# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Distributed variable implementation for TPUs.

N.B. This is an experimental feature that should only be used for Keras
support.

It is unsupported and will be removed in favor of Distribution Strategy soon.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib

import numpy as np

from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope


@contextlib.contextmanager
def _handle_graph(handle):
  """Makes the graph that owns `handle` the default graph for the block."""
  with handle.graph.as_default():
    yield


def _enclosing_tpu_context():
  """Returns the enclosing XLAControlFlowContext, or None if not inside one."""
  # pylint: disable=protected-access
  context = ops.get_default_graph()._get_control_flow_context()
  # pylint: enable=protected-access
  while context is not None and not isinstance(
      context, control_flow_ops.XLAControlFlowContext):
    context = context.outer_context
  return context


class ReplicatedVariable(object):
  """A replicated variable for use on TPUs.

  When accessed inside a tpu.replicate() context, this variable acts as if it
  is a single variable whose handle is a replicated input to the computation.

  Outside a tpu.replicate() context currently this object has pretty murky
  semantics, especially with respect to things such as
  * initialization
  * colocation.
  """

  def __init__(self, name, variables):
    self._name = name
    self._primary_var = variables[0]
    self._common_name = self._primary_var.name.split(":")[0]
    self._vars = variables
    self._cached_value = None
    self._dtype = variables[0].dtype

  @property
  def handle(self):
    tpu_context = _enclosing_tpu_context()
    if tpu_context is None:
      return self._primary_var.handle

    return tpu_context.get_replicated_var_handle(self._name, self._vars)

  @contextlib.contextmanager
  def _assign_dependencies(self):
    """Makes assignments depend on the cached value, if any.

    This prevents undefined behavior with reads not ordered wrt writes.

    Yields:
      None.
93 """ 94 if self._cached_value is not None: 95 with ops.control_dependencies([self._cached_value]): 96 yield 97 else: 98 yield 99 100 @property 101 def initializer(self): 102 return control_flow_ops.group([v.initializer for v in self._vars]) 103 104 @property 105 def graph(self): 106 return self._primary_var.graph 107 108 @property 109 def _shared_name(self): 110 return self._common_name 111 112 @property 113 def _unique_id(self): 114 return self._primary_var._unique_id # pylint: disable=protected-access 115 116 @property 117 def name(self): 118 return self._name 119 120 @property 121 def dtype(self): 122 return self._primary_var.dtype 123 124 @property 125 def shape(self): 126 return self._primary_var.shape 127 128 def get_shape(self): 129 return self._primary_var.get_shape() 130 131 def to_proto(self, export_scope=None): 132 return self._primary_var.to_proto(export_scope=export_scope) 133 134 @property 135 def constraint(self): 136 return None 137 138 @property 139 def op(self): 140 return self.get().op 141 142 @property 143 def is_tensor_like(self): 144 return True 145 146 def _read_variable_op(self): 147 if _enclosing_tpu_context() is None: 148 return self._primary_var.read_value() 149 v = gen_resource_variable_ops.read_variable_op(self.handle, self._dtype) 150 return v 151 152 def read_value(self): 153 return self._read_variable_op() 154 155 def is_initialized(self, name=None): 156 return self._vars[0].is_initialized(name=name) 157 158 def __getitem__(self, *args): 159 return self.read_value().__getitem__(*args) 160 161 def assign(self, value, use_locking=None, name=None, read_value=False): 162 """Assign `value` to all replicas. 163 164 Outside of the tpu.rewrite context, assign explicitly to all replicas. 165 Inside of the tpu.rewrite context, assigns to the local replica. 

    Arguments:
      value: Tensor to assign
      use_locking: ignored
      name: ignored
      read_value: if True, return the new value of the variable after the
        assignment
    Returns:
      Assignment operation, or new value of the variable if `read_value` is
      True
    """
    del use_locking
    if _enclosing_tpu_context() is None:
      assign_ops = []
      with self._assign_dependencies():
        for var in self._vars:
          assign_ops.append(var.assign(value, use_locking=None, name=name))

      if read_value:
        with ops.control_dependencies(assign_ops):
          return self.read_value()
      else:
        return control_flow_ops.group(assign_ops)

    with _handle_graph(self.handle), self._assign_dependencies():
      value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
      assign_op = gen_resource_variable_ops.assign_variable_op(
          self.handle, value_tensor, name=name)
      if read_value:
        return self._read_variable_op()
    return assign_op

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    del use_locking
    with _handle_graph(self.handle), self._assign_dependencies():
      assign_add_op = gen_resource_variable_ops.assign_add_variable_op(
          self.handle,
          ops.convert_to_tensor(delta, dtype=self.dtype),
          name=name)
      if read_value:
        return self._read_variable_op()
    return assign_add_op

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    del use_locking
    with _handle_graph(self.handle), self._assign_dependencies():
      assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(
          self.handle,
          ops.convert_to_tensor(delta, dtype=self.dtype),
          name=name)
      if read_value:
        return self._read_variable_op()
    return assign_sub_op

  def get(self):
    return self._primary_var

  @property
  def _in_graph_mode(self):
    return self._primary_var._in_graph_mode  # pylint: disable=protected-access

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor."""
    # pylint: disable=protected-access
    if _enclosing_tpu_context() is None:
      return self._primary_var._dense_var_to_tensor(dtype, name, as_ref)
    # pylint: enable=protected-access
    if dtype is not None and dtype != self.dtype:
      return math_ops.cast(self._read_variable_op(), dtype)
    if as_ref:
      return self.handle
    else:
      return self.read_value()
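

# A minimal illustrative sketch (not part of this module's API): how a
# ReplicatedVariable could be assembled by hand from per-core ResourceVariables
# and then updated outside of a tpu.replicate()/tpu.rewrite() context. The
# device strings and variable names here are assumptions for the example; in
# practice, `replicated_scope` and `replicated_variable_for_optimizer` below
# build these objects for you.
def _example_manual_replicated_variable(num_replicas=2):
  """Builds one logical variable backed by one ResourceVariable per core."""
  replica_vars = []
  for i in range(num_replicas):
    with ops.device("device:TPU:{}".format(i)):
      replica_vars.append(
          resource_variable_ops.ResourceVariable(0.0, name="w_{}".format(i)))
  replicated = ReplicatedVariable("w", replica_vars)

  # Outside a TPU context, `assign` fans the update out to every replica and
  # `read_value` reads the primary (replica 0) copy.
  update = replicated.assign(1.0, read_value=True)
  return replicated, update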


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
def _tensor_conversion(var, dtype=None, name=None, as_ref=False):
  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


def replicated_fetch_function(var):
  # pylint: disable=protected-access
  return ([var._dense_var_to_tensor()], lambda v: v[0])
  # pylint: enable=protected-access


ops.register_tensor_conversion_function(ReplicatedVariable, _tensor_conversion)
ops.register_dense_tensor_like_type(ReplicatedVariable)
session_lib.register_session_run_conversion_functions(
    ReplicatedVariable, replicated_fetch_function)


def replicated_scope(num_replicas):
  """Variable scope for constructing replicated variables."""

  def _replicated_variable_getter(getter, name, *args, **kwargs):
    """Getter that constructs replicated variables."""
    collections = kwargs.pop("collections", None)
    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    kwargs["collections"] = []

    variables = []
    index = {}
    for i in range(num_replicas):
      replica_name = "{}/{}".format(name, i)
      with ops.device("device:TPU:{}".format(i)):
        v = getter(*args, name=replica_name, **kwargs)
      variables.append(v)
      index[i] = v
    result = ReplicatedVariable(name, variables)

    g = ops.get_default_graph()
    # If "trainable" is True, the getter will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace them with the ReplicatedVariable. We can't set
    # "trainable" to False for the getter since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for v in index.values():
        if v in l:
          l.remove(v)
    g.add_to_collections(collections, result)

    return result

  return variable_scope.variable_scope(
      "", custom_getter=_replicated_variable_getter)


@contextlib.contextmanager
def replicated_variable_for_optimizer(num_replicas):
  """Context manager for optimizer weights. Overrides K.variable."""
  if num_replicas == 1:
    yield
    return

  try:
    old_v = backend.variable

    def opt_variable(value, dtype=None, name=None, constraint=None):
      """Instantiates a variable and returns it."""
      if dtype is None:
        dtype = backend.floatx()

      variables = []
      for i in range(num_replicas):
        # Keras holds the variables in the optimizer class instance, so the
        # name does not matter here. The ResourceVariable constructor will
        # find a unique name (even when name=None) for each replica.
        with ops.device("device:TPU:{}".format(i)):
          v = resource_variable_ops.ResourceVariable(
              value,
              dtype=dtypes_module.as_dtype(dtype),
              name=name,
              constraint=constraint)
        variables.append(v)
      name = "replicate_{}_{}".format("variable" if name is None else name,
                                      ops.uid())
      v = ReplicatedVariable(name, variables)

      # pylint: disable=protected-access
      if isinstance(value, np.ndarray):
        v._keras_shape = value.shape
      elif hasattr(value, "shape"):
        v._keras_shape = backend.int_shape(value)
      v._uses_learning_phase = False
      backend.track_variable(v)
      return v

    backend.variable = opt_variable
    yield

  finally:
    backend.variable = old_v
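

# An illustrative usage sketch (an assumption about the intended wiring, not a
# tested recipe): `replicated_scope` makes tf.get_variable-style creation
# inside it return ReplicatedVariables, and `replicated_variable_for_optimizer`
# temporarily rebinds `backend.variable` (K.variable) so that optimizer state
# created by Keras is replicated as well. The variable names and shapes below
# are made up for the example.
def _example_replicated_creation(num_replicas=8):
  """Creates a replicated model weight and a replicated optimizer variable."""
  with replicated_scope(num_replicas):
    # The custom getter fans this single request out to one variable per TPU
    # core and returns the wrapping ReplicatedVariable.
    kernel = variable_scope.get_variable("dense_kernel", shape=[128, 10])

  with replicated_variable_for_optimizer(num_replicas):
    # Inside this context, backend.variable builds one ResourceVariable per
    # core and wraps them in a ReplicatedVariable.
    iterations = backend.variable(0, dtype="int64", name="iterations")

  return kernel, iterations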