1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Critical Section object and execution logic."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import contextlib
23import threading
24
25from tensorflow.python.eager import context
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import control_flow_ops
30from tensorflow.python.ops import gen_resource_variable_ops
31from tensorflow.python.ops import tensor_array_ops
32from tensorflow.python.util import nest
33from tensorflow.python.util import object_identity
34from tensorflow.python.util.tf_export import tf_export
35
36
__all__ = ["CriticalSection"]


# Graph collection keys.  In graph mode, every `CriticalSection` is added to
# CRITICAL_SECTIONS, and an `_ExecutionSignature` for every call to
# `CriticalSection.execute` is added to CRITICAL_SECTION_EXECUTIONS.
CRITICAL_SECTIONS = "critical_sections"
CRITICAL_SECTION_EXECUTIONS = "critical_section_executions"
43
44
class _ExecutionSignature(
    collections.namedtuple(
        "_ExecutionSignature",
        "op handle resources exclusive_resource_access")):
  """Record of one graph-mode `CriticalSection.execute` call.

  Fields:
    op: The operation of the lock tensor created by `mutex_lock` for this
      execution.
    handle: The mutex resource handle of the owning `CriticalSection`.
    resources: List of resource tensors consumed by ops created in `fn`.
    exclusive_resource_access: Whether this execution asked for exclusive
      access to `resources`.
  """
51
52
def _identity(x):
  """Return an identity of `x`, choosing the right op for its kind.

  A `TensorArray` uses its own `identity()` method and an `Operation` is
  wrapped in a `group`.  When executing eagerly, `None` passes through
  unchanged.  Everything else goes through `array_ops.identity`.
  """
  if isinstance(x, tensor_array_ops.TensorArray):
    return x.identity()
  if isinstance(x, ops.Operation):
    return control_flow_ops.group(x)
  if context.executing_eagerly() and x is None:
    return None
  return array_ops.identity(x)
63
64
def _get_device_or_colocation(op):
  """Return `op.device` when it is set, else `op`'s colocation attr (or None)."""
  device = op.device
  if device:
    return device
  return _get_colocation(op)
67
68
def _get_colocation(op):
  """Return the `_class` (colocation) attr of `op`, or None if unavailable.

  An op that has no such attr (`ValueError`) or no `get_attr` method at all
  (`AttributeError`) yields None.
  """
  try:
    colocation = op.get_attr("_class")
  except (ValueError, AttributeError):
    colocation = None
  return colocation
75
76
# Thread-local holder: each thread lazily gets its own `value` list of the
# CriticalSection signatures it is currently executing inside.
_CRITICAL_SECTION_STACK = threading.local()


def _get_critical_section_stack():
  """Return the calling thread's stack of active CriticalSection signatures.

  The list is created empty on the first call from a given thread and the
  same list object is returned on every later call from that thread.
  """
  if not hasattr(_CRITICAL_SECTION_STACK, "value"):
    _CRITICAL_SECTION_STACK.value = []
  return _CRITICAL_SECTION_STACK.value
86
87
@contextlib.contextmanager
def _push_critical_section_stack(signature):
  """Mark `signature` as active on this thread for the duration of the block.

  Entering the same CriticalSection again while it is already active on this
  thread would deadlock, so that is rejected up front.

  Args:
    signature: A `CriticalSection._signature` tuple, which identifies a
      CriticalSection by container, shared name, and device/colocation.

  Yields:
    Nothing.  The body runs with `signature` on top of the thread's stack.

  Raises:
    ValueError: If `signature` is already on this thread's stack.
    RuntimeError: If, when the block exits, the top of the stack is not
      `signature` (the stack was modified inside the block).
  """
  active = _get_critical_section_stack()
  if signature in active:
    raise ValueError(
        "Attempting to lock a CriticalSection in which we are "
        "already running.  This is illegal and may cause deadlocks.")
  active.append(signature)
  try:
    yield
  finally:
    popped = active.pop()
    if popped != signature:
      raise RuntimeError(
          "CriticalSection stack inconsistency: expected signature "
          "{} but saw {}".format(signature, popped))
