# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""EXPERIMENTAL utilities for parameter server training with eager execution.

Note: this should eventually be merged with the distribution strategy for
ParameterServer.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import time

from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training.tracking import base as trackable


def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
  """Creates a variable handle with information to do shape inference."""
  container = ops.get_default_graph()._container  # pylint: disable=protected-access
  if container is None:
    container = ""
  handle = resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
                                               shared_name=shared_name,
                                               name=name,
                                               container=container)
  if graph_mode:
    return handle

  with context.graph_mode(), ops.Graph().as_default() as graph:
    h = resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
                                            shared_name=shared_name,
                                            name=name,
                                            container=container)

    # Tensor._handle_data contains information for the shape-inference code to
    # know the shape and dtype of the variable pointed to by a handle. Since
    # shape inference doesn't run in eager mode, we copy this data here for
    # when the handle is captured by an eager mode function.
    # pylint: disable=protected-access
    handle._handle_data = resource_variable_ops.get_resource_handle_data(h)
    # pylint: enable=protected-access
  # Clean up op->graph->op reference cycles.
  ops.dismantle_graph(graph)
  return handle


class SharedVariable(resource_variable_ops.ResourceVariable):
  """Experimental Variable designed for parameter server training.

  A SharedVariable has a name, and two instances of SharedVariable with the
  same name will have the same value, even if they are in different Sessions,
  as long as they are placed on the same device.

  The storage associated with SharedVariables is also not deleted when they go
  out of scope.
  """

  def __init__(self,  # pylint: disable=super-init-not-called
               initial_value=None,
               trainable=True,
               name=None,
               dtype=None,
               constraint=None,
               initialize=True,
               **unused_kwargs):
    """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must
        have a shape specified unless `validate_shape` is set to False. Can
        also be a callable with no argument that returns the initial value
        when called.
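        For example, `initial_value=lambda: tf.ones([4, 4])` (shape chosen
        purely for illustration) defers building the initial value until the
        variable is constructed.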
        (Note that initializer functions from init_ops.py must first be bound
        to a shape before being used here.)
      trainable: If `True`, automatically watches this variable on GradientTape
        whenever it's used.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
        a Tensor) or float32 will be used (if it is a Python object convertible
        to a Tensor).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.
      initialize: if True, runs initialization in eager execution; leaves the
        variable uninitialized otherwise.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
    """
    if initial_value is None:
      raise ValueError("initial_value must be specified.")
    init_from_fn = callable(initial_value)

    if isinstance(initial_value, ops.Tensor) and hasattr(
        initial_value, "graph") and initial_value.graph.building_function:
      raise ValueError("Tensor-typed variable initializers must either be "
                       "wrapped in an init_scope or callable "
                       "(e.g., `tf.Variable(lambda : "
                       "tf.truncated_normal([10, 40]))`) when building "
                       "functions. Please file a feature request if this "
                       "restriction inconveniences you.")

    if constraint is not None and not callable(constraint):
      raise ValueError("The `constraint` argument must be a callable.")

    if isinstance(initial_value, trackable.CheckpointInitialValue):
      self._maybe_initialize_trackable()
      self._update_uid = initial_value.checkpoint_position.restore_uid
      initial_value = initial_value.wrapped_value

    self._trainable = trainable
    self._save_slice_info = None
    # Store the graph key so optimizers know how to only retrieve variables
    # from this graph.
    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    with ops.init_scope():
      self._in_graph_mode = not context.executing_eagerly()
      with ops.name_scope(name, "Variable", []
                          if init_from_fn else [initial_value]) as name:
        # pylint: disable=protected-access
        handle_name = ops._name_from_scope_name(name)
        shared_name = handle_name
        if init_from_fn:
          # Use attr_scope and device(None) to simulate the behavior of
          # colocate_with when the variable we want to colocate with doesn't
          # yet exist.
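          # In graph mode the callable is evaluated under ops.device(None) and
          # converted to a tensor; in eager mode it is simply called and its
          # result converted below.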
          if self._in_graph_mode:
            with ops.name_scope("Initializer"), ops.device(None):
              initial_value = ops.convert_to_tensor(
                  initial_value(), name="initial_value", dtype=dtype)
            self._handle = _eager_safe_variable_handle(
                shape=initial_value.get_shape(),
                dtype=initial_value.dtype.base_dtype,
                shared_name=shared_name,
                name=name,
                graph_mode=self._in_graph_mode)
            self._shape = initial_value.get_shape()
          else:
            initial_value = initial_value()
            with ops.name_scope("Initializer"):
              initial_value = ops.convert_to_tensor(
                  initial_value, name="initial_value", dtype=dtype)
            self._handle = _eager_safe_variable_handle(
                shape=initial_value.get_shape(),
                dtype=initial_value.dtype.base_dtype,
                shared_name=shared_name,
                name=name,
                graph_mode=False)
            self._shape = initial_value.get_shape()
        # pylint: enable=protected-access

        # Or get the initial value from a Tensor or Python object.
        else:
          with ops.name_scope("Initializer"):
            initial_value = ops.convert_to_tensor(
                initial_value, name="initial_value", dtype=dtype)
          # pylint: disable=protected-access
          if (self._in_graph_mode and initial_value is not None and
              initial_value.op._get_control_flow_context() is not None):
            raise ValueError(
                "Initializer for variable %s is from inside a control-flow "
                "construct, such as a loop or conditional. When creating a "
                "variable inside a loop or conditional, use a lambda as the "
                "initializer." % name)
          # pylint: enable=protected-access
          self._handle = _eager_safe_variable_handle(
              shape=initial_value.get_shape(),
              dtype=initial_value.dtype.base_dtype,
              shared_name=shared_name,
              name=name,
              graph_mode=self._in_graph_mode)
          self._shape = initial_value.get_shape()

        self._unique_id = shared_name
        self._initial_value = initial_value if self._in_graph_mode else None
        self._handle_name = handle_name + ":0"
        self._dtype = initial_value.dtype.base_dtype
        self._constraint = constraint

        if self._in_graph_mode:
          with ops.name_scope("IsInitialized"):
            self._is_initialized_op = (
                resource_variable_ops.var_is_initialized_op(self._handle))
          if initial_value is not None:
            with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
              self._initializer_op = (
                  resource_variable_ops.assign_variable_op(
                      self._handle,
                      self._try_guard_against_uninitialized_dependencies(
                          initial_value),
                      name=n))
          with ops.name_scope("Read"), ops.colocate_with(self._handle):
            # Manually assign reads to the handle's device to avoid log
            # messages.
            with ops.device(self._handle.device):
              value = self._read_variable_op()
            self._graph_element = value
          self._cached_value = None
        else:
          if initialize:
            resource_variable_ops.assign_variable_op(self._handle,
                                                     initial_value)
          self._is_initialized_op = None
          self._initializer_op = None
          self._graph_element = None
          self._cached_value = None

    self._handle_deleter = None
    self._cached_shape_as_list = None


@contextlib.contextmanager
def parameter_server_scope(is_chief, ps_job_name, num_ps_tasks):
  """Strategy to use parameter servers in eager.

  Creates SharedVariable objects for variables created in this scope. These
  SharedVariable objects will be placed round-robin on the parameter servers
  specified by the ps_job_name and num_ps_tasks arguments.
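  Variables are assigned to tasks in creation order, so with, for example,
  `num_ps_tasks=3` successive variables land on tasks 0, 1, 2, 0, 1, and so
  on.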

  To use parameter servers, you only need to wrap your model initialization in
  this scope:

  ```
  with tf.contrib.eager.parameter_server_scope(
      is_chief, ps_job_name, num_ps_tasks):
    my_model = tf.keras.Sequential([...])  # Or
    input = tf.keras.Input(...)
    ....
    my_model = tf.keras.Model(input, output)
    my_model.compile(...)
    # or other usages of the model.
  ```

  Args:
    is_chief: Boolean. Whether this worker is responsible for initializing
      variables.
    ps_job_name: The name of the ps job in this cluster.
    num_ps_tasks: The number of ps tasks to use.

  Yields:
    Nothing; this function is used as a context manager.
  """
  # Note: capturing in a list to allow assignment.
  ps_index = [0]

  def variable_creator_scope(unused_next_creator, **kwargs):
    kwargs["initialize"] = is_chief
    with ops.device(
        "/job:%s/task:%s" % (ps_job_name, ps_index[0] % num_ps_tasks)):
      ps_index[0] += 1
      v = SharedVariable(**kwargs)
      if not is_chief:
        while not resource_variable_ops.var_is_initialized_op(v.handle):
          time.sleep(10)
    return v

  with variable_scope.variable_creator_scope(variable_creator_scope):
    yield
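

# A minimal usage sketch, not part of this module's API: it assumes a cluster
# that defines a "ps" job with two tasks and that this process has already
# been configured to reach them (cluster setup is omitted); the job name,
# task count, and model below are placeholders chosen for illustration.
if __name__ == "__main__":
  import tensorflow as tf

  tf.enable_eager_execution()
  is_chief = True  # Exactly one worker should pass is_chief=True.
  with parameter_server_scope(
      is_chief=is_chief, ps_job_name="ps", num_ps_tasks=2):
    # Variables created here become SharedVariables placed round-robin on
    # /job:ps/task:0 and /job:ps/task:1.
    layer = tf.keras.layers.Dense(10)
    _ = layer(tf.zeros([1, 5]))  # Building the layer creates its variables
                                 # through the creator scope above.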