# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Fault tolerance test for parameter server training in TF2."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gc
import os
import threading
import time

from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.coordinator import cluster_coordinator
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator as thread_coordinator
from tensorflow.python.training import server_lib

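# Error-message substrings and coordinator thread names that the tests below
# match against.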
_RPC_ERROR_FROM_WORKER = "GRPC error information from remote target /job:worker"
_RPC_ERROR_FROM_PS = "GRPC error information from remote target /job:ps"
_WORKER_PREEMPTION_THREAD_NAME = "WorkerPreemptionHandler"
_WORKER_THREAD_PREFIX = "WorkerClosureProcessingLoop"


class Model(object):
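  """Toy model whose train functions are scheduled via the coordinator."""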

  def __init__(self, coordinator):
    self.cluster_coord = coordinator
    self.strategy = self.cluster_coord.strategy
    with self.cluster_coord.strategy.scope():
      self.build()

  def build(self):
    self.w = variables.Variable(
        initial_value=random_ops.random_uniform((10, 10)), dtype=dtypes.float32)
    self.iterations = variables.Variable(initial_value=0, dtype=dtypes.int32)
    # Allow external control to make the model run its train_fn in an infinite
    # loop. This allows us to reliably test worker preemption in the middle of
    # function execution.
    self.do_infinite_step = variables.Variable(False)

    def dataset_fn():
      data = random_ops.random_uniform((10, 10))
      dataset = dataset_ops.DatasetV2.from_tensors([data]).repeat()
      return dataset

    self.iterator = iter(
        self.cluster_coord.create_per_worker_dataset(dataset_fn))

  def _train_fn_internal(self, iterator):
    x = math_ops.matmul(array_ops.squeeze(next(iterator)), self.w)
    x = math_ops.matmul(random_ops.random_uniform((10, 10)), x)
    self.w.assign_add(x)

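  # `train_fn` runs one step, then keeps stepping while `do_infinite_step` is
  # True so that tests can preempt a worker mid-execution; `iterations` is
  # only incremented once the loop exits.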
  @def_function.function
  def train_fn(self, iterator):
    self._train_fn_internal(iterator)
    while self.do_infinite_step:
      self._train_fn_internal(iterator)
    self.iterations.assign_add(1)

  def schedule_training_functions(self, num_steps):
    with self.strategy.scope():
      for _ in range(num_steps):
        self.cluster_coord.schedule(self.train_fn, args=(self.iterator,))

  def join_training_functions(self):
    self.do_infinite_step.assign(False)
    self.cluster_coord.join()


class BaseFaultToleranceTest(object):
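  """Base class holding the fault tolerance test cases.

  Subclasses set the number of workers and parameter servers via `setUp`.
  """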

  def setUp(self, num_workers, num_ps):
    super(BaseFaultToleranceTest, self).setUp()

    # Set the environment variable to prevent hanging upon job failure and
    # restart. Note that it defaults to 'use_caller' at Google, but defaults
    # to False in OSS.
    os.environ["GRPC_FAIL_FAST"] = "use_caller"

    self._cluster = multi_worker_test_base.create_multi_process_cluster(
        num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
    self._cluster_def = self._cluster.cluster_resolver.cluster_spec().as_dict()
    self._cluster_def["chief"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(
        server_lib.ClusterSpec(self._cluster_def), rpc_layer="grpc")

    # The strategy's constructor connects to the cluster.
    self.strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        cluster_resolver)
    self.cluster_coord = cluster_coordinator.ClusterCoordinator(self.strategy)

    self.thread_coord = thread_coordinator.Coordinator(
        clean_stop_exception_types=[])
    self.num_workers = num_workers
    self.num_ps = num_ps

  def tearDown(self):
    super(BaseFaultToleranceTest, self).tearDown()
    self._cluster.stop()
    self._cluster = None

  def _restart(self, downtime_secs, job):
    """Kills `job` (index: 0) and restarts it after `downtime_secs`.

    Args:
      downtime_secs: seconds to wait before restarting the job.
      job: a string specifying the job to restart.
    """
    self._cluster.kill_task(job, 0)
    time.sleep(downtime_secs)
    self.assertFalse(context.check_alive("/job:%s/replica:0/task:0" % job))
    self._cluster.start_task(job, 0)
    while not context.check_alive("/job:%s/replica:0/task:0" % job):
      time.sleep(1)

  def _restart_in_thread(self, downtime_secs, restart_job):

    def _restart_fn():
      with self.thread_coord.stop_on_exception():
        self._restart(downtime_secs, restart_job)

    restart_thread = threading.Thread(target=_restart_fn)
    restart_thread.start()
    return restart_thread

  def _ensure_threads_closed(self):
    """Ensures worker and preemption threads are closed."""

    def _get_running_threads():
      """Returns a set of all running thread names."""
      running_threads = set()
      for thread in threading.enumerate():
        if thread.name is not None:
          running_threads.add(thread.name)
      return running_threads

    def _has_thread(prefix, running_threads):
      """Returns whether any of `running_threads` starts with `prefix`."""
      for thread in running_threads:
        if thread.startswith(prefix):
          return True
      return False

    # Worker and preemption threads should exist before releasing
    # ClusterCoordinator.
    running_threads = _get_running_threads()
    self.assertTrue(_has_thread(_WORKER_THREAD_PREFIX, running_threads))
    self.assertIn(_WORKER_PREEMPTION_THREAD_NAME, running_threads)

    # Release the coordinator and strategy and wait for their threads to close.
    self.cluster_coord = None
    self.strategy = None
    gc.collect()
    time.sleep(1)

    # Verify thread names.
    running_threads = _get_running_threads()
    self.assertNotIn(_WORKER_PREEMPTION_THREAD_NAME, running_threads)
    self.assertFalse(_has_thread(_WORKER_THREAD_PREFIX, running_threads))

  def _create_model_and_run_indefinitely(self):
    model = Model(self.cluster_coord)
    model.do_infinite_step.assign(True)
    model.schedule_training_functions(10)
    # The model runs an infinite training step, so at this point we expect
    # `self.num_workers` infinite closures to be in flight and
    # `10 - self.num_workers` closures in the queue.
    while (self.cluster_coord._cluster._closure_queue._inflight_closure_count <
           self.num_workers):
      time.sleep(0.1)
    return model

  def testClusterCoordinatorDestroyed(self):
    self._ensure_threads_closed()

  def testWorkerPreemptionBetweenFunctions(self):
    model = Model(self.cluster_coord)
    model.schedule_training_functions(2)
    model.join_training_functions()
    self.assertEqual(model.iterations.numpy(), 2)

    self._restart(downtime_secs=2, job="worker")

    model.schedule_training_functions(2)
    model.join_training_functions()
    self.assertEqual(model.iterations.numpy(), 4)

  def testWorkerPreemptionMidstFunction(self):
    model = Model(self.cluster_coord)
    model.do_infinite_step.assign(True)

    model.schedule_training_functions(4)
    # The model runs an infinite training step, so at this point we expect
    # `self.num_workers` infinite closures to be in flight and
    # `4 - self.num_workers` closures in the queue.
    while (self.cluster_coord._cluster._closure_queue._inflight_closure_count
           < self.num_workers):
      time.sleep(0.1)
    self.assertFalse(self.cluster_coord.done())
    self._restart(downtime_secs=2, job="worker")
    model.join_training_functions()
    self.assertGreaterEqual(model.iterations.numpy(), 4)

  def testOneWorkerPreemptionWithCancellation(self):
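    # `error_function` raises `InvalidArgumentError`, which should cancel the
    # remaining scheduled closures: `join()` surfaces the original error,
    # `long_function_result.fetch()` raises `CancelledError`, and the
    # coordinator remains usable for newly scheduled functions afterwards.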

    @def_function.function
    def normal_function():
      x = random_ops.random_uniform((2, 10))
      y = random_ops.random_uniform((10, 2))
      return math_ops.reduce_mean(math_ops.matmul(x, y))

    @def_function.function
    def error_function():
      x = random_ops.random_uniform((2, 10))
      y = random_ops.random_uniform((10, 2))
      check_ops.assert_non_positive_v2(
          math_ops.reduce_sum(math_ops.matmul(x, y)))
      return x

    @def_function.function
    def long_function():
      x = random_ops.random_uniform((1000, 1000))
      for _ in math_ops.range(10000):
        a = random_ops.random_uniform((1000, 1000))
        b = random_ops.random_uniform((1000, 1000))
        x += math_ops.matmul(a, b)
      return x

    for _ in range(3):
      self.cluster_coord.schedule(normal_function)
    long_function_result = self.cluster_coord.schedule(long_function)
    self.cluster_coord.schedule(error_function)

    time.sleep(1)  # Let it run a couple steps.
    self._restart(1, "worker")

    with self.assertRaises(errors.InvalidArgumentError):
      self.cluster_coord.join()

    with self.assertRaises(errors.CancelledError):
      long_function_result.fetch()

    for _ in range(3):
      self.cluster_coord.schedule(normal_function)
    self.cluster_coord.join()

  def testHandleDatasetCreationFailure(self):
    model = Model(self.cluster_coord)

    restart_thread = self._restart_in_thread(5, "worker")

    model.schedule_training_functions(3)
    model.join_training_functions()

    self.thread_coord.join([restart_thread])
    self.assertGreaterEqual(model.iterations.numpy(), 3)

  def testWorkerPreemptionErrorType(self):
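    # Run functions directly on the worker (bypassing the coordinator) and
    # restart the worker while they are running. Any `UnavailableError` raised
    # should point at the worker RPC target, not the parameter server.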

    @def_function.function
    def worker_train_fn():
      x = random_ops.random_uniform((2, 10))
      y = random_ops.random_uniform((10, 2))
      return math_ops.reduce_mean(math_ops.matmul(x, y))

    def run_fn():
      with self.thread_coord.stop_on_exception():
        with ops.device("/job:worker/replica:0/task:0"):
          for _ in range(3):
            for _ in range(3):
              worker_train_fn()
            time.sleep(5)

    run_thread = threading.Thread(target=run_fn)
    run_thread.start()
    time.sleep(1)  # Let it run a couple steps.
    self._restart(2, "worker")

    try:
      self.thread_coord.join([run_thread])
    except errors.UnavailableError as e:
      logging.info("Got exception %r, error message is %s", e, e)

      self.assertIn(_RPC_ERROR_FROM_WORKER, str(e))  # pylint: disable=g-assert-in-except
      self.assertNotIn(_RPC_ERROR_FROM_PS, str(e))

      self.assertTrue("failed to connect to all addresses" in str(e) or
                      "Unable to find a context_id" in str(e) or
                      "Socket closed" in str(e) or
                      "Connection reset by peer" in str(e) or
                      "Transport closed" in str(e))

  def testWorkerPreemptionErrorTypeWithPythonFunction(self):

    def worker_train_fn():
      x = random_ops.random_uniform((2, 10))
      y = random_ops.random_uniform((10, 2))
      return math_ops.reduce_mean(math_ops.matmul(x, y))

    def run_fn():
      with self.thread_coord.stop_on_exception():
        with ops.device("/job:worker/replica:0/task:0"):
          for _ in range(3):
            for _ in range(3):
              worker_train_fn()
            time.sleep(5)

    run_thread = threading.Thread(target=run_fn)
    run_thread.start()
    time.sleep(1)  # Let it run a couple steps.
    self._restart(2, "worker")

    try:
      self.thread_coord.join([run_thread])
    except errors.UnavailableError as e:
      logging.info("Got exception %r, error message is %s", e, e)

      self.assertIn(_RPC_ERROR_FROM_WORKER, str(e))  # pylint: disable=g-assert-in-except
      self.assertNotIn(_RPC_ERROR_FROM_PS, str(e))

      self.assertTrue("failed to connect to all addresses" in str(e) or
                      "Unable to find a context_id" in str(e) or
                      "Socket closed" in str(e) or
                      "Connection reset by peer" in str(e) or
                      "Transport closed" in str(e))

  def testPSPreemptionErrorType(self):
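    # Place a variable on the PS so that reading it from the worker fails once
    # the PS is restarted. The error should reference /job:ps and be either
    # `UnavailableError` or `AbortedError` (the latter typically when a reused
    # RPC channel sees a different device incarnation after the restart).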

    with ops.device("/job:ps/replica:0/task:0"):
      v = variables.Variable(
          initial_value=random_ops.random_uniform((2, 10)),
          dtype=dtypes.float32)

    @def_function.function
    def worker_train_fn():
      y = random_ops.random_uniform((10, 2))
      return math_ops.reduce_mean(math_ops.matmul(v, y))

    def run_fn():
      with self.thread_coord.stop_on_exception():
        with ops.device("/job:worker/replica:0/task:0"):
          for _ in range(3):
            for _ in range(3):
              worker_train_fn()
            time.sleep(5)

    run_thread = threading.Thread(target=run_fn)
    run_thread.start()
    time.sleep(1)  # Let it run a couple steps.

    # Use a short restart delay to cover the case where the RPC channel is
    # reused.
    self._restart(1, "ps")

    try:
      self.thread_coord.join([run_thread])
    except (errors.UnavailableError, errors.AbortedError) as e:
      logging.info("Got exception %r, error message is %s", e, e)
      self.assertIn(_RPC_ERROR_FROM_PS, str(e))  # pylint: disable=g-assert-in-except

      if isinstance(e, errors.UnavailableError):
        self.assertTrue("failed to connect to all addresses" in str(e) or
                        "Unable to find a context_id" in str(e) or
                        "Socket closed" in str(e) or
                        "Connection reset by peer" in str(e) or
                        "Transport closed" in str(e))

      if isinstance(e, errors.AbortedError):
        self.assertIn("RecvTensor expects a different device incarnation",
                      str(e))
      self._ensure_threads_closed()

  def testTwoWorkersPreempted(self):
    if self.num_workers < 2:
      self.skipTest("Worker number is less than 2.")
    model = self._create_model_and_run_indefinitely()

    self.assertFalse(self.cluster_coord.done())
    self._cluster.kill_task("worker", 0)
    self._cluster.kill_task("worker", 1)
    time.sleep(2)
    self.assertFalse(context.check_alive("/job:worker/replica:0/task:0"))
    self.assertFalse(context.check_alive("/job:worker/replica:0/task:1"))
    self._cluster.start_task("worker", 0)
    self._cluster.start_task("worker", 1)
    time.sleep(2)
    self.assertTrue(context.check_alive("/job:worker/replica:0/task:0"))
    self.assertTrue(context.check_alive("/job:worker/replica:0/task:1"))

    model.join_training_functions()
    self.assertGreaterEqual(model.iterations.numpy(), 10)

  def testWorkerContinuousFailure(self):
    model = self._create_model_and_run_indefinitely()

    self.assertFalse(self.cluster_coord.done())
    self._cluster.kill_task("worker", 0)
    time.sleep(2)
    self.assertFalse(context.check_alive("/job:worker/replica:0/task:0"))
    self._cluster.start_task("worker", 0)
    time.sleep(2)
    self.assertTrue(context.check_alive("/job:worker/replica:0/task:0"))
    self._cluster.kill_task("worker", 0)
    time.sleep(2)
    self.assertFalse(context.check_alive("/job:worker/replica:0/task:0"))
    self._cluster.start_task("worker", 0)
    time.sleep(2)
    self.assertTrue(context.check_alive("/job:worker/replica:0/task:0"))

    model.join_training_functions()
    self.assertGreaterEqual(model.iterations.numpy(), 10)

  def testNumpyFetchedAfterWorkerFailure(self):
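    # A `RemoteValue` whose function has already completed should remain
    # fetchable even after the worker that produced it is killed.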

    with self.strategy.scope():
      v = variables.Variable(initial_value=0, dtype=dtypes.int32)

    @def_function.function
    def worker_fn():
      return v + 1, v - 1

    remote_value = self.cluster_coord.schedule(worker_fn)
    # Fetching before the worker task is killed should succeed.
    self.assertEqual((1, -1), remote_value.fetch())
    self._cluster.kill_task("worker", 0)
    # Fetching after the worker task is killed should also succeed.
    self.assertEqual((1, -1), remote_value.fetch())

  def testClusterStateNotDisrupted(self):
    # This test has side effects and can disrupt other tests, even though the
    # resources it creates are not used by subsequent tests.
    # TODO(b/155209534): enable this test.
    # self.testPSPreemptionErrorType()

    self.thread_coord = thread_coordinator.Coordinator(
        clean_stop_exception_types=[])
    self.testWorkerPreemptionMidstFunction()

    self.thread_coord = thread_coordinator.Coordinator(
        clean_stop_exception_types=[])
    self.testWorkerPreemptionErrorType()

    # In the previous tests, workers may fail after training is done. The
    # following tests, however, start by creating resources, where failures
    # are not handled.
    # TODO(b/153888707): enable the following two tests.
    # self.testTwoWorkersPreempted()
    # self.testWorkerContinuousFailure()

  def testJoinRaisesUnavailableErrorAtPsFailure(self):
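    # Once the closure queue has recorded the PS failure, `join()` should
    # raise one of the expected error types.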
    self._create_model_and_run_indefinitely()
    self._cluster.kill_task("ps", 0)
    while self.cluster_coord._cluster._closure_queue._error is None:
      time.sleep(1)
    with self.assertRaises((errors.UnavailableError, errors.NotFoundError,
                            errors.FailedPreconditionError)):
      self.cluster_coord.join()

  def testScheduleRaisesUnavailableErrorAtPsFailure(self):
    self._create_model_and_run_indefinitely()
    self._cluster.kill_task("ps", 0)
    while self.cluster_coord._cluster._closure_queue._error is None:
      time.sleep(1)
    with self.assertRaises((errors.UnavailableError, errors.NotFoundError,
                            errors.FailedPreconditionError)):
      self.cluster_coord.schedule(def_function.function(lambda: None))

  def testWorkerExecutionAfterPsFailureRaisesExpectedError(self):
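    # Kill every PS, then run a trivial function directly on each worker; each
    # attempt should fail with an error that `_is_ps_failure` recognizes.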
    model = self._create_model_and_run_indefinitely()
    for i in range(self.num_ps):
      self._cluster.kill_task("ps", i)
    while self.cluster_coord._cluster._closure_queue._error is None:
      time.sleep(1)

    @def_function.function
    def trivial_function():
      return model.iterations + 1

    for i in range(self.num_workers):
      try:
        with ops.device("/job:worker/replica:0/task:{}".format(i)):
          trivial_function()
      except Exception as e:  # pylint: disable=broad-except
        if cluster_coordinator._is_ps_failure(e):
          if i < self.num_workers - 1:
            continue
          return
      raise AssertionError("Executing a function after the PS has failed "
                           "should result in a PS failure.")


class MultiWorkerFaultToleranceTest(BaseFaultToleranceTest, test.TestCase):
  """Multi worker fault tolerance tests.

  This covers the ordinary cases where multiple workers and parameter servers
  are used.
  """

  def setUp(self):
    super(MultiWorkerFaultToleranceTest, self).setUp(2, 2)


class SingleWorkerFaultToleranceTest(BaseFaultToleranceTest, test.TestCase):
  """Single worker fault tolerance tests.

  This covers the case where training should be able to continue in a
  single-worker cluster even if the only worker becomes unavailable at some
  point and later recovers (with multiple workers, training may succeed on the
  workers that did not fail). A single-worker setup is rarely used in
  practice, but these tests are important to ensure correct behavior.
  """

  def setUp(self):
    super(SingleWorkerFaultToleranceTest, self).setUp(1, 1)


if __name__ == "__main__":
  v2_compat.enable_v2_behavior()
  multi_process_runner.test_main()