1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Strategy combinations for combinations.combine()."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python import tf2
22from tensorflow.python.distribute import central_storage_strategy
23from tensorflow.python.distribute import cluster_resolver
24from tensorflow.python.distribute import collective_all_reduce_strategy
25from tensorflow.python.distribute import combinations
26from tensorflow.python.distribute import distribution_strategy_context
27from tensorflow.python.distribute import mirrored_strategy as mirrored_lib
28from tensorflow.python.distribute import multi_process_runner
29from tensorflow.python.distribute import multi_worker_test_base
30from tensorflow.python.distribute import one_device_strategy as one_device_lib
31from tensorflow.python.distribute import test_util
32from tensorflow.python.distribute import tpu_strategy as tpu_lib
33from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
34from tensorflow.python.eager import context
35from tensorflow.python.eager import remote
36from tensorflow.python.platform import flags
37from tensorflow.python.tpu import device_assignment as device_assignment_lib
38from tensorflow.python.tpu import tpu_strategy_util
39from tensorflow.python.util.tf_export import tf_export
40
# Namespace under which the combinations below are exported via tf_export.
_TF_INTERNAL_API_PREFIX = "__internal__.distribute.combinations."

# Process-wide TPU state shared by every TPU strategy creator in this module:
# whether a TPU cluster connection has already been made, and the topology
# returned by the most recent TPU system initialization.
_did_connect_to_cluster = False
_topology = None

# Short alias for the collective all-reduce extended class; its health-check
# switch is turned off when multi-worker strategies are created.
CollectiveAllReduceExtended = (
    collective_all_reduce_strategy.CollectiveAllReduceExtended
)
47
48
49def _version_chooser(tf1_cls, tf2_cls):
50
51  def creator(*args, **kwargs):
52    if tf2.enabled():
53      return tf2_cls(*args, **kwargs)
54    return tf1_cls(*args, **kwargs)
55
56  return creator
57
58
# Version-dependent factories: every call constructs the TF2 class when TF2
# behavior is enabled and the "V1" class otherwise.
MirroredStrategy = _version_chooser(
    tf1_cls=mirrored_lib.MirroredStrategyV1,
    tf2_cls=mirrored_lib.MirroredStrategy)
CentralStorageStrategy = _version_chooser(
    tf1_cls=central_storage_strategy.CentralStorageStrategyV1,
    tf2_cls=central_storage_strategy.CentralStorageStrategy)
OneDeviceStrategy = _version_chooser(
    tf1_cls=one_device_lib.OneDeviceStrategyV1,
    tf2_cls=one_device_lib.OneDeviceStrategy)
