1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
7#     http://www.apache.org/licenses/LICENSE-2.0
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Class MirroredStrategy implementing tf.distribute.Strategy."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
21import contextlib
22import functools
23import threading
24import weakref
26from tensorflow.python import pywrap_tfe
27from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
28from tensorflow.python.autograph.impl import api as autograph
29from tensorflow.python.distribute import distribute_lib
30from tensorflow.python.distribute import distribute_utils
31from tensorflow.python.distribute import shared_variable_creator
32from tensorflow.python.eager import context
33from tensorflow.python.eager import def_function
34from tensorflow.python.framework import device as tf_device
35from tensorflow.python.framework import ops
36from tensorflow.python.ops import summary_ops_v2
37from tensorflow.python.ops import variable_scope
38from tensorflow.python.platform import tf_logging as logging
39from tensorflow.python.training import coordinator
# Per-strategy cache of wrapped `def_function.Function` objects used by
# `call_for_each_replica`: strategy -> (original Function -> cloned wrapper).
# Both levels are weakly keyed.
_cfer_fn_cache = weakref.WeakKeyDictionary()


def call_for_each_replica(strategy, fn, args=None, kwargs=None):
  """Invokes `fn` once for every replica (worker device) of `strategy`.

  Running this eagerly is slow; wrapping the call in a `tf.function` is
  strongly recommended.

  Args:
    strategy: the `tf.distribute.Strategy` whose replicas run `fn`.
    fn: the function to run on each worker device.
    args: positional arguments for `fn`; `None` is treated as `()`.
    kwargs: keyword arguments for `fn`; `None` is treated as `{}`.

  Returns:
    The values returned by `fn` on all replicas, wrapped together.
  """
  args = () if args is None else args
  kwargs = {} if kwargs is None else kwargs

  if isinstance(fn, def_function.Function):
    per_strategy_cache = _cfer_fn_cache.get(strategy)
    if per_strategy_cache is None:
      per_strategy_cache = weakref.WeakKeyDictionary()
      _cfer_fn_cache[strategy] = per_strategy_cache

    function_wrapper = per_strategy_cache.get(fn)
    if function_wrapper is None:
      # Clone `fn` instead of decorating `call_for_each_replica` with a new
      # @tf.function so the clone keeps the decorator arguments of `fn`. The
      # clone's body re-enters this function with the plain Python function,
      # which therefore reaches `_call_for_each_replica` inside the trace.
      # NOTE(review): the partial holds a strong reference to `strategy`, which
      # is also the weak key of `_cfer_fn_cache`; the cached entry likely keeps
      # `strategy` alive for the life of the process -- confirm.
      function_wrapper = fn._clone(  # pylint: disable=protected-access
          python_function=functools.partial(
              call_for_each_replica, strategy, fn.python_function))
      per_strategy_cache[fn] = function_wrapper
    return function_wrapper(args, kwargs)

  if not context.executing_eagerly():
    # When reached through the cloned wrapper above, AutoGraph stops converting
    # at `_call_for_each_replica` (TF library code is allowlisted), so convert
    # the user's Python function explicitly to keep it converted.
    fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
  else:
    logging.log_first_n(
        logging.WARN,
        "Using %s eagerly has significant overhead currently. We will be "
        "working on improving this in the future, but for now please wrap "
        "`call_for_each_replica` or `experimental_run` or `run` inside a "
        "tf.function to get the best performance." %
        strategy.__class__.__name__,
        5)

  return _call_for_each_replica(strategy, fn, args, kwargs)
@contextlib.contextmanager
def _enter_graph(g, eager, creator_stack=None):
  """Context manager for selecting a graph and maybe eager mode.

  Must be decorated with `contextlib.contextmanager`: callers use it directly
  in `with` statements, and a bare generator function does not support the
  context-manager protocol.

  Args:
    g: graph entered via `g.as_default()` for the duration of the block.
    eager: if true, `context.eager_mode()` is also entered (inside
      `g.as_default()`).
    creator_stack: if not None, assigned to `g._variable_creator_stack` after
      the graph (and eager mode) are entered. The previous stack is not
      restored on exit.

  Yields:
    None.
  """
  if eager:
    with g.as_default(), context.eager_mode():
      if creator_stack is not None:
        g._variable_creator_stack = creator_stack  # pylint: disable=protected-access
      yield
  else:
    with g.as_default():
      if creator_stack is not None:
        g._variable_creator_stack = creator_stack  # pylint: disable=protected-access
      yield
