1# Lint as: python3
2# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Multi-process runner for testing purpose."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23import contextlib
24import json
25import os
26import signal
27import sys
28import threading
29import time
30import unittest
31import weakref
32
33from absl import logging
34import six
35from six.moves import queue as Queue
36
37from tensorflow.python import tf2
38from tensorflow.python.compat import v2_compat
39from tensorflow.python.distribute import multi_process_lib
40from tensorflow.python.eager import context
41from tensorflow.python.util.tf_export import tf_export
42
# The multiprocessing module object as exposed by `multi_process_lib`. The rest
# of this file goes through this alias, e.g. `multiprocessing.Pipe`.
multiprocessing = getattr(multi_process_lib, 'multiprocessing')
44
45# pylint: disable=g-import-not-at-top
46try:
47  # `faulthandler` is not available in py2.
48  import faulthandler
49except ImportError:
50  faulthandler = None
51
52# TODO(b/150264776): Remove after resolving CI issue.
53try:
54  import dill
55except ImportError:
56  dill = None
57
58# TODO(b/150264776): Remove after resolving CI issue.
59try:
60  import tblib.pickling_support
61  # For pickling traceback objects.
62  tblib.pickling_support.install()
63except ImportError:
64  pass
65
66
# Status a subprocess reports back to the parent process. `is_successful` is
# True when the subprocess ended without error; when it is False, `exc_info`
# holds the exception information so the parent process can re-raise it.
# `return_value` is the value collected into the result returned by `join()`.
_ProcessStatusInfo = collections.namedtuple(
    '_ProcessStatusInfo',
    'task_type task_id is_successful exc_info return_value')

# What a successful MultiProcessRunner run hands back to the caller.
MultiProcessRunnerResult = collections.namedtuple(
    'MultiProcessRunnerResult', 'return_value stdout')

# Per-task settings passed to a subprocess so it can configure itself.
TestEnvironment = collections.namedtuple(
    'TestEnvironment',
    'task_type task_id cluster_spec rpc_layer grpc_fail_fast v2_enabled '
    'executing_eagerly')

# Objects shared between the main process and the worker processes:
#
# `process_status_queue`: used internally by `multi_process_runner` so that
#   subprocesses can tell the parent whether they succeeded and, if not, what
#   the error stack trace was.
# `parent_to_sub_queue`: carries messages from the parent to subprocesses. At
#   the moment it is only used to ask subprocesses to terminate.
#   TODO(rchao): Remove this once subprocess is terminated by SIGKILL.
# `streaming_pipe_w`: write end of the pipe that streams a subprocess' stdout
#   and stderr to the parent process.
# `barrier`: a barrier whose parties are all of the subprocesses.
Resources = collections.namedtuple(
    'Resources',
    'process_status_queue parent_to_sub_queue streaming_pipe_w barrier')

# Picked so that a `join()` timeout is hit before the default "medium" test
# timeout.
_DEFAULT_TIMEOUT_SEC = 200

