# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes implementing a multi-worker ps DistributionStrategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy


from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import device_setter
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

_LOCAL_CPU = "/device:CPU:0"
_LOCAL_GPU_0 = "/device:GPU:0"


# TODO(yuefengz): maybe cache variables on local CPU.
@tf_export("distribute.experimental.ParameterServerStrategy")
class ParameterServerStrategy(distribute_lib.DistributionStrategy):
  """A parameter server DistributionStrategy.

  This strategy class works for both local training and between-graph
  replicated training for multiple workers. It uses `TFConfigClusterResolver`
  to detect configurations for multi-worker training. In multi-worker training
  mode, i.e. when `TFConfigClusterResolver` has detected the 'TF_CONFIG'
  environment variable and 'TF_CONFIG' has a cluster spec, variables and
  updates to those variables are assigned to parameter servers while other
  operations are assigned to workers. In local training mode, variables are
  assigned to the local CPU or the only GPU. When each worker has more than
  one GPU, operations are replicated across these GPUs. In both cases,
  operations are replicated but variables are not, and the workers share a
  common view of which parameter server each variable is assigned to.

  This class assumes between-graph replication will be used and works on a
  graph for a particular worker. Note that each graph and worker is
  independent. This means that while each worker synchronously computes a
  single gradient update across all of its GPUs, updates between workers
  proceed asynchronously. Operations that occur only on the first replica
  (such as incrementing the global step) will occur on the first replica *of
  every worker*.

  Callers are expected to use `call_for_each_replica(fn, ...)` for any
  operations which can potentially be replicated across replicas (i.e.
  multiple GPUs), even if there is only a CPU or a single GPU. When defining
  `fn`, extra caution needs to be taken:

  1) It is generally not recommended to open a device scope under the
  strategy's scope. A device scope (i.e. calling `tf.device`) will be merged
  with or override the device for operations, but will not change the device
  for variables.

  2) It is also not recommended to open a colocation scope (i.e. calling
  `tf.colocate_with`) under the strategy's scope. For colocating variables,
  use `strategy.extended.colocate_vars_with` instead. Colocating ops may
  create conflicts in device assignment.
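
  For example, a rough sketch of Estimator-based multi-worker training with
  this strategy could look like the following (assuming 'TF_CONFIG' is already
  set in the environment for this task, and that `model_fn`, `train_spec` and
  `eval_spec` are defined by the user):

  ```python
  strategy = tf.distribute.experimental.ParameterServerStrategy()
  run_config = tf.estimator.RunConfig(train_distribute=strategy)
  estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
  tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
  ```

  Each worker and parameter server runs the same program; the cluster spec in
  'TF_CONFIG' determines the role of each task.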
  """

  def __init__(self):
    """Initializes this strategy with default TFConfigClusterResolver."""
    super(ParameterServerStrategy, self).__init__(
        ParameterServerStrategyExtended(self))


class ParameterServerStrategyExtended(
    distribute_lib.DistributionStrategyExtended):
  """Implementation of ParameterServerStrategy."""

  def __init__(self,
               container_strategy,
               cluster_resolver=TFConfigClusterResolver()):
    super(ParameterServerStrategyExtended, self).__init__(container_strategy)
    self._initialize_strategy(cluster_resolver)

    # We typically don't need to do all-reduce in this strategy.
    self._cross_device_ops = (
        cross_device_ops_lib.ReductionToOneDevice(reduce_to_device=_LOCAL_CPU))

  def _initialize_strategy(self, cluster_resolver):
    if cluster_resolver.cluster_spec().as_dict():
      self._initialize_multi_worker(cluster_resolver)
    else:
      self._initialize_local(cluster_resolver)

  def _initialize_multi_worker(self, cluster_resolver):
    """Initialize devices for multiple workers.

    It creates variable devices and compute devices. Variables and operations
    will be assigned to them respectively. We have one compute device per
    replica. The variable device is a device function or device string. The
    default variable device assigns variables to parameter servers in a
    round-robin fashion.

    Args:
      cluster_resolver: a descendant of `ClusterResolver`.

    Raises:
      ValueError: if the cluster doesn't have ps jobs.
    """
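    # For illustration only (hostnames and ports below are placeholders, not
    # values used by this module), a cluster spec accepted by this strategy
    # typically looks like:
    #   {
    #       "chief": ["host0:2222"],
    #       "worker": ["host1:2222", "host2:2222"],
    #       "ps": ["host3:2222", "host4:2222"],
    #   }
    # The "ps" job must be present; this is validated further below.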
    # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs
    # in some cases.
    if isinstance(cluster_resolver, TFConfigClusterResolver):
      num_gpus = context.num_gpus()
    else:
      num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)

    # Save the num_gpus_per_worker for the configure method.
    self._num_gpus_per_worker = num_gpus

    cluster_spec = cluster_resolver.cluster_spec()
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    if not task_type or task_id is None:
      raise ValueError("When `cluster_spec` is given, you must also specify "
                       "`task_type` and `task_id`")
    cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
    assert cluster_spec.as_dict()

    worker_device = "/job:%s/task:%d" % (task_type, task_id)
    self._input_host_device = numpy_dataset.SingleDevice(worker_device)

    # Define compute devices, which is a list of device strings, one for each
    # replica. When there are GPUs, replicate operations on these GPUs.
    # Otherwise, place operations on CPU.
    if num_gpus > 0:
      compute_devices = tuple(
          "%s/device:GPU:%d" % (worker_device, i) for i in range(num_gpus))
    else:
      compute_devices = (worker_device,)

    self._device_map = values.ReplicaDeviceMap(compute_devices)
    self._input_workers = input_lib.InputWorkers(
        self._device_map, [(worker_device, compute_devices)])

    # In distributed mode, place variables on ps jobs in a round-robin fashion.
    # Note that devices returned from `replica_device_setter` are not
    # canonical and therefore we don't canonicalize all variable devices to
    # make them consistent.
    # TODO(yuefengz): support passing a strategy object to control variable
    # assignment.
    # TODO(yuefengz): merge the logic of replica_device_setter into this
    # class.
    num_ps_replicas = len(cluster_spec.as_dict().get("ps", []))
    if num_ps_replicas == 0:
      raise ValueError("The cluster spec needs to have `ps` jobs.")
    self._variable_device = device_setter.replica_device_setter(
        ps_tasks=num_ps_replicas,
        worker_device=worker_device,
        merge_devices=True,
        cluster=cluster_spec)
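    # For illustration only (a sketch of the placement, not executed here):
    # with two ps tasks, the device setter above places consecutively created
    # variables as
    #   v0 -> /job:ps/task:0, v1 -> /job:ps/task:1, v2 -> /job:ps/task:0, ...
    # while non-variable ops stay on `worker_device`.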

    # The `_parameter_devices` is needed for the `parameter_devices` property
    # and is a list of all variable devices. Here parameter devices are all
    # tasks of the "ps" job.
    self._parameter_devices = tuple(map("/job:ps/task:{}".format,
                                        range(num_ps_replicas)))

    # Add a default device so that ops without specified devices will not end
    # up on other workers.
    self._default_device = worker_device

    self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                                task_id)
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id

    logging.info(
        "Multi-worker ParameterServerStrategy with "
        "cluster_spec = %r, task_type = %r, task_id = %r, "
        "num_ps_replicas = %r, is_chief = %r, device_map = %r, "
        "variable_device = %r", cluster_spec.as_dict(), task_type, task_id,
        num_ps_replicas, self._is_chief, self._device_map,
        self._variable_device)

  def _initialize_local(self, cluster_resolver):
    """Initialize internal devices for local training."""
    worker_device = device_util.canonicalize("/device:CPU:0")
    self._input_host_device = numpy_dataset.SingleDevice(worker_device)

    # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs
    # in some cases.
    if isinstance(cluster_resolver, TFConfigClusterResolver):
      num_gpus = context.num_gpus()
    else:
      num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)

    # Save the num_gpus_per_worker for the configure method.
    self._num_gpus_per_worker = num_gpus

    # Define compute devices, which is a list of device strings, one for each
    # replica. When there are GPUs, replicate operations on these GPUs.
    # Otherwise, place operations on CPU.
    if num_gpus > 0:
      compute_devices = tuple(map("/device:GPU:{}".format, range(num_gpus)))
    else:
      compute_devices = (_LOCAL_CPU,)

    self._device_map = values.ReplicaDeviceMap(compute_devices)
    self._input_workers = input_lib.InputWorkers(
        self._device_map, [(worker_device, compute_devices)])

    # If there is only one GPU, put everything on that GPU. Otherwise, place
    # variables on CPU.
    if num_gpus == 1:
      assert len(compute_devices) == 1
      self._variable_device = _LOCAL_GPU_0
      self._parameter_devices = (_LOCAL_GPU_0,)
    else:
      self._variable_device = _LOCAL_CPU
      self._parameter_devices = (_LOCAL_CPU,)

    self._is_chief = True
    self._cluster_spec = None
    self._task_type = None
    self._task_id = None

    logging.info(
        "ParameterServerStrategy with compute_devices = %r, "
        "variable_device = %r", compute_devices, self._variable_device)

  def _validate_colocate_with_variable(self, colocate_with_variable):
    values.validate_colocate(colocate_with_variable, self)

  def _make_dataset_iterator(self, dataset):
    return input_lib.DatasetIterator(dataset, self._input_workers,
                                     self._num_replicas_in_sync)

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    """Distributes the dataset produced by `input_fn` to each local device."""
    if self._cluster_spec:
      input_pipeline_id = multi_worker_util.id_in_cluster(
          self._cluster_spec, self._task_type, self._task_id)
      num_input_pipelines = multi_worker_util.worker_count(
          self._cluster_spec, self._task_type)
    else:
      input_pipeline_id = 0
      num_input_pipelines = 1
    input_context = distribute_lib.InputContext(
        num_input_pipelines=num_input_pipelines,
        input_pipeline_id=input_pipeline_id,
        num_replicas_in_sync=self._num_replicas_in_sync)
    return input_lib.InputFunctionIterator(input_fn, self._input_workers,
                                           [input_context])

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, self._input_host_device, session)

  def _broadcast_to(self, tensor, destinations):
    # This is both a fast path for Python constants, and a way to delay
    # converting Python values to a tensor until we know what type it
    # should be converted to. Otherwise we have trouble with:
    #   global_step.assign_add(1)
    # since the `1` gets broadcast as an int32 but global_step is int64.
    if isinstance(tensor, (float, int)):
      return tensor
    if not cross_device_ops_lib.check_destinations(destinations):
      # TODO(josh11b): Use current logical device instead of 0 here.
      destinations = values.LogicalDeviceSpec(
          device_map=self._device_map, logical_device=0)
    return self._cross_device_ops.broadcast(tensor, destinations)

  def _allow_variable_partition(self):
    return not context.executing_eagerly()

  # TODO(yuefengz): not all ops in device_setter.STANDARD_PS_OPS will go
  # through this creator, such as "MutableHashTable".
  def _create_variable(self, next_creator, *args, **kwargs):
    if self._num_replicas_in_sync > 1:
      aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
      if aggregation not in (
          vs.VariableAggregation.NONE,
          vs.VariableAggregation.SUM,
          vs.VariableAggregation.MEAN,
          vs.VariableAggregation.ONLY_FIRST_REPLICA
      ):
        raise ValueError("Invalid variable aggregation mode: %s for "
                         "variable: %s" % (aggregation, kwargs["name"]))

      def var_creator(*args, **kwargs):
        """Create an AggregatingVariable and fix up collections."""
        # Record what collections this variable should be added to.
        collections = kwargs.pop("collections", None)
        if collections is None:
          collections = [ops.GraphKeys.GLOBAL_VARIABLES]
        kwargs["collections"] = []

        # Create and wrap the variable.
        v = next_creator(*args, **kwargs)
        wrapped = values.AggregatingVariable(
            self._container_strategy(), v, aggregation)

        # Add the wrapped variable to the requested collections.
        # The handling of eager mode and the global step matches
        # ResourceVariable._init_from_args().
        if not context.executing_eagerly():
          g = ops.get_default_graph()
          # If "trainable" is True, next_creator() will add the contained
          # variable to the TRAINABLE_VARIABLES collection, so we manually
          # remove it and replace with the wrapper. We can't set "trainable"
          # to False for next_creator() since that causes functions like
          # implicit_gradients to skip those variables.
          if kwargs.get("trainable", True):
            collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
            l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
            if v in l:
              l.remove(v)
          g.add_to_collections(collections, wrapped)
        elif ops.GraphKeys.GLOBAL_STEP in collections:
          ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped)

        return wrapped
    else:
      var_creator = next_creator

    if "colocate_with" in kwargs:
      colocate_with = kwargs["colocate_with"]
      if isinstance(colocate_with, numpy_dataset.SingleDevice):
        with ops.device(colocate_with.device):
          return var_creator(*args, **kwargs)
      with ops.device(None):
        with ops.colocate_with(colocate_with):
          return var_creator(*args, **kwargs)

    with ops.colocate_with(None, ignore_existing=True):
      with ops.device(self._variable_device):
        return var_creator(*args, **kwargs)
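
  # For example (a sketch only; `strategy` is a `ParameterServerStrategy`
  # instance created by user code and is not defined in this module):
  #   with strategy.scope():
  #     v = tf.get_variable(
  #         "v", initializer=0.,
  #         aggregation=tf.VariableAggregation.SUM)
  # When there is more than one replica in sync, `v` is wrapped in an
  # `AggregatingVariable` and updates made in a replica context are merged
  # across replicas according to the aggregation mode before being applied
  # on the parameter server.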

  def _call_for_each_replica(self, fn, args, kwargs):
    # pylint: disable=protected-access
    return mirrored_strategy._call_for_each_replica(
        self._container_strategy(), self._device_map, fn, args, kwargs)
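
  # For example (a sketch only; `strategy` is assumed to be the containing
  # `ParameterServerStrategy` and `step_fn` / `inputs` are user-defined):
  #   per_replica_outputs = strategy.extended.call_for_each_replica(
  #       step_fn, args=(inputs,))
  # Each replica (one per local GPU, or a single CPU replica) runs `step_fn`
  # once; variable reads and updates inside `step_fn` go to the parameter
  # servers.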

  def _verify_destinations_not_different_worker(self, destinations):
    if not self._cluster_spec:
      return
    if destinations is None:
      return
    for d in cross_device_ops_lib.get_devices_from(destinations):
      d_spec = tf_device.DeviceSpec.from_string(d)
      if d_spec.job == self._task_type and d_spec.task != self._task_id:
        raise ValueError(
            "Cannot reduce to another worker: %r, current worker is %r" %
            (d, self._input_workers.worker_devices[0]))

  def _reduce_to(self, reduce_op, value, destinations):
    self._verify_destinations_not_different_worker(destinations)
    if not isinstance(value, values.DistributedValues):
      # pylint: disable=protected-access
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, self._device_map, value, destinations)
    return self._cross_device_ops.reduce(
        reduce_op, value, destinations=destinations)

  def _batch_reduce_to(self, reduce_op, value_destination_pairs):
    for _, destinations in value_destination_pairs:
      self._verify_destinations_not_different_worker(destinations)
    return self._cross_device_ops.batch_reduce(reduce_op,
                                               value_destination_pairs)

  def _select_single_value(self, structured):
    """Selects the single value from each distributed value in `structured`."""

    def _select_fn(x):  # pylint: disable=g-missing-docstring
      if isinstance(x, values.Mirrored):
        if len(x.devices) == 1:
          return x.primary
        else:
          raise ValueError(
              "You cannot update a variable with a Mirrored object with "
              "multiple components %r when using ParameterServerStrategy. "
              "You must specify a single value or a Mirrored with a single "
              "value." % x)
      elif isinstance(x, values.PerReplica):
        raise ValueError(
            "You cannot update a variable with a PerReplica object %r when "
            "using ParameterServerStrategy. You must specify a single value "
            "or a Mirrored with a single value." % x)
      else:
        return x

    return nest.map_structure(_select_fn, structured)

  def _update(self, var, fn, args, kwargs, group):
    if isinstance(var, values.AggregatingVariable):
      var = var.get()
    if not isinstance(var, resource_variable_ops.ResourceVariable):
      raise ValueError(
          "You cannot update `var` %r. It must be a Variable." % var)
    with ops.colocate_with(var), distribute_lib.UpdateContext(var.device):
      result = fn(var, *self._select_single_value(args),
                  **self._select_single_value(kwargs))
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  # TODO(yuefengz): does it need to call _select_single_value?
  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    with ops.device(
        colocate_with.device), distribute_lib.UpdateContext(colocate_with):
      result = fn(*args, **kwargs)
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  def _local_results(self, val):
    if isinstance(val, values.DistributedValues):
      return val.values
    return (val,)

  def value_container(self, val):
    if (hasattr(val, "_aggregating_container") and
        not isinstance(val, values.AggregatingVariable)):
      wrapper = val._aggregating_container()  # pylint: disable=protected-access
      if wrapper is not None:
        return wrapper
    return val

  def read_var(self, var):
    # No need to distinguish between normal variables and replica-local
    # variables.
    return array_ops.identity(var)

  def _configure(self,
                 session_config=None,
                 cluster_spec=None,
                 task_type=None,
                 task_id=None):
    """Configures the strategy class.

    The strategy object will be re-initialized if `cluster_spec` is given but
    was not passed in the constructor.

    Args:
      session_config: not used currently.
      cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
        cluster configurations.
      task_type: the current task type.
      task_id: the current task id.

    Raises:
      ValueError: if `cluster_spec` is given but `task_type` or `task_id` is
        not.
    """
    if cluster_spec:
      # Use the num_gpus_per_worker recorded in constructor since _configure
      # doesn't take num_gpus.
      cluster_resolver = SimpleClusterResolver(
          cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
          task_type=task_type,
          task_id=task_id,
          num_accelerators={"GPU": self._num_gpus_per_worker})
      self._initialize_multi_worker(cluster_resolver)

    if session_config:
      session_config.CopyFrom(self._update_config_proto(session_config))

  def _update_config_proto(self, config_proto):
    updated_config = copy.deepcopy(config_proto)
    if not self._cluster_spec:
      updated_config.isolate_session_state = True
      return updated_config

    updated_config.isolate_session_state = False

    assert self._task_type
    assert self._task_id is not None

    # The device filters prevent communication between workers.
    if self._task_type not in ["chief", "worker"]:
      return updated_config
    del updated_config.device_filters[:]
    updated_config.device_filters.extend(
        ["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"])
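    # For illustration only (task values are placeholders): with
    # task_type = "worker" and task_id = 1, the filters become
    #   ["/job:worker/task:1", "/job:ps"]
    # so this worker only sees its own devices and the parameter servers.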
    return updated_config

  @property
  def _num_replicas_in_sync(self):
    return self._device_map.num_replicas_in_graph

  @property
  def worker_devices(self):
    return self._device_map.all_devices

  @property
  def worker_devices_by_replica(self):
    return self._device_map.devices_by_replica

  @property
  def parameter_devices(self):
    return self._parameter_devices

  def non_slot_devices(self, var_list):
    return min(var_list, key=lambda x: x.name)

  @property
  def experimental_between_graph(self):
    # TODO(yuefengz): Should this return False in the local case?
    return True

  @property
  def experimental_should_init(self):
    return self._is_chief

  @property
  def should_checkpoint(self):
    return self._is_chief

  @property
  def should_save_summary(self):
    return self._is_chief

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.

    `make_input_fn_iterator` assumes per-replica batching.

    Returns:
      Boolean.
    """
    return True