1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""This module customizes `test_combinations` for `tf.distribute.Strategy`.
16
17Additionally it provides `generate()`, `combine()` and `times()` with
18`tf.distribute.Strategy` customizations as a default.
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25import collections
26import copy
27import re
28import sys
29import types
30import unittest
31
32import six
33
34from tensorflow.python.client import session
35from tensorflow.python.distribute import collective_all_reduce_strategy
36from tensorflow.python.distribute import distribute_lib
37from tensorflow.python.distribute import multi_process_runner
38from tensorflow.python.distribute import multi_worker_test_base
39from tensorflow.python.eager import context
40from tensorflow.python.eager import def_function
41from tensorflow.python.framework import combinations as framework_combinations
42from tensorflow.python.framework import ops
43from tensorflow.python.framework import test_combinations as combinations_lib
44from tensorflow.python.framework import test_util
45from tensorflow.python.platform import flags
46from tensorflow.python.platform import tf_logging as logging
47from tensorflow.python.util import tf_decorator
48from tensorflow.python.util import tf_inspect
49from tensorflow.python.util.tf_export import tf_export
50
51
52# TODO(rchao): Rename `distribution` parameter to `strategy` or
53# `distribute_strategy` in all tests.
class DistributionParameter(combinations_lib.ParameterModifier):
  """Replaces `NamedDistribution` arguments with the strategies they wrap.

  Every keyword argument whose value is a `NamedDistribution` is swapped for
  the object returned by that wrapper's `strategy` property. Arguments of any
  other type are left untouched (they are simply not part of the result).
  """

  def modified_arguments(self, kwargs, requested_parameters):
    # `use_var_policy` is a temporary knob used while testing the variable
    # policy rollout. When it is truthy its value is copied onto the
    # `extended` object of every strategy created below.
    var_policy = kwargs.get("use_var_policy", None)
    replacements = {}
    for arg_name, arg_value in kwargs.items():
      if not isinstance(arg_value, NamedDistribution):
        continue
      replacements[arg_name] = arg_value.strategy
      if var_policy:
        replacements[arg_name].extended._use_var_policy = var_policy
    return replacements
74
75
class ClusterParameters(combinations_lib.ParameterModifier):
  """Adds cluster parameters derived from a `NamedDistribution` argument.

  The `has_chief`, `num_workers` and `runner` values are taken from the
  `NamedDistribution` among the test arguments if there is one, and from the
  explicit test arguments (or their defaults) otherwise.

  It needs to be before DistributionParameter, because that modifier replaces
  `NamedDistribution` arguments with the strategies they wrap.
  """

  def modified_arguments(self, kwargs, requested_parameters):
    """Computes the cluster parameters for one test combination.

    Args:
      kwargs: the test arguments of the current combination.
      requested_parameters: collection of parameter names; only the cluster
        parameters whose names appear here are included in the result.

    Returns:
      A dict with the subset of `has_chief`, `num_workers` and `runner` that is
      present in `requested_parameters`.

    Raises:
      ValueError: if several `NamedDistribution`s are given and any of them is
        multi worker, or if explicit `has_chief`/`num_workers` arguments
        disagree with the `NamedDistribution`.
    """
    strategy = None
    for v in kwargs.values():
      if not isinstance(v, NamedDistribution):
        continue
      # Check both the previously seen and the current distribution, so the
      # result does not depend on the order of the arguments: a multi worker
      # NamedDistribution may not be combined with any other one.
      if strategy is not None and max(
          _num_total_workers(strategy.has_chief, strategy.num_workers),
          _num_total_workers(v.has_chief, v.num_workers)) > 1:
        raise ValueError("Only support one NamedDistribution for multi worker "
                         "tests.")
      strategy = v

    if strategy is not None:
      has_chief = strategy.has_chief
      num_workers = strategy.num_workers
      runner = strategy.runner
      if "has_chief" in kwargs and kwargs["has_chief"] != has_chief:
        raise ValueError(
            "both has_chief and strategy specified but are not compatible")
      if "num_workers" in kwargs and kwargs["num_workers"] != num_workers:
        raise ValueError(
            "both num_workers and strategy specified but are not compatible")
    else:
      has_chief = kwargs.get("has_chief", False)
      num_workers = kwargs.get("num_workers", 1)
      runner = kwargs.get("runner", None)

    # Always set cluster parameters if they're requested, so that generate()
    # works when there's no strategy in the combinations.
    update = {}
    if "has_chief" in requested_parameters:
      update["has_chief"] = has_chief
    if "num_workers" in requested_parameters:
      update["num_workers"] = num_workers
    if "runner" in requested_parameters:
      update["runner"] = runner
    return update
