1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Base testing class for strategies that require multiple nodes."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import contextlib
22import copy
23import json
24import multiprocessing
25import os
26import subprocess
27import sys
28import threading
29import unittest
30
31import six
32
33_portpicker_import_error = None
34try:
35  import portpicker  # pylint: disable=g-import-not-at-top
36except (ImportError, ModuleNotFoundError) as _error:  # pylint: disable=invalid-name
37  _portpicker_import_error = _error
38  portpicker = None
39
40# pylint: disable=g-import-not-at-top
41from tensorflow.core.protobuf import config_pb2
42from tensorflow.core.protobuf import rewriter_config_pb2
43from tensorflow.python.client import session
44from tensorflow.python.distribute import distribute_coordinator as dc
45from tensorflow.python.distribute import multi_process_runner
46from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
47from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
48from tensorflow.python.eager import context
49from tensorflow.python.eager import remote
50from tensorflow.python.framework import errors
51from tensorflow.python.framework import ops
52from tensorflow.python.framework import test_util
53from tensorflow.python.platform import test
54from tensorflow.python.platform import tf_logging as logging
55from tensorflow.python.training import coordinator
56from tensorflow.python.training import server_lib
57from tensorflow.python.util import deprecation
58from tensorflow.python.util import nest
59from tensorflow.python.util.compat import collections_abc
60from tensorflow.python.util.tf_export import tf_export
61
62
# Lock guarding `ASSIGNED_PORTS`, which records every port already handed out
# by `pick_unused_port` in this process.
lock = threading.Lock()
ASSIGNED_PORTS = set()

# The unpatched std-server launcher, kept so that a mock wrapper can still
# delegate to the real implementation.
original_run_std_server = dc._run_std_server  # pylint: disable=protected-access
67
68
def pick_unused_port():
  """Returns an unused and unassigned local port.

  Ports are requested from `portpicker` until one is above 10000 and has not
  already been returned by this function in the current process. The chosen
  port is recorded in `ASSIGNED_PORTS` while holding `lock`.

  Returns:
    An int port number greater than 10000.

  Raises:
    ImportError: if `portpicker` could not be imported.
    unittest.SkipTest: if `portpicker` reports that no free port is available.
  """
  if _portpicker_import_error:
    raise _portpicker_import_error  # pylint: disable=raising-bad-type

  # `ASSIGNED_PORTS` is only mutated in place (never rebound), so no `global`
  # declaration is required.
  with lock:
    while True:
      try:
        port = portpicker.pick_unused_port()
      except portpicker.NoFreePortFoundError as e:
        raise unittest.SkipTest('Flakes in portpicker library do not represent '
                                'TensorFlow errors.') from e
      if port > 10000 and port not in ASSIGNED_PORTS:
        ASSIGNED_PORTS.add(port)
        logging.info('Using local port %r', port)
        return port
86
87
def _create_cluster(num_workers,
                    num_ps,
                    has_chief=False,
                    has_eval=False,
                    protocol='grpc',
                    worker_config=None,
                    ps_config=None,
                    eval_config=None,
                    worker_name='worker',
                    ps_name='ps',
                    chief_name='chief'):
  """Creates and starts local servers and returns the cluster_spec dict.

  Args:
    num_workers: number of worker tasks to start.
    num_ps: number of parameter-server tasks to start.
    has_chief: whether to add and start a single chief task.
    has_eval: whether to add and start a single 'evaluator' task.
    protocol: the server protocol, e.g. 'grpc'.
    worker_config: ConfigProto used by the worker and chief servers.
    ps_config: ConfigProto used by the ps servers.
    eval_config: ConfigProto used by the evaluator server.
    worker_name: job name used for workers.
    ps_name: job name used for parameter servers.
    chief_name: job name used for the chief.

  Returns:
    A dict mapping job names to lists of 'localhost:<port>' addresses.

  Raises:
    ImportError: if `portpicker` could not be imported.
  """
  if _portpicker_import_error:
    raise _portpicker_import_error  # pylint: disable=raising-bad-type

  def _as_local_addresses(ports):
    """Formats each port as a 'localhost:<port>' address."""
    return ['localhost:%s' % p for p in ports]

  # Ports are reserved in the order: workers, ps, evaluator, chief.
  worker_ports = [pick_unused_port() for _ in range(num_workers)]
  ps_ports = [pick_unused_port() for _ in range(num_ps)]

  cluster_dict = {}
  if num_workers > 0:
    cluster_dict[worker_name] = _as_local_addresses(worker_ports)
  if num_ps > 0:
    cluster_dict[ps_name] = _as_local_addresses(ps_ports)
  if has_eval:
    cluster_dict['evaluator'] = _as_local_addresses([pick_unused_port()])
  if has_chief:
    cluster_dict[chief_name] = _as_local_addresses([pick_unused_port()])

  cluster_spec = server_lib.ClusterSpec(cluster_dict)

  def _start_servers(job_name, num_tasks, config):
    """Starts `num_tasks` in-process servers for `job_name`."""
    for task_index in range(num_tasks):
      server_lib.Server(
          cluster_spec,
          job_name=job_name,
          protocol=protocol,
          task_index=task_index,
          config=config,
          start=True)

  # Servers are started in the order: workers, ps, chief, evaluator.
  _start_servers(worker_name, num_workers, worker_config)
  _start_servers(ps_name, num_ps, ps_config)
  if has_chief:
    # The chief reuses the worker configuration.
    _start_servers(chief_name, 1, worker_config)
  if has_eval:
    _start_servers('evaluator', 1, eval_config)

  return cluster_dict