123
124
@tf_export("CriticalSection")
class CriticalSection(object):
  """Critical section.

  A `CriticalSection` object is a resource in the graph which executes subgraphs
  in **serial** order.  A common example of a subgraph one may wish to run
  exclusively is the one given by the following function:

  ```python
  v = resource_variable_ops.ResourceVariable(0.0, name="v")

  def count():
    value = v.read_value()
    with tf.control_dependencies([value]):
      with tf.control_dependencies([v.assign_add(1)]):
        return tf.identity(value)
  ```

  Here, a snapshot of `v` is captured in `value`; and then `v` is updated.
  The snapshot value is returned.

  If multiple workers or threads all execute `count` in parallel, there is no
  guarantee that access to the variable `v` is atomic at any point within
  any thread's calculation of `count`.  In fact, even implementing an atomic
  counter that guarantees that the user will see each value `0, 1, ...,` is
  currently impossible.

  The solution is to ensure any access to the underlying resource `v` is
  only processed through a critical section:

  ```python
  cs = CriticalSection()
  f1 = cs.execute(count)
  f2 = cs.execute(count)
  output = f1 + f2
  session.run(output)
  ```
  The two executions of `count` that produce `f1` and `f2` will run serially,
  and updates to `v` will be atomic.

  **NOTES**

  All resource objects, including the critical section and any captured
  variables of functions executed on that critical section, will be
  colocated to the same device (host and cpu/gpu).

  When using multiple critical sections on the same resources, there is no
  guarantee of exclusive access to those resources.  This behavior is disallowed
  by default (but see the kwarg `exclusive_resource_access`).

  For example, running the same function in two separate critical sections
  will not ensure serial execution:

  ```python
  v = tf.compat.v1.get_variable("v", initializer=0.0, use_resource=True)
  def accumulate(up):
    x = v.read_value()
    with tf.control_dependencies([x]):
      with tf.control_dependencies([v.assign_add(up)]):
        return tf.identity(x)
  ex1 = CriticalSection().execute(
    lambda: accumulate(1.0), exclusive_resource_access=False)
  ex2 = CriticalSection().execute(
    lambda: accumulate(1.0), exclusive_resource_access=False)
  bad_sum = ex1 + ex2
  sess.run(v.initializer)
  sess.run(bad_sum)  # May return 0.0
  ```
  """

  def __init__(self, name=None, shared_name=None,
               critical_section_def=None, import_scope=None):
    """Creates a critical section.

    Args:
      name: Optional name used for the name scope of the underlying mutex op.
      shared_name: Optional `shared_name` of the underlying mutex resource.
        Defaults to the resolved name scope.  An empty string creates a
        unique, unshared critical section.
      critical_section_def: Not supported; must be `None`.
      import_scope: Currently unused.

    Raises:
      ValueError: If `critical_section_def` is given.
    """
    context.ensure_initialized()
    # The check tests `shared_name` so that it agrees with the error message.
    if critical_section_def and shared_name is not None:
      raise ValueError("critical_section_def and shared_name are "
                       "mutually exclusive.")
    if critical_section_def:
      raise ValueError("critical_section_def is not supported.")
    self._init_from_args(name, shared_name)

  def _init_from_args(self, name, shared_name):  # pylint: disable=invalid-name
    """Initialize the CriticalSection from constructor arguments.

    Creates the mutex resource handle and computes `self._signature`, which
    is used to detect re-entrant locking on the current thread.
    """
    with ops.name_scope(name, "CriticalSection", []) as name:
      with ops.init_scope():
        # pylint: disable=protected-access
        container = ops.get_default_graph()._container
        # pylint: enable=protected-access
        if shared_name is None:
          shared_name = name
        if container is None:
          container = ""
        self._handle = gen_resource_variable_ops.mutex_v2(
            shared_name=shared_name, container=container, name=name)
        # Get a uniquely identifying signature for the handle.
        self._signature = (
            container,
            # If shared_name is empty, a unique CriticalSection is created.
            shared_name or id(self._handle),
            _get_device_or_colocation(self._handle))

    if not context.executing_eagerly():
      ops.add_to_collections(CRITICAL_SECTIONS, self)

  @property
  def name(self):
    """The name of the op that created the underlying mutex handle."""
    return self._handle.op.name

  def execute(self, fn, exclusive_resource_access=True, name=None):
    """Execute function `fn()` inside the critical section.

    `fn` should not accept any arguments.  To pass extra arguments to `fn`,
    wrap the call in a lambda:

    ```python
    critical_section.execute(lambda: fn(*my_args, **my_kwargs))
    ```

    Args:
      fn: The function to execute.  Must return at least one tensor.
      exclusive_resource_access: Whether the resources required by
        `fn` should be exclusive to this `CriticalSection`.  Default: `True`.
        You may want to set this to `False` if you will be accessing a
        resource in read-only mode in two different CriticalSections.
      name: The name to use when creating the execute operation.

    Returns:
      Identities of the value returned by `fn()`, with the same (possibly
      nested) structure.  Each output has a control dependency on the op
      that consumes the mutex lock.

    Raises:
      ValueError: If `fn` attempts to lock this `CriticalSection` in any nested
        or lazy way that may cause a deadlock.
      ValueError: If `exclusive_resource_access == True` and
        another `CriticalSection` has an execution requesting the same
        resources as `fn`.  Note, even if `exclusive_resource_access` is
        `False`, a `ValueError` will be raised if another execution in another
        `CriticalSection` that shares resources with `fn` was created with
        `exclusive_resource_access=True`.
    """
    with ops.name_scope(name, "critical_section_execute", []):
      # Ensure that mutex locking only happens *after* all args and
      # kwargs have been executed.  This avoids certain types of deadlocks.
      with _push_critical_section_stack(self._signature):
        lock = gen_resource_variable_ops.mutex_lock(self._handle)

        if not context.executing_eagerly():
          # NOTE(ebrevdo): This is to ensure we don't pick up spurious
          # Operations created by other threads.
          with ops.get_default_graph()._lock:  # pylint: disable=protected-access
            existing_ops = ops.get_default_graph().get_operations()
            with ops.control_dependencies([lock]):
              r = fn()
            # TODO(ebrevdo): If creating critical sections in a python loop,
            # this makes graph creation time quadratic.  Revisit if this
            # becomes a problem.
            created_ops = (set(ops.get_default_graph().get_operations())
                           .difference(existing_ops))
        else:
          with ops.control_dependencies([lock]):
            r = fn()

      if not context.executing_eagerly():
        self._add_control_dependencies_to_lock(created_ops, lock.op)

        # captured_resources is the set of resources that are directly
        # accessed only by ops created during fn(), not by any
        # ancestors of those ops in the graph.
        captured_resources = object_identity.ObjectIdentitySet([
            input_ for op in created_ops
            for input_ in op.inputs
            if input_.dtype == dtypes.resource
        ])

        # NOTE(ebrevdo): The only time self._is_self_handle() is True
        # in this call is if one of the recently created ops, within
        # the execute(), themselves attempt to access the
        # CriticalSection.  This will cause a deadlock.
        if any(self._is_self_handle(x) for x in captured_resources):
          raise ValueError(
              "Attempting to lock a CriticalSection in which we are "
              "already running.  This is illegal and may cause deadlocks.")

        self._check_multiple_access_to_resources(
            captured_resources, exclusive_resource_access)

      r_flat = [_identity(x) for x in nest.flatten(r)]

      with ops.control_dependencies(r_flat):
        # The identity must run on the same machine as self._handle
        with ops.colocate_with(self._handle):
          # Do not use array_ops.identity as there are special
          # optimizations within TensorFlow which seem to elide it
          # even when optimizations are disabled(!).
          ensure_lock_exists = gen_resource_variable_ops.consume_mutex_lock(
              lock)

        # Make sure that if any element of r is accessed, all of
        # them are executed together.
        r = nest.pack_sequence_as(r, control_flow_ops.tuple(nest.flatten(r)))

      with ops.control_dependencies([ensure_lock_exists]):
        outputs = nest.map_structure(_identity, r)

      if not context.executing_eagerly():
        signature = _ExecutionSignature(
            op=lock.op,
            handle=self._handle,
            resources=list(captured_resources),
            exclusive_resource_access=exclusive_resource_access)
        ops.add_to_collections(
            CRITICAL_SECTION_EXECUTIONS, signature)

      return outputs

  def _add_control_dependencies_to_lock(self, created_ops, lock_op):
    """To avoid deadlocks, all args must be executed before lock_op.

    Args:
      created_ops: Iterable of ops created while running `fn`.
      lock_op: The mutex-lock operation that `fn`'s ops depend on.
    """
    # pylint: disable=protected-access
    # Collect all arguments (explicit inputs and control inputs) of all ops
    # created by fn().  Key them by Operation._id rather than by Operation
    # object: TF can create distinct Operation objects for the same op, so
    # neither set membership nor id(op) is a reliable way to deduplicate.
    all_args_dict = {}
    for op in created_ops:
      for input_ in op.inputs:
        all_args_dict[input_.op._id] = input_.op
      for input_op in op.control_inputs:
        all_args_dict[input_op._id] = input_op

    # Remove ops created within fn, or that lock_op already has a
    # control dependency on.  Also remove a possible self-loop.
    for op in created_ops:
      all_args_dict.pop(op._id, None)
    for op in lock_op.control_inputs:
      all_args_dict.pop(op._id, None)
    for input_ in lock_op.inputs:
      all_args_dict.pop(input_.op._id, None)
    all_args_dict.pop(lock_op._id, None)

    if not all_args_dict:
      # No control dependencies to add; return early.
      return

    # This group is important: it ensures that any ops in all_args
    # outside the control context of the lock_op (and this fn, which
    # runs in the same context) are added to this context before
    # being added to the control dependencies of lock_op.
    all_args = control_flow_ops.group(*all_args_dict.values())

    lock_op._add_control_input(all_args)
    # pylint: enable=protected-access

  def _is_self_handle(self, x):
    """Check if the tensor `x` is the same Mutex as `self._handle`.

    Returns a truthy value when `x` refers to the same mutex, else a falsy
    one (not necessarily a `bool`).
    """
    if isinstance(x, ops.EagerTensor):
      return x is self._handle
    # NOTE(review): if neither op carries a `_class` colocation attr, both
    # `_get_colocation` calls return None, so the final clause is True even
    # when the explicit devices differ.  Confirm this leniency is intended.
    return (x.op.type == "MutexV2"
            # blank shared_name means the op will create a unique one.
            and x.op.get_attr("shared_name")
            and (x.op.get_attr("shared_name") ==
                 self._handle.op.get_attr("shared_name"))
            and (x.op.device == self._handle.op.device
                 or _get_colocation(x.op) == _get_colocation(self._handle.op)))

  def _check_multiple_access_to_resources(
      self, captured_resources, exclusive_resource_access):
    """Raise if captured_resources are accessed by another CriticalSection.

    Args:
      captured_resources: Set of tensors of type resource.
      exclusive_resource_access: Whether this execution requires exclusive
        resource access.

    Raises:
      ValueError: If any tensors in `captured_resources` are also accessed
        by another `CriticalSection`, and at least one of them requires
        exclusive resource access.
    """
    # Collections and op introspection do not work in eager mode.  This is
    # generally ok, since eager mode (as of writing) executes sequentially
    # anyway.
    for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
      if self._is_self_handle(sg.handle):
        # Other executions in the same critical section are allowed.
        continue
      if not (exclusive_resource_access or sg.exclusive_resource_access):
        # Neither execution requested exclusive access.
        continue
      resource_intersection = captured_resources.intersection(sg.resources)
      if resource_intersection:
        raise ValueError(
            "This execution would access resources: %s.  Either this "
            "lock (CriticalSection: %s) or lock '%s' "
            "(CriticalSection: %s) requested exclusive resource access "
            "of this resource.  Did you mean to call execute with keyword "
            "argument exclusive_resource_access=False?" %
            (list(resource_intersection), self._handle, sg, sg.handle))
422