117
118
class DistributionCombination(combinations_lib.TestCombination):
  """Sets up distribution strategy for tests.

  Skips combinations whose `NamedDistribution` opts out of XLA when running
  XLA tests, and registers the parameter modifiers that turn
  `NamedDistribution` arguments into strategies.
  """

  def should_execute_combination(self, kwargs):
    if test_util.is_xla_enabled():
      for value in kwargs.values():
        if isinstance(value, NamedDistribution) and value.no_xla:
          return (
              False,
              "n/a: skipping strategy combination with no_xla=True in XLA tests")
    return (True, None)

  def parameter_modifiers(self):
    modifiers = [DistributionParameter()]
    modifiers.append(combinations_lib.OptionalParameter("use_var_policy"))
    return modifiers
137
138
class ClusterCombination(combinations_lib.TestCombination):
  """Sets up multi worker tests.

  Its only job is to register `ClusterParameters` as a parameter modifier.
  """

  def parameter_modifiers(self):
    cluster_modifier = ClusterParameters()
    return [cluster_modifier]
144
145
class GPUCombination(combinations_lib.TestCombination):
  """Enable tests to request GPU hardware and skip non-GPU combinations.

  This class expects test_combinations to be generated with `NamedDistribution`
  wrapping instances of `tf.distribute.Strategy`.

  Optionally, the `required_gpus` argument is supported.  GPU hardware is
  required, if its value is `True` or > 0.

  Attributes:
    GPU_TEST: The environment is considered to have GPU hardware available if
              the name of the program ends with "test_gpu" or
              "test_xla_gpu". Holds the regex match object (or None), which is
              used for its truthiness.
  """

  GPU_TEST = re.search(r"(test_gpu|test_xla_gpu)$", sys.argv[0])

  def should_execute_combination(self, kwargs):
    """Decides whether a combination can run with the available GPUs.

    Args:
      kwargs: the test arguments of the current combination.

    Returns:
      `(True, None)` if the combination should run, otherwise
      `(False, reason)`.

    Raises:
      ValueError: if `required_gpus` is combined with `NamedDistribution`
        arguments.
    """
    distributions = [
        v for v in kwargs.values() if isinstance(v, NamedDistribution)
    ]
    required_gpus = kwargs.get("required_gpus", None)

    if distributions and required_gpus:
      raise ValueError("Do not use `required_gpus` and arguments of type "
                       "NamedDistribution together.")

    number_of_required_gpus = max([required_gpus or 0] +
                                  [d.required_gpus or 0 for d in distributions])

    if not number_of_required_gpus and GPUCombination.GPU_TEST:
      # GPU test binaries only run combinations that actually need a GPU.
      return (False, "Test that doesn't require GPUs.")
    if number_of_required_gpus > 0:
      # Query the device count once and use it for both the check and the
      # message.
      available_gpus = context.num_gpus()
      if available_gpus < number_of_required_gpus:
        return (False, ("Only {} of {} required GPUs are available.".format(
            available_gpus, number_of_required_gpus)))
    return (True, None)

  def parameter_modifiers(self):
    return [combinations_lib.OptionalParameter("required_gpus")]