# Seconds to wait before force-killing child processes that timed out. Timed-out
# children are first sent SIGTERM so they can dump their stack traces, and that
# dump may take a long time.
_FORCE_KILL_WAIT_SEC = 30
107
108
class MultiProcessRunner(object):
  """A utility class to start multiple processes to simulate a cluster.

  We need to use multiple processes to simulate a cluster in TF 2.0 tests
  because TF 2.0 has some process-global data structures that have to be
  separated by processes. We also need child processes to test out our fault
  tolerance because shutting down a standard TensorFlow server within its
  process is not supported.

  Note: the main test program that uses this runner class must run main program
  via `test_main` defined in this file. Using this runner in non-test binaries
  is not supported yet.

  This class is not thread-safe. Child processes will inherit TF2 behavior flag.
  """

  def __init__(self,
               fn,
               cluster_spec,
               rpc_layer=None,
               max_run_time=None,
               grpc_fail_fast=None,
               stream_output=True,
               return_output=False,
               use_dill_for_args=True,
               daemon=False,
               dependence_on_chief=True,
               auto_restart=False,
               args=None,
               kwargs=None):
    """Instantiation of a `MultiProcessRunner`.

    Args:
      fn: Function to be run on child processes. This will be run on processes
        for all task types.
      cluster_spec: Dict for cluster spec. The utility function
        `tf.__internal__.distribute.multi_process_runner.create_cluster_spec`
        can be conveniently used to create such dict. The following is an
        example of cluster with three workers and two ps's.
        {"worker": ["worker0.example.com:2222",
                    "worker1.example.com:2222",
                    "worker2.example.com:2222"],
         "ps": ["ps0.example.com:2222",
                "ps1.example.com:2222"]}
      rpc_layer: RPC layer to use. Default value is 'grpc'.
      max_run_time: `None` or integer. If not `None`, child processes are forced
        to exit at approximately this many seconds after this utility is called.
        We achieve this through `signal.alarm()` api. Note that this is best
        effort at Python level since Python signal handler does not get executed
        when it runs lower level C/C++ code. So it can be delayed for
        arbitrarily long time. If any of the child process is still running when
        `max_run_time` is up, they will be force-terminated and an
        `UnexpectedSubprocessExitError` may be raised. If `None`, child
        processes are not forced to exit.
      grpc_fail_fast: Whether GRPC connection between processes should fail
        without retrying. Defaults to None, in which case the environment
        variable is not explicitly set.
      stream_output: True if the output/error from the subprocesses should be
        streamed to be printed in parent process' log. Defaults to True.
      return_output: If True, the output/error from the subprocesses should be
        collected to be attached to the resulting namedtuple returned from
        `join()`. The list of output can be retrieved via `stdout` attribute.
        Defaults to False.
      use_dill_for_args: Whether to use dill to pickle `args` and `kwargs`. dill
        can pickle more objects, but doesn't work with types in
        `multiprocessing` library like `Mutex`.
      daemon: Whether to start processes as daemons.
      dependence_on_chief: Whether to terminate the cluster if the chief exits.
        If auto_restart is True, it only terminates the cluster if the chief
        exits with a zero exit code.
      auto_restart: Whether to automatically restart processes that exit with
        non-zero exit code.
      args: Positional arguments to be sent to `fn` run on subprocesses.
      kwargs: Keyword arguments to be sent to `fn` run on subprocesses.

    Raises:
      RuntimeError: if `multi_process_runner.test_main()` is not called (raised
        as `NotInitializedError`).
      ValueError: if there are more than one chief in the `cluster_spec`, or if
        `fn` is not callable.
    """

    assert cluster_spec is not None
    if 'chief' in cluster_spec and len(cluster_spec['chief']) > 1:
      raise ValueError('If chief exists in the cluster, there must be at most '
                       'one chief. Current `cluster_spec` has {} chiefs.'
                       .format(len(cluster_spec['chief'])))
    if not multi_process_lib.initialized():
      raise NotInitializedError(
          '`multi_process_runner` is not initialized. '
          'Please call `tf.__internal__.distribute.multi_process_runner.'
          'test_main()` within `if __name__ == \'__main__\':` block '
          'in your python module to properly initialize '
          '`multi_process_runner`.')
    if not callable(fn):
      raise ValueError('fn is not a callable')

    self._fn = fn
    self._cluster_spec = cluster_spec
    self._rpc_layer = rpc_layer or 'grpc'
    self._max_run_time = max_run_time
    self._grpc_fail_fast = grpc_fail_fast
    self._stream_output = stream_output
    # TODO(rchao): Revisit return_output argument to consider other solution.
    self._return_output = return_output
    self._dependence_on_chief = dependence_on_chief
    self._use_dill_for_args = use_dill_for_args
    self._daemon = daemon
    self._auto_restart = auto_restart
    self._args = args or ()
    self._kwargs = kwargs or {}

    # Child processes should have the same v2 and eager behavior.
    self._v2_enabled = tf2.enabled()
    self._executing_eagerly = context.executing_eagerly()

    self._joined = False
    self._process_lock = threading.Lock()
    # Guarded by self._process_lock.
    self._processes = {}
    # Record which processes are terminated. Due to a bug in Python<3.7,
    # terminated processes return 255 exit code, which should cause an exception
    # in join().
    # https://bugs.python.org/issue30589
    # Guarded by self._process_lock.
    self._terminated = set()
    self._reading_threads = []

    self._manager = manager()
    self._process_status_queue = self._manager.Queue()
    self._parent_to_sub_queue = self._manager.Queue()
    parties = sum(len(addresses) for addresses in self._cluster_spec.values())
    self._barrier = self._manager.Barrier(parties)

    # We use a queue to collect outputs from worker processes since it's thread
    # safe.
    self._streaming_queue = self._manager.Queue()

    self._watchdog_thread = None

  def set_args(self, args=None, kwargs=None):
    """Replaces the arguments passed to the default `fn`.

    The new values are used by processes started afterwards with the `fn`
    given at `__init__`.

    Args:
      args: New positional arguments, or `None` to keep the current ones. An
        empty sequence clears them.
      kwargs: New keyword arguments, or `None` to keep the current ones. An
        empty dict clears them.
    """
    # Compare against None (rather than truthiness) so that callers can
    # explicitly reset the arguments to empty.
    if args is not None:
      self._args = args
    if kwargs is not None:
      self._kwargs = kwargs

  def _check_not_joined(self):
    """Raises `ValueError` if `join()` has been called.

    The caller is required to hold self._process_lock.
    """
    if self._joined:
      raise ValueError('cannot start new processes after '
                       'MultiProcessRunner.join() is called')

  def _continuously_readline_from_sub(self, pipe_r, task_type, task_id):
    """Function to continuously read lines from subprocesses.

    Each line is prefixed with the task type and id, then printed (if
    `stream_output`) and/or collected for `join()`'s result (if
    `return_output`). Runs until the pipe reaches EOF.
    """
    # The prefix is the same for every line of this task, so build it once.
    prefix = '[{}-{}]:'.format(task_type, task_id).ljust(14)
    with os.fdopen(pipe_r.fileno(), 'r', closefd=False) as reader:
      for line in reader:
        formatted_line = '{} {}'.format(prefix, line)
        if self._stream_output:
          # TODO(rchao): Use a lock here to ensure the printed lines are not
          # broken.
          print(formatted_line, end='', flush=True)
        if self._return_output:
          self._streaming_queue.put(formatted_line)

  def _start_subprocess_and_reading_thread(self,
                                           task_type,
                                           task_id,
                                           cluster_spec=None,
                                           fn=None,
                                           args=None,
                                           kwargs=None):
    """Start a subprocess and a thread that reads lines from the subprocess.

    Also (re)starts the watchdog thread if it is not running. All callers in
    this class hold self._process_lock.
    """

    if dill is None:
      raise unittest.SkipTest(
          'TODO(b/150264776): Resolve dependency issue in CI')

    test_env = TestEnvironment(
        task_type=task_type,
        task_id=task_id,
        cluster_spec=cluster_spec or self._cluster_spec,
        rpc_layer=self._rpc_layer,
        grpc_fail_fast=self._grpc_fail_fast,
        v2_enabled=self._v2_enabled,
        executing_eagerly=self._executing_eagerly,
    )
    pipe_r, pipe_w = multiprocessing.Pipe(duplex=False)
    resources = Resources(
        process_status_queue=self._process_status_queue,
        parent_to_sub_queue=self._parent_to_sub_queue,
        streaming_pipe_w=pipe_w,
        barrier=self._barrier,
    )
    if fn is None:
      fn, args, kwargs = self._fn, self._args, self._kwargs
    # Always use dill to pickle fn so that we support more callable
    # types, e.g. lambda.
    fn = dill.dumps(fn, dill.HIGHEST_PROTOCOL)
    if self._use_dill_for_args:
      args = dill.dumps(args, dill.HIGHEST_PROTOCOL)
      kwargs = dill.dumps(kwargs, dill.HIGHEST_PROTOCOL)

    p = _Process(
        test_env=test_env,
        target=_ProcFunc(),
        args=(resources, test_env, fn, args, kwargs, self._use_dill_for_args),
        daemon=self._daemon)
    p.start()
    self._processes[(task_type, task_id)] = p
    self._terminated.discard((task_type, task_id))

    # For each subprocess, we dedicate a thread continuously reading lines
    # from them.
    thread = threading.Thread(  # pylint: disable=unexpected-keyword-arg
        target=self._continuously_readline_from_sub,
        args=(pipe_r, task_type, task_id))
    thread.start()
    self._reading_threads.append(thread)

    if self._watchdog_thread is None or not self._watchdog_thread.is_alive():
      self._watchdog_thread = threading.Thread(target=self._process_watchdog)
      self._watchdog_thread.start()

  def start(self):
    """Starts processes, one for each task in `cluster_spec`.

    Note that this is best effort by the applicable multiprocessing library,
    and it may take up to seconds for a subprocess to be successfully started.

    Raises:
      ValueError: if the runner was already started or joined.
    """
    with self._process_lock:
      if self._processes:
        raise ValueError('MultiProcessRunner already started.')
      self._check_not_joined()

      for task_type, addresses in self._cluster_spec.items():
        for task_id, _ in enumerate(addresses):
          self._start_subprocess_and_reading_thread(task_type, task_id)

    # TODO(rchao): Remove the need of using SIGALRM if possible. At this time,
    # without this the tests become very flaky.
    if self._max_run_time is not None:

      def handler(signum, frame):
        del signum, frame
        self.terminate_all()

      signal.signal(signal.SIGALRM, handler)
      signal.alarm(self._max_run_time)

  def start_in_process_as(self, as_task_type, as_task_id):
    """Start the processes, with the specified task run in main process.

    This is similar to `start()` except that the task with task_type
    `as_task_type` and task_id `as_task_id` is run in the main process.
    This method is particularly useful when debugging tool such as `pdb` is
    needed in some specific task. Note that since this method is blocking until
    that specific task exits, additional actions would need a thread to be
    called:

    ```python
    def fn():
      # user code to be run
      import pdb; pdb.set_trace()

    def follow_ups():
      time.sleep(5)
      mpr.start_single_process(
          task_type='evaluator',
          task_id=0)

    mpr = multi_process_runner.MultiProcessRunner(
        fn,
        multi_worker_test_base.create_cluster_spec(
            has_chief=True, num_workers=1))
    threading.Thread(target=follow_ups).start()
    mpr.start_in_process_as(as_task_type='chief', as_task_id=0)
    mpr.join()
    ```

    Note that if `return_output=True`, the logs/stdout by task
    run by the main process is not available in result.stdout.

    Args:
      as_task_type: The task type to be run in the main process.
      as_task_id: The task id to be run in the main process.

    Raises:
      ValueError: if the runner was already started or joined.
    """
    # Check and start under the lock, consistent with `start()`.
    with self._process_lock:
      if self._processes:
        raise ValueError('MultiProcessRunner already started.')
      self._check_not_joined()
      for task_type, addresses in self._cluster_spec.items():
        for task_id, _ in enumerate(addresses):
          if not (task_type == as_task_type and task_id == as_task_id):
            self._start_subprocess_and_reading_thread(task_type, task_id)

    _set_tf_config(as_task_type, as_task_id, self._cluster_spec,
                   self._rpc_layer)
    self._fn(*self._args, **self._kwargs)

  def start_single_process(self,
                           task_type,
                           task_id,
                           cluster_spec=None,
                           fn=None,
                           args=None,
                           kwargs=None):
    """Starts a single process.

    This starts a process in the cluster with the task type, task id, and the
    process function (`fn`). If process function is `None`, the function
    provided at `__init__` will be used. If `cluster_spec` is `None`, the
    cluster spec provided at `__init__` will be used.

    TODO(rchao): It is meant that all subprocesses will be updated with the new
    cluster spec, but this has yet to be implemented. At this time only the
    newly started subprocess picks up this updated cluster spec.

    Args:
      task_type: The task type.
      task_id: The task id.
      cluster_spec: The cluster spec to be used on the newly started
        process. If `None`, the cluster spec provided at `__init__` will be
        used.
      fn: The process function to be run on the newly started
        process. If specified, specify `args` and `kwargs` as well. If `None`,
        the function provided at `__init__` will be used.
      args: Optional positional arguments to be supplied in `fn`.
      kwargs: Optional keyword arguments to be supplied in `fn`.

    Raises:
      ValueError: if `join()` has already been called.
    """
    with self._process_lock:
      self._check_not_joined()
      self._start_subprocess_and_reading_thread(
          task_type,
          task_id,
          cluster_spec=cluster_spec,
          fn=fn,
          args=args or (),
          kwargs=kwargs or {})

  def _queue_to_list(self, queue_to_convert):
    """Convert `queue.Queue` to `list`."""
    list_to_return = []
    # Calling `queue.empty()` is not reliable.
    while True:
      try:
        list_to_return.append(queue_to_convert.get(block=False))
      except Queue.Empty:
        break
    return list_to_return

  def _get_process_statuses(self):
    """Drains the status queue into a dict keyed by `(task_type, task_id)`."""
    # One worker may have multiple statuses. We only keep the last one.
    statuses = {}
    for status in self._queue_to_list(self._process_status_queue):
      statuses[(status.task_type, status.task_id)] = status
    return statuses

  def get_process_id(self, task_type, task_id):
    """Returns the subprocess id given the task type and task id."""
    with self._process_lock:
      p = self._processes.get((task_type, task_id), None)
    return p.pid if p else None

  def get_process_exit_code(self, task_type, task_id):
    """Returns the subprocess exit code given the task type and task id.

    Args:
      task_type: The task type.
      task_id: The task id.

    Returns:
      The subprocess exit code; `None` if the subprocess has not exited yet.

    Raises:
      KeyError: If the corresponding subprocess is not found with `task_type`
        and `task_id`.
    """
    with self._process_lock:
      p = self._processes[(task_type, task_id)]
    return p.exitcode if p else None

  def process_exists(self, task_type, task_id):
    """Returns whether the subprocess still exists given the task type and id.

    Args:
      task_type: The task type.
      task_id: The task id.

    Returns:
      Boolean; whether the subprocess still exists. If the subprocess has
      exited, this returns False.
    """
    return self.get_process_exit_code(task_type, task_id) is None

  def _process_watchdog(self):
    """Simulates a cluster management system.

    - If auto_restart is True, it restarts processes that exit with a non-zero
      exit code. Note that when join() times out it overrides auto_restart to
      False.
    - If dependence_on_chief is True, it terminates all processes once the chief
      exits. If auto_restart is also True, it only terminates all processes if
      the chief exit with a zero exit code, otherwise it restarts the chief.

    This runs in self._watchdog_thread.
    """
    while True:
      time.sleep(1)
      with self._process_lock:
        chief = self._processes.get(('chief', 0), None)
        # Terminate the cluster when _dependence_on_chief is True if either:
        # - chief has exited with zero exit code.
        # - chief has exited with non-zero exit code and self._auto_restart is
        #   False.
        if chief and self._dependence_on_chief and chief.exitcode is not None:
          if chief.exitcode == 0 or (not self._auto_restart):
            for p in self._processes.values():
              # Give other processes a chance to exit on their own.
              p.join(timeout=3)
            self._terminate_all()
            for p in self._processes.values():
              p.join()
            return

        # Auto restart failed processes if self._auto_restart is True.
        if self._auto_restart:
          has_failure = False
          for (task_type, task_id), p in self._processes.items():
            if p.exitcode is not None and p.exitcode != 0:
              has_failure = True
              logging.info('Restarting failed %s-%d', task_type, task_id)
              self._start_subprocess_and_reading_thread(task_type, task_id)
          if has_failure:
            continue

        # Exit the thread if all processes have exited at this point.
        if all(p.exitcode is not None for p in self._processes.values()):
          return

  def _reraise_if_subprocess_error(self, process_statuses):
    """Re-raises the first reported subprocess failure, if any.

    The re-raised exception gets an `mpr_result` attribute holding the
    `MultiProcessRunnerResult` collected so far.
    """
    for process_status in process_statuses.values():
      assert isinstance(process_status, _ProcessStatusInfo)
      if not process_status.is_successful:
        process_status.exc_info[1].mpr_result = self._get_mpr_result(
            process_statuses)
        six.reraise(*process_status.exc_info)

  def _handle_join_timeout(self, timeout):
    """Terminates all subprocesses after `join()` timed out, then raises.

    Subprocesses are first sent SIGTERM so they have a chance to dump stack
    traces. If the watchdog thread is still alive `_FORCE_KILL_WAIT_SEC`
    seconds later, they are terminated again with the default signal of
    `terminate_all()`.

    Args:
      timeout: the `join()` timeout, used in the error message.

    Raises:
      Exception: re-raised from a subprocess that reported a failure.
      SubprocessTimeoutError: if no subprocess reported a failure.
    """
    with self._process_lock:
      self._auto_restart = False
    logging.error('Timeout when joining for child processes. Terminating...')
    self.terminate_all(sig=signal.SIGTERM)
    # Wait for the processes to terminate by themselves first, so they have a
    # chance to dump stacktraces. After _FORCE_KILL_WAIT_SEC, we SIGKILL them.
    self._watchdog_thread.join(_FORCE_KILL_WAIT_SEC)
    if self._watchdog_thread.is_alive():
      logging.error('Timeout when waiting for child processes to '
                    'print stacktrace. Sending SIGKILL...')
      self.terminate_all()
      self._watchdog_thread.join()
    process_statuses = self._get_process_statuses()
    self._reraise_if_subprocess_error(process_statuses)
    raise SubprocessTimeoutError(
        'One or more subprocesses timed out, where timeout was set to {}s. '
        'Please change the `timeout` argument for '
        '`MultiProcessRunner.join()` or `multi_process_runner.run()` '
        'if it should be adjusted.'.format(timeout),
        self._get_mpr_result(process_statuses))

  def join(self, timeout=_DEFAULT_TIMEOUT_SEC):
    """Joins all the processes with timeout.

    If any of the subprocesses does not exit approximately after `timeout`
    seconds has passed after `join` call, this raises a
    `SubprocessTimeoutError`.

    Note: At timeout, it uses SIGTERM to terminate the subprocesses, in order to
    log the stack traces of the subprocesses when they exit. However, this
    results in timeout when the test runs with tsan (thread sanitizer); if tsan
    is being run on the test targets that rely on timeout to assert information,
    `MultiProcessRunner.terminate_all()` must be called after `join()`, before
    the test exits, so the subprocesses are terminated with SIGKILL, and data
    race is removed.

    Any pending `max_run_time` alarm is cleared whether `join` returns or
    raises. If no subprocess was ever started (for example
    `start_in_process_as` on a single-task cluster), there is nothing to wait
    for.

    Args:
      timeout: optional integer or `None`. If provided as an integer, and not
      all processes report status within roughly `timeout` seconds, a
      `SubprocessTimeoutError` exception will be raised. If `None`, `join` never
      times out.

    Returns:
      A `MultiProcessRunnerResult` object, which has two attributes,
      `return_value` and `stdout`. `return_value` always contains a list of
      return values from the subprocesses, although the order is not meaningful.
      If `return_output` argument is True at `__init__`, `stdout` is available
      that contains a list of all messages from subprocesses' stdout and stderr.

    Raises:
      ValueError: if `timeout` is not an integer or `None`, or if `join` has
        already been called.
      SubprocessTimeoutError: if not all processes report status approximately
        within `timeout` seconds. When this is raised, a
        `MultiProcessRunnerResult` object can be retrieved by
        `SubprocessTimeoutError`'s mpr_result attribute, which has the same
        structure as above 'Returns' section describes.
      UnexpectedSubprocessExitError: If any of the subprocesses did not exit
        properly (for example, they exit on SIGTERM or SIGKILL signal). When
        this is raised, a `MultiProcessRunnerResult` object can be retrieved by
        `UnexpectedSubprocessExitError`'s mpr_result attribute, which has the
        same structure as above 'Returns' section describes. If `max_run_time`
        is not `None`, it is expected that some subprocesses may be
        force-killed when `max_run_time` is up, and this is raised in those
        cases.
      Exception: if there is an Exception propagated from any subprocess. When
        this is raised, a `MultiProcessRunnerResult` object can be retrieved by
        the re-raised exception's mpr_result attribute, which has the same
        structure as above 'Returns' section describes.
    """
    if timeout and not isinstance(timeout, int):
      raise ValueError('`timeout` must be an integer or `None`.')
    with self._process_lock:
      if self._joined:
        raise ValueError("MultiProcessRunner can't be joined twice.")
      self._joined = True

    try:
      # No watchdog means no subprocess was ever started.
      if self._watchdog_thread is not None:
        self._watchdog_thread.join(timeout)
        if self._watchdog_thread.is_alive():
          # Timeout. Force termination to dump worker processes stack trace.
          self._handle_join_timeout(timeout)

      for (task_type, task_id), p in self._processes.items():
        logging.info('%s-%d exit code: %s', task_type, task_id, p.exitcode)

      process_statuses = self._get_process_statuses()
      self._reraise_if_subprocess_error(process_statuses)

      # Checking all the processes that are expected to exit properly.
      for (task_type, task_id), p in self._processes.items():
        # Successfully exiting process has exit code 0. We ignore processes that
        # are terminated.
        assert p.exitcode is not None
        if (p.exitcode > 0 and (task_type, task_id) not in self._terminated):
          raise UnexpectedSubprocessExitError(
              'Subprocess %s-%d exited with exit code %s. See logs for details.'
              % (task_type, task_id, p.exitcode),
              self._get_mpr_result(process_statuses))

      logging.info('Joining log reading threads.')
      for thread in self._reading_threads:
        thread.join()
      logging.info('Joined log reading threads.')

      return self._get_mpr_result(process_statuses)
    finally:
      # Clear the alarm, also on error paths, so a pending `max_run_time`
      # alarm cannot fire later during unrelated code.
      signal.alarm(0)

  def _get_mpr_result(self, process_statuses):
    """Builds a `MultiProcessRunnerResult` from statuses and streamed output."""
    stdout = self._queue_to_list(self._streaming_queue)
    return_values = [
        status.return_value
        for status in process_statuses.values()
        if status.return_value is not None
    ]
    return MultiProcessRunnerResult(stdout=stdout, return_value=return_values)

  def terminate(self, task_type, task_id):
    """Terminates the process with `task_type` and `task_id`.

    If auto_restart=True, the terminated task will be restarted unless the chief
    has already exited with zero exit code.

    Args:
      task_type: the task type.
      task_id: the task id.

    Raises:
      ValueError: if no process exists for `task_type` and `task_id`.
    """
    with self._process_lock:
      p = self._processes.get((task_type, task_id), None)
      if p is None:
        raise ValueError('{}-{} does not exist'.format(task_type, task_id))
      self._terminated.add((task_type, task_id))
      # TODO(crccw): change to use Process.terminate() as well.
      self._parent_to_sub_queue.put('terminate {} {}'.format(
          task_type, task_id))
      p.join()

  def _terminate_all(self, sig=None):
    """Terminates all subprocesses.

    The caller is required to hold self._process_lock.

    Args:
      sig: the signal used to terminate the process. The default is SIGKILL.
    """

    # Use SIGKILL as default. In systems where that's unavailable such as
    # windows, use SIGTERM.
    sig = sig or getattr(signal, 'SIGKILL', signal.SIGTERM)
    for (task_type, task_id), p in self._processes.items():
      if p.exitcode is not None:
        logging.info('%s-%d has already exited. Not terminating.', task_type,
                     task_id)
        continue
      try:
        os.kill(p.pid, sig)
        self._terminated.add((task_type, task_id))
        logging.info('%s-%d terminated with signal %r.', task_type, task_id,
                     sig)
      except ProcessLookupError:
        logging.info('Attempting to kill %s-%d but it does not exist.',
                     task_type, task_id)

  def terminate_all(self, sig=None):
    """Terminates all subprocesses."""
    with self._process_lock:
      self._terminate_all(sig)
