1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utility to get tf.distribute.Strategy related contexts."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import ops
22from tensorflow.python.util.lazy_loader import LazyLoader
23from tensorflow.python.util.tf_export import tf_export
24
25
# There is a circular dependency between this module and the `distribute_lib`
# module, so `distribute_lib` is loaded lazily to work around it.
# NOTE(review): the first argument ("distribute_lib") presumably has to match
# the global name assigned here, since LazyLoader is handed this module's
# `globals()` -- confirm against tensorflow/python/util/lazy_loader.py before
# renaming either.
distribute_lib = LazyLoader(
    "distribute_lib", globals(),
    "tensorflow.python.distribute.distribute_lib")
31
32# ------------------------------------------------------------------------------
33# Internal API for setting the current thread mode as being either in a
34# replica or cross-replica context for a particular tf.distribute.Strategy.
35
36
37class _ThreadMode(object):
38
39  def __init__(self, dist, cross, replica):
40    self.strategy = dist
41    self.cross_replica_context = cross
42    self.replica_context = replica
43
44
class _CrossReplicaThreadMode(_ThreadMode):
  """Thread mode for a cross-replica context of `strategy`.

  Both `strategy` and `cross_replica_context` refer to the given strategy,
  and `replica_context` is `None`.
  """

  def __init__(self, strategy):
    _ThreadMode.__init__(self, dist=strategy, cross=strategy, replica=None)
49
50
class _InReplicaThreadMode(_ThreadMode):
  """Thread mode for code running inside the replica context `replica_ctx`.

  The strategy is read from `replica_ctx.strategy`, and
  `cross_replica_context` is `None`.
  """

  def __init__(self, replica_ctx):
    owning_strategy = replica_ctx.strategy
    _ThreadMode.__init__(
        self, dist=owning_strategy, cross=None, replica=replica_ctx)
55
56
def _push_per_thread_mode(context):
  """Pushes `context` onto the default graph's distribution-strategy stack."""
  mode_stack = ops.get_default_graph()._distribution_strategy_stack  # pylint: disable=protected-access
  mode_stack.append(context)
59
60
def _pop_per_thread_mode():
  """Removes the innermost entry from the default graph's strategy stack."""
  mode_stack = ops.get_default_graph()._distribution_strategy_stack  # pylint: disable=protected-access
  mode_stack.pop()
63
64
class _DefaultReplicaThreadMode(_ThreadMode):
  """Fallback mode returned by `_get_per_thread_mode()`.

  It is used when the default graph's strategy stack has no entries. It pairs
  the default strategy with the default replica context and has no
  cross-replica context.
  """

  def __init__(self):
    # The default strategy is resolved before the default replica context.
    default_strategy = _get_default_strategy()
    default_replica_ctx = _get_default_replica_context()
    _ThreadMode.__init__(
        self, dist=default_strategy, cross=None, replica=default_replica_ctx)
74
75
def _get_per_thread_mode():
  """Returns the innermost thread mode, or the default replica mode.

  The default replica mode is returned when the default graph's strategy
  stack is empty (IndexError) or cannot be accessed (AttributeError).
  """
  try:
    mode_stack = ops.get_default_graph()._distribution_strategy_stack  # pylint: disable=protected-access
    innermost = mode_stack[-1]
  except (AttributeError, IndexError):
    return _get_default_replica_mode()
  return innermost
81
82
83# ------------------------------------------------------------------------------
84# Public API for accessing the current thread mode
85
86
@tf_export("distribute.get_replica_context")
def get_replica_context():
  """Returns the current `tf.distribute.ReplicaContext` or `None`.

  Returns `None` if in a cross-replica context.

  Note that execution:

  1. starts in the default (single-replica) replica context (this function
     will return the default `ReplicaContext` object);
  2. switches to cross-replica context (in which case this will return
     `None`) when entering a `with tf.distribute.Strategy.scope():` block;
  3. switches to a (non-default) replica context inside
     `extended.call_for_each_replica(fn, ...)`;
  4. if `fn` calls `get_replica_context().merge_call(merge_fn, ...)`, then
     inside `merge_fn` you are back in the cross-replica context (and again
     this function will return `None`).

  Note that you can also go directly from step 1 to 4 to switch to a
  cross-replica context for the default `tf.distribute.Strategy`. You may
  also switch from the cross-replica context of 4 to a replica context by
  calling `extended.call_for_each_replica()`, jumping back to step 3.

  Most `tf.distribute.Strategy` methods may only be executed in
  a cross-replica context; in a replica context you should use the
  `ReplicaContext` API instead.

  Returns:
    The current `ReplicaContext` object when in a replica context scope,
    else `None`.

    Within a particular block, exactly one of these two things will be true:

    * `get_replica_context()` returns non-`None`, or
    * `tf.distribute.in_cross_replica_context()` returns True.
  """
  return _get_per_thread_mode().replica_context