186
187
class TPUCombination(combinations_lib.TestCombination):
  """Allow to request TPU hardware and skip non-TPU combinations.

  This class expects test_combinations to be generated with `NamedDistribution`
  wrapping instances of `tf.distribute.Strategy`.

  Optionally, the `required_tpus` parameter is supported.  TPU hardware is
  required, if its argument is `True` or > 0.

  Optionally, the `use_cloud_tpu` parameter is supported. If TPU hardware is
  required by `required_tpus`, it specifically must be a Cloud TPU (specified
  with `--tpu`) if `use_cloud_tpu` is `True`.

  Attributes:
    TPU_TEST: The environment is considered to have TPU hardware available if
              the name of the program contains "test_tpu".
  """

  TPU_TEST = "test_tpu" in sys.argv[0]

  def should_execute_combination(self, kwargs):
    """Decides whether a combination can run on the available TPU setup.

    Args:
      kwargs: the test arguments of the current combination.

    Returns:
      `(True, None)` if the combination should run, otherwise
      `(False, reason)` for the first skip condition that applies.

    Raises:
      ValueError: if both `required_tpus` and `required_tpu` are given, or if
        `required_tpus` is combined with `NamedDistribution` arguments.
    """
    named_distributions = [
        value for value in kwargs.values()
        if isinstance(value, NamedDistribution)
    ]
    # TODO(isaprykin): Migrate all tests away from using 'required_tpu' in favor
    # of 'required_tpus'.
    if all(key in kwargs for key in ("required_tpus", "required_tpu")):
      raise ValueError("Do not use `required_tpu`.  Both `required_tpus` and "
                       "`required_tpu` were specified.")
    requested_tpus = (kwargs.get("required_tpus", None) or
                      kwargs.get("required_tpu", None))

    if named_distributions and requested_tpus:
      raise ValueError("Do not use `required_tpus` and arguments of type "
                       "NamedDistribution together.")

    # TODO(isaprykin): Add support for a particular number of TPUs.  Right now
    # it's binary.
    tpu_requirements = [requested_tpus or 0]
    tpu_requirements.extend(d.required_tpu or 0 for d in named_distributions)
    number_of_required_tpus = max(tpu_requirements)

    cloud_tpu_requests = [kwargs.get("use_cloud_tpu")]
    cloud_tpu_requests.extend(d.use_cloud_tpu for d in named_distributions)
    use_cloud_tpu = any(cloud_tpu_requests)

    # The `--tpu` flag may not be defined; a missing or empty flag both mean
    # that no Cloud TPU was specified.
    if hasattr(flags.FLAGS, "tpu"):
      tpu = flags.FLAGS.tpu or ""
    else:
      tpu = ""

    on_tpu_test = TPUCombination.TPU_TEST
    # Checked in order; the first matching condition provides the skip reason.
    skip_checks = (
        (not number_of_required_tpus and on_tpu_test,
         "Test that doesn't require TPUs."),
        (number_of_required_tpus and not on_tpu_test,
         "Test requires a TPU, but it's not available."),
        (use_cloud_tpu and not tpu,
         "Test requires a Cloud TPU, but none specified."),
        (not use_cloud_tpu and tpu,
         "Test requires local TPU, but Cloud TPU specified."),
    )
    for should_skip, reason in skip_checks:
      if should_skip:
        return (False, reason)
    return (True, None)

  def parameter_modifiers(self):
    optional_names = ("required_tpus", "required_tpu", "use_cloud_tpu")
    return [combinations_lib.OptionalParameter(name) for name in optional_names]
