# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A tf.distribute.Strategy for running on a single device."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


# TODO(josh11b): Do we wrap values in types to generate errors if you are
# doing something that won't work with other DistributionStrategy
# implementations?


@tf_export("distribute.OneDeviceStrategy", v1=[])
class OneDeviceStrategy(distribute_lib.Strategy):
  """A distribution strategy for running on a single device.

  Using this strategy will place any variables created in its scope on the
  specified device. Input distributed through this strategy will be
  prefetched to the specified device. Moreover, any functions called via
  `strategy.run` will be placed on the specified device as well.

  Typical usage of this strategy could be testing your code with the
  tf.distribute.Strategy API before switching to other strategies which
  actually distribute to multiple devices/machines.

  For example:
  ```
  strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")

  with strategy.scope():
    v = tf.Variable(1.0)
    print(v.device)  # /job:localhost/replica:0/task:0/device:GPU:0

  def step_fn(x):
    return x * 2

  result = 0
  for i in range(10):
    result += strategy.run(step_fn, args=(i,))
  print(result)  # 90
  ```
  """

  def __init__(self, device):
    """Creates a `OneDeviceStrategy`.

    Args:
      device: Device string identifier for the device on which the variables
        should be placed. See class docs for more details on how the device is
        used. Examples: "/cpu:0", "/gpu:0", "/device:CPU:0", "/device:GPU:0"
    """
    super(OneDeviceStrategy, self).__init__(OneDeviceExtended(self, device))
    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
        "OneDeviceStrategy")

  def experimental_distribute_dataset(self, dataset, options=None):  # pylint: disable=useless-super-delegation
    """Distributes a tf.data.Dataset instance provided via dataset.

    In this case, there is only one device, so this is only a thin wrapper
    around the input dataset. It will, however, prefetch the input data to the
    specified device. The returned distributed dataset can be iterated over
    similar to how regular datasets can.

    NOTE: Currently, the user cannot add any more transformations to a
    distributed dataset.

    Example:
    ```
    strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
    dataset = tf.data.Dataset.range(10).batch(2)
    dist_dataset = strategy.experimental_distribute_dataset(dataset)
    for x in dist_dataset:
      print(x)  # [0, 1], [2, 3],...
    ```

    Args:
      dataset: `tf.data.Dataset` to be prefetched to device.
      options: `tf.distribute.InputOptions` used to control options on how this
        dataset is distributed.

    Returns:
      A "distributed `Dataset`" that the caller can iterate over.
    """
    return super(OneDeviceStrategy, self).experimental_distribute_dataset(
        dataset, options)

  def distribute_datasets_from_function(
      self,
      dataset_fn,  # pylint: disable=useless-super-delegation
      options=None):
    """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`.

    `dataset_fn` will be called once for each worker in the strategy. In this
    case, we only have one worker and one device so `dataset_fn` is called
    once.

    The `dataset_fn` should take an `tf.distribute.InputContext` instance where
    information about batching and input replication can be accessed:

    ```
    def dataset_fn(input_context):
      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
      d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
      return d.shard(
          input_context.num_input_pipelines, input_context.input_pipeline_id)

    inputs = strategy.distribute_datasets_from_function(dataset_fn)

    for batch in inputs:
      replica_results = strategy.run(replica_fn, args=(batch,))
    ```

    IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a
    per-replica batch size, unlike `experimental_distribute_dataset`, which uses
    the global batch size. This may be computed using
    `input_context.get_per_replica_batch_size`.

    Args:
      dataset_fn: A function taking a `tf.distribute.InputContext` instance and
        returning a `tf.data.Dataset`.
      options: `tf.distribute.InputOptions` used to control options on how this
        dataset is distributed.

    Returns:
      A "distributed `Dataset`", which the caller can iterate over like regular
      datasets.
    """
    return super(OneDeviceStrategy,
                 self).distribute_datasets_from_function(dataset_fn, options)

  def experimental_local_results(self, value):  # pylint: disable=useless-super-delegation
    """Returns the list of all local per-replica values contained in `value`.

    In `OneDeviceStrategy`, the `value` is always expected to be a single
    value, so the result is just the value in a tuple.

    Args:
      value: A value returned by `experimental_run()`, `run()`,
        `extended.call_for_each_replica()`, or a variable created in `scope`.

    Returns:
      A tuple of values contained in `value`. If `value` represents a single
      value, this returns `(value,)`.
    """
    return super(OneDeviceStrategy, self).experimental_local_results(value)

  def run(self, fn, args=(), kwargs=None, options=None):  # pylint: disable=useless-super-delegation
    """Run `fn` on each replica, with the given arguments.

    In `OneDeviceStrategy`, `fn` is simply called within a device scope for the
    given device, with the provided arguments.

    Args:
      fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
      args: (Optional) Positional arguments to `fn`.
      kwargs: (Optional) Keyword arguments to `fn`.
      options: (Optional) An instance of `tf.distribute.RunOptions` specifying
        the options to run `fn`.

    Returns:
      Return value from running `fn`.
    """
    return super(OneDeviceStrategy, self).run(fn, args, kwargs, options)

  def reduce(self, reduce_op, value, axis):  # pylint: disable=useless-super-delegation
    """Reduce `value` across replicas.

    In `OneDeviceStrategy`, there is only one replica, so if axis=None, value
    is simply returned. If axis is specified as something other than None,
    such as axis=0, value is reduced along that axis and returned.

    Example:
    ```
    t = tf.range(10)

    result = strategy.reduce(tf.distribute.ReduceOp.SUM, t, axis=None).numpy()
    # result: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

    result = strategy.reduce(tf.distribute.ReduceOp.SUM, t, axis=0).numpy()
    # result: 45
    ```

    Args:
      reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
        be combined.
      value: A "per replica" value, e.g. returned by `run` to
        be combined into a single tensor.
      axis: Specifies the dimension to reduce along within each
        replica's tensor. Should typically be set to the batch dimension, or
        `None` to only reduce across replicas (e.g. if the tensor has no batch
        dimension).

    Returns:
      A `Tensor`.
    """
    return super(OneDeviceStrategy, self).reduce(reduce_op, value, axis)

  def scope(self):  # pylint: disable=useless-super-delegation
    """Returns a context manager selecting this Strategy as current.

    Inside a `with strategy.scope():` code block, this thread
    will use a variable creator set by `strategy`, and will
    enter its "cross-replica context".

    In `OneDeviceStrategy`, all variables created inside `strategy.scope()`
    will be on `device` specified at strategy construction time.
    See example in the docs for this class.

    Returns:
      A context manager to use for creating variables with this strategy.
    """
    return super(OneDeviceStrategy, self).scope()