124
125
def get_cross_replica_context():
  """Returns the current tf.distribute.Strategy if in a cross-replica context.

  DEPRECATED: Please use `in_cross_replica_context()` and
  `get_strategy()` instead.

  Note that execution:

  1. starts in the default (single-replica) replica context;
  2. switches to cross-replica context when entering a
     `with tf.distribute.Strategy.scope():` block;
  3. switches to a (non-default) replica context inside
     `extended.call_for_each_replica(fn, ...)`;
  4. if `fn` calls `get_replica_context().merge_call(merge_fn, ...)`, then
     inside `merge_fn` you are back in the cross-replica context.

  Note that you can also go directly from step 1 to 4 to switch to a
  cross-replica context for the default `tf.distribute.Strategy`. You may
  also switch from the cross-replica context of 4 to a replica context by
  calling `extended.call_for_each_replica()`, jumping back to step 3.

  Most `tf.distribute.Strategy` methods may only be executed in
  a cross-replica context.

  Returns:
    Returns the current `tf.distribute.Strategy` object in a cross-replica
    context, or `None`.

    Exactly one of `get_replica_context()` and `get_cross_replica_context()`
    will return `None` in a particular block.
  """
  return _get_per_thread_mode().cross_replica_context
158
159
@tf_export("distribute.in_cross_replica_context")
def in_cross_replica_context():
  """Reports whether the caller is currently in a cross-replica context.

  See `tf.distribute.get_replica_context` for how contexts change.

  Returns:
    True when in a cross-replica context (where `get_replica_context()`
    returns `None`), False when in a replica context (where
    `get_replica_context()` returns a non-`None` value).
  """
  cross_ctx = _get_per_thread_mode().cross_replica_context
  return cross_ctx is not None
172
173
@tf_export("distribute.get_strategy")
def get_strategy():
  """Returns the current `tf.distribute.Strategy` object.

  This is typically called from a cross-replica context:

  ```
  if tf.distribute.in_cross_replica_context():
    strategy = tf.distribute.get_strategy()
    ...
  ```

  Returns:
    A `tf.distribute.Strategy` object: `strategy` when inside a
    `with strategy.scope()` block, and otherwise the default (single-replica)
    `tf.distribute.Strategy` object.
  """
  current_mode = _get_per_thread_mode()
  return current_mode.strategy
192
193
@tf_export("distribute.has_strategy")
def has_strategy():
  """Returns whether a non-default `tf.distribute.Strategy` is current.

  Returns:
    True when called inside a `with strategy.scope():` block.
  """
  current_strategy = get_strategy()
  default_strategy = _get_default_strategy()
  return current_strategy is not default_strategy
202
203
def get_strategy_and_replica_context():
  """Returns `(strategy, replica_context)` for the current thread mode.

  Both values are read from a single lookup of the current mode. The
  `replica_context` element is `None` in a cross-replica context.
  """
  mode = _get_per_thread_mode()
  strategy, replica_ctx = mode.strategy, mode.replica_context
  return strategy, replica_ctx
207
208
209# ------------------------------------------------------------------------------
210# Defaults that are used when no tf.distribute.Strategy is explicitly created.
211# We create them lazily in a function so that we can workaround the circular
212# dependency on distribute_lib. See lazy loader at the top of this file.
213
214_defaults = {
215    "strategy": None,
216    "replica_context": None,
217    "replica_mode": None
218}
219
220
def _get_default_strategy():
  """Returns the module-level default strategy, creating it on first use."""
  cached = _defaults["strategy"]
  if cached is None:
    cached = distribute_lib._DefaultDistributionStrategy()  # pylint: disable=protected-access
    _defaults["strategy"] = cached
  return cached
225
226
def _get_default_replica_context():
  """Returns the cached default `ReplicaContext` (replica id 0).

  It is built on first use around the default strategy.
  """
  cached = _defaults["replica_context"]
  if cached is None:
    cached = distribute_lib.ReplicaContext(
        _get_default_strategy(), replica_id_in_sync_group=0)
    _defaults["replica_context"] = cached
  return cached
232
233
def _get_default_replica_mode():
  """Returns the cached `_DefaultReplicaThreadMode`, creating it on first use."""
  cached = _defaults["replica_mode"]
  if cached is None:
    cached = _DefaultReplicaThreadMode()
    _defaults["replica_mode"] = cached
  return cached
238
239
# Backward-compatible aliases that keep the older function names working.
get_distribution_strategy, has_distribution_strategy = (
    get_strategy, has_strategy)
243