# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Fault tolerance test for parameter server training in TF2."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gc
import os
import threading
import time

from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.coordinator import cluster_coordinator
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator as thread_coordinator
from tensorflow.python.training import server_lib

_RPC_ERROR_FROM_WORKER = "GRPC error information from remote target /job:worker"
_RPC_ERROR_FROM_PS = "GRPC error information from remote target /job:ps"
_WORKER_PREEMPTION_THREAD_NAME = "WorkerPreemptionHandler"
_WORKER_THREAD_PREFIX = "WorkerClosureProcessingLoop"


class Model(object):

  def __init__(self, coordinator):
    self.cluster_coord = coordinator
    self.strategy = self.cluster_coord.strategy
    with self.cluster_coord.strategy.scope():
      self.build()

  def build(self):
    self.w = variables.Variable(
        initial_value=random_ops.random_uniform((10, 10)),
        dtype=dtypes.float32)
    self.iterations = variables.Variable(initial_value=0, dtype=dtypes.int32)
    # Allow external control to make the model run its train_fn in an infinite
    # loop. This allows us to reliably test worker preemption in the middle of
    # function execution.
    self.do_infinite_step = variables.Variable(False)

    def dataset_fn():
      data = random_ops.random_uniform((10, 10))
      dataset = dataset_ops.DatasetV2.from_tensors([data]).repeat()
      return dataset

    self.iterator = iter(
        self.cluster_coord.create_per_worker_dataset(dataset_fn))

  def _train_fn_internal(self, iterator):
    x = math_ops.matmul(array_ops.squeeze(next(iterator)), self.w)
    x = math_ops.matmul(random_ops.random_uniform((10, 10)), x)
    self.w.assign_add(x)

  @def_function.function
  def train_fn(self, iterator):
    self._train_fn_internal(iterator)
    while self.do_infinite_step:
      self._train_fn_internal(iterator)
    self.iterations.assign_add(1)

  def schedule_training_functions(self, num_steps):
    with self.strategy.scope():
      for _ in range(num_steps):
        self.cluster_coord.schedule(self.train_fn, args=(self.iterator,))

  def join_training_functions(self):
    self.do_infinite_step.assign(False)
    self.cluster_coord.join()