248
249
class NamedDistribution(object):
  """Wraps a `tf.distribute.Strategy` and adds a name for test titles.

  The name is what `__repr__` returns, so it appears in generated test case
  names. The other attributes describe the hardware and cluster the strategy
  needs; the GPU/TPU/cluster combinations in this module read them to skip or
  configure tests.
  """

  def __init__(self,
               name,
               distribution_fn,
               required_gpus=None,
               required_tpu=False,
               use_cloud_tpu=False,
               has_chief=False,
               num_workers=1,
               pool_runner_fn=None,
               no_xla=False):
    """Initialize NamedDistribution.

    Args:
      name: Name that will be a part of the name of the test case.
      distribution_fn: A callable that creates a `tf.distribute.Strategy`. It
        is invoked anew every time `strategy` is read.
      required_gpus: The number of GPUs that the strategy requires.
      required_tpu: Whether the strategy requires TPU.
      use_cloud_tpu: Whether the strategy requires cloud TPU.
      has_chief: Whether the strategy requires a chief worker.
      num_workers: The number of workers that the strategy requires.
      pool_runner_fn: An optional callable that returns a MultiProcessPoolRunner
        to run the test. It is invoked every time `runner` is read.
      no_xla: Whether to skip in XLA tests.
    """
    # Private factories, exposed through the `runner`/`strategy` properties.
    self._name = name
    self._distribution_fn = distribution_fn
    self._pool_runner_fn = pool_runner_fn
    # Public requirement attributes read by the test combinations.
    self.required_gpus = required_gpus
    self.required_tpu = required_tpu
    self.use_cloud_tpu = use_cloud_tpu
    self.has_chief = has_chief
    self.num_workers = num_workers
    self.no_xla = no_xla

  @property
  def runner(self):
    """The result of calling `pool_runner_fn`, or None if it was not given."""
    make_runner = self._pool_runner_fn
    return None if make_runner is None else make_runner()

  @property
  def strategy(self):
    """A strategy freshly produced by `distribution_fn`."""
    make_strategy = self._distribution_fn
    return make_strategy()

  def __repr__(self):
    return self._name
300
301
# These allow adding combinations that run a function both as a
# tf.function and eagerly.
#
# @combinations.generate(
#   combinations.combine(
#     tf_function = [combinations.tf_function, combinations.no_tf_function]
#   )
# )
# def testXXX(tf_function):
#   @tf_function
#   def foo():
#     tf.add(1., 1.)
#
#   foo()
#
# `tf_function` is named "TfFunction" and wraps `def_function.function`;
# `no_tf_function` is named "NoTfFunction" and wraps an identity function.
tf_function = combinations_lib.NamedObject("TfFunction", def_function.function)
no_tf_function = combinations_lib.NamedObject("NoTfFunction", lambda f: f)
318
319
def concat(*combined):
  """Concatenates several collections of combinations into one list.

  Args:
    *combined: iterables whose elements are collected in argument order.

  Returns:
    A new list holding every element of every argument, in order.
  """
  return [combination for group in combined for combination in group]
326
327
@tf_export("__internal__.distribute.combinations.generate", v1=[])
def generate(combinations, test_combinations=()):
  # pylint: disable=g-doc-args,g-doc-return-or-yield
  """Distributed adapter of `tf.__internal__.test.combinations.generate`.

  All tests with distributed strategy should use this one instead of
  `tf.__internal__.test.combinations.generate`. On top of it, this function
  supports strategy combinations, GPU/TPU requirements and multi worker tests.

  See `tf.__internal__.test.combinations.generate` for usage.
  """
  # pylint: enable=g-doc-args,g-doc-return-or-yield
  all_test_combinations = (
      framework_combinations.EagerGraphCombination(),
      framework_combinations.TFVersionCombination(),
      ClusterCombination(),
      DistributionCombination(),
      GPUCombination(),
      TPUCombination(),
  ) + test_combinations
  # The multi worker wrapping is applied before the framework decorator, so
  # that the framework decorator ends up outermost and applies all parameter
  # modifiers before the multi worker wrapper sees the arguments.
  combination_decorator = combinations_lib.generate(
      combinations, test_combinations=all_test_combinations)

  def decorator(test_method_or_class):
    if not isinstance(test_method_or_class, type):
      # A single test method.
      return combination_decorator(_multi_worker_test(test_method_or_class))

    # A test class: wrap each of its test methods with _multi_worker_test.
    test_class = test_method_or_class
    method_prefix = unittest.TestLoader.testMethodPrefix
    # Snapshot the attributes, since setattr() below mutates the class dict.
    for attr_name, attr_value in list(test_class.__dict__.items()):
      if (attr_name.startswith(method_prefix) and
          isinstance(attr_value, types.FunctionType)):
        setattr(test_class, attr_name, _multi_worker_test(attr_value))
    return combination_decorator(test_class)

  return decorator
