1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
7#     http://www.apache.org/licenses/LICENSE-2.0
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Classes implementing a multi-worker ps DistributionStrategy."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
21import copy
24from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
25from tensorflow.python.distribute import device_util
26from tensorflow.python.distribute import distribute_lib
27from tensorflow.python.distribute import input_lib
28from tensorflow.python.distribute import mirrored_strategy
29from tensorflow.python.distribute import multi_worker_util
30from tensorflow.python.distribute import numpy_dataset
31from tensorflow.python.distribute import values
32from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
33from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
34from tensorflow.python.eager import context
35from tensorflow.python.framework import device as tf_device
36from tensorflow.python.framework import ops
37from tensorflow.python.ops import array_ops
38from tensorflow.python.ops import resource_variable_ops
39from tensorflow.python.ops import variable_scope as vs
40from tensorflow.python.platform import tf_logging as logging
41from tensorflow.python.training import device_setter
42from tensorflow.python.util import nest
43from tensorflow.python.util.tf_export import tf_export
45_LOCAL_CPU = "/device:CPU:0"
46_LOCAL_GPU_0 = "/device:GPU:0"
49# TODO(yuefengz): maybe cache variables on local CPU.
class ParameterServerStrategy(distribute_lib.DistributionStrategy):
  """A parameter server DistributionStrategy.

  This strategy class works for both local training and between-graph replicated
  training for multiple workers. It uses `TFConfigClusterResolver` to detect
  configurations for multi-worker training. In multi-worker training mode, i.e.
  when `TFConfigClusterResolver` has detected the 'TF_CONFIG' environment
  variable and 'TF_CONFIG' has a cluster spec, variables and updates to those
  variables are assigned to parameter servers and other operations are assigned
  to workers. In local training mode, variables are assigned to the local CPU,
  or to the GPU if there is exactly one. When each worker has more than one
  GPU, operations will be replicated on these GPUs. In both cases, operations
  are replicated but variables are not, and all workers share a common view of
  which parameter server each variable is assigned to.

  This class assumes between-graph replication will be used and works on a graph
  for a particular worker. Note that each graph and worker is independent.
  This means that while each worker will synchronously compute a single gradient
  update across all GPUs, updates between workers proceed asynchronously.
  Operations that occur only on the first replica (such as incrementing the
  global step) will occur on the first replica *of every worker*.

  Callers are expected to use `call_for_each_replica(fn, ...)` for any
  operations that could be replicated across replicas (i.e. multiple GPUs),
  even if there is only a CPU or one GPU. When defining the `fn`, extra
  caution needs to be taken:

  1) It is generally not recommended to open a device scope under the strategy's
  scope. A device scope (i.e. calling `tf.device`) will be merged with or
  override the device for operations but will not change the device for
  variables.

  2) It is also not recommended to open a colocation scope (i.e. calling
  `tf.colocate_with`) under the strategy's scope. For colocating variables, use
  `strategy.extended.colocate_vars_with` instead. Colocation of ops can
  create device-assignment conflicts.
  """

  def __init__(self):
    """Initializes this strategy with a default `TFConfigClusterResolver`."""
    super(ParameterServerStrategy, self).__init__(
        ParameterServerStrategyExtended(self))
