# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""EXPERIMENTAL utilities for parameter server training with eager execution.
16
17Note: this should eventually be merged with the distribution strategy for
18ParameterServer.
19"""
20
21
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26import contextlib
27import time
28
29from tensorflow.python.eager import context
30from tensorflow.python.framework import ops
31from tensorflow.python.ops import resource_variable_ops
32from tensorflow.python.ops import variable_scope
33from tensorflow.python.training.tracking import base as trackable


def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
  """Creates a variable handle with information to do shape inference."""
  container = ops.get_default_graph()._container  # pylint: disable=protected-access
  if container is None:
    container = ""
  handle = resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
                                               shared_name=shared_name,
                                               name=name,
                                               container=container)
  if graph_mode:
    return handle

  # In eager mode, build a throwaway graph copy of the handle op so its static
  # shape/dtype metadata can be recovered and attached to the eager handle.
  with context.graph_mode(), ops.Graph().as_default() as graph:
    h = resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
                                            shared_name=shared_name,
                                            name=name,
                                            container=container)

    # Tensor._handle_data contains information for the shape-inference code to
    # know the shape and dtype of the variable pointed to by a handle. Since
    # shape inference doesn't run in eager mode, we copy this data here for
    # when the handle is captured by an eager-mode function.
    # pylint: disable=protected-access
    handle._handle_data = resource_variable_ops.get_resource_handle_data(h)
    # pylint: enable=protected-access
  # Clean up op->graph->op reference cycles.
  ops.dismantle_graph(graph)
  return handle


class SharedVariable(resource_variable_ops.ResourceVariable):
  """Experimental Variable designed for parameter server training.

  A SharedVariable has a name, and two instances of SharedVariable with the
  same name will have the same value, even if they are in different Sessions,
  as long as they are placed on the same device.

  The storage associated with SharedVariables is also not deleted when they go
  out of scope.
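
  For example (a minimal sketch; the job name, device placement, and variable
  name are illustrative), two workers that each run the following against the
  same cluster end up reading and writing the same storage:

  ```
  with tf.device("/job:ps/task:0"):
    v = SharedVariable(1.0, name="weight")
  ```

  The second worker to run this does not create new storage; its handle
  aliases the variable the first worker created under the shared name
  "weight", so updates such as `v.assign_add(1.0)` are visible to both.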
75  """

  def __init__(self,  # pylint: disable=super-init-not-called
               initial_value=None,
               trainable=True,
               name=None,
               dtype=None,
               constraint=None,
               initialize=True,
               **unused_kwargs):
    """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must
        have a shape specified unless `validate_shape` is set to False. Can
        also be a callable with no argument that returns the initial value
        when called. (Note that initializer functions from init_ops.py must
        first be bound to a shape before being used here.)
      trainable: If `True`, GradientTapes automatically watch uses of this
        variable.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
        a Tensor) or float32 will be used (if it is a Python object convertible
        to a Tensor).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.
      initialize: If `True`, runs initialization in eager execution; otherwise
        leaves the variable uninitialized.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
    """
    if initial_value is None:
      raise ValueError("initial_value must be specified.")
    init_from_fn = callable(initial_value)

    if isinstance(initial_value, ops.Tensor) and hasattr(
        initial_value, "graph") and initial_value.graph.building_function:
      raise ValueError("Tensor-typed variable initializers must either be "
                       "wrapped in an init_scope or callable "
                       "(e.g., `tf.Variable(lambda : "
                       "tf.truncated_normal([10, 40]))`) when building "
                       "functions. Please file a feature request if this "
                       "restriction inconveniences you.")

    if constraint is not None and not callable(constraint):
      raise ValueError("The `constraint` argument must be a callable.")

    if isinstance(initial_value, trackable.CheckpointInitialValue):
      self._maybe_initialize_trackable()
      self._update_uid = initial_value.checkpoint_position.restore_uid
      initial_value = initial_value.wrapped_value

    self._trainable = trainable
    self._save_slice_info = None
    # Store the graph key so optimizers know to retrieve only variables from
    # this graph.
    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    with ops.init_scope():
      self._in_graph_mode = not context.executing_eagerly()
      with ops.name_scope(name, "Variable", []
                          if init_from_fn else [initial_value]) as name:
        # pylint: disable=protected-access
        handle_name = ops._name_from_scope_name(name)
        shared_name = handle_name
        if init_from_fn:
          # Use attr_scope and device(None) to simulate the behavior of
          # colocate_with when the variable we want to colocate with doesn't
          # yet exist.
          if self._in_graph_mode:
            with ops.name_scope("Initializer"), ops.device(None):
              initial_value = ops.convert_to_tensor(
                  initial_value(), name="initial_value", dtype=dtype)
            self._handle = _eager_safe_variable_handle(
                shape=initial_value.get_shape(),
                dtype=initial_value.dtype.base_dtype,
                shared_name=shared_name,
                name=name,
                graph_mode=self._in_graph_mode)
            self._shape = initial_value.get_shape()
          else:
            initial_value = initial_value()
            with ops.name_scope("Initializer"):
              initial_value = ops.convert_to_tensor(
                  initial_value, name="initial_value", dtype=dtype)
            self._handle = _eager_safe_variable_handle(
                shape=initial_value.get_shape(),
                dtype=initial_value.dtype.base_dtype,
                shared_name=shared_name,
                name=name,
                graph_mode=False)
            self._shape = initial_value.get_shape()
        # pylint: enable=protected-access

        # Or get the initial value from a Tensor or Python object.
        else:
          with ops.name_scope("Initializer"):
            initial_value = ops.convert_to_tensor(
                initial_value, name="initial_value", dtype=dtype)
          # pylint: disable=protected-access
          if (self._in_graph_mode and initial_value is not None and
              initial_value.op._get_control_flow_context() is not None):
            raise ValueError(
                "Initializer for variable %s is from inside a control-flow "
                "construct, such as a loop or conditional. When creating a "
                "variable inside a loop or conditional, use a lambda as the "
                "initializer." % name)
          # pylint: enable=protected-access
          self._handle = _eager_safe_variable_handle(
              shape=initial_value.get_shape(),
              dtype=initial_value.dtype.base_dtype,
              shared_name=shared_name,
              name=name,
              graph_mode=self._in_graph_mode)
          self._shape = initial_value.get_shape()

        self._unique_id = shared_name
        self._initial_value = initial_value if self._in_graph_mode else None
        self._handle_name = handle_name + ":0"
        self._dtype = initial_value.dtype.base_dtype
        self._constraint = constraint

        if self._in_graph_mode:
          with ops.name_scope("IsInitialized"):
            self._is_initialized_op = (
                resource_variable_ops.var_is_initialized_op(self._handle))
          if initial_value is not None:
            with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
              self._initializer_op = (
                  resource_variable_ops.assign_variable_op(
                      self._handle,
                      self._try_guard_against_uninitialized_dependencies(
                          initial_value),
                      name=n))
          with ops.name_scope("Read"), ops.colocate_with(self._handle):
            # Manually assign reads to the handle's device to avoid log
            # messages.
            with ops.device(self._handle.device):
              value = self._read_variable_op()
            self._graph_element = value
            self._cached_value = None
        else:
          if initialize:
            resource_variable_ops.assign_variable_op(self._handle,
                                                     initial_value)
          self._is_initialized_op = None
          self._initializer_op = None
          self._graph_element = None
          self._cached_value = None

    self._handle_deleter = None
    self._cached_shape_as_list = None


@contextlib.contextmanager
def parameter_server_scope(is_chief, ps_job_name, num_ps_tasks):
  """Strategy to use parameter servers in eager execution.

  Creates SharedVariable objects for variables created in this scope. These
  SharedVariable objects will be placed round-robin on the parameter servers
  specified by the ps_job_name and num_ps_tasks arguments.

  To use parameter servers, you need only wrap your model initialization in
  this scope:

  ```
  with tf.contrib.eager.parameter_server_scope(
      is_chief, ps_job_name, num_ps_tasks):
    my_model = tf.keras.Sequential([...])  # Or
    input = tf.keras.Input(...)
    ...
    my_model = tf.keras.Model(input, output)
  my_model.compile(...)
  # or other usages of the model.
  ```
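
  Variables created in the scope are placed round-robin: with
  `num_ps_tasks=2`, the first variable lands on `/job:<ps_job_name>/task:0`,
  the second on task 1, the third on task 0 again, and so on.

  A minimal two-worker sketch (cluster setup is elided and `build_model` is a
  hypothetical helper):

  ```
  # On the chief worker:
  with tf.contrib.eager.parameter_server_scope(
      is_chief=True, ps_job_name="ps", num_ps_tasks=2):
    model = build_model()

  # Every other worker runs the same code with is_chief=False; instead of
  # initializing each variable, it blocks until the chief has done so.
  ```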

  Args:
    is_chief: Boolean. Whether this worker is responsible for initializing
      variables.
    ps_job_name: The name of the ps job in this cluster.
    num_ps_tasks: The number of ps tasks to use.

  Yields:
    Nothing; within the scope, variables are created as SharedVariables.
  """
  # Note: the ps task index is captured in a list so the nested function can
  # rebind it (Python 2 has no `nonlocal`).
  ps_index = [0]

  def variable_creator_scope(unused_next_creator, **kwargs):
    kwargs["initialize"] = is_chief
    # Place variables round-robin across the ps tasks.
    with ops.device(
        "/job:%s/task:%s" % (ps_job_name, ps_index[0] % num_ps_tasks)):
      ps_index[0] += 1
      v = SharedVariable(**kwargs)
      if not is_chief:
        # Poll until the chief has initialized the shared storage.
        while not resource_variable_ops.var_is_initialized_op(v.handle):
          time.sleep(10)
      return v

  with variable_scope.variable_creator_scope(variable_creator_scope):
    yield