154
155
def create_in_process_cluster(num_workers,
                              num_ps,
                              has_chief=False,
                              has_eval=False,
                              rpc_layer='grpc'):
  """Create an in-process cluster that consists of only standard server.

  Args:
    num_workers: number of worker servers to start.
    num_ps: number of parameter servers to start.
    has_chief: whether to also start a chief server.
    has_eval: whether to also start an evaluator server.
    rpc_layer: the protocol used by the servers, e.g. 'grpc'.

  Returns:
    A dict mapping job names to lists of 'localhost:<port>' addresses.

  Raises:
    unittest.SkipTest: if the gRPC servers could not be started.
  """
  # Leave some memory for cuda runtime. The denominator counts the worker,
  # chief and evaluator tasks; `max(..., 1)` avoids a ZeroDivisionError for
  # clusters that contain none of them (e.g. a ps-only cluster).
  num_gpu_tasks = num_workers + int(has_chief) + int(has_eval)
  gpu_mem_frac = 0.7 / max(num_gpu_tasks, 1)
  worker_config = config_pb2.ConfigProto()
  worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac

  # The cluster may hang if workers don't have enough inter_op threads. See
  # b/172296720 for more details.
  if multiprocessing.cpu_count() < 4:
    worker_config.inter_op_parallelism_threads = 4

  # Enable collective ops which has no impact on non-collective ops.
  # TODO(yuefengz, tucker): removing this after we move the initialization of
  # collective mgr to the session level.
  if has_chief:
    worker_config.experimental.collective_group_leader = (
        '/job:chief/replica:0/task:0')
  else:
    worker_config.experimental.collective_group_leader = (
        '/job:worker/replica:0/task:0')

  ps_config = config_pb2.ConfigProto()
  ps_config.device_count['GPU'] = 0

  eval_config = config_pb2.ConfigProto()
  eval_config.experimental.collective_group_leader = ''

  # Create in-process servers. Once an in-process tensorflow server is created,
  # there is no way to terminate it. So we create one cluster per test process.
  # We could've started the server in another process, we could then kill that
  # process to terminate the server. The reasons why we don't want multiple
  # processes are
  # 1) it is more difficult to manage these processes;
  # 2) there is something global in CUDA such that if we initialize CUDA in the
  # parent process, the child process cannot initialize it again and thus cannot
  # use GPUs (https://stackoverflow.com/questions/22950047).
  try:
    return _create_cluster(
        num_workers,
        num_ps=num_ps,
        has_chief=has_chief,
        has_eval=has_eval,
        worker_config=worker_config,
        ps_config=ps_config,
        eval_config=eval_config,
        protocol=rpc_layer)
  except errors.UnknownError as e:
    if 'Could not start gRPC server' in e.message:
      raise unittest.SkipTest('Cannot start std servers.') from e
    raise