def _cpu_device(device):
  """Returns `device` as a string with its device type/index set to CPU:0.

  Parses `device` with `DeviceSpec.from_string`, replaces `device_type` with
  "CPU" and `device_index` with 0, and serializes the result back to a string.
  """
  return tf_device.DeviceSpec.from_string(device).replace(
      device_type="CPU", device_index=0).to_string()
class _RequestedStop(Exception):  # pylint: disable=g-bad-exception-name
  """Raised in a replica thread resumed after the coordinator requested a stop.

  Registered as a clean-stop exception type of the `Coordinator` created in
  `_call_for_each_replica`.
  """
def _call_for_each_replica(distribution, fn, args, kwargs):
  """Run `fn` in separate threads, once per replica/worker device.

  One `_MirroredReplicaThread` is started per entry of
  `distribution.extended.worker_devices`. The calling (main) thread drives the
  replicas and executes any `merge_call()` merge functions they request.

  Args:
    distribution: the DistributionStrategy object.
    fn: function to run (will be run once per replica, each in its own thread).
    args: positional arguments for `fn`; per-replica values are selected with
      `distribute_utils.select_replica`.
    kwargs: keyword arguments for `fn`; per-replica values are selected the
      same way.

  Returns:
    Merged return value of `fn` across all replicas (regrouped with
    `distribute_utils.regroup`), or None if a stop was requested.

  Raises:
    RuntimeError: If some replicas finish `fn` while others are still inside
      a `get_replica_context().merge_call()`, i.e. the replicas made a
      different number of `merge_call()` calls.
  """
  # TODO(josh11b): Add this option once we add synchronization to variable
  # creation. Until then, this is pretty unsafe to use.
  # Always False, so the `run_concurrently` branch below is currently dead.
  run_concurrently = False
  if not context.executing_eagerly():
    # Needed for per-thread device, etc. contexts in graph mode.
    ops.get_default_graph().switch_to_thread_local()

  # `_RequestedStop` (raised by replicas resumed after a stop request) is
  # treated as a clean stop rather than an error.
  coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))

  # Shared by every replica's variable creator.
  shared_variable_store = {}
  devices = distribution.extended.worker_devices

  # TODO(isaprykin): Create these threads once instead of during every call.
  threads = []
  for index in range(len(devices)):
    variable_creator_fn = shared_variable_creator.make_fn(
        shared_variable_store, index)
    t = _MirroredReplicaThread(
        distribution, coord, index, devices, variable_creator_fn, fn,
        distribute_utils.select_replica(index, args),
        distribute_utils.select_replica(index, kwargs))
    threads.append(t)

  # Started threads block on `should_run` until the loop below releases them.
  for t in threads:
    t.start()

  # Protocol between this (main) thread and each _MirroredReplicaThread (`MRT`):
  # the main thread sets `MRT.should_run` to let a replica run, then waits on
  # `MRT.has_paused`. The replica sets `has_paused` either when `fn` has
  # returned (then `MRT.done` is True) or when it reaches
  # `get_replica_context().merge_call()`. When every replica is paused in
  # `merge_call`, their `merge_call` arguments are regrouped, `merge_fn` runs
  # once here, and its result is split per replica into `MRT.merge_result`.
  # Setting `MRT.should_run` again resumes `fn`, and that replica's
  # `merge_call` returns its `MRT.merge_result`.
  # In the (sequential) default branch, replicas are resumed one at a time:
  # each runs until it pauses before the next one is released.

  try:
    with coord.stop_on_exception():
      all_done = False
      while not all_done and not coord.should_stop():
        done = []
        if run_concurrently:
          for t in threads:
            t.should_run.set()
          for t in threads:
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        else:
          for t in threads:
            t.should_run.set()
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        if coord.should_stop():
          return None
        all_done = all(done)
        if not all_done:
          # Some replicas finished while others are paused in merge_call().
          if any(done):
            raise RuntimeError("Some replicas made a different number of "
                               "replica_context().merge_call() calls.")
          # get_replica_context().merge_call() case
          merge_args = distribute_utils.regroup(
              tuple(t.merge_args for t in threads))
          merge_kwargs = distribute_utils.regroup(
              tuple(t.merge_kwargs for t in threads))
          # We capture the name_scope of the MRT when we call merge_fn
          # to ensure that if we have opened a name scope in the MRT,
          # it will be respected when executing the merge function. We only
          # capture the name_scope from the first MRT and assume it is
          # the same for all other MRTs.
          mtt_captured_name_scope = threads[0].captured_name_scope
          mtt_captured_var_scope = threads[0].captured_var_scope
          # Capture and merge the control dependencies from all the threads.
          mtt_captured_control_deps = set()
          for t in threads:
            mtt_captured_control_deps.update(t.captured_control_deps)
          # Only the first replica's merge_fn is called; it receives the
          # strategy followed by the regrouped arguments.
          with ops.name_scope(mtt_captured_name_scope),\
              ops.control_dependencies(mtt_captured_control_deps), \
              variable_scope.variable_scope(mtt_captured_var_scope):
            merge_result = threads[0].merge_fn(distribution, *merge_args,
                                               **merge_kwargs)
          for r, t in enumerate(threads):
            t.merge_result = distribute_utils.select_replica(r, merge_result)
  finally:
    # Release every thread, including ones blocked at their first
    # `should_run.wait()` or inside `merge_call`. After a stop request they
    # observe `coord.should_stop()` and exit (returning early or raising
    # `_RequestedStop`), so the join below can complete.
    for t in threads:
      t.should_run.set()
    # NOTE(review): per tf.train.Coordinator semantics, join() presumably
    # re-raises an exception recorded via `stop_on_exception` -- confirm.
    coord.join(threads)

  return distribute_utils.regroup(tuple(t.main_result for t in threads))