# No version chooser here: only the V2 CollectiveAllReduceStrategy is
# supported in these combinations.
CollectiveAllReduceStrategy = (
    collective_all_reduce_strategy.CollectiveAllReduceStrategy
)
69
70
71# pylint: disable=missing-docstring
72def _get_tpu_strategy_creator(steps_per_run,
73                              use_single_core=False,
74                              enable_packed_variable=False,
75                              **kwargs):
76
77  def _create_tpu_strategy():
78    FLAGS = flags.FLAGS  # pylint: disable=invalid-name
79    global _did_connect_to_cluster
80    global _topology
81
82    try:
83      # Attempt to locally discover the TPU. This will fail for Cloud TPU, in
84      # which case we fall back to the values passed as flags.
85      resolver = tpu_cluster_resolver.TPUClusterResolver()
86      did_automatically_resolve = True
87    except ValueError:
88      did_automatically_resolve = False
89
90      # These flags will be defined by tpu_test_wrapper.py.
91      resolver = tpu_cluster_resolver.TPUClusterResolver(
92          tpu=hasattr(FLAGS, "tpu") and FLAGS.tpu or "",
93          zone=hasattr(FLAGS, "zone") and FLAGS.zone or None,
94          project=hasattr(FLAGS, "project") and FLAGS.project or None,
95      )
96
97    # Only connect once per process, rather than per test method.
98    if not _did_connect_to_cluster:
99      if getattr(FLAGS, "tpu", "") or did_automatically_resolve:
100        remote.connect_to_cluster(resolver)
101        _did_connect_to_cluster = True
102      _topology = tpu_strategy_util.initialize_tpu_system(resolver)
103
104    device_assignment = None
105    if use_single_core:
106      device_assignment = device_assignment_lib.DeviceAssignment(
107          _topology,
108          core_assignment=device_assignment_lib.SINGLE_CORE_ASSIGNMENT)
109
110    # Steps per run is only supported in TF 1.x
111    if tf2.enabled():
112      strategy = tpu_lib.TPUStrategy(resolver, device_assignment, **kwargs)
113    else:
114      strategy = tpu_lib.TPUStrategyV1(resolver, steps_per_run,
115                                       device_assignment, **kwargs)
116    strategy._enable_packed_variable_in_eager_mode = enable_packed_variable  # pylint: disable=protected-access
117    return strategy
118
119  return _create_tpu_strategy
120
121
def _mirrored_strategy_with_collective_key_base(devices):
  """Bumps the collective key base, then creates a `MirroredStrategy`.

  The class-level `_collective_key_base` of both the TF1 and TF2 mirrored
  strategy classes is raised by 100000 on every call. This is presumably so
  that strategies created one after another use non-overlapping collective
  keys; confirm against mirrored_strategy.py.

  Args:
    devices: device strings passed to the version-chosen `MirroredStrategy`.

  Returns:
    The newly created mirrored strategy.
  """
  key_base_increment = 100000
  for strategy_cls in (mirrored_lib.MirroredStrategyV1,
                       mirrored_lib.MirroredStrategy):
    strategy_cls._collective_key_base += key_base_increment  # pylint: disable=protected-access
  return MirroredStrategy(devices)
126
127
128def _get_multi_worker_mirrored_creator(required_gpus):
129
130  def _create_multi_worker_mirrored():
131    tf_config = cluster_resolver.TFConfigClusterResolver()
132    master = tf_config.master()
133    if tf_config.rpc_layer:
134      # Strip off the rpc_layer suffix.
135      master = master[len("%s://" % tf_config.rpc_layer):]
136    resolver = cluster_resolver.SimpleClusterResolver(
137        cluster_spec=tf_config.cluster_spec(),
138        task_type=tf_config.task_type,
139        task_id=tf_config.task_id,
140        master=master,
141        environment=tf_config.environment,
142        num_accelerators={"GPU": required_gpus},
143        rpc_layer=tf_config.rpc_layer or "grpc",
144    )
145    # Disable health check. We don't have a reliable to shutdown the strategy
146    # (and thus the health check) at the end of a test. Turning on health check
147    # causes some flakiness since we re-create part of the server when creating
148    # a strategy, and our tests are capable of handling failures.
149    CollectiveAllReduceExtended._enable_check_health = False  # pylint: disable=protected-access
150    # Always create the strategy in eager mode so that it starts the server and
151    # configures the eager context. The eager context can no longer be
152    # configured after initialization.
153    with context.eager_mode():
154      strategy = CollectiveAllReduceStrategy(cluster_resolver=resolver)
155    # TODO(b/152320929): Wait for the cluster before proceeding, otherwise
156    # collectives may hang if any worker launches collectives before the chief
157    # creates the strategy.
158    try:
159      multi_process_runner.get_barrier().wait()
160    except ValueError:
161      # If the creator is called in the main process,
162      # multi_process_runner.get_barrier() raises ValueError, which is safe to
163      # ignore.
164      pass
165    return strategy
166
167  return _create_multi_worker_mirrored
168
169
170def _deferred_pool_runner(has_chief, num_workers, initializer=None):
171  """Returns a callable that returns the pool runner.
172
173  It creates the pool runner only upon first invocation. This avoids creating it
174  when this file is imported.
175
176  Args:
177    has_chief: whether there should be a chief.
178    num_workers: the number of workers excluding the chief.
179    initializer: initializer of each process.
180
181  Returns:
182    A callable that returns the runner.
183  """
184
185  container = []
186
187  def get_or_create():
188    if not container:
189      cluster_spec = multi_worker_test_base.create_cluster_spec(
190          has_chief=has_chief,
191          num_workers=num_workers,
192          num_ps=0,
193          has_eval=False)
194      runner = multi_process_runner.MultiProcessPoolRunner(
195          cluster_spec, initializer=initializer)
196      container.append(runner)
197    return container[0]
198
199  return get_or_create
200
201
# Each pool's initializer creates a multi-worker strategy (with required_gpus=0)
# in every pool process, so that the server is started before any test runs.
# The pools themselves are only built on first use.