368
369
# Re-export the generic combination helpers so that tests can get `generate()`,
# `combine()`, `times()` and `NamedObject` from this module alone.
combine = combinations_lib.combine
times = combinations_lib.times
NamedObject = combinations_lib.NamedObject
373
374
# False in the main test process; `_test_runner` sets it to True in every
# worker process it runs in. `_multi_worker_test` behaves differently depending
# on this flag; see its documentation for details.
_running_in_worker = False


def in_main_process():
  """Reports whether the caller is running in the main test process.

  Use this to guard test-environment preparation that should only happen in
  the main process, not in the worker processes.

  Returns:
    A boolean: True in the main process, False in a worker process.
  """
  in_worker = _running_in_worker
  return not in_worker
391
392
class TestEnvironment(object):
  """Mutable bag of test environment settings shared with worker processes.

  Attributes may only be assigned while running in the main process; any
  assignment in a worker process, including the one made by `__init__`,
  raises ValueError.

  Attributes:
    tf_data_service_dispatcher: defaults to None.
  """

  def __init__(self):
    self.tf_data_service_dispatcher = None

  def __setattr__(self, name, value):
    if in_main_process():
      super().__setattr__(name, value)
      return
    raise ValueError(
        "combinations.env() should only be modified in the main process. "
        "Condition your code on combinations.in_main_process().")
404
405
# The environment returned by env(). In worker processes `_test_runner`
# rebinds it to the object sent over from the main process.
_env = TestEnvironment()


def env():
  """Returns the object that holds the test environment information.

  Tests should modify this in the main process if needed, and it will be
  passed to the worker processes each time a test case is run.

  Returns:
    a TestEnvironment object.
  """
  return _env
419
420
421_TestResult = collections.namedtuple("_TestResult", ["status", "message"])
422
423
424def _test_runner(test_id, test_env):
425  """Executes the test with the given test_id.
426
427  This is a simple wrapper around TestRunner to be used with
428  multi_process_runner. Similar to test.main(), but it executes only one test
429  specified by test_id and returns whether the test succeeds. If the test fails,
430  the function prints failures and errors to stdout.
431
432  Args:
433    test_id: TestCase.id()
434    test_env: a TestEnvironment object.
435
436  Returns:
437    A boolean indicates whether the test succeeds.
438  """
439  global _running_in_worker, _env
440  # No need to restore the value of _running_in_worker since it should always be
441  # True in worker processes.
442  _running_in_worker = True
443  _env = test_env
444  test = unittest.defaultTestLoader.loadTestsFromName(test_id)
445  runner = unittest.TextTestRunner()
446  result = runner.run(test)
447  # Treat expected failures as failures, so that the main process can get
448  # them and fail as expected. Also treat errors as failures to simplify the
449  # handling.
450  failures = result.failures + result.expectedFailures + result.errors
451  if failures:
452    ret = _TestResult(status="failure", message=failures[0][1])
453  elif result.skipped:
454    ret = _TestResult(status="skipped", message=result.skipped[0][1])
455  else:
456    # Treat unexpectedSuccesses as OK so that the test case in the main process
457    # succeed as well.
458    ret = _TestResult(status="ok", message=None)
459  # Print tracebacks to stdout and multi_process_runner will collect
460  # them and stream back to the main process.
461  if ret.message:
462    print(ret.message)
463  return ret
464
465
def _multi_worker_test(test_method):
  """Decorate test_method so that it runs in each worker.

  We use `multi_process_runner` to simulate multiple workers. Since we run
  this function in the main process and all worker processes, this decoration
  behaves differently in the main process and worker processes. In the main
  process, it spawns subprocesses and runs the test on each of them; in a worker
  process, it executes test in the same way as a normal test, e.g.
  setUp()/tearDown() are called before/after the test.

  Args:
    test_method: a function which must be a test method.

  Returns:
    Decorated `test_method`. Note that the decorated function has additional
    arguments `has_chief`, `num_workers` and `runner`; they are consumed by the
    decorator and not forwarded to `test_method`.
  """

  def decorator(self, has_chief, num_workers, runner, **kwargs):
    if _num_total_workers(has_chief, num_workers) == 1 or _running_in_worker:
      # We're in worker process or the test is for single worker. Either case we
      # execute the test method directly instead of spawning subprocesses.

      # For MultiWorkerMirroredStrategy(CollectiveAllReduceStrategy), install a
      # session that connects to the local server. This is necessary for multi
      # worker graph mode tests to work. Those tests cannot use their graphs or
      # sessions, including the one returned by self.cached_session(). Since
      # existing tests may already be doing so, we only install the session for
      # multi worker tests.
      with _multi_worker_session(kwargs):
        test_method(self, **kwargs)
      return

    # We're in the main process. We spawn subprocesses and run the *test* on
    # each of them. Note that we're not directly executing test_method passed to
    # _multi_worker_test, because we need setUp()/tearDown() to be called and
    # all the decorations on the test method. The conceptual call stack is:
    #   [main process]test.main()
    #     [main process]test_runner.run(test)
    #       [main process]wrapper by combinations.generate()
    #         [main process]_multi_worker_test.decorator()
    #           # A sub process goes through the same code path as the main
    #           # process.
    #           [sub process]_test_runner()
    #             [sub process]test_runner.run(test)
    #               [sub process]wrapper by combinations.generate()
    #                 [sub process]_multi_worker_test.decorator()
    #                   # _running_in_worker is True
    #                   [sub process]test_method()
    test_id = self.id()
    if runner:
      results = runner.run(_test_runner, args=(test_id, _env))
    else:
      cluster_spec = multi_worker_test_base.create_cluster_spec(
          has_chief=has_chief,
          num_workers=num_workers,
          num_ps=0,
          has_eval=False)
      results = multi_process_runner.run(
          _test_runner, cluster_spec, args=(test_id, _env)).return_value

    skip_reason = None
    for result in results:
      if result.status == "failure":
        # We can't tell which worker the return value came from, so we fail on
        # the first error. `self.fail` raises, which also ends the loop.
        self.fail(result.message)
      elif result.status == "skipped":
        # Record the skip reason, but do not actually skip the test in case some
        # processes fail instead.
        skip_reason = result.message
    if skip_reason is not None:
      self.skipTest(skip_reason)

  argspec = tf_inspect.getfullargspec(test_method)
  decorator_args = (argspec.args or []) + ["has_chief", "num_workers", "runner"]
  decorator_argspec = argspec._replace(args=decorator_args)
  return tf_decorator.make_decorator(
      test_method, decorator, decorator_argspec=decorator_argspec)