214
215
class MultiProcessCluster(object):
  """A cluster of TensorFlow servers in separate processes.

  This class is not thread-safe.
  """

  def __init__(self, cluster_resolver):
    """Describes the cluster; no process is started until `start()`.

    Args:
      cluster_resolver: a cluster resolver whose `cluster_spec()` and
        `rpc_layer` describe the servers to run.
    """
    self._cluster_resolver = cluster_resolver
    self._cluster_spec = cluster_resolver.cluster_spec().as_dict()
    self._rpc_layer = cluster_resolver.rpc_layer
    # task_type -> list (indexed by task_id) of manager Events.
    self._start_events = {}
    self._finish_events = {}
    self._mpr_manager = multi_process_runner.manager()

    def task_function(start_events, finish_events):
      """Runs one server in a subprocess until its finish event is set."""
      cluster_resolver = TFConfigClusterResolver()
      cluster_spec = cluster_resolver.cluster_spec()
      task_type = cluster_resolver.task_type
      task_id = cluster_resolver.task_id
      rpc_layer = cluster_resolver.rpc_layer

      logging.info(
          'Starting server with cluster_spec = %r, task_type = %r, '
          'task_id = %r, rpc_layer = %r', cluster_spec, task_type, task_id,
          rpc_layer)

      # TODO(yuefengz): support GPU clusters.
      server_config = config_pb2.ConfigProto()
      server_config.device_count['GPU'] = 0

      # Set the environment variable to prevent hanging upon job failure and
      # restart. Note that it defaults to 'use_caller' at Google, but defaults
      # to False in OSS.
      os.environ['GRPC_FAIL_FAST'] = 'use_caller'

      server_lib.Server(
          cluster_spec,
          job_name=task_type,
          protocol=rpc_layer,
          task_index=task_id,
          config=server_config,
          start=True)

      # Tell the parent the server is up, then block until asked to finish.
      start_event = start_events[task_type][task_id]
      start_event.set()

      finish_event = finish_events[task_type][task_id]
      finish_event.wait()

      os._exit(0)  # pylint: disable=protected-access

    self._task_function = task_function
    self._mpr = None

  def start(self):
    """Starts one TensorFlow server for each task in the cluster_resolver.

    It will wait until all the servers are up before returning.

    Raises:
      ValueError: if the cluster has already been started.
    """
    if self._mpr:
      raise ValueError('The cluster has already been started.')
    for task_type, task_addresses in self._cluster_spec.items():
      self._start_events[task_type] = []
      self._finish_events[task_type] = []
      for _ in task_addresses:
        self._start_events[task_type].append(self._mpr_manager.Event())
        self._finish_events[task_type].append(self._mpr_manager.Event())

    self._mpr = multi_process_runner.MultiProcessRunner(
        self._task_function,
        self._cluster_spec,
        args=(self._start_events, self._finish_events),
        rpc_layer=self._rpc_layer,
        stream_output=False,
        return_output=False,
        use_dill_for_args=False)
    self._mpr.start()
    for start_events in self._start_events.values():
      for start_event in start_events:
        start_event.wait()

  def stop(self):
    """Stops all the servers.

    Raises:
      ValueError: if the cluster has not been started (or was already
        stopped).
    """
    if not self._mpr:
      raise ValueError('The cluster has not been started.')
    for finish_events in self._finish_events.values():
      for finish_event in finish_events:
        finish_event.set()
    try:
      self._mpr.join()
    except multi_process_runner.UnexpectedSubprocessExitError:
      # TODO(yuefengz): investigate why processes exit with 255.
      pass
    self._mpr = None
    self._start_events = {}
    self._finish_events = {}

  def kill_task(self, task_type, task_id):
    """Kill a server given task_type and task_id.

    Args:
      task_type: the type of the task such as "worker".
      task_id: the id the task such as 1.

    Raises:
      ValueError: if the task has not started or has already been killed.
    """
    assert self._mpr
    if (not self._start_events[task_type][task_id].is_set() or
        self._finish_events[task_type][task_id].is_set()):
      raise ValueError("The task %s:%d doesn't exist." % (task_type, task_id))

    self._finish_events[task_type][task_id].set()
    self._mpr._processes[(task_type, task_id)].join()  # pylint: disable=protected-access

  def start_task(self, task_type, task_id):
    """Starts a server given task_type and task_id.

    Args:
      task_type: the type of the task such as "worker".
      task_id: the id the task such as 1.

    Raises:
      ValueError: if the server already exists.
    """
    assert self._mpr

    if (not self._start_events[task_type][task_id].is_set() or
        not self._finish_events[task_type][task_id].is_set()):
      raise ValueError(
          'The task %s:%d is still alive. You cannot start another one.' %
          (task_type, task_id))
    # Fresh events for the restarted task; the old ones are already set.
    self._start_events[task_type][task_id] = self._mpr_manager.Event()
    self._finish_events[task_type][task_id] = self._mpr_manager.Event()
    self._mpr.start_single_process(task_type=task_type, task_id=task_id)
    self._start_events[task_type][task_id].wait()

  @property
  def cluster_resolver(self):
    """Returns a deep copy of the cluster resolver this cluster was built from."""
    return copy.deepcopy(self._cluster_resolver)
