1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utility to get tf.distribute.Strategy related contexts."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import contextlib
22import threading
23
24from tensorflow.python import tf2
25from tensorflow.python.framework import ops
26from tensorflow.python.util.lazy_loader import LazyLoader
27from tensorflow.python.util.tf_export import tf_export
28
29
# There is a circular dependency between this module and the `distribute_lib`
# module, so `distribute_lib` is loaded lazily to work around it. The name
# `distribute_lib` is bound in this module's globals() right away; the real
# import is deferred until a member is first accessed (which is why
# `_get_default_strategy` touches a member before using it).
distribute_lib = LazyLoader(
    "distribute_lib", globals(),
    "tensorflow.python.distribute.distribute_lib")
35
36# ------------------------------------------------------------------------------
37# Internal API for setting the current thread mode as being either in a
38# replica or cross-replica context for a particular tf.distribute.Strategy.
39
40
41class _ThreadMode(object):
42
43  def __init__(self, dist, cross, replica):
44    self.strategy = dist
45    self.cross_replica_context = cross
46    self.replica_context = replica
47
48
class _CrossReplicaThreadMode(_ThreadMode):
  """Mode for a cross-replica context of `strategy`.

  `strategy` serves both as the active strategy and as the cross-replica
  context. There is no replica context.
  """

  def __init__(self, strategy):
    _ThreadMode.__init__(self, dist=strategy, cross=strategy, replica=None)
53
54
class _InReplicaThreadMode(_ThreadMode):
  """Mode for running inside the replica context `replica_ctx`.

  The active strategy is taken from `replica_ctx.strategy`. There is no
  cross-replica context.
  """

  def __init__(self, replica_ctx):
    _ThreadMode.__init__(
        self, dist=replica_ctx.strategy, cross=None, replica=replica_ctx)
59
60
def _push_per_thread_mode(context):
  """Pushes `context` onto the default graph's distribution strategy stack."""
  mode_stack = ops.get_default_graph()._distribution_strategy_stack  # pylint: disable=protected-access
  mode_stack.append(context)
63
64
def _pop_per_thread_mode():
  """Removes the innermost entry from the default graph's strategy stack."""
  mode_stack = ops.get_default_graph()._distribution_strategy_stack  # pylint: disable=protected-access
  mode_stack.pop()
67
68
class _DefaultReplicaThreadMode(_ThreadMode):
  """Type of the fallback value returned by `_get_per_thread_mode()`.

  Used when the graph's strategy stack is empty. It describes a replica
  context of the default strategy, using the default replica context.
  """

  def __init__(self):
    default_strategy = _get_default_strategy()
    default_replica = _get_default_replica_context()
    _ThreadMode.__init__(
        self, dist=default_strategy, cross=None, replica=default_replica)
78
79
def _get_per_thread_mode():
  """Returns the innermost `_ThreadMode` on the default graph's stack.

  Falls back to the default replica mode in two cases: the stack is empty
  (`IndexError`), or the default graph has no strategy stack
  (`AttributeError`).
  """
  try:
    mode_stack = ops.get_default_graph()._distribution_strategy_stack  # pylint: disable=protected-access
    innermost = mode_stack[-1]
  except (AttributeError, IndexError):
    innermost = _get_default_replica_mode()
  return innermost
85
86
87# ------------------------------------------------------------------------------
88# Public API for accessing the current thread mode
89
90
@tf_export("distribute.get_replica_context")
def get_replica_context():
  """Returns the current `tf.distribute.ReplicaContext` or `None`.

  Returns `None` if in a cross-replica context.

  Note that execution:

  1. starts in the default (single-replica) replica context (this function
     will return the default `ReplicaContext` object);
  2. switches to cross-replica context (in which case this will return
     `None`) when entering a `with tf.distribute.Strategy.scope():` block;
  3. switches to a (non-default) replica context inside `strategy.run(fn, ...)`;
  4. if `fn` calls `get_replica_context().merge_call(merge_fn, ...)`, then
     inside `merge_fn` you are back in the cross-replica context (and again
     this function will return `None`).

  Most `tf.distribute.Strategy` methods may only be executed in a
  cross-replica context. In a replica context, use the API of the
  `tf.distribute.ReplicaContext` object returned by this method instead.

  ```
  assert tf.distribute.get_replica_context() is not None  # default
  with strategy.scope():
    assert tf.distribute.get_replica_context() is None

    def f():
      replica_context = tf.distribute.get_replica_context()  # for strategy
      assert replica_context is not None
      tf.print("Replica id: ", replica_context.replica_id_in_sync_group,
               " of ", replica_context.num_replicas_in_sync)

    strategy.run(f)
  ```

  Returns:
    The current `tf.distribute.ReplicaContext` object when in a replica context
    scope, else `None`.

    Within a particular block, exactly one of these two things will be true:

    * `get_replica_context()` returns non-`None`, or
    * `tf.distribute.in_cross_replica_context()` returns True.
  """
  current_mode = _get_per_thread_mode()
  return current_mode.replica_context