717
718
class _Process(multi_process_lib.Process):
  """`Process` subclass that prepares environment variables before running.

  Before the original `run` executes, it sets `GRPC_FAIL_FAST` (when
  configured) and calls `_set_tf_config` with the task described by
  `test_env`.
  """

  # TODO(crccw): consider moving other logics in _ProcFunc to _Process.

  def __init__(self, test_env, **kwargs):
    super(_Process, self).__init__(**kwargs)
    self._test_env = test_env
    # Keep a handle on the original `run` and shadow it with a wrapper that
    # configures the environment first.
    original_run = self.run
    self._actual_run = original_run
    self.run = self._run_with_setenv

  def _run_with_setenv(self):
    # Environment variables go in before anything else happens, since
    # setenv() is not thread-safe.
    env = self._test_env
    fail_fast = env.grpc_fail_fast
    if fail_fast is not None:
      os.environ['GRPC_FAIL_FAST'] = str(fail_fast)
    _set_tf_config(env.task_type, env.task_id, env.cluster_spec,
                   env.rpc_layer)
    return self._actual_run()
739
740
class _ProcFunc(object):
  """Represents a callable to run in a subprocess.

  An instance is called inside each child process with the run's `resources`
  and the per-task `test_env`. It redirects the child's stdout/stderr into the
  streaming pipe, starts a daemon thread that watches for termination messages
  from the parent, runs the user's `fn`, and reports the outcome as a
  `_ProcessStatusInfo` on `resources.process_status_queue`.
  """

  @contextlib.contextmanager
  def _runtime_mode(self, executing_eagerly):
    """Enters `context.eager_mode()` if `executing_eagerly`, else graph mode."""
    if executing_eagerly:
      with context.eager_mode():
        yield
    else:
      with context.graph_mode():
        yield

  def _message_checking_func(self, task_type, task_id):
    """A function that regularly checks messages from parent process.

    Meant to run in a daemon thread of the subprocess. It polls
    `parent_to_sub_queue` without blocking until it receives the message
    'terminate <task_type> <task_id>' addressed to this task. Termination
    messages addressed to other tasks are put back on the queue. Once its own
    message arrives, it reports a `_ProcessStatusInfo` with
    `is_successful=True` and exits the whole process with `os._exit(1)`.

    Args:
      task_type: the task type of this subprocess.
      task_id: the task index of this subprocess.

    Raises:
      ValueError: if a message that does not start with 'terminate' is
        received. It is raised in the checking thread, not the main thread.
    """
    # TODO(rchao): Remove this once parent uses SIGKILL to terminate subprocess.
    while True:
      try:
        message = self._resources.parent_to_sub_queue.get(block=False)

        # Currently the only possible message is termination.
        if not message.startswith('terminate'):
          raise ValueError('Unrecognized message: {}'.format(message))

        if message == 'terminate {} {}'.format(task_type, task_id):
          break
        else:
          # If the message is not targeting this process, put it back to the
          # queue.
          # NOTE(review): the longer sleep presumably gives the intended
          # recipient a chance to dequeue the message first.
          self._resources.parent_to_sub_queue.put(message)
          time.sleep(1)
      except Queue.Empty:
        time.sleep(0.1)
    # NOTE(review): the status is reported as successful even though the
    # process exits with code 1 below; presumably the parent relies on its
    # `_terminated` bookkeeping to treat this exit as expected — confirm.
    self._resources.process_status_queue.put(
        _ProcessStatusInfo(
            task_type=task_type,
            task_id=task_id,
            is_successful=True,
            exc_info=None,
            return_value=None))
    # `os._exit(1)` is used to more reliably terminate a subprocess.
    os._exit(1)  # pylint: disable=protected-access

  def _close_streaming(self):
    """Close stdout, stderr and streaming pipe.

    We need to explicitly close them since Tensorflow may take a while to exit,
    so that the reading threads in the main process can exit more quickly.
    Both streams are flushed before being closed so no buffered output is lost.
    """
    sys.stdout.flush()
    sys.stderr.flush()
    sys.stdout.close()
    sys.stderr.close()
    self._resources.streaming_pipe_w.close()

  def __call__(self, resources, test_env, fn, args, kwargs, use_dill_for_args):
    """The wrapper function that actually gets run in child process(es).

    Args:
      resources: the run's IPC resources, exposing `barrier`,
        `parent_to_sub_queue`, `process_status_queue` and `streaming_pipe_w`.
      test_env: the per-task environment, exposing `task_type`, `task_id`,
        `v2_enabled` and `executing_eagerly`.
      fn: the dill-serialized function to run.
      args: positional arguments for `fn`; dill-serialized when
        `use_dill_for_args` is truthy.
      kwargs: keyword arguments for `fn`; dill-serialized when
        `use_dill_for_args` is truthy.
      use_dill_for_args: whether `args` and `kwargs` must be `dill.loads`-ed.

    Raises:
      Exception: the exception raised by `fn`, re-raised after it has been
        reported on `process_status_queue`.
      SystemExit: with code 0 once `fn` has returned successfully.
    """

    global _barrier

    self._resources = resources
    # Make the shared barrier available to `get_barrier()` in this process.
    _barrier = self._resources.barrier
    fn = dill.loads(fn)
    if use_dill_for_args:
      args = dill.loads(args)
      kwargs = dill.loads(kwargs)

    if faulthandler is not None:
      # Dump Python tracebacks on fatal errors and on SIGTERM; `chain=True`
      # also calls the previously installed SIGTERM handler.
      faulthandler.enable()
      faulthandler.register(signal.SIGTERM, chain=True)

    # All logging should go to stderr to be streamed to the main process.
    logging.set_stderrthreshold(logging.DEBUG)

    # Assign sys.stdout and sys.stderr as duplicates of `streaming_pipe_w` so
    # print() and logging.*() write directly to `streaming_pipe_w`.
    # Unfortunately since we cannot prepend task_type and task_id information to
    # the streamed logs we will need a thread per subprocess to distinguish
    # where the piece of message is from.
    os.dup2(resources.streaming_pipe_w.fileno(), sys.stdout.fileno())
    os.dup2(resources.streaming_pipe_w.fileno(), sys.stderr.fileno())

    pid = os.getpid()
    logging.info('Subprocess with PID %d (%s, %d) is now being started.', pid,
                 test_env.task_type, test_env.task_id)

    # The thread will be dedicated to checking messages from the parent process.
    threading.Thread(  # pylint: disable=unexpected-keyword-arg
        target=self._message_checking_func,
        args=(test_env.task_type, test_env.task_id),
        daemon=True).start()

    if test_env.v2_enabled:
      v2_compat.enable_v2_behavior()

    with self._runtime_mode(test_env.executing_eagerly):
      info = _run_contained(test_env.task_type, test_env.task_id, fn, args,
                            kwargs)
      self._resources.process_status_queue.put(info)

      # Re-raise the exception in addition to reporting it to the parent
      # process, so that even if `--test_timeout` flag is set and the
      # error doesn't make it to be shown in parent process before bazel's
      # timeout, the log would still show what happens in this subprocess,
      # instead of silently suppressing the error due to early bazel
      # timeout. Raising an error in the subprocess produces stack trace in
      # the log, but the program continues running.
      if not info.is_successful:
        six.reraise(*info.exc_info)

      self._close_streaming()

    # Exit with code 0 as it's considered successful exit at this point.
    sys.exit(0)