351
352
def create_multi_process_cluster(num_workers,
                                 num_ps,
                                 has_chief=False,
                                 has_eval=False,
                                 rpc_layer='grpc'):
  """Creates and starts a `MultiProcessCluster` listening on local ports.

  Args:
    num_workers: number of worker tasks.
    num_ps: number of parameter-server tasks.
    has_chief: whether the cluster contains a chief task.
    has_eval: whether the cluster contains an evaluator task.
    rpc_layer: the RPC layer used by the servers, e.g. 'grpc'.

  Returns:
    The started `MultiProcessCluster`; call `stop()` on it to shut it down.
  """
  resolver = SimpleClusterResolver(
      server_lib.ClusterSpec(
          create_cluster_spec(
              has_chief=has_chief,
              num_workers=num_workers,
              num_ps=num_ps,
              has_eval=has_eval)),
      rpc_layer=rpc_layer)
  multi_process_cluster = MultiProcessCluster(resolver)
  multi_process_cluster.start()
  return multi_process_cluster
369
370
@tf_export(
    '__internal__.distribute.multi_process_runner.create_cluster_spec', v1=[])
def create_cluster_spec(has_chief=False,
                        num_workers=1,
                        num_ps=0,
                        has_eval=False):
  """Create a cluster spec with tasks with unused local ports.

  This utility finds available ports at localhost, and returns a dict that
  represents the cluster spec that utilizes those ports, according to the
  arguments. The dict representing the cluster spec contains task types, and
  their instances' addresses. Note that this is usually only for testing
  purposes using multiple processes in the local machine, and should not be
  used for real multi-worker TensorFlow programs, where the addresses need to
  point to the processes at separate machines.

  This util is useful when creating the `cluster_spec` arg for
  `tf.__internal__.distribute.multi_process_runner.run`.

  Args:
    has_chief: Whether the generated cluster spec should contain "chief" task
      type.
    num_workers: Number of workers to use in the cluster spec.
    num_ps: Number of parameter servers to use in the cluster spec.
    has_eval: Whether this cluster spec has evaluator.

  Returns:
    A dict that represents the cluster spec using localhost ports for the tasks.

  Raises:
    ImportError: if `portpicker` could not be imported.

  Example:

  ```python
  cluster_spec = (
      tf.__internal__.distribute.multi_process_runner.create_cluster_spec(
          has_chief=True, num_workers=2, num_ps=2))
  # An example of cluster_spec is
  # {'chief': ['localhost:23381'],
  #  'worker': ['localhost:19197', 'localhost:22903'],
  #  'ps': ['localhost:16912', 'localhost:21535']}

  cluster_spec = (
      tf.__internal__.distribute.multi_process_runner.create_cluster_spec(
          has_chief=False, num_workers=0, num_ps=0, has_eval=True))
  # An example of cluster_spec is
  # {'evaluator': ['localhost:23381']}
  ```
  """
  if _portpicker_import_error:
    raise _portpicker_import_error  # pylint: disable=raising-bad-type

  def _local_addresses(count):
    """Reserves `count` fresh ports and formats them as localhost addresses."""
    return ['localhost:%s' % pick_unused_port() for _ in range(count)]

  # Tasks are added (and ports reserved) in the order chief, worker, ps,
  # evaluator.
  spec = {}
  if has_chief:
    spec['chief'] = _local_addresses(1)
  if num_workers:
    spec['worker'] = _local_addresses(num_workers)
  if num_ps:
    spec['ps'] = _local_addresses(num_ps)
  if has_eval:
    spec['evaluator'] = _local_addresses(1)
  return spec