137
138
def get_cross_replica_context():
  """Returns the current tf.distribute.Strategy if in a cross-replica context.

  DEPRECATED: Please use `in_cross_replica_context()` and
  `get_strategy()` instead.

  Returns:
    The current `tf.distribute.Strategy` object when in a cross-replica
    context, or `None` when in a replica context.

    In a particular block, exactly one of `get_replica_context()` and
    `get_cross_replica_context()` returns `None`.
  """
  current_mode = _get_per_thread_mode()
  return current_mode.cross_replica_context
153
154
@tf_export("distribute.in_cross_replica_context")
def in_cross_replica_context():
  """Returns `True` if in a cross-replica context.

  See `tf.distribute.get_replica_context` for details.

  ```
  assert not tf.distribute.in_cross_replica_context()
  with strategy.scope():
    assert tf.distribute.in_cross_replica_context()

    def f():
      assert not tf.distribute.in_cross_replica_context()

    strategy.run(f)
  ```

  Returns:
    `True` if in a cross-replica context (`get_replica_context()` returns
    `None`), or `False` if in a replica context (`get_replica_context()`
    returns non-`None`).
  """
  current_mode = _get_per_thread_mode()
  return current_mode.cross_replica_context is not None
178
179
@tf_export("distribute.get_strategy")
def get_strategy():
  """Returns the current `tf.distribute.Strategy` object.

  Typically only used in a cross-replica context:

  ```
  if tf.distribute.in_cross_replica_context():
    strategy = tf.distribute.get_strategy()
    ...
  ```

  Returns:
    A `tf.distribute.Strategy` object. Inside a `with strategy.scope()` block
    this is `strategy`. Otherwise it is the default (single-replica)
    `tf.distribute.Strategy` object.
  """
  current_mode = _get_per_thread_mode()
  return current_mode.strategy
198
199
@tf_export("distribute.has_strategy")
def has_strategy():
  """Returns whether there is a current non-default `tf.distribute.Strategy`.

  ```
  assert not tf.distribute.has_strategy()
  with strategy.scope():
    assert tf.distribute.has_strategy()
  ```

  Returns:
    True if inside a `with strategy.scope():` block (including a scope entered
    by `tf.distribute.experimental_set_strategy`), False otherwise.
  """
  active_strategy = get_strategy()
  return active_strategy is not _get_default_strategy()
214
215
def get_strategy_and_replica_context():
  """Returns `(strategy, replica_context)` for the current mode in one lookup.

  Both values are read from the same `_ThreadMode`, so they always describe
  the same context. `replica_context` is `None` in a cross-replica context.
  """
  current_mode = _get_per_thread_mode()
  return current_mode.strategy, current_mode.replica_context
219
220
@tf_export("distribute.experimental_set_strategy")
def experimental_set_strategy(strategy):
  """Set a `tf.distribute.Strategy` as current without `with strategy.scope()`.

  ```
  tf.distribute.experimental_set_strategy(strategy1)
  f()
  tf.distribute.experimental_set_strategy(strategy2)
  g()
  tf.distribute.experimental_set_strategy(None)
  h()
  ```

  is equivalent to:

  ```
  with strategy1.scope():
    f()
  with strategy2.scope():
    g()
  h()
  ```

  In general, you should use the `with strategy.scope():` API, but this
  alternative may be convenient in notebooks where you would have to put
  each cell in a `with strategy.scope():` block.

  The scope entered here is stored on the current default graph. The next
  call (including `experimental_set_strategy(None)`) exits that scope before
  doing anything else.

  Note: This should only be called outside of any TensorFlow scope to
  avoid improper nesting.

  Args:
    strategy: A `tf.distribute.Strategy` object or None.

  Raises:
    RuntimeError: If called inside a `with strategy.scope():`. The check runs
      only after any scope set by a previous call has been exited, so that
      scope alone does not trigger the error. Errors raised while exiting the
      previous scope or entering the new one are not caught here.
  """
  # Undo the scope (if any) entered by a previous call to this function.
  old_scope = ops.get_default_graph()._global_distribute_strategy_scope  # pylint: disable=protected-access
  if old_scope is not None:
    old_scope.__exit__(None, None, None)
    ops.get_default_graph()._global_distribute_strategy_scope = None  # pylint: disable=protected-access
  # The globally-set scope has been exited, so a remaining non-default strategy
  # means we are still inside some other strategy scope.
  if has_strategy():
    raise RuntimeError(
        "Must not be called inside a `tf.distribute.Strategy` scope.")
  if strategy is not None:
    # Enter the scope manually (not via `with`) so it stays active after this
    # function returns, and record it so that the next call can exit it.
    new_scope = strategy.scope()
    new_scope.__enter__()
    ops.get_default_graph()._global_distribute_strategy_scope = new_scope  # pylint: disable=protected-access