class BaseFaultToleranceTest(object):  # pylint: disable=missing-docstring

  def setUp(self, num_workers, num_ps):
    super(BaseFaultToleranceTest, self).setUp()

    # Set the environment variable to prevent hanging upon job failure and
    # restart. Note that it defaults to 'use_caller' at Google, but defaults
    # to False in OSS.
    os.environ["GRPC_FAIL_FAST"] = "use_caller"

    self._cluster = multi_worker_test_base.create_multi_process_cluster(
        num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
    self._cluster_def = self._cluster.cluster_resolver.cluster_spec().as_dict()
    self._cluster_def["chief"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(
        server_lib.ClusterSpec(self._cluster_def), rpc_layer="grpc")

    # The strategy's constructor connects to the cluster.
    self.strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        cluster_resolver)
    self.cluster_coord = cluster_coordinator.ClusterCoordinator(self.strategy)

    self.thread_coord = thread_coordinator.Coordinator(
        clean_stop_exception_types=[])
    self.num_workers = num_workers
    self.num_ps = num_ps

  def tearDown(self):
    super(BaseFaultToleranceTest, self).tearDown()
    self._cluster.stop()
    self._cluster = None

  def _restart(self, downtime_secs, job):
    """Kills `job` (index: 0) and restarts it after `downtime_secs`.

    Args:
      downtime_secs: secs before restarting the job.
      job: a string specifying the job to restart.
    """
    self._cluster.kill_task(job, 0)
    time.sleep(downtime_secs)
    self.assertFalse(context.check_alive("/job:%s/replica:0/task:0" % job))
    self._cluster.start_task(job, 0)
    while not context.check_alive("/job:%s/replica:0/task:0" % job):
      time.sleep(1)

  def _restart_in_thread(self, downtime_secs, restart_job):

    def _restart_fn():
      with self.thread_coord.stop_on_exception():
        self._restart(downtime_secs, restart_job)

    restart_thread = threading.Thread(target=_restart_fn)
    restart_thread.start()
    return restart_thread

  def _ensure_threads_closed(self):
    """Ensures worker and preemption threads are closed."""

    def _get_running_threads():
      """Returns a set of all running thread names."""
      running_threads = set()
      for thread in threading.enumerate():
        if thread.name is not None:
          running_threads.add(thread.name)
      return running_threads

    def _has_thread(prefix, running_threads):
      """Returns whether any thread in `running_threads` starts with `prefix`."""
      for thread in running_threads:
        if thread.startswith(prefix):
          return True
      return False

    # Worker and preemption threads should exist before releasing
    # ClusterCoordinator.
    running_threads = _get_running_threads()
    self.assertTrue(_has_thread(_WORKER_THREAD_PREFIX, running_threads))
    self.assertIn(_WORKER_PREEMPTION_THREAD_NAME, running_threads)

    # Wait for threads to close.
    self.cluster_coord = None
    self.strategy = None
    gc.collect()
    time.sleep(1)

    # Verify thread names.
    running_threads = _get_running_threads()
    self.assertNotIn(_WORKER_PREEMPTION_THREAD_NAME, running_threads)
    self.assertFalse(_has_thread(_WORKER_THREAD_PREFIX, running_threads))

  def _create_model_and_run_indefinitely(self):
    model = Model(self.cluster_coord)
    model.do_infinite_step.assign(True)
    model.schedule_training_functions(10)
    # The model runs an infinite training step, so at this point we expect
    # `self.num_workers` infinite closures to be inflight and
    # `10 - self.num_workers` closures to be queued.
    while (self.cluster_coord._cluster._closure_queue._inflight_closure_count <
           self.num_workers):
      time.sleep(0.1)
    return model

  def testClusterCoordinatorDestroyed(self):
    self._ensure_threads_closed()

  def testWorkerPreemptionBetweenFunctions(self):
    model = Model(self.cluster_coord)
    model.schedule_training_functions(2)
    model.join_training_functions()
    self.assertEqual(model.iterations.numpy(), 2)

    self._restart(downtime_secs=2, job="worker")

    model.schedule_training_functions(2)
    model.join_training_functions()
    self.assertEqual(model.iterations.numpy(), 4)

  def testWorkerPreemptionMidstFunction(self):
    model = Model(self.cluster_coord)
    model.do_infinite_step.assign(True)

    model.schedule_training_functions(4)
    # The model runs an infinite training step, so at this point we expect
    # `self.num_workers` infinite closures to be inflight and
    # `4 - self.num_workers` closures to be queued.
    while (self.cluster_coord._cluster._closure_queue._inflight_closure_count
           < self.num_workers):
      time.sleep(0.1)
    self.assertFalse(self.cluster_coord.done())
    self._restart(downtime_secs=2, job="worker")
    model.join_training_functions()
    self.assertGreaterEqual(model.iterations.numpy(), 4)

  def testOneWorkerPreemptionWithCancellation(self):

    @def_function.function
    def normal_function():
      x = random_ops.random_uniform((2, 10))
      y = random_ops.random_uniform((10, 2))
      return math_ops.reduce_mean(math_ops.matmul(x, y))

    @def_function.function
    def error_function():
      x = random_ops.random_uniform((2, 10))
      y = random_ops.random_uniform((10, 2))
      check_ops.assert_non_positive_v2(
          math_ops.reduce_sum(math_ops.matmul(x, y)))
      return x

    @def_function.function
    def long_function():
      x = random_ops.random_uniform((1000, 1000))
      for _ in math_ops.range(10000):
        a = random_ops.random_uniform((1000, 1000))
        b = random_ops.random_uniform((1000, 1000))
        x += math_ops.matmul(a, b)
      return x

    for _ in range(3):
      self.cluster_coord.schedule(normal_function)
    long_function_result = self.cluster_coord.schedule(long_function)
    self.cluster_coord.schedule(error_function)

    time.sleep(1)  # Let it run a couple steps.
    self._restart(1, "worker")

    with self.assertRaises(errors.InvalidArgumentError):
      self.cluster_coord.join()

    with self.assertRaises(errors.CancelledError):
      long_function_result.fetch()

    for _ in range(3):
      self.cluster_coord.schedule(normal_function)
    self.cluster_coord.join()

  def testHandleDatasetCreationFailure(self):
    model = Model(self.cluster_coord)

    restart_thread = self._restart_in_thread(5, "worker")

    model.schedule_training_functions(3)
    model.join_training_functions()

    self.thread_coord.join([restart_thread])
    self.assertGreaterEqual(model.iterations.numpy(), 3)

  def testWorkerPreemptionErrorType(self):

    @def_function.function
    def worker_train_fn():
      x = random_ops.random_uniform((2, 10))
      y = random_ops.random_uniform((10, 2))
      return math_ops.reduce_mean(math_ops.matmul(x, y))

    def run_fn():
      with self.thread_coord.stop_on_exception():
        with ops.device("/job:worker/replica:0/task:0"):
          for _ in range(3):
            for _ in range(3):
              worker_train_fn()
            time.sleep(5)

    run_thread = threading.Thread(target=run_fn)
    run_thread.start()
    time.sleep(1)  # Let it run a couple steps.
    self._restart(2, "worker")

    try:
      self.thread_coord.join([run_thread])
    except errors.UnavailableError as e:
      logging.info("Got exception %r, error message is %s", e, e)

      self.assertIn(_RPC_ERROR_FROM_WORKER, str(e))  # pylint: disable=g-assert-in-except
      self.assertNotIn(_RPC_ERROR_FROM_PS, str(e))

      self.assertTrue("failed to connect to all addresses" in str(e) or
                      "Unable to find a context_id" in str(e) or
                      "Socket closed" in str(e) or
                      "Connection reset by peer" in str(e) or
                      "Transport closed" in str(e))

  def testWorkerPreemptionErrorTypeWithPythonFunction(self):

    def worker_train_fn():
      x = random_ops.random_uniform((2, 10))
      y = random_ops.random_uniform((10, 2))
      return math_ops.reduce_mean(math_ops.matmul(x, y))

    def run_fn():
      with self.thread_coord.stop_on_exception():
        with ops.device("/job:worker/replica:0/task:0"):
          for _ in range(3):
            for _ in range(3):
              worker_train_fn()
            time.sleep(5)

    run_thread = threading.Thread(target=run_fn)
    run_thread.start()
    time.sleep(1)  # Let it run a couple steps.
    self._restart(2, "worker")

    try:
      self.thread_coord.join([run_thread])
    except errors.UnavailableError as e:
      logging.info("Got exception %r, error message is %s", e, e)

      self.assertIn(_RPC_ERROR_FROM_WORKER, str(e))  # pylint: disable=g-assert-in-except
      self.assertNotIn(_RPC_ERROR_FROM_PS, str(e))

      self.assertTrue("failed to connect to all addresses" in str(e) or
                      "Unable to find a context_id" in str(e) or
                      "Socket closed" in str(e) or
                      "Connection reset by peer" in str(e) or
                      "Transport closed" in str(e))

  def testPSPreemptionErrorType(self):

    with ops.device("/job:ps/replica:0/task:0"):
      v = variables.Variable(
          initial_value=random_ops.random_uniform((2, 10)),
          dtype=dtypes.float32)

    @def_function.function
    def worker_train_fn():
      y = random_ops.random_uniform((10, 2))
      return math_ops.reduce_mean(math_ops.matmul(v, y))

    def run_fn():
      with self.thread_coord.stop_on_exception():
        with ops.device("/job:worker/replica:0/task:0"):
          for _ in range(3):
            for _ in range(3):
              worker_train_fn()
            time.sleep(5)

    run_thread = threading.Thread(target=run_fn)
    run_thread.start()
    time.sleep(1)  # Let it run a couple steps.

    # Use a short restart delay to cover the case that the RPC channel is
    # reused.
    self._restart(1, "ps")

    try:
      self.thread_coord.join([run_thread])
    except (errors.UnavailableError, errors.AbortedError) as e:
      logging.info("Got exception %r, error message is %s", e, e)
      self.assertIn(_RPC_ERROR_FROM_PS, str(e))  # pylint: disable=g-assert-in-except

      if isinstance(e, errors.UnavailableError):
        self.assertTrue("failed to connect to all addresses" in str(e) or
                        "Unable to find a context_id" in str(e) or
                        "Socket closed" in str(e) or
                        "Connection reset by peer" in str(e) or
                        "Transport closed" in str(e))

      if isinstance(e, errors.AbortedError):
        self.assertIn("RecvTensor expects a different device incarnation",
                      str(e))
      self._ensure_threads_closed()

  def testTwoWorkersPreempted(self):
    if self.num_workers < 2:
      self.skipTest("Worker number is less than 2.")
    model = self._create_model_and_run_indefinitely()

    self.assertFalse(self.cluster_coord.done())
    self._cluster.kill_task("worker", 0)
    self._cluster.kill_task("worker", 1)
    time.sleep(2)
    self.assertFalse(context.check_alive("/job:worker/replica:0/task:0"))
    self.assertFalse(context.check_alive("/job:worker/replica:0/task:1"))
    self._cluster.start_task("worker", 0)
    self._cluster.start_task("worker", 1)
    time.sleep(2)
    self.assertTrue(context.check_alive("/job:worker/replica:0/task:0"))
    self.assertTrue(context.check_alive("/job:worker/replica:0/task:1"))

    model.join_training_functions()
    self.assertGreaterEqual(model.iterations.numpy(), 10)

  def testWorkerContinuousFailure(self):
    model = self._create_model_and_run_indefinitely()

    self.assertFalse(self.cluster_coord.done())
    self._cluster.kill_task("worker", 0)
    time.sleep(2)
    self.assertFalse(context.check_alive("/job:worker/replica:0/task:0"))
    self._cluster.start_task("worker", 0)
    time.sleep(2)
    self.assertTrue(context.check_alive("/job:worker/replica:0/task:0"))
    self._cluster.kill_task("worker", 0)
    time.sleep(2)
    self.assertFalse(context.check_alive("/job:worker/replica:0/task:0"))
    self._cluster.start_task("worker", 0)
    time.sleep(2)
    self.assertTrue(context.check_alive("/job:worker/replica:0/task:0"))

    model.join_training_functions()
    self.assertGreaterEqual(model.iterations.numpy(), 10)

  def testNumpyFetchedAfterWorkerFailure(self):

    with self.strategy.scope():
      v = variables.Variable(initial_value=0, dtype=dtypes.int32)

    @def_function.function
    def worker_fn():
      return v + 1, v - 1

    remote_value = self.cluster_coord.schedule(worker_fn)
    # Fetching before the worker task is killed should succeed.
    self.assertEqual((1, -1), remote_value.fetch())
    self._cluster.kill_task("worker", 0)
    # So should fetching after the worker task is killed.
    self.assertEqual((1, -1), remote_value.fetch())

  def testClusterStateNotDisrupted(self):
    # This test has side effects and can disrupt other tests, even if the
    # resource created by it will not be used in following tests.
    # TODO(b/155209534): enable this test.
    # self.testPSPreemptionErrorType()

    self.thread_coord = thread_coordinator.Coordinator(
        clean_stop_exception_types=[])
    self.testWorkerPreemptionMidstFunction()

    self.thread_coord = thread_coordinator.Coordinator(
        clean_stop_exception_types=[])
    self.testWorkerPreemptionErrorType()

    # In previous tests, workers may fail after training is done. But the
    # following tests start with creating resources where failure is not
    # handled.
    # TODO(b/153888707): enable the following two tests.
    # self.testTwoWorkersPreempted()
    # self.testWorkerContinuousFailure()

  def testJoinRaisesUnavailableErrorAtPsFailure(self):
    self._create_model_and_run_indefinitely()
    self._cluster.kill_task("ps", 0)
    while self.cluster_coord._cluster._closure_queue._error is None:
      time.sleep(1)
    with self.assertRaises((errors.UnavailableError, errors.NotFoundError,
                            errors.FailedPreconditionError)):
      self.cluster_coord.join()

  def testScheduleRaisesUnavailableErrorAtPsFailure(self):
    self._create_model_and_run_indefinitely()
    self._cluster.kill_task("ps", 0)
    while self.cluster_coord._cluster._closure_queue._error is None:
      time.sleep(1)
    with self.assertRaises((errors.UnavailableError, errors.NotFoundError,
                            errors.FailedPreconditionError)):
      self.cluster_coord.schedule(def_function.function(lambda: None))

  def testWorkerExecutionAfterPsFailureRaisesExpectedError(self):
    model = self._create_model_and_run_indefinitely()
    for i in range(self.num_ps):
      self._cluster.kill_task("ps", i)
    while self.cluster_coord._cluster._closure_queue._error is None:
      time.sleep(1)

    @def_function.function
    def trivial_function():
      return model.iterations + 1

    for i in range(self.num_workers):
      try:
        with ops.device("/job:worker/replica:0/task:{}".format(i)):
          trivial_function()
      except Exception as e:  # pylint: disable=broad-except
        if cluster_coordinator._is_ps_failure(e):
          if i < self.num_workers - 1:
            continue
          return
      raise AssertionError("Executing a function after PS failure should "
                           "result in a PS failure error.")


class MultiWorkerFaultToleranceTest(BaseFaultToleranceTest, test.TestCase):
  """Multi worker fault tolerance tests.

  This covers the ordinary cases where multiple workers and PS are used.
  """

  def setUp(self):
    super(MultiWorkerFaultToleranceTest, self).setUp(2, 2)


class SingleWorkerFaultToleranceTest(BaseFaultToleranceTest, test.TestCase):
  """Single worker fault tolerance tests.

  This covers the cases that ensure training can continue in a single-worker
  cluster, even if the only worker becomes unavailable at some point and is
  later recovered (with multiple workers, it is possible that training
  succeeds on the workers that did not fail). Realistically, a single worker
  is very rarely used, but these tests are important to ensure the correct
  behavior.
  """

  def setUp(self):
    super(SingleWorkerFaultToleranceTest, self).setUp(1, 1)


if __name__ == "__main__":
  v2_compat.enable_v2_behavior()
  multi_process_runner.test_main()