# Chief + 1 worker.
_two_worker_pool = _deferred_pool_runner(
    num_workers=1,
    has_chief=True,
    initializer=_get_multi_worker_mirrored_creator(required_gpus=0),
)
# Chief + 3 workers.
_four_worker_pool = _deferred_pool_runner(
    num_workers=3,
    has_chief=True,
    initializer=_get_multi_worker_mirrored_creator(required_gpus=0),
)
212
213
# pylint: disable=g-long-lambda
# Each NamedDistribution below pairs a label with a zero-argument function that
# creates the strategy. The `required_*` / `use_cloud_tpu` / `has_chief` /
# `num_workers` keyword arguments presumably let `combinations` decide which
# environments a combination can run in — confirm in combinations.py.

# The process's default strategy, obtained from
# distribution_strategy_context._get_default_strategy.
default_strategy = combinations.NamedDistribution(
    "Default",
    distribution_strategy_context._get_default_strategy,  # pylint: disable=protected-access
    required_gpus=None)
# Single-device strategies on the local CPU / first GPU.
one_device_strategy = combinations.NamedDistribution(
    "OneDeviceCPU", lambda: OneDeviceStrategy("/cpu:0"), required_gpus=None)
one_device_strategy_gpu = combinations.NamedDistribution(
    "OneDeviceGPU", lambda: OneDeviceStrategy("/gpu:0"), required_gpus=1)
# Single-device strategies placed on task 1 of the "worker" job; these assume
# a cluster with at least two worker tasks exists (presumably set up by the
# test) — confirm.
one_device_strategy_on_worker_1 = combinations.NamedDistribution(
    "OneDeviceOnWorker1CPU",
    lambda: OneDeviceStrategy("/job:worker/replica:0/task:1/cpu:0"),
    required_gpus=None)
one_device_strategy_gpu_on_worker_1 = combinations.NamedDistribution(
    "OneDeviceOnWorker1GPU",
    lambda: OneDeviceStrategy("/job:worker/replica:0/task:1/gpu:0"),
    required_gpus=1)
# TPU strategies. steps_per_run only affects the TF1 TPUStrategyV1; under TF2
# the creator ignores it.
tpu_strategy = combinations.NamedDistribution(
    "TPU", _get_tpu_strategy_creator(steps_per_run=2), required_tpu=True)
# Same as "TPU", but sets _enable_packed_variable_in_eager_mode on the strategy.
tpu_strategy_packed_var = combinations.NamedDistribution(
    "TPUPackedVar",
    _get_tpu_strategy_creator(steps_per_run=2, enable_packed_variable=True),
    required_tpu=True)
tpu_strategy_one_step = combinations.NamedDistribution(
    "TPUOneStep", _get_tpu_strategy_creator(steps_per_run=1), required_tpu=True)
# Single-core variants use a DeviceAssignment built from
# SINGLE_CORE_ASSIGNMENT.
tpu_strategy_one_core = combinations.NamedDistribution(
    "TPUOneCore",
    _get_tpu_strategy_creator(steps_per_run=2, use_single_core=True),
    required_tpu=True)
tpu_strategy_one_step_one_core = combinations.NamedDistribution(
    "TPUOneStepOneCore",
    _get_tpu_strategy_creator(steps_per_run=1, use_single_core=True),
    required_tpu=True)