268
269
270# ------------------------------------------------------------------------------
271# Internal helpers.
272
273
@contextlib.contextmanager
def enter_or_assert_strategy(strategy):
  """Context manager that makes `strategy` current for the enclosed block.

  If no non-default strategy is active, the block runs inside
  `strategy.scope()`. Otherwise the current scope is left untouched, and
  `_assert_strategy` checks that it belongs to `strategy`, raising
  `RuntimeError` if it does not.
  """
  if has_strategy():
    _assert_strategy(strategy)
    yield
    return
  with strategy.scope():
    yield
282
283
284# ------------------------------------------------------------------------------
285# Defaults that are used when no tf.distribute.Strategy is explicitly created.
# They are created lazily inside functions so that we can work around the
# circular dependency on distribute_lib. See the lazy loader at the top of
# this file.
288
289_defaults = {
290    "strategy": None,
291    "replica_context": None,
292    "replica_mode": None
293}
294# Note: These need to be different locks since _get_default_replica_context
295# calls _get_default_strategy inside its lock, and them using the same lock
296# can lead to deadlock.
297_default_strategy_lock = threading.Lock()
298_default_replica_context_lock = threading.Lock()
299_default_replica_mode_lock = threading.Lock()
300
301
def _assert_strategy(strategy):
  """Raises `RuntimeError` unless `strategy` is the current strategy.

  Fails in two cases: no non-default strategy is active, or a different
  strategy object is active.
  """
  if not has_strategy():
    raise RuntimeError('Need to be inside "with strategy.scope()" for %s' %
                       (strategy,))
  active = get_strategy()
  if active is strategy:
    return
  raise RuntimeError(
      "Mixing different tf.distribute.Strategy objects: %s is not %s" %
      (active, strategy))
311
312
def _get_default_strategy():
  """Returns the default (single-replica) `tf.distribute.Strategy` singleton.

  The instance is created on the first call. It is a
  `distribute_lib._DefaultDistributionStrategy` when `tf2.enabled()` is true
  and a `distribute_lib._DefaultDistributionStrategyV1` otherwise. Creation
  happens under `_default_strategy_lock`, with a second check inside the
  lock, so concurrent first callers create only one instance.
  """
  if _defaults["strategy"] is None:
    # Avoid race condition causing two defaults to be created
    with _default_strategy_lock:
      if _defaults["strategy"] is None:
        # pylint: disable=protected-access
        # Make sure distribute_lib module is loaded by accessing some member.
        _ = distribute_lib._creating_default_strategy_singleton
        # The flag is raised only around the constructor call. Presumably
        # distribute_lib's default-strategy constructors check it so they
        # cannot be built elsewhere (TODO confirm in distribute_lib). The
        # `finally` lowers it again even if construction (or
        # `tf2.enabled()`) raises. Without it, the flag would stay set for
        # the rest of the process.
        distribute_lib._creating_default_strategy_singleton = True
        try:
          if tf2.enabled():
            _defaults["strategy"] = (
                distribute_lib._DefaultDistributionStrategy())
          else:
            _defaults["strategy"] = (
                distribute_lib._DefaultDistributionStrategyV1())
        finally:
          distribute_lib._creating_default_strategy_singleton = False
        # pylint: enable=protected-access
  return _defaults["strategy"]
330
331
def _get_default_replica_context():
  """Returns the default `ReplicaContext` singleton, creating it on first use.

  The context is a `distribute_lib._DefaultReplicaContext` for the default
  strategy, with `replica_id_in_sync_group=0`. `_get_default_strategy` is
  called while `_default_replica_context_lock` is held, which is why that
  lock must differ from `_default_strategy_lock`.
  """
  if _defaults["replica_context"] is not None:
    return _defaults["replica_context"]
  # Re-check under the lock so concurrent first callers create only one.
  with _default_replica_context_lock:
    if _defaults["replica_context"] is None:
      # pylint: disable=protected-access
      _defaults["replica_context"] = distribute_lib._DefaultReplicaContext(
          _get_default_strategy(), replica_id_in_sync_group=0)
      # pylint: enable=protected-access
  return _defaults["replica_context"]
342
343
def _get_default_replica_mode():
  """Returns the `_DefaultReplicaThreadMode` singleton, creating it on first use.

  `_get_per_thread_mode` falls back to this value when the graph's strategy
  stack is empty or missing.
  """
  if _defaults["replica_mode"] is not None:
    return _defaults["replica_mode"]
  # Re-check under the lock so concurrent first callers create only one.
  with _default_replica_mode_lock:
    if _defaults["replica_mode"] is None:
      _defaults["replica_mode"] = _DefaultReplicaThreadMode()
  return _defaults["replica_mode"]
351
352
# Aliases kept so that callers using the old names keep working.
get_distribution_strategy, has_distribution_strategy = (
    get_strategy, has_strategy)
356