# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes implementing a multi-worker ps DistributionStrategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver

# pylint: disable=protected-access,invalid-name,line-too-long
CoreParameterServerStrategy = parameter_server_strategy.ParameterServerStrategy
CoreParameterServerExtended = parameter_server_strategy.ParameterServerStrategyExtended

# pylint: enable=protected-access,invalid-name,line-too-long


class ParameterServerStrategy(distribute_lib.DistributionStrategy):
  """A parameter server DistributionStrategy.

  *** contrib version ***

  This strategy class works for both local training and between-graph
  replicated training for multiple workers. If a `cluster_spec` is specified
  (for this contrib version it is parsed from the
  ["TF_CONFIG" environment
  variable](https://www.tensorflow.org/api_docs/python/tf/estimator/RunConfig)),
  variables and updates to those variables are assigned to parameter servers
  and other operations are assigned to workers. If `cluster_spec` is not set,
  training is local and variables are assigned to the local CPU or the only
  GPU. When each worker has more than one GPU, operations are replicated
  across these GPUs. In both cases, operations are replicated but variables
  are not, and all workers share a common view of which parameter server a
  variable is assigned to.

  This class assumes between-graph replication will be used and works on a
  graph for a particular worker. Note that each graph and worker is
  independent. This means that while each worker synchronously computes a
  single gradient update across all of its GPUs, updates between workers
  proceed asynchronously. Operations that occur only on the first replica
  (such as incrementing the global step) will occur on the first replica *of
  every worker*.

  You are expected to call `call_for_each_replica(fn, ...)` for any operations
  which can potentially be replicated across replicas (i.e. multiple GPUs),
  even if there is only a CPU or a single GPU. When defining `fn`, extra
  caution needs to be taken:

  1) Always use `tf.get_variable` instead of `tf.Variable`, which is not able
  to refer to the same variable on different replicas.

  2) It is generally not recommended to open a device scope under the
  strategy's scope. A device scope (i.e. calling `tf.device`) will be merged
  with or override the device for operations but will not change the device
  for variables.

  3) It is also not recommended to open a colocation scope (i.e. calling
  `tf.colocate_with`) under the strategy's scope. For colocating variables,
  use `strategy.extended.colocate_vars_with` instead. Colocating ops may
  create device-assignment conflicts.
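
  A minimal usage sketch (illustrative only; it assumes this class is exposed
  as `tf.contrib.distribute.ParameterServerStrategy` and that `replica_fn` is
  a user-defined function that builds the per-replica computation, invoked via
  `strategy.extended.call_for_each_replica`):

  ```python
  strategy = tf.contrib.distribute.ParameterServerStrategy(
      num_gpus_per_worker=2)
  with strategy.scope():
    # Use `tf.get_variable` so every replica refers to the same variable,
    # which is placed on a parameter server.
    counter = tf.get_variable("counter", shape=[], dtype=tf.int64,
                              initializer=tf.zeros_initializer(),
                              trainable=False)
    # Replicate `replica_fn` across the local GPUs (or the CPU).
    strategy.extended.call_for_each_replica(replica_fn)
  ```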
76  """
77
78  def __init__(self, num_gpus_per_worker=0):
79    """Initializes this strategy.
80
81    Args:
82      num_gpus_per_worker: number of local GPUs or GPUs per worker, the default
83        is 0 meaning CPU only.
84
85    Raises:
86      ValueError: if `cluster_spec` is given but `task_type` or `task_id` is
87        not.
88    """
89    super(ParameterServerStrategy, self).__init__(
90        ParameterServerExtended(self, num_gpus_per_worker))
91
  # Override to change the documentation to reflect the different handling of
  # global vs. local batch size between core and contrib.
  def make_dataset_iterator(self, dataset):  # pylint: disable=useless-super-delegation
    """Makes an iterator for input provided via `dataset`.

    NOTE: The batch size of the `dataset` argument is treated differently for
    this contrib version of `ParameterServerStrategy`.

    Data from the given dataset will be distributed evenly across all the
    compute replicas. The input dataset is assumed to be batched by the
    per-replica batch size.

    The user can also use `make_input_fn_iterator` to customize which input is
    fed to which replica or worker.
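
    For example, a minimal sketch (`features` and `per_replica_batch_size`
    are placeholder names, not part of this API):

    ```python
    dataset = tf.data.Dataset.from_tensor_slices(features).batch(
        per_replica_batch_size)
    iterator = strategy.make_dataset_iterator(dataset)
    # `initialize` returns initializer op(s) to run before fetching inputs.
    init_ops = iterator.initialize()
    ```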

    Args:
      dataset: `tf.data.Dataset` that will be distributed evenly across all
        replicas.

    Returns:
      A `tf.distribute.InputIterator` which returns inputs for each step of
      the computation. The user should call `initialize` on the returned
      iterator.
    """
    return super(ParameterServerStrategy, self).make_dataset_iterator(dataset)

  # Override to change the documentation to reflect the different handling of
  # global vs. local batch size between core and contrib.
  def experimental_make_numpy_iterator(  # pylint: disable=useless-super-delegation
      self, numpy_input, batch_size, num_epochs=1, shuffle=1024, session=None):
    """Makes an iterator for input provided via a nest of numpy arrays.

    NOTE: The `batch_size` argument here has different behavior for this
    contrib version of `ParameterServerStrategy`.
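
    For example, a minimal sketch (`features` and `labels` are placeholder
    NumPy arrays, not part of this API):

    ```python
    iterator = strategy.experimental_make_numpy_iterator(
        (features, labels), batch_size=32, num_epochs=None, shuffle=1024)
    ```

    Here `batch_size=32` means each replica consumes 32 examples per step, so
    the global batch size is `32 * strategy.num_replicas_in_sync`.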

    Args:
      numpy_input: A nest of NumPy input arrays that will be distributed
        evenly across all replicas.
      batch_size: The number of entries from the arrays that each replica
        should consume in one step of the computation. This is the per-replica
        batch size; the global batch size will be this times
        `num_replicas_in_sync`.
      num_epochs: The number of times to iterate through the examples. A value
        of `None` means repeat forever.
      shuffle: Size of the buffer to use for shuffling the input examples.
        Use `None` to disable shuffling.
      session: (TensorFlow v1.x graph execution only) A session used for
        initialization.

    Returns:
      A `tf.distribute.InputIterator` which returns inputs for each step of
      the computation. The user should call `initialize` on the returned
      iterator.
    """
    return super(ParameterServerStrategy,
                 self).experimental_make_numpy_iterator(
                     numpy_input, batch_size, num_epochs, shuffle, session)


class ParameterServerExtended(CoreParameterServerExtended):
  """Implementation of ParameterServerStrategy."""

  def __init__(self, container_strategy, num_gpus_per_worker):
    # Use TFConfigClusterResolver to parse TF_CONFIG. We don't want to change
    # the constructor's interface to allow a customized cluster resolver, so
    # we wrap it in a SimpleClusterResolver to override num_accelerators.
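    # For illustration only, TF_CONFIG for a two-worker, one-ps cluster
    # typically looks like the following (addresses are placeholders):
    #   {"cluster": {"worker": ["host1:2222", "host2:2222"],
    #                "ps": ["host3:2222"]},
    #    "task": {"type": "worker", "index": 0}}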
    tfconfig = TFConfigClusterResolver()
    cluster_resolver = SimpleClusterResolver(
        cluster_spec=tfconfig.cluster_spec(),
        task_type=tfconfig.task_type,
        task_id=tfconfig.task_id,
        num_accelerators={'GPU': num_gpus_per_worker})
    super(ParameterServerExtended, self).__init__(
        container_strategy, cluster_resolver=cluster_resolver)

  def _make_dataset_iterator(self, dataset):
    return input_lib.DatasetIterator(dataset, self._input_workers)

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """The contrib version of PS strategy uses per-replica batch size."""
    return False