854
855
# Active MultiProcessPoolRunner instances. We need to shut them down when the
# program exits; this is done by setting the `tearDownModule` of the module
# containing `__main__` (see `test_main()`). Note this is set in both the
# parent process and the subprocesses. A WeakSet is used so that registering a
# runner here does not by itself keep the runner alive.
_active_pool_runners = weakref.WeakSet()
860
861
def _shutdown_all_pool_runners():
  """Calls `shutdown()` on every runner still in `_active_pool_runners`."""
  for pool_runner in _active_pool_runners:
    pool_runner.shutdown()
865
866
def is_oss():
  """Returns whether the test is run under OSS.

  The check is a heuristic: the test is considered to run in OSS when the
  program path, `sys.argv[0]`, contains the substring 'bazel'. An empty
  `sys.argv` yields False.
  """
  argv = sys.argv
  if not argv:
    return False
  return 'bazel' in argv[0]
870
871
class MultiProcessPoolRunner(object):
  """A utility class to start a process pool to simulate a cluster.

  It's similar to MultiProcessRunner, but uses a pool of processes to avoid the
  expensive initialization cost of Tensorflow. The pool is started lazily by
  the first `run()` call, and its worker processes then serve every later
  `run()` until `shutdown()` is called.
  """

  def __init__(self, cluster_spec, initializer=None):
    """Creates a multi-process pool runner.

    Only the arguments are stored here; no process is started until the first
    `run()`. The runner is registered in `_active_pool_runners` so that the
    `tearDownModule` installed by `test_main()` can shut it down.

    Args:
      cluster_spec: Dict for cluster spec. The following is an example of
        cluster with three workers.
        {"worker": ["worker0.example.com:2222",
                    "worker1.example.com:2222",
                    "worker2.example.com:2222"]}
      initializer: a callable to be called at the startup of worker processes.
        It is serialized with dill, so it must be serializable by dill.

    Raises:
      RuntimeError: if `multi_process_runner.test_main()` is not called.
      ValueError: if there are more than one chief in the `cluster_spec`.

    NOTE(review): the constructor itself does not raise these; they are
    presumably raised from the first `run()`, which creates the underlying
    `MultiProcessRunner` — confirm.
    """
    _active_pool_runners.add(self)
    self._cluster_spec = cluster_spec
    self._initializer = initializer
    # Parent-side pipe end per (task_type, task_id); filled in by `_start()`.
    self._conn = {}
    # The underlying MultiProcessRunner; None until `_start()` runs.
    self._runner = None

  def __del__(self):
    # Best-effort cleanup if the runner is garbage collected without an
    # explicit `shutdown()`.
    self.shutdown()

  def shutdown(self):
    """Shuts down the worker pool.

    Closes the parent's end of every worker pipe, which is intended to make
    the workers' `conn.recv()` raise `EOFError` so `_pool_runner_worker`
    returns. It then joins the underlying runner. Exceptions raised by the join
    are logged and ignored. Calling this more than once is harmless, because
    the pipes and the runner are cleared after the first call.
    """
    for conn in self._conn.values():
      conn.close()
    self._conn = {}
    if self._runner is not None:
      try:
        self._runner.join()
      except Exception as e:  # pylint: disable=broad-except
        logging.error(
            'Ignoring exception when shutting down MultiProcessPoolRunner: %s',
            e)
      self._runner = None

  def _start(self):
    """Starts the worker pool.

    Creates a `MultiProcessRunner` with a no-op `fn`. It then starts one
    `_pool_runner_worker` process per task in `cluster_spec`, each connected to
    this process through its own duplex pipe.

    Raises:
      unittest.SkipTest: if `dill` is not available.
    """
    # We need different arguments for different processes so we're passing a
    # no-op fn here and use start_single_process instead.

    if dill is None:
      raise unittest.SkipTest(
          'TODO(b/150264776): Resolve dependency issue in CI')

    self._runner = MultiProcessRunner(
        fn=lambda: None,
        cluster_spec=self._cluster_spec,
        use_dill_for_args=False)
    if self._initializer:
      initializer = dill.dumps(self._initializer, dill.HIGHEST_PROTOCOL)
    else:
      initializer = None
    for task_type, addresses in self._cluster_spec.items():
      for task_id, _ in enumerate(addresses):
        conn1, conn2 = multiprocessing.Pipe(duplex=True)
        self._conn[(task_type, task_id)] = conn1
        self._runner.start_single_process(
            task_type,
            task_id,
            fn=_pool_runner_worker,
            args=(task_type, task_id, initializer, conn2))

  def run(self, fn, args=None, kwargs=None):
    """Runs `fn` with `args` and `kwargs` on all jobs.

    Starts the pool on first use and sends the dill-serialized `fn` to every
    worker. It then collects a result from every worker before inspecting any
    of them, so each request sent over a pipe has its reply consumed.

    Args:
      fn: The function to be run.
      args: Optional positional arguments to be supplied in `fn`.
      kwargs: Optional keyword arguments to be supplied in `fn`.

    Returns:
      A list of return values. `None` return values are omitted.

    Raises:
      RuntimeError: if a worker's pipe is closed unexpectedly, e.g. because
        the worker process died. The pool is shut down in that case.
      unittest.SkipTest: if the pool has to be started and `dill` is not
        available.
      Exception: the exception raised by `fn` in the first worker (in
        `cluster_spec` order) whose call failed.
    """
    # TODO(b/150264776): skip in OSS until it's implemented.
    # NOTE(review): the created object is discarded; presumably constructing
    # it is what skips the test in OSS — confirm.
    multi_process_lib.Process()
    if self._runner is None:
      self._start()

    fn = dill.dumps(fn, dill.HIGHEST_PROTOCOL)
    for conn in self._conn.values():
      conn.send((fn, args or [], kwargs or {}))

    process_statuses = []
    for (task_type, task_id), conn in self._conn.items():
      logging.info('Waiting for the result from %s-%d', task_type, task_id)
      try:
        process_statuses.append(conn.recv())
      except EOFError:
        # This shouldn't happen due to exceptions in fn. This usually
        # means bugs in the runner.
        self.shutdown()
        raise RuntimeError('Unexpected EOF. Worker process may have died. '
                           'Please report a bug')

    return_values = []
    for process_status in process_statuses:
      assert isinstance(process_status, _ProcessStatusInfo)
      if not process_status.is_successful:
        six.reraise(*process_status.exc_info)
      if process_status.return_value is not None:
        return_values.append(process_status.return_value)

    return return_values