@tf_export(v1=["distribute.OneDeviceStrategy"])  # pylint: disable=empty-docstring
class OneDeviceStrategyV1(distribute_lib.StrategyV1):

  # Reuse the V2 class docstring, inserting an eager-execution call at the top
  # of its example. The search pattern must match the indentation of the
  # `OneDeviceStrategy` docstring exactly, otherwise the replace is a no-op.
  __doc__ = OneDeviceStrategy.__doc__.replace(
      "For example:\n  ```",
      "For example:\n  ```\n  tf.enable_eager_execution()")

  def __init__(self, device):
    super(OneDeviceStrategyV1, self).__init__(OneDeviceExtended(self, device))
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
        "OneDeviceStrategy")
  __init__.__doc__ = OneDeviceStrategy.__init__.__doc__


# TODO(josh11b): Switch to V2 after callers have been updated to only V2 APIs.
class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
  """Implementation of OneDeviceStrategy."""

  def __init__(self, container_strategy, device):
    """Stores the resolved compute device and its host (input) device.

    Args:
      container_strategy: The strategy object this extended object belongs to;
        forwarded to the base class constructor.
      device: Device string; passed through `device_util.resolve` before being
        stored.
    """
    super(OneDeviceExtended, self).__init__(container_strategy)
    self._device = device_util.resolve(device)
    self._input_device = device_util.get_host_for_device(self._device)

  def _input_workers_with_options(self, options=None):
    """Builds the `InputWorkers` for this strategy.

    Args:
      options: Optional input options object. Only its
        `experimental_prefetch_to_device` attribute is read here.

    Returns:
      An `input_lib.InputWorkers` pairing the input device with the compute
      device when `options` is falsy or requests prefetching to device, and
      with the input device itself otherwise.
    """
    if not options or options.experimental_prefetch_to_device:
      return input_lib.InputWorkers([(self._input_device, (self._device,))])
    else:
      return input_lib.InputWorkers([(self._input_device,
                                      (self._input_device,))])

  @property
  def _input_workers(self):
    """The `InputWorkers` obtained with no input options."""
    return self._input_workers_with_options()

  def _create_variable(self, next_creator, **kwargs):
    """Calls `next_creator` under the device/colocation scope for the variable.

    Args:
      next_creator: Callable invoked with the remaining `kwargs` to actually
        create the variable.
      **kwargs: Creator arguments. `colocate_with` is removed before they are
        forwarded: `None` means use this strategy's device, a
        `numpy_dataset.SingleDevice` means use that object's device, and
        anything else is passed to `ops.colocate_with`.

    Returns:
      Whatever `next_creator` returns.
    """
    colocate_with = kwargs.pop("colocate_with", None)
    if colocate_with is None:
      with ops.device(self._device):
        return next_creator(**kwargs)
    elif isinstance(colocate_with, numpy_dataset.SingleDevice):
      with ops.device(colocate_with.device):
        return next_creator(**kwargs)
    else:
      with ops.colocate_with(colocate_with):
        return next_creator(**kwargs)

  def _validate_colocate_with_variable(self, colocate_with_variable):
    """Delegates validation to `distribute_utils.validate_colocate`."""
    distribute_utils.validate_colocate(colocate_with_variable, self)

  def _make_dataset_iterator(self, dataset):
    """Make iterator from dataset without splitting the batch."""
    # Note that split_batch_by argument is not passed because it is always 1 in
    # this strategy, and adding it adds unnecessary overhead to the dataset.
    return input_lib.DatasetIterator(dataset, self._input_workers,
                                     self._container_strategy())

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    """Makes an `InputFunctionIterator` over `input_fn`.

    Args:
      input_fn: Input function handed to `input_lib.InputFunctionIterator`.
      replication_mode: Accepted for interface compatibility; this
        implementation does not read it.

    Returns:
      An `input_lib.InputFunctionIterator` built with a single default
      `distribute_lib.InputContext`.
    """
    return input_lib.InputFunctionIterator(input_fn, self._input_workers,
                                           [distribute_lib.InputContext()],
                                           self._container_strategy())

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    """Builds a dataset from `numpy_input` placed on the input device."""
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, numpy_dataset.SingleDevice(self._input_device), session)

  def _broadcast_to(self, tensor, destinations):
    """Returns `tensor` unchanged; `destinations` is ignored."""
    del destinations
    return tensor

  def _experimental_distribute_dataset(self, dataset, options):
    """Wraps `dataset` as a distributed dataset for the single device.

    Args:
      dataset: The dataset to distribute.
      options: Optional input options; may be falsy.

    Returns:
      The result of `input_lib.get_distributed_dataset`.

    Raises:
      NotImplementedError: If `options` requests
        `InputReplicationMode.PER_REPLICA`.
    """
    # Note that split_batch_by argument is not passed because it is always 1 in
    # this strategy, and adding it adds unnecessary overhead to the dataset.
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function`."
      )
    return input_lib.get_distributed_dataset(
        dataset,
        self._input_workers_with_options(options),
        self._container_strategy())

  def _distribute_datasets_from_function(self, dataset_fn, options):
    """Builds a distributed dataset from `dataset_fn`.

    Args:
      dataset_fn: Function producing the dataset; forwarded to
        `input_lib.get_distributed_datasets_from_function` together with a
        single default `distribute_lib.InputContext`.
      options: Optional input options; may be falsy.

    Returns:
      The result of `input_lib.get_distributed_datasets_from_function`.

    Raises:
      NotImplementedError: If `options` requests
        `InputReplicationMode.PER_REPLICA`.
    """
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function` "
          "of tf.distribute.MirroredStrategy")
    return input_lib.get_distributed_datasets_from_function(
        dataset_fn,
        self._input_workers_with_options(options),
        [distribute_lib.InputContext()],
        self._container_strategy())

  def _experimental_distribute_values_from_function(self, value_fn):
    """Calls `value_fn` once with a default `ValueContext` and returns it."""
    # TODO(b/137795644): This should return a PerReplica value but other
    # methods like run in OneDeviceStrategy need to be modified
    # to do the same.
    return value_fn(distribute_lib.ValueContext())

  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                          initial_loop_values=None):
    """Runs `fn` for `iterations` steps inside a single `while_loop`.

    Args:
      fn: Step function called as `fn(ctx, iterator.get_next())`, where `ctx`
        is the `input_lib.MultiStepContext` returned by this method.
      iterator: Object whose `get_next()` supplies the input for each step.
      iterations: Loop bound; the loop runs while the counter is below it.
      initial_loop_values: Optional structure of initial loop-carried values;
        it is flattened with `nest.flatten`. Defaults to an empty dict.

    Returns:
      The `MultiStepContext`, with `run_op` set to a group of the loop outputs
      and its last-step outputs repacked into their original structure.
    """
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)

    ctx = input_lib.MultiStepContext()
    def body(i, *args):
      """A wrapper around `fn` to create the while loop body."""
      del args
      fn_result = fn(ctx, iterator.get_next())
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      with ops.control_dependencies([fn_result]):
        return [i + 1] + flat_last_step_outputs

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop. This is useful in cases where we might need to exit
    # these contexts and get back to the outer context to do some things, for
    # e.g. create an op which should be evaluated only once at the end of the
    # loop on the host. One such usage is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    try:
      # TODO(priyag): Use max_iterations instead of an explicit counter.
      cond = lambda i, *args: i < iterations
      i = constant_op.constant(0)
      loop_result = control_flow_ops.while_loop(
          cond, body, [i] + initial_loop_values, name="",
          parallel_iterations=1, back_prop=False, swap_memory=False,
          return_same_structure=True)
    finally:
      # Drop the temporary attribute even when building the loop raises, so a
      # failed call does not leave a stale graph context on this object.
      del self._outer_control_flow_context

    ctx.run_op = control_flow_ops.group(loop_result)

    # Convert the last_step_outputs from a list to the original dict structure
    # of last_step_outputs.
    last_step_tensor_outputs = loop_result[1:]
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
    return ctx

  def _call_for_each_replica(self, fn, args, kwargs):
    """Calls `fn(*args, **kwargs)` in the device scope and a replica context."""
    strategy = self._container_strategy()
    with ops.device(self._device), _OneDeviceReplicaContext(strategy):
      return fn(*args, **kwargs)

  def _reduce_to(self, reduce_op, value, destinations, options):
    """Returns `value` unchanged; all other arguments are ignored."""
    del reduce_op, destinations, options
    return value

  def _gather_to_implementation(self, value, destinations, axis, options):
    """Returns `value` unchanged; all other arguments are ignored."""
    del destinations, axis, options
    return value

  def _update(self, var, fn, args, kwargs, group):
    """Runs `fn(var, *args, **kwargs)` via `_update_non_slot`."""
    # The implementations of _update() and _update_non_slot() are identical
    # except _update() passes `var` as the first argument to `fn()`.
    return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group)

  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    """Runs `fn(*args, **kwargs)` on this strategy's device.

    Args:
      colocate_with: Ignored; the call always runs on this strategy's device.
      fn: Function to call.
      args: Positional arguments for `fn`.
      kwargs: Keyword arguments for `fn`.
      group: If truthy, the result of `fn` is returned as is; otherwise every
        leaf of the result is wrapped in a one-element tuple.

    Returns:
      The (possibly wrapped) result of `fn`.
    """
    del colocate_with
    with ops.device(self._device), distribute_lib.UpdateContext(self._device):
      result = fn(*args, **kwargs)
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  def read_var(self, replica_local_var):
    """Read the aggregate value of a replica-local variable."""
    return array_ops.identity(replica_local_var)

  def _local_results(self, value):
    """Wraps `value` in a one-element tuple."""
    return (value,)

  def value_container(self, value):
    """Returns `value` itself."""
    return value

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    return False

  @property
  def _num_replicas_in_sync(self):
    """Always 1 for this strategy."""
    return 1

  @property
  def worker_devices(self):
    """One-element tuple holding this strategy's device."""
    return (self._device,)

  @property
  def parameter_devices(self):
    """One-element tuple holding this strategy's device."""
    return (self._device,)

  def non_slot_devices(self, var_list):
    """One-element tuple holding this strategy's device; ignores `var_list`."""
    del var_list
    return (self._device,)

  @property
  def experimental_should_init(self):
    return True

  @property
  def experimental_between_graph(self):
    return False

  @property
  def should_checkpoint(self):
    return True

  @property
  def should_save_summary(self):
    return True

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """Global and per-replica batching are equivalent for OneDeviceStrategy."""
    return True

  @property
  def _support_per_replica_values(self):
    return False

  def _get_local_replica_id(self, replica_id_in_sync_group):
    """Returns `replica_id_in_sync_group` unchanged."""
    return replica_id_in_sync_group


class _OneDeviceReplicaContext(distribute_lib.ReplicaContext):
  """ReplicaContext for OneDeviceStrategy."""

  def __init__(self, strategy):
    """Initializes the context with replica id 0 for `strategy`."""
    distribute_lib.ReplicaContext.__init__(
        self, strategy, replica_id_in_sync_group=0)

  @property
  def devices(self):
    """The worker devices reported by the strategy's extended object."""
    return self._strategy.extended.worker_devices