# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class implementing a single machine parameter server strategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.util.tf_export import tf_export


@tf_export('distribute.experimental.CentralStorageStrategy', v1=[])
class CentralStorageStrategy(distribute_lib.Strategy):
  """A one-machine strategy that puts all variables on a single device.

  Variables are assigned to local CPU or the only GPU. If there is more
  than one GPU, compute operations (other than variable update operations)
  will be replicated across all GPUs.

  For Example:
  ```
  strategy = tf.distribute.experimental.CentralStorageStrategy()
  # Create a dataset
  ds = tf.data.Dataset.range(5).batch(2)
  # Distribute that dataset
  dist_dataset = strategy.experimental_distribute_dataset(ds)

  with strategy.scope():
    @tf.function
    def train_step(val):
      return val + 1

    # Iterate over the distributed dataset
    for x in dist_dataset:
      # process dataset elements
      strategy.run(train_step, args=(x,))
  ```
  """

  def __init__(self, compute_devices=None, parameter_device=None):
    """Initializes the strategy with optional device strings.
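
    For example (a minimal sketch; the device strings are illustrative and
    depend on the devices actually available):
    ```
    strategy = tf.distribute.experimental.CentralStorageStrategy(
        compute_devices=['GPU:0', 'GPU:1'], parameter_device='CPU:0')
    ```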

    Args:
      compute_devices: an optional list of strings for the devices to
        replicate models on. If this is not provided, all local GPUs will be
        used; if there is no GPU, the local CPU will be used.
      parameter_device: an optional device string for which device to put
        variables on. The default one is CPU or GPU if there is only one.
    """
    extended = parameter_server_strategy.ParameterServerStrategyExtended(
        self,
        compute_devices=compute_devices,
        parameter_device=parameter_device)
    super(CentralStorageStrategy, self).__init__(extended)
    distribute_lib.distribution_strategy_gauge.get_cell('V2').set(
        'CentralStorageStrategy')

  @classmethod
  def _from_num_gpus(cls, num_gpus):
    # Builds a strategy whose compute devices are the local devices derived
    # from `num_gpus` (all local GPUs, or the local CPU when `num_gpus` is 0).
    return cls(device_util.local_devices_from_num_gpus(num_gpus))

  def experimental_distribute_dataset(self, dataset, options=None):  # pylint: disable=useless-super-delegation
    """Distributes a tf.data.Dataset instance provided via dataset.

    The returned dataset is a wrapped strategy dataset which creates a
    multidevice iterator under the hood. It prefetches the input data to the
    specified devices on the worker. The returned distributed dataset can be
    iterated over similarly to regular datasets.

    NOTE: Currently, the user cannot add any more transformations to a
    distributed dataset.

    For Example:
    ```
    # With 1 CPU and 1 GPU.
    strategy = tf.distribute.experimental.CentralStorageStrategy()
    dataset = tf.data.Dataset.range(10).batch(2)
    dist_dataset = strategy.experimental_distribute_dataset(dataset)
    for x in dist_dataset:
      print(x)  # Prints PerReplica values [0, 1], [2, 3],...
    ```

    Args:
      dataset: `tf.data.Dataset` to be prefetched to device.
      options: `tf.distribute.InputOptions` used to control options on how this
        dataset is distributed.

    Returns:
      A "distributed `Dataset`" that the caller can iterate over.
    """
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          'InputReplicationMode.PER_REPLICA '
          'is only supported in '
          '`experimental_distribute_datasets_from_function`.'
      )
    return super(CentralStorageStrategy, self).experimental_distribute_dataset(
        dataset, options)

  def experimental_local_results(self, value):  # pylint: disable=useless-super-delegation
    """Returns the list of all local per-replica values contained in `value`.

    In `CentralStorageStrategy` there is a single worker so the value returned
    will be all the values on that worker.
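
    For example (a minimal sketch; `train_step` and the dataset are
    illustrative placeholders):
    ```
    strategy = tf.distribute.experimental.CentralStorageStrategy()
    dist_dataset = strategy.experimental_distribute_dataset(
        tf.data.Dataset.range(4).batch(2))

    @tf.function
    def train_step(val):
      return val + 1

    for x in dist_dataset:
      per_replica_result = strategy.run(train_step, args=(x,))
      # With a single compute device this is a one-element tuple; with more
      # devices there is one entry per compute replica.
      local_values = strategy.experimental_local_results(per_replica_result)
    ```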

    Args:
      value: A value returned by `run()`, `extended.call_for_each_replica()`,
        or a variable created in `scope`.

    Returns:
      A tuple of values contained in `value`. If `value` represents a single
      value, this returns `(value,)`.
    """
    return super(CentralStorageStrategy, self).experimental_local_results(value)

  def run(self, fn, args=(), kwargs=None, options=None):  # pylint: disable=useless-super-delegation
    """Run `fn` on each replica, with the given arguments.

    In `CentralStorageStrategy`, `fn` is called on each of the compute
    replicas, with the provided "per replica" arguments specific to that device.
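
    For example (a minimal sketch; `dist_dataset` and `step_fn` are
    illustrative placeholders):
    ```
    strategy = tf.distribute.experimental.CentralStorageStrategy()
    dist_dataset = strategy.experimental_distribute_dataset(
        tf.data.Dataset.range(8).batch(4))

    @tf.function
    def step_fn(val):
      return val * 2

    for x in dist_dataset:
      # `x` is a "per replica" value; `run` invokes `step_fn` once per compute
      # replica with that replica's portion of `x`.
      result = strategy.run(step_fn, args=(x,))
    ```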

    Args:
      fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
      args: (Optional) Positional arguments to `fn`.
      kwargs: (Optional) Keyword arguments to `fn`.
      options: (Optional) An instance of `tf.distribute.RunOptions` specifying
        the options to run `fn`.

    Returns:
      Return value from running `fn`.
    """
    return super(CentralStorageStrategy, self).run(fn, args, kwargs, options)

  def reduce(self, reduce_op, value, axis):  # pylint: disable=useless-super-delegation
    """Reduce `value` across replicas.

    Given a per-replica value returned by `run`, say a
    per-example loss, the batch will be divided across all the replicas. This
    function allows you to aggregate across replicas and optionally also across
    batch elements.  For example, if you have a global batch size of 8 and 2
    replicas, values for examples `[0, 1, 2, 3]` will be on replica 0 and
    `[4, 5, 6, 7]` will be on replica 1. By default, `reduce` will just
    aggregate across replicas, returning `[0+4, 1+5, 2+6, 3+7]`. This is useful
    when each replica is computing a scalar or some other value that doesn't
    have a "batch" dimension (like a gradient). More often you will want to
    aggregate across the global batch, which you can get by specifying the batch
    dimension as the `axis`, typically `axis=0`. In this case it would return a
    scalar `0+1+2+3+4+5+6+7`.

    If there is a last partial batch, you will need to specify an axis so
    that the resulting shape is consistent across replicas. So if the last
    batch has size 6 and it is divided into [0, 1, 2, 3] and [4, 5], you
    would get a shape mismatch unless you specify `axis=0`. If you specify
    `tf.distribute.ReduceOp.MEAN`, using `axis=0` will use the correct
    denominator of 6. Contrast this with computing `reduce_mean` to get a
    scalar value on each replica and this function to average those means,
    which will weigh some values `1/8` and others `1/4`.

    For Example:
    ```
    strategy = tf.distribute.experimental.CentralStorageStrategy(
        compute_devices=['CPU:0', 'GPU:0'], parameter_device='CPU:0')
    ds = tf.data.Dataset.range(10)
    # Distribute that dataset
    dist_dataset = strategy.experimental_distribute_dataset(ds)

    with strategy.scope():
      @tf.function
      def train_step(val):
        # pass through
        return val

      # Iterate over the distributed dataset
      for x in dist_dataset:
        result = strategy.run(train_step, args=(x,))

    result = strategy.reduce(tf.distribute.ReduceOp.SUM, result,
                             axis=None).numpy()
    # result: array([ 4,  6,  8, 10])

    result = strategy.reduce(tf.distribute.ReduceOp.SUM, result, axis=0).numpy()
    # result: 28
    ```

    Args:
      reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
        be combined.
      value: A "per replica" value, e.g. returned by `run` to
        be combined into a single tensor.
      axis: Specifies the dimension to reduce along within each
        replica's tensor. Should typically be set to the batch dimension, or
        `None` to only reduce across replicas (e.g. if the tensor has no batch
        dimension).

    Returns:
      A `Tensor`.
    """
    return super(CentralStorageStrategy, self).reduce(reduce_op, value, axis)


@tf_export(v1=['distribute.experimental.CentralStorageStrategy'])  # pylint: disable=missing-docstring
class CentralStorageStrategyV1(distribute_lib.StrategyV1):

  __doc__ = CentralStorageStrategy.__doc__

  def __init__(self, compute_devices=None, parameter_device=None):
    super(CentralStorageStrategyV1, self).__init__(
        parameter_server_strategy.ParameterServerStrategyExtended(
            self,
            compute_devices=compute_devices,
            parameter_device=parameter_device))
    distribute_lib.distribution_strategy_gauge.get_cell('V1').set(
        'CentralStorageStrategy')

  __init__.__doc__ = CentralStorageStrategy.__init__.__doc__