985
986
def _pool_runner_worker(task_type, task_id, initializer, conn):
  """Serves work sent by a `MultiProcessPoolRunner` until the pipe closes.

  After running the optional `initializer` once, the worker repeatedly
  receives a `(serialized_fn, args, kwargs)` tuple from `conn`. It executes
  the function through `_run_contained`, which turns an `Exception` raised by
  the function into a failed `_ProcessStatusInfo` instead of propagating it,
  and sends that status back through `conn`. The worker returns once
  `conn.recv()` raises `EOFError`.

  Args:
    task_type: the task type.
    task_id: the task index.
    initializer: a dill-serialized callable to execute during startup, or a
      falsy value for no initialization.
    conn: a multiprocessing.Connection object to listen for tasks and send
      results.
  """
  if initializer:
    dill.loads(initializer)()

  while True:
    try:
      request = conn.recv()
    except EOFError:
      # The other end of the pipe was closed; no more work will arrive.
      return
    serialized_fn, fn_args, fn_kwargs = request
    status = _run_contained(task_type, task_id, dill.loads(serialized_fn),
                            fn_args, fn_kwargs)
    # Push buffered output out before reporting the result.
    sys.stdout.flush()
    sys.stderr.flush()
    conn.send(status)
1014
1015
def _run_contained(task_type, task_id, fn, args, kwargs):
  """Runs `fn` with `args` and `kwargs`.

  The function returns _ProcessStatusInfo which captures the return value and
  the exception. An `Exception` raised by `fn` is captured rather than
  propagated. Exceptions that do not derive from `Exception`, such as the
  `SystemExit` raised when `fn` calls `sys.exit()`, are not handled here.

  Args:
    task_type: the task type.
    task_id: the task index.
    fn: the function to be run.
    args: optional positional arguments to be supplied in `fn`. `None` means
      no positional arguments.
    kwargs: optional keyword arguments to be supplied in `fn`. `None` means no
      keyword arguments.

  Returns:
    a _ProcessStatusInfo. On success, `is_successful` is True, `exc_info` is
    None and `return_value` is what `fn` returned. On failure, `is_successful`
    is False, `exc_info` is the `sys.exc_info()` triple of the exception and
    `return_value` is None.
  """
  try:
    # Honor the documented optionality of `args`/`kwargs`, consistent with
    # `MultiProcessPoolRunner.run`, which substitutes empty containers.
    # Without this, `None` would make the call itself raise a TypeError that
    # is reported as if `fn` had failed.
    return_value = fn(*(() if args is None else args),
                      **({} if kwargs is None else kwargs))
  except Exception:  # pylint: disable=broad-except
    return _ProcessStatusInfo(
        task_type=task_type,
        task_id=task_id,
        is_successful=False,
        exc_info=sys.exc_info(),
        return_value=None)
  return _ProcessStatusInfo(
      task_type=task_type,
      task_id=task_id,
      is_successful=True,
      exc_info=None,
      return_value=return_value)