# Same creator as "TPU", but flagged with use_cloud_tpu=True.
cloud_tpu_strategy = combinations.NamedDistribution(
    "CloudTPU",
    _get_tpu_strategy_creator(steps_per_run=2),
    required_tpu=True,
    use_cloud_tpu=True)
# Mirrored strategies. Every creation goes through
# _mirrored_strategy_with_collective_key_base, which raises the class-level
# collective key base before constructing the strategy.
mirrored_strategy_with_one_cpu = combinations.NamedDistribution(
    "Mirrored1CPU",
    lambda: _mirrored_strategy_with_collective_key_base(["/cpu:0"]))
mirrored_strategy_with_one_gpu = combinations.NamedDistribution(
    "Mirrored1GPU",
    lambda: _mirrored_strategy_with_collective_key_base(["/gpu:0"]),
    required_gpus=1)
mirrored_strategy_with_gpu_and_cpu = combinations.NamedDistribution(
    "MirroredCPUAndGPU",
    lambda: _mirrored_strategy_with_collective_key_base(["/gpu:0", "/cpu:0"]),
    required_gpus=1)
mirrored_strategy_with_two_gpus = combinations.NamedDistribution(
    "Mirrored2GPUs",
    lambda: _mirrored_strategy_with_collective_key_base(["/gpu:0", "/gpu:1"]),
    required_gpus=2)
# Uses virtual CPUs /cpu:1 and /cpu:2, so tests must create at least 3 logical
# CPUs in setUp, e.g. test_util.set_logical_devices_to_at_least("CPU", 3)
# (set_virtual_cpus_to_at_least(3) is an older alias slated for removal).
mirrored_strategy_with_cpu_1_and_2 = combinations.NamedDistribution(
    "Mirrored2CPU",
    lambda: _mirrored_strategy_with_collective_key_base(["/cpu:1", "/cpu:2"]))
mirrored_strategy_with_cpu_1_and_2.__doc__ = (
    """Mirrored strategy with 2 virtual CPUs.

    Should set up logical devices before use
    """)
# Central-storage strategies; these are created through the version-chosen
# CentralStorageStrategy factory directly (no collective key bump).
central_storage_strategy_with_two_gpus = combinations.NamedDistribution(
    "CentralStorage2GPUs",
    lambda: CentralStorageStrategy(["/gpu:0", "/gpu:1"]),
    required_gpus=2)
central_storage_strategy_with_gpu_and_cpu = combinations.NamedDistribution(
    "CentralStorageCPUAndGPU",
    lambda: CentralStorageStrategy(["/gpu:0", "/cpu:0"]),
    required_gpus=1)
# Multi-worker mirrored strategies. Each entry builds a
# CollectiveAllReduceStrategy from TF_CONFIG and runs in a shared, lazily
# created process pool: all 2-task entries share _two_worker_pool, and the
# 4-task entry uses _four_worker_pool. All of them pass no_xla=True, presumably
# excluding XLA variants — confirm in combinations.py.

# chief + 1 worker, with CPU.
multi_worker_mirrored_2x1_cpu = combinations.NamedDistribution(
    "MultiWorkerMirrored2x1CPU",
    _get_multi_worker_mirrored_creator(required_gpus=0),
    has_chief=True,
    num_workers=1,
    pool_runner_fn=_two_worker_pool,
    no_xla=True,
)
# chief + 1 worker, with 1 GPU each.
multi_worker_mirrored_2x1_gpu = combinations.NamedDistribution(
    "MultiWorkerMirrored2x1GPU",
    _get_multi_worker_mirrored_creator(required_gpus=1),
    has_chief=True,
    num_workers=1,
    required_gpus=1,
    pool_runner_fn=_two_worker_pool,
    no_xla=True,
)
# chief + 1 worker, with 2 GPUs each.
multi_worker_mirrored_2x2_gpu = combinations.NamedDistribution(
    "MultiWorkerMirrored2x2GPU",
    _get_multi_worker_mirrored_creator(required_gpus=2),
    has_chief=True,
    num_workers=1,
    required_gpus=2,
    pool_runner_fn=_two_worker_pool,
    no_xla=True,
)
# chief + 3 workers, with CPU.
multi_worker_mirrored_4x1_cpu = combinations.NamedDistribution(
    "MultiWorkerMirrored4x1CPU",
    _get_multi_worker_mirrored_creator(required_gpus=0),
    has_chief=True,
    num_workers=3,
    pool_runner_fn=_four_worker_pool,
    no_xla=True,
)
322
323
# Both execution modes, suitable as the `mode` value of combinations.combine.
graph_and_eager_modes = ["graph", "eager"]