class _MirroredReplicaThread(threading.Thread):
  """A thread that runs a function on one replica's device."""

  def __init__(self, dist, coord, replica_id, devices, variable_creator_fn,
               fn, args, kwargs):
    """Captures the creating thread's contexts for later use in `run()`.

    Args:
      dist: the distribution strategy being run.
      coord: the `Coordinator` shared by all replica threads.
      replica_id: index of this replica; selects its entry in `devices`.
      devices: the worker devices, indexed by `replica_id`.
      variable_creator_fn: variable creator installed (via
        `variable_creator_scope`) while `fn` runs.
      fn: the function to run on this replica.
      args: positional arguments for `fn` for this replica.
      kwargs: keyword arguments for `fn` for this replica.
    """
    super(_MirroredReplicaThread, self).__init__()
    self.coord = coord
    self.distribution = dist
    self.devices = devices
    self.replica_id = replica_id
    self.replica_id_in_sync_group = (
        dist.extended._get_replica_id_in_sync_group(replica_id))  # pylint: disable=protected-access

    self.variable_creator_fn = variable_creator_fn
    # State needed to run and return the results of `fn`.
    self.main_fn = fn
    self.main_args = args
    self.main_kwargs = kwargs
    self.main_result = None
    self.done = False
    # State needed to run the next merge_call() (if any) requested via
    # ReplicaContext.
    self.merge_fn = None
    self.merge_args = None
    self.merge_kwargs = None
    self.merge_result = None
    self.captured_name_scope = None
    self.captured_var_scope = None
    # Set by `_MirroredReplicaContext._merge_call` together with the other
    # captured_* fields; initialized here so every thread has the attribute.
    self.captured_control_deps = None
    # We use a thread.Event for the main thread to signal when this
    # thread should start running (`should_run`), and another for
    # this thread to transfer control back to the main thread
    # (`has_paused`, either when it gets to a
    # `get_replica_context().merge_call` or when `fn` returns). In
    # either case the event starts cleared, is signaled by calling
    # set(). The receiving thread waits for the signal by calling
    # wait() and then immediately clearing the event using clear().
    self.should_run = threading.Event()
    self.has_paused = threading.Event()
    # These fields have to do with inheriting various contexts from the
    # parent thread:
    context.ensure_initialized()
    ctx = context.context()
    self.in_eager = ctx.executing_eagerly()
    self.record_thread_local_summary_state()
    self.record_thread_local_eager_context_state()
    self.context_device_policy = (
        pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(
            ctx._context_handle))  # pylint: disable=protected-access
    self.graph = ops.get_default_graph()
    with ops.init_scope():
      self._init_in_eager = context.executing_eagerly()
      self._init_graph = ops.get_default_graph()
    self._variable_creator_stack = self.graph._variable_creator_stack[:]  # pylint: disable=protected-access
    self._var_scope = variable_scope.get_variable_scope()
    # Adding a "/" at end lets us re-enter this scope later.
    self._name_scope = self.graph.get_name_scope()
    if self._name_scope:
      self._name_scope += "/"
    # Replicas other than 0 get their own "replica_<id>/" sub-scope.
    if self.replica_id > 0:
      if not self._name_scope:
        self._name_scope = ""
      self._name_scope += "replica_%d/" % self.replica_id

  def run(self):
    """Runs `main_fn` once released by the main thread.

    Waits for the first `should_run` signal. If the coordinator has already
    requested a stop, returns without running `fn`. Otherwise restores the
    captured thread-local state and runs `fn` under the captured graphs,
    device policy, replica context, device, name scope and variable scope.
    `has_paused` is always set on exit so the main thread is never left
    waiting.
    """
    self.should_run.wait()
    self.should_run.clear()
    try:
      if self.coord.should_stop():
        return
      self.restore_thread_local_summary_state()
      self.restore_thread_local_eager_context_state()
      # TODO(josh11b): Use current logical device instead of 0 here.
      with self.coord.stop_on_exception(), \
          _enter_graph(self._init_graph, self._init_in_eager), \
          _enter_graph(self.graph, self.in_eager,
                       self._variable_creator_stack), \
          context.device_policy(self.context_device_policy), \
          _MirroredReplicaContext(self.distribution,
                                  self.replica_id_in_sync_group), \
          ops.device(self.devices[self.replica_id]), \
          ops.name_scope(self._name_scope), \
          variable_scope.variable_scope(
              self._var_scope, reuse=self.replica_id > 0), \
          variable_scope.variable_creator_scope(self.variable_creator_fn):
        self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
        self.done = True
    finally:
      self.has_paused.set()

  def record_thread_local_summary_state(self):
    """Record the thread local summary state in self."""
    # TODO(slebedev): is this still relevant? the referenced bug is closed.
    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
    self._summary_step = summary_state.step
    self._summary_writer = summary_state.writer
    self._summary_recording = summary_state.is_recording
    self._summary_recording_distribution_strategy = (
        summary_state.is_recording_distribution_strategy)

  def restore_thread_local_summary_state(self):
    """Restore thread local summary state from self."""
    # TODO(slebedev): is this still relevant? the referenced bug is closed.
    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
    summary_state.step = self._summary_step
    summary_state.writer = self._summary_writer
    summary_state.is_recording = self._summary_recording
    summary_state.is_recording_distribution_strategy = (
        self._summary_recording_distribution_strategy)

  def record_thread_local_eager_context_state(self):
    """Record the eager context's thread-local op callbacks in self."""
    ctx = context.context()
    eager_context_state = ctx._thread_local_data  # pylint: disable=protected-access
    self._eager_context_op_callbacks = eager_context_state.op_callbacks
    # TODO(b/125892694): record other fields in EagerContext.

  def restore_thread_local_eager_context_state(self):
    """Restore the eager context's thread-local op callbacks from self."""
    ctx = context.context()
    eager_context_state = ctx._thread_local_data  # pylint: disable=protected-access
    eager_context_state.op_callbacks = self._eager_context_op_callbacks
    # TODO(b/125892694): record other fields in EagerContext.