1056
1057
@tf_export('__internal__.distribute.multi_process_runner'
           '.SubprocessTimeoutError',
           v1=[])
class SubprocessTimeoutError(RuntimeError):
  """Raised when at least one subprocess times out.

  The namedtuple object describing the multi-process run result is attached
  as the `mpr_result` attribute. See
  `tf.__internal__.distribute.multi_process_runner.run` for its structure.
  """

  def __init__(self, msg, mpr_result):
    self.mpr_result = mpr_result
    super(SubprocessTimeoutError, self).__init__(msg)
1074
1075
@tf_export('__internal__.distribute.multi_process_runner'
           '.UnexpectedSubprocessExitError',
           v1=[])
class UnexpectedSubprocessExitError(RuntimeError):
  """Raised when at least one subprocess exits unexpectedly.

  The namedtuple object describing the multi-process run result is attached
  as the `mpr_result` attribute. See
  `tf.__internal__.distribute.multi_process_runner.run` for its structure.
  """

  def __init__(self, msg, mpr_result):
    self.mpr_result = mpr_result
    super(UnexpectedSubprocessExitError, self).__init__(msg)
1093
1094
@tf_export(
    '__internal__.distribute.multi_process_runner.NotInitializedError', v1=[])
class NotInitializedError(RuntimeError):
  """An error indicating `multi_process_runner.run` is used without init.

  When this is raised, the user is expected to call
  `tf.__internal__.distribute.multi_process_runner.test_main()` within the
  `if __name__ == '__main__':` block to properly initialize
  `multi_process_runner.run`.
  """