# TODO(crccw): remove after tf-nightly picks up the new API.
def set_virtual_cpus_to_at_least(num_virtual_cpus):
  """Deprecated alias for the CPU case of the test_util helper.

  Forwards to `test_util.set_logical_devices_to_at_least` with device type
  "CPU"; new code should call that helper directly.

  Args:
    num_virtual_cpus: minimum number of logical CPUs requested.
  """
  device_type = "CPU"
  test_util.set_logical_devices_to_at_least(device_type, num_virtual_cpus)
330
331
# Non-TPU, non-multi-worker strategies, including the default strategy.
strategies_minus_tpu = [
    default_strategy,
    one_device_strategy,
    one_device_strategy_gpu,
    mirrored_strategy_with_gpu_and_cpu,
    mirrored_strategy_with_two_gpus,
    central_storage_strategy_with_gpu_and_cpu,
]

# Same as strategies_minus_tpu without the default strategy.
# NOTE(review): this list also omits central_storage_strategy_with_gpu_and_cpu,
# which strategies_minus_tpu includes; confirm whether that is intentional.
strategies_minus_default_and_tpu = [
    one_device_strategy,
    one_device_strategy_gpu,
    mirrored_strategy_with_gpu_and_cpu,
    mirrored_strategy_with_two_gpus,
]

# TPU strategies used by the aggregate lists below (the single-core variants
# are not included).
tpu_strategies = [
    tpu_strategy,  # steps_per_run=2
    tpu_strategy_one_step,
    tpu_strategy_packed_var,
    cloud_tpu_strategy,
]

all_strategies_minus_default = strategies_minus_default_and_tpu + tpu_strategies

all_strategies = strategies_minus_tpu + tpu_strategies

# Strategies presumably expected to run with two replicas — confirm for the
# TPU entries, whose replica count depends on the TPU system.
two_replica_strategies = [
    mirrored_strategy_with_gpu_and_cpu,
    mirrored_strategy_with_two_gpus,
    multi_worker_mirrored_2x1_cpu,
    multi_worker_mirrored_2x1_gpu,
    tpu_strategy,  # steps_per_run=2
    tpu_strategy_one_step,
    central_storage_strategy_with_gpu_and_cpu,
]

# 2 tasks x 2 GPUs, and 4 tasks x 1 CPU.
four_replica_strategies = [
    multi_worker_mirrored_2x2_gpu,
    multi_worker_mirrored_4x1_cpu,
]

# TODO(b/159831907): replace with two_replica_strategies after the tests using
# it work with MWMS.
multidevice_strategies = [
    mirrored_strategy_with_gpu_and_cpu,
    mirrored_strategy_with_two_gpus,
    tpu_strategy,  # steps_per_run=2
    tpu_strategy_one_step
]

# Multi-worker strategies that run on the two-task pool (the 4x1 CPU variant
# is not included here).
multiworker_strategies = [
    multi_worker_mirrored_2x1_cpu, multi_worker_mirrored_2x1_gpu,
    multi_worker_mirrored_2x2_gpu
]
387
388
def strategy_minus_tpu_combinations():
  """Returns `strategies_minus_tpu` combined with both graph and eager modes."""
  return combinations.combine(
      distribution=strategies_minus_tpu, mode=graph_and_eager_modes)


