# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class implementing a single machine parameter server strategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.util.tf_export import tf_export


@tf_export('distribute.experimental.CentralStorageStrategy', v1=[])
class CentralStorageStrategy(distribute_lib.Strategy):
  """A one-machine strategy that puts all variables on a single device.

  Variables are assigned to the local CPU or the only GPU. If there is more
  than one GPU, compute operations (other than variable update operations)
  will be replicated across all GPUs.

  For Example:
  ```
  strategy = tf.distribute.experimental.CentralStorageStrategy()
  # Create a dataset
  ds = tf.data.Dataset.range(5).batch(2)
  # Distribute that dataset
  dist_dataset = strategy.experimental_distribute_dataset(ds)

  with strategy.scope():
    @tf.function
    def train_step(val):
      return val + 1

    # Iterate over the distributed dataset
    for x in dist_dataset:
      # process dataset elements
      strategy.run(train_step, args=(x,))
  ```
  """

  def __init__(self, compute_devices=None, parameter_device=None):
    """Initializes the strategy with optional device strings.

    Args:
      compute_devices: an optional list of device strings to replicate the
        model on. If this is not provided, all local GPUs will be used; if
        there is no GPU, the local CPU will be used.
      parameter_device: an optional device string specifying which device to
        put variables on. The default is the CPU, or the GPU if there is only
        one.
    """
    extended = parameter_server_strategy.ParameterServerStrategyExtended(
        self,
        compute_devices=compute_devices,
        parameter_device=parameter_device)
    super(CentralStorageStrategy, self).__init__(extended)
    distribute_lib.distribution_strategy_gauge.get_cell('V2').set(
        'CentralStorageStrategy')

  @classmethod
  def _from_num_gpus(cls, num_gpus):
    return cls(device_util.local_devices_from_num_gpus(num_gpus))

  def experimental_distribute_dataset(self, dataset, options=None):  # pylint: disable=useless-super-delegation
    """Distributes a tf.data.Dataset instance provided via dataset.

    The returned dataset is a wrapped strategy dataset which creates a
    multidevice iterator under the hood. It prefetches the input data to the
    specified devices on the worker. The returned distributed dataset can be
    iterated over similarly to regular datasets.

    NOTE: Currently, the user cannot add any more transformations to a
    distributed dataset.

    For Example:
    ```
    # with 1 CPU and 1 GPU
    strategy = tf.distribute.experimental.CentralStorageStrategy()
    dataset = tf.data.Dataset.range(10).batch(2)
    dist_dataset = strategy.experimental_distribute_dataset(dataset)
    for x in dist_dataset:
      print(x)  # Prints PerReplica values [0, 1], [2, 3], ...
    ```

    Args:
      dataset: `tf.data.Dataset` to be prefetched to device.
      options: `tf.distribute.InputOptions` used to control options on how
        this dataset is distributed.

    Returns:
      A "distributed `Dataset`" that the caller can iterate over.
    """
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          'InputReplicationMode.PER_REPLICA '
          'is only supported in '
          '`experimental_distribute_datasets_from_function`.'
      )
    return super(CentralStorageStrategy, self).experimental_distribute_dataset(
        dataset, options)

  def experimental_local_results(self, value):  # pylint: disable=useless-super-delegation
    """Returns the list of all local per-replica values contained in `value`.

    In `CentralStorageStrategy` there is a single worker so the value returned
    will be all the values on that worker.
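
    For Example (a minimal sketch; the length of the returned tuple depends
    on the number of local compute devices):
    ```
    strategy = tf.distribute.experimental.CentralStorageStrategy()

    @tf.function
    def train_step(val):
      return val * 2

    # `run` returns a "per replica" value; `experimental_local_results`
    # unpacks it into a tuple of the local results on this single worker.
    per_replica_result = strategy.run(train_step, args=(tf.constant(3.0),))
    local_results = strategy.experimental_local_results(per_replica_result)
    ```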

    Args:
      value: A value returned by `run()`, `extended.call_for_each_replica()`,
        or a variable created in `scope`.

    Returns:
      A tuple of values contained in `value`. If `value` represents a single
      value, this returns `(value,)`.
    """
    return super(CentralStorageStrategy, self).experimental_local_results(value)

  def run(self, fn, args=(), kwargs=None, options=None):  # pylint: disable=useless-super-delegation
    """Run `fn` on each replica, with the given arguments.

    In `CentralStorageStrategy`, `fn` is called on each of the compute
    replicas, with the provided "per replica" arguments specific to that
    device.
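
    For Example (a minimal sketch; `replica_fn` and `tensor_input` are
    illustrative names, not part of the API):
    ```
    strategy = tf.distribute.experimental.CentralStorageStrategy()
    tensor_input = tf.constant(3.0)

    @tf.function
    def replica_fn(val):
      return val * 2.0

    # `fn` is invoked on every compute replica; the result is a
    # "per replica" value.
    result = strategy.run(replica_fn, args=(tensor_input,))
    ```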

    Args:
      fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
      args: (Optional) Positional arguments to `fn`.
      kwargs: (Optional) Keyword arguments to `fn`.
      options: (Optional) An instance of `tf.distribute.RunOptions` specifying
        the options to run `fn`.

    Returns:
      Return value from running `fn`.
    """
    return super(CentralStorageStrategy, self).run(fn, args, kwargs, options)

  def reduce(self, reduce_op, value, axis):  # pylint: disable=useless-super-delegation
    """Reduce `value` across replicas.

    Given a per-replica value returned by `run`, say a
    per-example loss, the batch will be divided across all the replicas. This
    function allows you to aggregate across replicas and optionally also
    across batch elements. For example, if you have a global batch size of 8
    and 2 replicas, values for examples `[0, 1, 2, 3]` will be on replica 0
    and `[4, 5, 6, 7]` will be on replica 1. By default, `reduce` will just
    aggregate across replicas, returning `[0+4, 1+5, 2+6, 3+7]`. This is
    useful when each replica is computing a scalar or some other value that
    doesn't have a "batch" dimension (like a gradient). More often you will
    want to aggregate across the global batch, which you can get by
    specifying the batch dimension as the `axis`, typically `axis=0`. In this
    case it would return a scalar `0+1+2+3+4+5+6+7`.

    If there is a last partial batch, you will need to specify an axis so
    that the resulting shape is consistent across replicas. So if the last
    batch has size 6 and it is divided into [0, 1, 2, 3] and [4, 5], you
    would get a shape mismatch unless you specify `axis=0`. If you specify
    `tf.distribute.ReduceOp.MEAN`, using `axis=0` will use the correct
    denominator of 6. Contrast this with computing `reduce_mean` to get a
    scalar value on each replica and using this function to average those
    means, which will weigh some values `1/8` and others `1/4`.

    For Example:
    ```
    strategy = tf.distribute.experimental.CentralStorageStrategy(
        compute_devices=['CPU:0', 'GPU:0'], parameter_device='CPU:0')
    ds = tf.data.Dataset.range(10)
    # Distribute that dataset
    dist_dataset = strategy.experimental_distribute_dataset(ds)

    with strategy.scope():
      @tf.function
      def train_step(val):
        # pass through
        return val

      # Iterate over the distributed dataset
      for x in dist_dataset:
        result = strategy.run(train_step, args=(x,))

    result = strategy.reduce(tf.distribute.ReduceOp.SUM, result,
                             axis=None).numpy()
    # result: array([ 4,  6,  8, 10])

    result = strategy.reduce(tf.distribute.ReduceOp.SUM, result, axis=0).numpy()
    # result: 28
    ```

    Args:
      reduce_op: A `tf.distribute.ReduceOp` value specifying how values
        should be combined.
      value: A "per replica" value, e.g. returned by `run`, to be combined
        into a single tensor.
      axis: Specifies the dimension to reduce along within each replica's
        tensor. Should typically be set to the batch dimension, or `None` to
        only reduce across replicas (e.g. if the tensor has no batch
        dimension).

    Returns:
      A `Tensor`.
    """
    return super(CentralStorageStrategy, self).reduce(reduce_op, value, axis)


@tf_export(v1=['distribute.experimental.CentralStorageStrategy'])  # pylint: disable=missing-docstring
class CentralStorageStrategyV1(distribute_lib.StrategyV1):

  __doc__ = CentralStorageStrategy.__doc__

  def __init__(self, compute_devices=None, parameter_device=None):
    super(CentralStorageStrategyV1, self).__init__(
        parameter_server_strategy.ParameterServerStrategyExtended(
            self,
            compute_devices=compute_devices,
            parameter_device=parameter_device))
    distribute_lib.distribution_strategy_gauge.get_cell('V1').set(
        'CentralStorageStrategy')

  __init__.__doc__ = CentralStorageStrategy.__init__.__doc__