546
547
548def _num_total_workers(has_chief, num_workers):
549  """Returns the number of workers including the chief."""
550  if has_chief:
551    return num_workers + 1
552  return num_workers
553
554
def _multi_worker_session(kwargs):
  """Returns a context manager entering a session configured for MWMS.

  Args:
    kwargs: a dict. Keyword arguments passed to the test.

  Returns:
    A context manager. If MultiWorkerMirroredStrategy
    (`CollectiveAllReduceStrategy`) is the one and only strategy in kwargs and
    it's in graph mode, it installs a session configured for that strategy as
    the default session and closes that session on exit. Otherwise, it's a
    no-op context manager.
  """
  import contextlib  # pylint: disable=g-import-not-at-top

  strategy = None
  for v in kwargs.values():
    if isinstance(v, distribute_lib.StrategyBase):
      if strategy is not None:
        logging.warning(
            "The test uses multiple strategies. Skipping "
            "entering a session that is configured for the strategy.")
        return ops.NullContextmanager()
      strategy = v
  if context.executing_eagerly() or not isinstance(
      strategy, collective_all_reduce_strategy.CollectiveAllReduceStrategy):
    return ops.NullContextmanager()
  sess_config = copy.deepcopy(context.context().config)
  sess_config = strategy.update_config_proto(sess_config)
  target = strategy.cluster_resolver.master()
  sess = session.Session(config=sess_config, target=target)

  @contextlib.contextmanager
  def _default_session_closed_on_exit():
    # `Session.as_default()` does not close the session when its context
    # exits, so close it explicitly instead of leaking one session per test.
    try:
      with sess.as_default() as default_sess:
        yield default_sess
    finally:
      sess.close()

  return _default_session_closed_on_exit()
582