class _MirroredReplicaContext(distribute_lib.ReplicaContext):
  """ReplicaContext for synchronized replica."""

  def _merge_call(self, fn, args, kwargs):
    """`merge_call()` implementation for synchronized replica.

    This pauses the current replica thread and passes `fn` and its arguments to
    the main thread. The main thread will wait until all replicas pause, then
    invoke `fn` with grouped arguments. The current replica thread will continue
    after `fn` completes.

    See `_call_for_each_replica` for the logic in the main thread.

    Args:
      fn: a function that is called in cross replica context with grouped
        arguments from each replica, as `fn(strategy, *grouped_args,
        **grouped_kwargs)`. `fn` should return grouped values.
      args: positional arguments to `fn`.
      kwargs: keyword arguments to `fn`.

    Returns:
      Return value of `fn` for the current replica (the main thread selects
      this replica's element of the grouped result).

    Raises:
      RuntimeError: when merge_call happens in a different graph, e.g. in a
        different tf.function, which is not supported now.
      _RequestedStop: when stop is requested.
    """
    t = threading.current_thread()
    assert isinstance(t, _MirroredReplicaThread)
    t.merge_fn = fn
    t.merge_args = args
    t.merge_kwargs = kwargs
    t.captured_name_scope = t.graph.get_name_scope()
    # Adding a "/" at end lets us re-enter this scope later.
    if t.captured_name_scope:
      t.captured_name_scope += "/"

    t.captured_var_scope = variable_scope.get_variable_scope()
    t.captured_control_deps = t.graph._current_control_dependencies()  # pylint: disable=protected-access

    # It is problematic if `merge_call` is called under a different graph than
    # the one that `_call_for_each_replica` is called under. There are three
    # cases in which this can happen:
    #
    #   1. The `fn` passed to `_call_for_each_replica` is decorated with
    #   `tf.function` and there is a `merge_call` in `fn`. Since
    #   MirroredStrategy traces a separate function per thread (per device),
    #   and each trace takes a shared lock, the lock is never released by the
    #   first thread and subsequent replica threads cannot proceed to trace
    #   their own functions. This issue is addressed by always converting
    #   `_call_for_each_replica(tf.function(f))` to
    #   `tf.function(_call_for_each_replica(f))` in `call_for_each_replica`.
    #
    #   2. The `fn` passed to `_call_for_each_replica` contains a nested
    #   `tf.function`, and there is a `merge_call` in the nested `tf.function`.
    #   In this case each thread can successfully trace its own function, but
    #   since the `merge_fn` passed to `merge_call` is executed in the main
    #   thread (where `_call_for_each_replica` is executed), it can't access
    #   the tensors that come from different graphs.
    #
    #   3. The `fn` passed to `_call_for_each_replica` contains a control-flow
    #   statement with a `merge_call` inside its body, and `fn` or
    #   `_call_for_each_replica` is decorated with `tf.function`. The
    #   control-flow statement creates a separate graph for its body; as in
    #   #2, the `merge_fn` executed in the main thread can't access the
    #   tensors that come from different graphs.
    #
    #   We raise an error for #2 and #3.
    if ops.get_default_graph() != t.graph:
      raise RuntimeError(
          "`merge_call` called while defining a new graph or a tf.function."
          " This can often happen if the function `fn` passed to"
          " `strategy.run()` contains a nested `@tf.function`, and the nested "
          "`@tf.function` contains a synchronization point, such as aggregating"
          " gradients (e.g, optimizer.apply_gradients), or if the function `fn`"
          " uses a control flow statement which contains a synchronization"
          " point in the body. Such behaviors are not yet supported. Instead,"
          " please avoid nested `tf.function`s or control flow statements that"
          " may potentially cross a synchronization boundary, for example,"
          " wrap the `fn` passed to `strategy.run` or the entire `strategy.run`"
          " inside a `tf.function` or move the control flow out of `fn`")

    # Hand control to the main thread and block until it resumes this replica.
    t.has_paused.set()
    t.should_run.wait()
    t.should_run.clear()
    if t.coord.should_stop():
      raise _RequestedStop()
    return t.merge_result

  @property
  def devices(self):
    """A one-element list: this replica's entry in `worker_devices_by_replica`.

    Must be called in this replica's context (`require_replica_context`).
    """
    distribute_lib.require_replica_context(self)
    return [
        self._strategy.extended.worker_devices_by_replica[
            self._replica_id_in_sync_group]
    ]