class ParameterServerStrategyExtended(
    distribute_lib.DistributionStrategyExtended):
  """Implementation of ParameterServerStrategy.

  If the cluster resolver reports a non-empty cluster spec, the strategy runs in
  multi-worker mode: variables are placed on the `ps` tasks and computation on
  the current worker's devices. Otherwise it runs in local mode: variables go
  to the local CPU (or to the GPU when there is exactly one) and computation is
  replicated across the local GPUs, or runs on the CPU if there are none.
  """

  def __init__(self,
               container_strategy,
               cluster_resolver=None):
    """Initializes devices from `cluster_resolver`.

    Args:
      container_strategy: the `DistributionStrategy` object wrapping this one
        (e.g. `ParameterServerStrategy`).
      cluster_resolver: a descendant of `ClusterResolver` object. When None, a
        new `TFConfigClusterResolver` is created for this instance.

    Raises:
      ValueError: in multi-worker mode, if `task_type`/`task_id` are missing or
        the cluster has no `ps` jobs.
    """
    super(ParameterServerStrategyExtended, self).__init__(container_strategy)
    if cluster_resolver is None:
      # Built per instance: a default-argument instance would be created once
      # at import time and silently shared by every strategy object.
      cluster_resolver = TFConfigClusterResolver()
    self._initialize_strategy(cluster_resolver)

    # We typically don't need to do all-reduce in this strategy.
    self._cross_device_ops = (
        cross_device_ops_lib.ReductionToOneDevice(reduce_to_device=_LOCAL_CPU))

  def _initialize_strategy(self, cluster_resolver):
    """Chooses multi-worker or local setup based on the cluster spec."""
    if cluster_resolver.cluster_spec().as_dict():
      self._initialize_multi_worker(cluster_resolver)
    else:
      self._initialize_local(cluster_resolver)

  @staticmethod
  def _num_gpus_from_resolver(cluster_resolver):
    """Returns the number of GPUs per worker to replicate computation on.

    Args:
      cluster_resolver: a descendant of `ClusterResolver` object.

    Returns:
      The GPU count: from the local context for `TFConfigClusterResolver`,
      otherwise the resolver's "GPU" accelerator count (0 if unreported).
    """
    # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in
    # some cases.
    if isinstance(cluster_resolver, TFConfigClusterResolver):
      return context.num_gpus()
    return cluster_resolver.num_accelerators().get("GPU", 0)

  def _initialize_multi_worker(self, cluster_resolver):
    """Initialize devices for multiple workers.

    It creates variable devices and compute devices. Variables and operations
    will be assigned to them respectively. We have one compute device per
    replica. The variable device is a device function or device string. The
    default variable device assigns variables to parameter servers in a
    round-robin fashion.

    Args:
      cluster_resolver: a descendant of `ClusterResolver` object.

    Raises:
      ValueError: if `task_type` or `task_id` is missing, or if the cluster
        doesn't have ps jobs.
    """
    num_gpus = self._num_gpus_from_resolver(cluster_resolver)

    # Save the num_gpus_per_worker for configure method.
    self._num_gpus_per_worker = num_gpus

    cluster_spec = cluster_resolver.cluster_spec()
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    if not task_type or task_id is None:
      raise ValueError("When `cluster_spec` is given, you must also specify "
                       "`task_type` and `task_id`")
    cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
    assert cluster_spec.as_dict()

    worker_device = "/job:%s/task:%d" % (task_type, task_id)
    self._input_host_device = numpy_dataset.SingleDevice(worker_device)

    # Compute devices: a tuple of device strings, one per replica. When there
    # are GPUs, replicate operations on these GPUs. Otherwise, place operations
    # on the worker's CPU.
    if num_gpus > 0:
      compute_devices = tuple(
          "%s/device:GPU:%d" % (worker_device, i) for i in range(num_gpus))
    else:
      compute_devices = (worker_device,)

    self._device_map = values.ReplicaDeviceMap(compute_devices)
    self._input_workers = input_lib.InputWorkers(
        self._device_map, [(worker_device, compute_devices)])

    # In distributed mode, place variables on ps jobs in a round-robin fashion.
    # Note that devices returned from `replica_device_setter` are not
    # canonical and therefore we don't canonicalize all variable devices to
    # make them consistent.
    # TODO(yuefengz): support passing a strategy object to control variable
    # assignment.
    # TODO(yuefengz): merge the logic of replica_device_setter into this
    # class.
    num_ps_replicas = len(cluster_spec.as_dict().get("ps", []))
    if num_ps_replicas == 0:
      raise ValueError("The cluster spec needs to have `ps` jobs.")
    self._variable_device = device_setter.replica_device_setter(
        ps_tasks=num_ps_replicas,
        worker_device=worker_device,
        merge_devices=True,
        cluster=cluster_spec)

    # The `_parameter_devices` is needed for the `parameter_devices` property
    # and is a list of all variable devices. Here parameter devices are all
    # tasks of the "ps" job.
    self._parameter_devices = tuple(map("/job:ps/task:{}".format,
                                        range(num_ps_replicas)))

    # Add a default device so that ops without specified devices will not end up
    # on other workers.
    self._default_device = worker_device

    self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                                task_id)
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id

    logging.info(
        "Multi-worker ParameterServerStrategy with "
        "cluster_spec = %r, task_type = %r, task_id = %r, "
        "num_ps_replicas = %r, is_chief = %r, device_map = %r, "
        "variable_device = %r", cluster_spec.as_dict(), task_type, task_id,
        num_ps_replicas, self._is_chief, self._device_map,
        self._variable_device)

  def _initialize_local(self, cluster_resolver):
    """Initialize internal devices for local training.

    Args:
      cluster_resolver: a descendant of `ClusterResolver` object; only its GPU
        count is used here.
    """
    worker_device = device_util.canonicalize(_LOCAL_CPU)
    self._input_host_device = numpy_dataset.SingleDevice(worker_device)

    num_gpus = self._num_gpus_from_resolver(cluster_resolver)

    # Save the num_gpus_per_worker for configure method.
    self._num_gpus_per_worker = num_gpus

    # Compute devices: a tuple of device strings, one per replica. When there
    # are GPUs, replicate operations on these GPUs. Otherwise, place operations
    # on CPU.
    if num_gpus > 0:
      compute_devices = tuple(map("/device:GPU:{}".format, range(num_gpus)))
    else:
      compute_devices = (_LOCAL_CPU,)

    self._device_map = values.ReplicaDeviceMap(compute_devices)
    self._input_workers = input_lib.InputWorkers(
        self._device_map, [(worker_device, compute_devices)])

    # If there is only one GPU, put everything on that GPU. Otherwise, place
    # variables on CPU.
    if num_gpus == 1:
      assert len(compute_devices) == 1
      self._variable_device = _LOCAL_GPU_0
      self._parameter_devices = (_LOCAL_GPU_0,)
    else:
      self._variable_device = _LOCAL_CPU
      self._parameter_devices = (_LOCAL_CPU,)

    # Local mode has no cluster: this process acts as the chief.
    self._is_chief = True
    self._cluster_spec = None
    self._task_type = None
    self._task_id = None

    logging.info(
        "ParameterServerStrategy with compute_devices = %r, "
        "variable_device = %r", compute_devices, self._variable_device)

  def _validate_colocate_with_variable(self, colocate_with_variable):
    values.validate_colocate(colocate_with_variable, self)

  def _make_dataset_iterator(self, dataset):
    return input_lib.DatasetIterator(dataset, self._input_workers,
                                     self._num_replicas_in_sync)

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    """Distributes the dataset to each local GPU.

    In multi-worker mode each worker gets its own input pipeline id; in local
    mode there is a single pipeline with id 0.
    """
    if self._cluster_spec:
      input_pipeline_id = multi_worker_util.id_in_cluster(
          self._cluster_spec, self._task_type, self._task_id)
      num_input_pipelines = multi_worker_util.worker_count(
          self._cluster_spec, self._task_type)
    else:
      input_pipeline_id = 0
      num_input_pipelines = 1
    input_context = distribute_lib.InputContext(
        num_input_pipelines=num_input_pipelines,
        input_pipeline_id=input_pipeline_id,
        num_replicas_in_sync=self._num_replicas_in_sync)
    return input_lib.InputFunctionIterator(input_fn, self._input_workers,
                                           [input_context])

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, self._input_host_device, session)

  def _broadcast_to(self, tensor, destinations):
    # This is both a fast path for Python constants, and a way to delay
    # converting Python values to a tensor until we know what type it
    # should be converted to. Otherwise we have trouble with:
    #   global_step.assign_add(1)
    # since the `1` gets broadcast as an int32 but global_step is int64.
    if isinstance(tensor, (float, int)):
      return tensor
    if not cross_device_ops_lib.check_destinations(destinations):
      # TODO(josh11b): Use current logical device instead of 0 here.
      destinations = values.LogicalDeviceSpec(
          device_map=self._device_map, logical_device=0)
    return self._cross_device_ops.broadcast(tensor, destinations)

  def _allow_variable_partition(self):
    return not context.executing_eagerly()

  # TODO(yuefengz): not all ops in device_setter.STANDARD_PS_OPS will go through
  # this creator, such as "MutableHashTable".
  def _create_variable(self, next_creator, *args, **kwargs):
    """Creates a variable on the variable device (or a colocation target).

    With more than one replica in sync, the variable is wrapped in a
    `values.AggregatingVariable` and graph collections are fixed up so that
    they hold the wrapper instead of the raw variable.

    Raises:
      ValueError: if the requested `aggregation` is not supported.
    """
    if self._num_replicas_in_sync > 1:
      aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
      if aggregation not in (
          vs.VariableAggregation.NONE,
          vs.VariableAggregation.SUM,
          vs.VariableAggregation.MEAN,
          vs.VariableAggregation.ONLY_FIRST_REPLICA
      ):
        # `aggregation` is not a str, so it must be formatted rather than
        # concatenated; otherwise a TypeError would mask this ValueError.
        raise ValueError(
            "Invalid variable aggregation mode: {} for variable: {}".format(
                aggregation, kwargs.get("name")))

      def var_creator(*args, **kwargs):
        """Create an AggregatingVariable and fix up collections."""
        # Record what collections this variable should be added to.
        collections = kwargs.pop("collections", None)
        if collections is None:
          collections = [ops.GraphKeys.GLOBAL_VARIABLES]
        kwargs["collections"] = []

        # Create and wrap the variable.
        v = next_creator(*args, **kwargs)
        wrapped = values.AggregatingVariable(
            self._container_strategy(), v, aggregation)

        # Add the wrapped variable to the requested collections.
        # The handling of eager mode and the global step matches
        # ResourceVariable._init_from_args().
        if not context.executing_eagerly():
          g = ops.get_default_graph()
          # If "trainable" is True, next_creator() will add the contained
          # variable to the TRAINABLE_VARIABLES collection, so we manually
          # remove it and replace with the wrapper. We can't set "trainable"
          # to False for next_creator() since that causes functions like
          # implicit_gradients to skip those variables.
          if kwargs.get("trainable", True):
            collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
            l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
            if v in l:
              l.remove(v)
          g.add_to_collections(collections, wrapped)
        elif ops.GraphKeys.GLOBAL_STEP in collections:
          ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped)

        return wrapped
    else:
      var_creator = next_creator

    if "colocate_with" in kwargs:
      colocate_with = kwargs["colocate_with"]
      if isinstance(colocate_with, numpy_dataset.SingleDevice):
        with ops.device(colocate_with.device):
          return var_creator(*args, **kwargs)
      with ops.device(None):
        with ops.colocate_with(colocate_with):
          return var_creator(*args, **kwargs)

    # Ignore any enclosing colocation so the variable lands on the variable
    # device chosen at initialization.
    with ops.colocate_with(None, ignore_existing=True):
      with ops.device(self._variable_device):
        return var_creator(*args, **kwargs)

  def _call_for_each_replica(self, fn, args, kwargs):
    # pylint: disable=protected-access
    return mirrored_strategy._call_for_each_replica(
        self._container_strategy(), self._device_map, fn, args, kwargs)

  def _verify_destinations_not_different_worker(self, destinations):
    """Raises ValueError if `destinations` names another task of this job.

    This is a no-op in local mode (no cluster spec) or when `destinations` is
    None.
    """
    if not self._cluster_spec:
      return
    if destinations is None:
      return
    for d in cross_device_ops_lib.get_devices_from(destinations):
      d_spec = tf_device.DeviceSpec.from_string(d)
      if d_spec.job == self._task_type and d_spec.task != self._task_id:
        raise ValueError(
            "Cannot reduce to another worker: %r, current worker is %r" %
            (d, self._input_workers.worker_devices[0]))

  def _reduce_to(self, reduce_op, value, destinations):
    self._verify_destinations_not_different_worker(destinations)
    if not isinstance(value, values.DistributedValues):
      # pylint: disable=protected-access
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, self._device_map, value, destinations)
    return self._cross_device_ops.reduce(
        reduce_op, value, destinations=destinations)

  def _batch_reduce_to(self, reduce_op, value_destination_pairs):
    for _, destinations in value_destination_pairs:
      self._verify_destinations_not_different_worker(destinations)
    return self._cross_device_ops.batch_reduce(reduce_op,
                                               value_destination_pairs)

  def _select_single_value(self, structured):
    """Select any single values in `structured`.

    Single-device `Mirrored` values are unwrapped to their primary value; other
    non-distributed values pass through unchanged.

    Raises:
      ValueError: on a multi-device `Mirrored` or any `PerReplica` value.
    """

    def _select_fn(x):  # pylint: disable=g-missing-docstring
      if isinstance(x, values.Mirrored):
        if len(x.devices) == 1:
          return x.primary
        else:
          raise ValueError(
              "You cannot update variable with a Mirrored object with multiple "
              "components %r when using ParameterServerStrategy. You must "
              "specify a single value or a Mirrored with a single value." % x)
      elif isinstance(x, values.PerReplica):
        raise ValueError(
            "You cannot update variable with a PerReplica object %r when using "
            "ParameterServerStrategy. You must specify a single value or a "
            "Mirrored with a single value" % x)
      else:
        return x

    return nest.map_structure(_select_fn, structured)

  def _update(self, var, fn, args, kwargs, group):
    """Runs `fn(var, *args, **kwargs)` colocated with `var`.

    Raises:
      ValueError: if `var` (after unwrapping an `AggregatingVariable`) is not a
        `ResourceVariable`.
    """
    if isinstance(var, values.AggregatingVariable):
      var = var.get()
    if not isinstance(var, resource_variable_ops.ResourceVariable):
      raise ValueError(
          "You can not update `var` %r. It must be a Variable." % var)
    with ops.colocate_with(var), distribute_lib.UpdateContext(var.device):
      result = fn(var, *self._select_single_value(args),
                  **self._select_single_value(kwargs))
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  # TODO(yuefengz): does it need to call _select_single_value?
  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    """Runs `fn(*args, **kwargs)` on the device of `colocate_with`."""
    with ops.device(
        colocate_with.device), distribute_lib.UpdateContext(colocate_with):
      result = fn(*args, **kwargs)
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  def _local_results(self, val):
    """Returns the components of `val` as a tuple (a 1-tuple if not distributed)."""
    if isinstance(val, values.DistributedValues):
      return val.values
    return (val,)

  def value_container(self, val):
    """Returns the `AggregatingVariable` wrapping `val`, if any, else `val`."""
    if (hasattr(val, "_aggregating_container") and
        not isinstance(val, values.AggregatingVariable)):
      wrapper = val._aggregating_container()  # pylint: disable=protected-access
      if wrapper is not None:
        return wrapper
    return val

  def read_var(self, var):
    # No need to distinguish between normal variables and replica-local
    # variables.
    return array_ops.identity(var)

  def _configure(self,
                 session_config=None,
                 cluster_spec=None,
                 task_type=None,
                 task_id=None):
    """Configures the strategy class.

    Whenever `cluster_spec` is given, the strategy is re-initialized for
    multi-worker training using that cluster spec, `task_type` and `task_id`,
    and the GPU count recorded at construction time.

    Args:
      session_config: if given, updated in place with this strategy's session
        settings (see `_update_config_proto`).
      cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
        cluster configurations.
      task_type: the current task type.
      task_id: the current task id.

    Raises:
      ValueError: if `cluster_spec` is given but `task_type` or `task_id` is
        not, or if the cluster spec has no `ps` jobs.
    """
    if cluster_spec:
      # Use the num_gpus_per_worker recorded in constructor since _configure
      # doesn't take num_gpus.
      cluster_resolver = SimpleClusterResolver(
          cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
          task_type=task_type,
          task_id=task_id,
          num_accelerators={"GPU": self._num_gpus_per_worker})
      self._initialize_multi_worker(cluster_resolver)

    if session_config:
      session_config.CopyFrom(self._update_config_proto(session_config))

  def _update_config_proto(self, config_proto):
    """Returns a copy of `config_proto` adjusted for this strategy.

    Local mode isolates session state. Multi-worker mode shares session state,
    and for "chief"/"worker" tasks restricts device filters to this task and
    the `ps` job.
    """
    updated_config = copy.deepcopy(config_proto)
    if not self._cluster_spec:
      updated_config.isolate_session_state = True
      return updated_config

    updated_config.isolate_session_state = False

    assert self._task_type
    assert self._task_id is not None

    # The device filters prevent communication between workers.
    if self._task_type not in ["chief", "worker"]:
      return updated_config
    del updated_config.device_filters[:]
    updated_config.device_filters.extend(
        ["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"])
    return updated_config

  @property
  def _num_replicas_in_sync(self):
    return self._device_map.num_replicas_in_graph

  @property
  def worker_devices(self):
    return self._device_map.all_devices

  @property
  def worker_devices_by_replica(self):
    return self._device_map.devices_by_replica

  @property
  def parameter_devices(self):
    return self._parameter_devices

  def non_slot_devices(self, var_list):
    # Picks the variable with the lexicographically smallest name, presumably so
    # the choice is deterministic across workers.
    return min(var_list, key=lambda x: x.name)

  @property
  def experimental_between_graph(self):
    # TODO(yuefengz): Should this return False in the local case?
    return True

  @property
  def experimental_should_init(self):
    return self._is_chief

  @property
  def should_checkpoint(self):
    return self._is_chief

  @property
  def should_save_summary(self):
    return self._is_chief

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.

    `make_input_fn_iterator` assumes per-replica batching.

    Returns:
      Boolean.
    """
    return True