435
436
@contextlib.contextmanager
def skip_if_grpc_server_cant_be_started(test_obj):
  """Turns a gRPC server start-up failure inside the block into a test skip.

  Any other `errors.UnknownError` (or any other exception) is re-raised
  unchanged.

  Args:
    test_obj: the running test case. On skip, its `test_skipped_reason`
      attribute is set before `skipTest` is called.

  Yields:
    None.
  """
  try:
    yield
  except errors.UnknownError as e:
    if 'Could not start gRPC server' not in e.message:
      raise
    skip_reason = 'Cannot start std servers.'
    test_obj.test_skipped_reason = skip_reason
    test_obj.skipTest(skip_reason)
448
449
class MultiWorkerTestBase(test.TestCase):
  """Base class for testing multi node strategy and dataset."""

  @classmethod
  def setUpClass(cls, num_workers=2, num_ps=1):  # pylint: disable=g-missing-super-call
    """Creates the in-process cluster shared by every test of the class.

    Args:
      num_workers: number of worker servers to start (2 by default).
      num_ps: number of parameter servers to start (1 by default).
    """
    cls._cluster_spec = create_in_process_cluster(num_workers=num_workers,
                                                  num_ps=num_ps)
    # Sessions connect to the first worker unless another target is given.
    cls._default_target = 'grpc://' + cls._cluster_spec['worker'][0]

  def setUp(self):
    # We only cache the session in one test because another test may have a
    # different session config or master target.
    self._thread_local = threading.local()
    self._thread_local.cached_session = None
    self._coord = coordinator.Coordinator()

  @contextlib.contextmanager
  def session(self, graph=None, config=None, target=None):
    """Create a test session with master target set to the testing cluster.

    Creates a test session that connects to the local testing cluster.

    Args:
      graph: Optional graph to use during the returned session.
      config: An optional config_pb2.ConfigProto to use to configure the
        session.
      target: the target of session to connect to.

    Yields:
      A Session object that should be used as a context manager to surround
      the graph building and execution code in a test case.
    """
    config = self._create_config(config)

    if target is None:
      target = self._default_target
    with session.Session(graph=graph, config=config, target=target) as sess:
      yield sess

  @contextlib.contextmanager
  # TODO(b/117573461): Overwrite self.evaluate() to use this function.
  def cached_session(self, graph=None, config=None, target=None):
    """Create a test session with master target set to the testing cluster.

    Creates a test session that connects to the local testing cluster.
    The session is only created once per test (per thread) and then reused;
    the arguments only take effect on the call that creates it.

    Args:
      graph: Optional graph to use during the returned session.
      config: An optional config_pb2.ConfigProto to use to configure the
        session.
      target: the target of session to connect to.

    Yields:
      A Session object that should be used as a context manager to surround
      the graph building and execution code in a test case. Note that the
      session will live until the end of the test.
    """
    config = self._create_config(config)

    if target is None:
      target = self._default_target
    if getattr(self._thread_local, 'cached_session', None) is None:
      # Pass `graph` through, consistent with `session()`.
      self._thread_local.cached_session = session.Session(
          graph=graph, config=config, target=target)
    sess = self._thread_local.cached_session
    with sess.graph.as_default(), sess.as_default():
      yield sess

  def _create_config(self, config):
    """Returns a copy of `config` (or a soft-placement default) for tests."""
    if config is None:
      config = config_pb2.ConfigProto(allow_soft_placement=True)
    else:
      config = copy.deepcopy(config)
    # Don't perform optimizations for tests so we don't inadvertently run
    # gpu ops on cpu
    config.graph_options.optimizer_options.opt_level = -1
    config.graph_options.rewrite_options.constant_folding = (
        rewriter_config_pb2.RewriterConfig.OFF)

    return config

  def _run_client(self, client_fn, task_type, task_id, num_gpus, eager_mode,
                  *args, **kwargs):
    """Runs `client_fn` in eager or graph mode under the test coordinator."""

    def wrapped_client_fn():
      # Exceptions are reported to the coordinator so `join` can re-raise them.
      with self._coord.stop_on_exception():
        client_fn(task_type, task_id, num_gpus, *args, **kwargs)

    if eager_mode:
      with context.eager_mode():
        wrapped_client_fn()
    else:
      with context.graph_mode():
        wrapped_client_fn()

  def _run_between_graph_clients(self, client_fn, cluster_spec, num_gpus, *args,
                                 **kwargs):
    """Runs several clients for between-graph replication.

    One thread is started per 'chief' and 'worker' task in `cluster_spec`;
    each runs in the same eager/graph mode as the caller.

    Args:
      client_fn: a function that needs to accept `task_type`, `task_id`,
        `num_gpus`.
      cluster_spec: a dict specifying jobs in a cluster.
      num_gpus: number of GPUs per worker.
      *args: will be passed to `client_fn`.
      **kwargs: will be passed to `client_fn`.
    """
    threads = []
    for task_type in ['chief', 'worker']:
      for task_id in range(len(cluster_spec.get(task_type, []))):
        t = threading.Thread(
            target=self._run_client,
            args=(client_fn, task_type, task_id, num_gpus,
                  context.executing_eagerly()) + args,
            kwargs=kwargs)
        t.start()
        threads.append(t)
    self._coord.join(threads)