1106
1107
def _set_tf_config(task_type, task_id, cluster_spec, rpc_layer=None):
  """Set TF_CONFIG environment variable.

  Args:
    task_type: the task type of the current process.
    task_id: the task index of the current process.
    cluster_spec: the cluster spec dict, stored under the 'cluster' key.
    rpc_layer: optional RPC layer; the 'rpc_layer' key is written only when
      this is not None.
  """
  config = dict(
      cluster=cluster_spec,
      task=dict(type=task_type, index=task_id),
  )
  if rpc_layer is not None:
    config['rpc_layer'] = rpc_layer
  serialized = json.dumps(config)
  os.environ['TF_CONFIG'] = serialized
1120
1121
@tf_export('__internal__.distribute.multi_process_runner.run', v1=[])
def run(fn,
        cluster_spec,
        rpc_layer=None,
        max_run_time=None,
        return_output=False,
        timeout=_DEFAULT_TIMEOUT_SEC,
        args=None,
        kwargs=None):
  """Run `fn` in multiple processes according to `cluster_spec`.

  Given a callable `fn`, `tf.__internal__.distribute.multi_process_runner.run`
  launches multiple processes, each of which runs `fn`. These processes are
  referred to as "subprocesses" or "child processes". Each of those subprocesses
  will have their `TF_CONFIG` environment variable set, according to
  `cluster_spec` and their task types. The stdout and stderr of the
  subprocesses are streamed back to the main process, where they can show up in
  the logs with a [type-id] prefix.

  `tf.__internal__.distribute.multi_process_runner.run` will block until all
  subprocesses have successfully exited, and return a namedtuple object that
  represents the run result. This object has a `return_value` attribute, which
  is a list that contains subprocesses `fn`'s return values, for those
  subprocesses that successfully returned from `fn`. The order of `return_value`
  list is not meaningful. If an optional arg `return_output` (default to False)
  is set to True, the namedtuple object will have an additional attribute
  `stdout`, which is a list containing the stdout of the subprocesses. If any
  subprocess' `fn` ends up raising an error, that error will be reraised from
  `tf.__internal__.distribute.multi_process_runner.run`, and the aforementioned
  namedtuple object will be available through the exception's
  `mpr_result` attribute.

  This utility is used for simulating running TensorFlow programs across
  multiple task types, and each task type may contain more than one task
  (except for "chief" where more than one task is prohibited). Test coverage of
  multi-worker training is the main application of this utility, where code
  written for multi-worker training can be realistically covered in unit tests.

  Any test module that uses
  `tf.__internal__.distribute.multi_process_runner.run()` must call
  `tf.__internal__.distribute.multi_process_runner.test_main()` instead of
  regular `test.main()` inside `if __name__ == '__main__':` block for proper
  initialization.

  Args:
    fn: Function to be run on child processes. This will be run on processes for
      all task types.
    cluster_spec: Dict for cluster spec. The utility function
      `tf.__internal__.distribute.multi_process_runner.create_cluster_spec` can
      be conveniently used to create such dict. The following is an example of
      cluster with three workers and two ps's.
      {"worker": ["worker0.example.com:2222",
                  "worker1.example.com:2222",
                  "worker2.example.com:2222"],
       "ps": ["ps0.example.com:2222",
              "ps1.example.com:2222"]}
    rpc_layer: RPC layer to use. Default value is 'grpc'.
    max_run_time: `None` or integer. If not `None`, child processes are forced
      to exit at approximately this many seconds after this utility is called.
      We achieve this through `signal.alarm()` api. Note that this is best
      effort at Python level since Python signal handler does not get executed
      when it runs lower level C/C++ code. So it can be delayed for arbitrarily
      long time. If any of the child processes is still running when
      `max_run_time` is up, they will be force-terminated and an
      `tf.__internal__.distribute.multi_process_runner
      .UnexpectedSubprocessExitError`
      may be raised. If `None`, child processes are not forced to exit.
    return_output: If True, the output/error from the subprocesses should be
      collected to be attached to the resulting namedtuple returned from this
      utility. The list of output can be retrieved via `stdout` attribute.
      Defaults to False.
    timeout: optional integer or `None`. If provided as an integer, and not all
      processes report status within roughly `timeout` seconds, a
      `tf.__internal__.distribute.multi_process_runner.SubprocessTimeoutError`
      exception will be raised. If `None`,
      `tf.__internal__.distribute.multi_process_runner.run` never times out.
      Defaults to the constant `_DEFAULT_TIMEOUT_SEC` defined in
      `multi_process_runner` module.
    args: Positional arguments to be sent to `fn` run on subprocesses.
    kwargs: Keyword arguments to be sent to `fn` run on subprocesses.

  Returns:
      A namedtuple object, which has two attributes,
      `return_value` and `stdout`. `return_value` always contains a list of
      return values from the subprocesses, although the order is not
      meaningful. If `return_output` argument is True, `stdout` is available
      that contains a list of all messages from subprocesses' stdout and
      stderr, and the order is mostly chronological.

  Raises:
    RuntimeError: if
      `tf.__internal__.distribute.multi_process_runner.test_main()` is
      not called in the test's `if __name__ == '__main__':` block.
    ValueError: if there are more than one chief in the `cluster_spec`.
    tf.__internal__.distribute.multi_process_runner.SubprocessTimeoutError: if
      not all processes report status approximately
      within `timeout` seconds. When this is raised, a
      namedtuple object can be retrieved by
      `tf.__internal__.distribute.multi_process_runner.SubprocessTimeoutError`'s
      `mpr_result` attribute, which has the same
      structure as above 'Returns' section describes.
    tf.__internal__.distribute.multi_process_runner
    .UnexpectedSubprocessExitError:
      If any of the subprocesses did not exit
      properly (for example, they exit on SIGTERM or SIGKILL signal). When
      this is raised, a namedtuple object can be retrieved by
      `tf.__internal__.distribute.multi_process_runner
      .UnexpectedSubprocessExitError`'s
      `mpr_result` attribute, which has the
      same structure as above 'Returns' section describes. If `max_run_time`
      is not `None`, it is expected that some subprocesses may be
      force-killed when `max_run_time` is up, and this is raised in those
      cases.
    Exception: if there is an Exception propagated from any subprocess. When
      this is raised, a namedtuple object can be retrieved by the exception's
      `mpr_result` attribute, which has the same structure as above 'Returns'
      section describes.

  Examples:

  ```python
  class SimpleMultiProcessTest(tf.test.TestCase):

    def test_simple_printing_and_return(self):

      def fn():
        resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()

        # This will print "[chief-0]:     Task type: chief , task id: 0"
        # for chief, for example.
        logging.info('Task type: %s, task id: %d',
                     resolver.task_type, resolver.task_id)

        return resolver.task_type

      result = tf.__internal__.distribute.multi_process_runner.run(
          fn=fn,
          cluster_spec=(
              tf.__internal__
              .distribute.multi_process_runner.create_cluster_spec(
                  has_chief=True, num_workers=2)))
      assert sorted(result.return_value) == ['chief', 'worker', 'worker']

    def test_error_from_fn(self):

      def fn():
        resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
        raise ValueError('Task type {}, task id {} errors out'.format(
            resolver.task_type, resolver.task_id))

      with self.assertRaisesRegex(ValueError,
                                  'Task type worker, task id 0 errors out'):
        cluster_spec = (
            tf.__internal__.distribute.multi_process_runner.create_cluster_spec(
                num_workers=1))
        tf.__internal__.distribute.multi_process_runner.run(
            fn=fn, cluster_spec=cluster_spec)


  if __name__ == '__main__':
    tf.__internal__.distribute.multi_process_runner.test_main()
  ```
  """
  runner = MultiProcessRunner(
      fn,
      cluster_spec,
      rpc_layer,
      max_run_time=max_run_time,
      return_output=return_output,
      args=args,
      kwargs=kwargs)
  runner.start()
  return runner.join(timeout)