def tpu_strategy_combinations():
  """Returns `tpu_strategies` combined with graph mode only."""
  return combinations.combine(distribution=tpu_strategies, mode=["graph"])


def all_strategy_combinations():
  """Returns the non-TPU combinations followed by the TPU combinations."""
  return strategy_minus_tpu_combinations() + tpu_strategy_combinations()


def all_strategy_minus_default_and_tpu_combinations():
  """Returns `strategies_minus_default_and_tpu` in graph and eager modes.

  The module-level list is used directly instead of re-listing its members,
  so this function and the list cannot drift apart.
  """
  return combinations.combine(
      distribution=strategies_minus_default_and_tpu,
      mode=graph_and_eager_modes)


def all_strategy_combinations_minus_default():
  """Returns the non-default, non-TPU combinations plus the TPU combinations."""
  return (all_strategy_minus_default_and_tpu_combinations() +
          tpu_strategy_combinations())
414
415
# Register a subset of the combinations above as internal API constants named
# "__internal__.distribute.combinations.<name>" (see _TF_INTERNAL_API_PREFIX).
# `v1=[]` presumably means no TF1 (v1) API name is created — confirm against
# tf_export. The second argument to export_constant is the name of the
# module-level constant being exported.
# NOTE(review): keep one literal call per exported name; API-generation tooling
# may locate exports by scanning source. Not exported here:
# one_device_strategy_on_worker_1, one_device_strategy_gpu_on_worker_1,
# tpu_strategy_one_step, tpu_strategy_one_step_one_core and
# multi_worker_mirrored_4x1_cpu.
tf_export(
    _TF_INTERNAL_API_PREFIX + "central_storage_strategy_with_gpu_and_cpu",
    v1=[]).export_constant(__name__,
                           "central_storage_strategy_with_gpu_and_cpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "central_storage_strategy_with_two_gpus",
    v1=[]).export_constant(__name__, "central_storage_strategy_with_two_gpus")
tf_export(
    _TF_INTERNAL_API_PREFIX + "cloud_tpu_strategy",
    v1=[]).export_constant(__name__, "cloud_tpu_strategy")
tf_export(
    _TF_INTERNAL_API_PREFIX + "default_strategy",
    v1=[]).export_constant(__name__, "default_strategy")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_cpu_1_and_2",
    v1=[]).export_constant(__name__, "mirrored_strategy_with_cpu_1_and_2")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_gpu_and_cpu",
    v1=[]).export_constant(__name__, "mirrored_strategy_with_gpu_and_cpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_one_cpu",
    v1=[]).export_constant(__name__, "mirrored_strategy_with_one_cpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_one_gpu",
    v1=[]).export_constant(__name__, "mirrored_strategy_with_one_gpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_two_gpus",
    v1=[]).export_constant(__name__, "mirrored_strategy_with_two_gpus")
tf_export(
    _TF_INTERNAL_API_PREFIX + "multi_worker_mirrored_2x1_cpu",
    v1=[]).export_constant(__name__, "multi_worker_mirrored_2x1_cpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "multi_worker_mirrored_2x1_gpu",
    v1=[]).export_constant(__name__, "multi_worker_mirrored_2x1_gpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "multi_worker_mirrored_2x2_gpu",
    v1=[]).export_constant(__name__, "multi_worker_mirrored_2x2_gpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "one_device_strategy",
    v1=[]).export_constant(__name__, "one_device_strategy")
tf_export(
    _TF_INTERNAL_API_PREFIX + "one_device_strategy_gpu",
    v1=[]).export_constant(__name__, "one_device_strategy_gpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "tpu_strategy",
    v1=[]).export_constant(__name__, "tpu_strategy")
tf_export(
    _TF_INTERNAL_API_PREFIX + "tpu_strategy_one_core",
    v1=[]).export_constant(__name__, "tpu_strategy_one_core")
tf_export(
    _TF_INTERNAL_API_PREFIX + "tpu_strategy_packed_var",
    v1=[]).export_constant(__name__, "tpu_strategy_packed_var")
468