570
571
class SingleWorkerTestBaseGraph(MultiWorkerTestBase):
  """Base class for testing remote single worker strategy graph and dataset."""

  @classmethod
  def setUpClass(cls):
    """Starts the shared cluster with exactly one worker.

    The number of ps tasks is left at the base class default.
    """
    super().setUpClass(num_workers=1)
578
579
class SingleWorkerTestBaseEager(test.TestCase):
  """Base class for testing remote single worker strategy eager and dataset."""

  def setUp(self):
    """Starts a one-worker local cluster and connects to its worker."""
    super(SingleWorkerTestBaseEager, self).setUp()
    workers, _ = test_util.create_local_cluster(num_workers=1, num_ps=0)
    remote.connect_to_remote_host(workers[0].target)

  def cached_session(self, *args, **kwargs):
    """Returns a no-op session context for code shared with graph tests.

    Any arguments the base `cached_session` accepts (graph, config, ...) are
    accepted and ignored, so shared test code that passes them still works.

    Returns:
      A `DummySession`.
    """
    del args, kwargs  # Unused: there is no real session in eager mode.
    return DummySession()
590
591
class DummySession(object):
  """A do-nothing stand-in for a `Session` used as a context manager.

  Entering yields `None`, and exiting never suppresses exceptions.
  """

  def __enter__(self):
    return None

  def __exit__(self, exception_type, exception_value, traceback):
    # A falsy return value lets any in-flight exception propagate.
    return None
599
600
class MockOsEnv(collections_abc.Mapping):
  """A class that allows per-thread TF_CONFIG.

  The 'TF_CONFIG' key is stored per thread, so each thread can hold its own
  value. Every other key lives in a single dict shared by all threads.
  """

  def __init__(self, *args):
    """Creates the mock environment.

    Args:
      *args: optional initial contents, with the same semantics as
        `dict(*args)` (at most one mapping or iterable of key/value pairs).
        A 'TF_CONFIG' entry is visible only to the constructing thread.
    """
    self._dict = dict()
    self._thread_local = threading.local()
    super(MockOsEnv, self).__init__()
    # Route initial entries through __setitem__ so 'TF_CONFIG' is stored
    # per thread like any later assignment.
    for key, val in dict(*args).items():
      self[key] = val

  def _local_dict(self):
    """Returns the calling thread's private dict, creating it on first use."""
    if not hasattr(self._thread_local, 'dict'):
      self._thread_local.dict = dict()
    return self._thread_local.dict

  def _dict_for(self, key):
    """Returns the backing dict that stores `key`."""
    return self._local_dict() if key == 'TF_CONFIG' else self._dict

  def get(self, key, default=None):
    return self._dict_for(key).get(key, default)

  def __getitem__(self, key):
    return self._dict_for(key)[key]

  def __setitem__(self, key, val):
    self._dict_for(key)[key] = val

  def __delitem__(self, key):
    del self._dict_for(key)[key]

  def __iter__(self):
    for key in self._local_dict():
      yield key
    for key in self._dict:
      yield key

  def __len__(self):
    return len(self._local_dict()) + len(self._dict)