1297
1298
# Barrier shared by the tasks of a run. It is assigned from `resources.barrier`
# by `_ProcFunc.__call__` inside each subprocess. While it is None,
# `get_barrier()` raises a ValueError.
_barrier = None
1301
1302
@tf_export('__internal__.distribute.multi_process_runner.get_barrier', v1=[])
def get_barrier():
  """Returns a `multiprocessing.Barrier` for `multi_process_runner.run`.

  `tf.__internal__.distribute.multi_process_runner.get_barrier()` returns a
  `multiprocessing.Barrier` object that can be used within the `fn` passed to
  `tf.__internal__.distribute.multi_process_runner.run`. Calling
  `barrier.wait()` blocks until all other tasks have also reached their
  `barrier.wait()` call, after which they proceed individually.

  Note that all tasks (subprocesses) have to reach the `barrier.wait()` call to
  proceed. Currently it is not supported to block on only a subset of tasks
  in the cluster.

  Example:
  ```python

  def fn():
    some_work_to_be_done_by_all_tasks()

    tf.__internal__.distribute.multi_process_runner.get_barrier().wait()

    # The barrier guarantees that at this point, all tasks have finished
    # `some_work_to_be_done_by_all_tasks()`
    some_other_work_to_be_done_by_all_tasks()

  result = tf.__internal__.distribute.multi_process_runner.run(
      fn=fn,
      cluster_spec=(
          tf.__internal__
          .distribute.multi_process_runner.create_cluster_spec(
              num_workers=2)))
  ```

  Returns:
    A `multiprocessing.Barrier` for `multi_process_runner.run`.

  Raises:
    ValueError: if no barrier has been set in the calling process, for example
      when called in the main process instead of a subprocess.
  """
  barrier = _barrier
  if barrier is not None:
    return barrier
  raise ValueError(
      'barrier is not defined. It is likely because you are calling '
      'get_barrier() in the main process. get_barrier() can only be called '
      'in the subprocesses.'
  )
1348
1349
# Lazily created `multiprocessing.Manager()` returned by `manager()`.
_manager = None
# Guards the one-time creation of `_manager`.
_manager_lock = threading.Lock()
1352
1353
def manager():
  """Returns the multiprocessing manager object for concurrency tools.

  The manager object is useful as it controls a server process that holds
  the python objects that can be shared across processes. This can be used
  for parent-subprocess communication:

  ```python
  manager = multi_process_runner.manager()
  some_event_happening_in_subprocess = manager.Event()
  mpr = multi_process_runner.MultiProcessRunner(fn, cluster_spec,
      args=(some_event_happening_in_subprocess,))
  mpr.start()
  some_event_happening_in_subprocess.wait()
  # Do something that should only happen after some event happens in the
  # subprocess.
  ```

  The manager is created on first use and the same object is returned by
  every later call; creation is guarded by a lock.

  Note that the user of multi_process_runner should not create additional
  `multiprocessing.Manager()` objects; doing so can result in segfault in
  some cases.

  This method should only be called after multi_process_runner.test_main() is
  called.
  """
  global _manager
  with _manager_lock:
    if _manager is not None:
      return _manager
    _manager = multiprocessing.Manager()
    return _manager
1383
1384
@tf_export('__internal__.distribute.multi_process_runner.test_main', v1=[])
def test_main():
  """Main function to be called within `__main__` of a test file.

  Any test module that uses
  `tf.__internal__.distribute.multi_process_runner.run()` must call this
  instead of regular `test.main()` inside its `if __name__ == '__main__':`
  block, or an error will be raised when
  `tf.__internal__.distribute.multi_process_runner.run()` is used. This method
  takes care of the initialization needed for launching multiple subprocesses.

  It also installs a `tearDownModule` on the `__main__` module. That function
  shuts down all active `MultiProcessPoolRunner`s and then calls the
  `tearDownModule` that was previously defined there, if any.

  Example:
  ```python
  class MyTestClass(tf.test.TestCase):
    def testSomething(self):
      # Testing code making use of
      # `tf.__internal__.distribute.multi_process_runner.run()`.
      ...

  if __name__ == '__main__':
    tf.__internal__.distribute.multi_process_runner.test_main()
  ```
  """
  # Active pool runners would block the program from exiting, which matters
  # for global pool runners. atexit was tried in the past and does not work in
  # some deployments, so shutdown is chained into the main module's
  # tearDownModule() instead.
  main_module = sys.modules['__main__']
  chained_tear_down = getattr(main_module, 'tearDownModule', None)

  def tear_down_module():
    _shutdown_all_pool_runners()
    if chained_tear_down is not None:
      chained_tear_down()

  main_module.tearDownModule = tear_down_module
  multi_process_lib.test_main()
1422