645
646
class IndependentWorkerTestBase(test.TestCase):
  """Testing infra for independent workers.

  Each simulated task runs in its own thread. During each test `os.environ`
  is patched with a `MockOsEnv`, so every thread sees its own 'TF_CONFIG'.
  """

  def _make_mock_run_std_server(self):
    """Returns a wrapper around `original_run_std_server`.

    The wrapper skips the test when the gRPC server cannot be started, and on
    each thread's first call waits on `self._barrier`.
    NOTE(review): `self._barrier` is not created in this class; subclasses
    presumably set it (sized to the number of tasks) — confirm with callers.
    """

    def _mock_run_std_server(*args, **kwargs):
      """Returns the std server once all threads have started it."""
      with skip_if_grpc_server_cant_be_started(self):
        std_server = original_run_std_server(*args, **kwargs)
      # Wait for all std servers to be brought up in order to reduce the chance
      # of remote sessions taking local ports that have been assigned to std
      # servers. Only the first call on each thread hits the barrier.
      already_started = getattr(self._thread_local, 'server_started', False)
      if not already_started:
        self._barrier.wait()
      self._thread_local.server_started = True
      return std_server

    return _mock_run_std_server

  def setUp(self):
    """Patches `os.environ` with a fresh `MockOsEnv` for this test."""
    self._mock_os_env = MockOsEnv()
    self._mock_context = test.mock.patch.object(os, 'environ',
                                                self._mock_os_env)
    self._coord = coordinator.Coordinator()
    super().setUp()
    self._mock_context.__enter__()
    # threading local object to be shared by all threads
    self._thread_local = threading.local()

  def tearDown(self):
    """Restores the real `os.environ`."""
    self._mock_context.__exit__(None, None, None)
    super().tearDown()

  def _task_thread(self, task_fn, tf_config, executing_eagerly, *args,
                   **kwargs):
    """Thread body: sets this thread's TF_CONFIG, then runs `task_fn`."""
    with self._coord.stop_on_exception():
      os.environ['TF_CONFIG'] = json.dumps(tf_config)
      # Force the new thread simulating a worker to run in the same context
      # mode as the parent thread does.
      if not executing_eagerly:
        with ops.Graph().as_default(), context.graph_mode():
          task_fn(*args, **kwargs)
        return
      with context.eager_mode():
        task_fn(*args, **kwargs)

  def _run_task_in_thread(self, task_fn, cluster_spec, task_type, task_id,
                          *args, **kwargs):
    """Run tasks in a thread.

    If `tf_config` is provided, use it for the new thread; if not, construct one
    from `cluster_spec`, `task_type`, and `task_id`, and provide it to the new
    thread to be set as `TF_CONFIG` environment.

    Args:
      task_fn: The function to run in the new thread.
      cluster_spec: The cluster spec.
      task_type: The task type.
      task_id: The task id.
      *args: Additional positional arguments to provide to the thread's task_fn.
      **kwargs: Additional keyword arguments to provide to the thread's task_fn.
        If `tf_config` is provided, that dict will be used for the TF_CONFIG for
        the new thread.

    Returns:
      The thread that has started.
    """
    tf_config = kwargs.pop('tf_config', None)
    if tf_config is None:
      tf_config = {'cluster': cluster_spec}
      # A falsy task_type produces a TF_CONFIG without a 'task' entry.
      if task_type:
        tf_config['task'] = {'type': task_type, 'index': task_id}
    thread = threading.Thread(
        target=self._task_thread,
        args=(task_fn, tf_config, context.executing_eagerly()) + args,
        kwargs=kwargs)
    thread.start()
    return thread

  def run_multiple_tasks_in_threads(self, task_fn, cluster_spec, *args,
                                    **kwargs):
    """Starts one thread per task in `cluster_spec`.

    Returns:
      A dict mapping each task type to the list of its started threads,
      indexed by task id.
    """
    # The task_fn should create std_server by itself.
    threads = {}
    for task_type in cluster_spec.keys():
      threads[task_type] = [
          self._run_task_in_thread(task_fn, cluster_spec, task_type, task_id,
                                   *args, **kwargs)
          for task_id in range(len(cluster_spec[task_type]))
      ]
    return threads

  def join_independent_workers(self, worker_threads):
    """Joins `worker_threads` via the coordinator, skipping on gRPC failure."""
    with skip_if_grpc_server_cant_be_started(self):
      self._coord.join(worker_threads)
751
752
class MultiWorkerMultiProcessTest(test.TestCase):
  """Testing infra for independent workers using multiple processes."""

  def _run_task_in_process(self, cmd_args, cluster_spec, task_type, task_id):
    """Launches `cmd_args` as a subprocess with its own TF_CONFIG.

    Returns:
      The `subprocess.Popen` object; stdout and stderr are binary pipes.
    """
    env = os.environ.copy()
    env['TF_CONFIG'] = json.dumps({
        'cluster': cluster_spec,
        'task': {
            'type': task_type,
            'index': task_id
        }
    })
    return subprocess.Popen(
        cmd_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)

  @deprecation.deprecated(
      None, '`run_multiple_tasks_in_processes` is deprecated; any new test '
      'requiring multiple processes should use `multi_process_runner` for '
      'better support of log printing, streaming, and more functionality.')
  def run_multiple_tasks_in_processes(self, cmd_args, cluster_spec):
    """Run `cmd_args` in a process for each task in `cluster_spec`.

    Returns:
      A dict mapping each task type to the list of its `Popen` objects,
      indexed by task id.
    """
    processes = {}
    for task_type in cluster_spec.keys():
      processes[task_type] = []
      for task_id in range(len(cluster_spec[task_type])):
        p = self._run_task_in_process(cmd_args, cluster_spec, task_type,
                                      task_id)
        processes[task_type].append(p)
    return processes

  @deprecation.deprecated(
      None, '`join_independent_workers` is deprecated; any new test '
      'requiring multiple processes should use `multi_process_runner` for '
      'better support of log printing, streaming, and more functionality.')
  def join_independent_workers(self, worker_processes):
    """Waits for every process and asserts that each exited with code 0."""
    return_codes = []
    for p in nest.flatten(worker_processes):
      try:
        # Calling p.wait() will hang if we don't consume its output.
        p.communicate()
      except ValueError:
        # The output of the process may have been consumed, in which case
        # calling `p.communicate()` will raise a ValueError without having
        # waited for the process. Wait explicitly so that `returncode` is
        # populated rather than left as None.
        p.wait()
      finally:
        return_codes.append(p.returncode)
    for return_code in return_codes:
      self.assertEqual(return_code, 0)

  @deprecation.deprecated(
      None, '`stream_stderr` is deprecated; any new test '
      'requiring multiple processes should use `multi_process_runner` for '
      'better support of log printing, streaming, and more functionality.')
  def stream_stderr(self, processes, print_only_first=False):
    """Consume stderr of all processes and print to stdout.

    To reduce the amount of logging, caller can set print_only_first to True.
    In that case, this function only prints stderr from the first process of
    each type.

    Args:
      processes: A dictionary from process type string -> list of processes.
      print_only_first: If true, only print output from first process of each
        type.
    """

    def _stream_stderr_single_process(process, type_string, index,
                                      print_to_stdout):
      """Consume a single process's stderr and optionally print to stdout."""
      while True:
        output = process.stderr.readline()
        if not output and process.poll() is not None:
          break
        if output and print_to_stdout:
          # Binary pipes yield bytes; decode so the line is not printed as a
          # b'...' repr.
          if isinstance(output, bytes):
            output = output.decode('utf-8', errors='replace')
          print('{}{} {}'.format(type_string, index, output.strip()))
          sys.stdout.flush()

    stream_threads = []
    for process_type, process_list in six.iteritems(processes):
      for i, process in enumerate(process_list):
        print_to_stdout = (not print_only_first) or (i == 0)
        thread = threading.Thread(
            target=_stream_stderr_single_process,
            args=(process, process_type, i, print_to_stdout))
        thread.start()
        stream_threads.append(thread)
    for thread in stream_threads:
      thread.join()
841
842
def get_tf_config_task():
  """Returns the 'task' entry of the JSON-encoded TF_CONFIG env variable."""
  tf_config = json.loads(os.environ['TF_CONFIG'])
  return tf_config['task']
845
846
def get_tf_config_cluster_spec():
  """Returns the 'cluster' entry of the JSON-encoded TF_CONFIG env variable."""
  tf_config = json.loads(os.environ['TF_CONFIG'])
  return tf_config['cluster']
849
850
def get_task_type():
  """Returns the task type (e.g. 'worker') from the current TF_CONFIG."""
  task = get_tf_config_task()
  return task['type']
853
854
def get_task_index():
  """Returns the task index from the current TF_CONFIG."""
  task = get_tf_config_task()
  return task['index']
857
858
def is_chief():
  """Returns whether the task described by TF_CONFIG is the chief.

  If the cluster has a 'chief' job, only the 'chief' task is the chief.
  Otherwise the first worker ('worker', index 0) acts as the chief.
  """
  if 'chief' in get_tf_config_cluster_spec():
    return get_task_type() == 'chief'
  return get_task_type() == 'worker' and get_task_index() == 0
863