1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Critical Section object and execution logic.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import contextlib 23import threading 24 25from tensorflow.python.eager import context 26from tensorflow.python.framework import dtypes 27from tensorflow.python.framework import ops 28from tensorflow.python.ops import array_ops 29from tensorflow.python.ops import control_flow_ops 30from tensorflow.python.ops import gen_resource_variable_ops 31from tensorflow.python.ops import tensor_array_ops 32from tensorflow.python.util import nest 33from tensorflow.python.util import object_identity 34from tensorflow.python.util.tf_export import tf_export 35 36 37__all__ = ["CriticalSection"] 38 39 40# Graph Keys 41CRITICAL_SECTIONS = "critical_sections" 42CRITICAL_SECTION_EXECUTIONS = "critical_section_executions" 43 44 45class _ExecutionSignature( 46 collections.namedtuple("_ExecutionSignature", 47 ("op", "handle", 48 "resources", "exclusive_resource_access"))): 49 """A class storing an `ExecuteInCriticalResource` op and associated attrs.""" 50 pass 51 52 53def _identity(x): 54 """Identity op that recognizes `TensorArray`, `Operation`, and `Tensor`.""" 55 if isinstance(x, tensor_array_ops.TensorArray): 56 return x.identity() 57 elif isinstance(x, ops.Operation): 58 return control_flow_ops.group(x) 59 elif context.executing_eagerly() and x is None: 60 return None 61 else: 62 return array_ops.identity(x) 63 64 65def _get_device_or_colocation(op): 66 return op.device or _get_colocation(op) 67 68 69def _get_colocation(op): 70 """Get colocation symbol from op, if any.""" 71 try: 72 return op.get_attr("_class") 73 except (ValueError, AttributeError): 74 return None 75 76 77_CRITICAL_SECTION_STACK = threading.local() 78 79 80def _get_critical_section_stack(): 81 try: 82 return _CRITICAL_SECTION_STACK.value 83 except AttributeError: 84 _CRITICAL_SECTION_STACK.value = [] 85 return _CRITICAL_SECTION_STACK.value 86 87 88@contextlib.contextmanager 89def _push_critical_section_stack(signature): 90 """Push a CriticalSection._signature to the thread-local stack. 91 92 If the signature is already on the stack, raise an error because it means 93 we're trying to execute inside the same locked CriticalSection, which 94 will create a deadlock. 95 96 Args: 97 signature: Tuple of the type `CriticalSection._signature`. Uniquely 98 identifies a CriticalSection by its `shared_name`, `container`, 99 and device. 100 101 Yields: 102 An empty value. The context is guaranteed to run without deadlock. 103 104 Raises: 105 ValueError: If the signature is already on the stack. 106 RuntimeError: If another thread or function modifies the current stack 107 entry during the yield. 108 """ 109 stack = _get_critical_section_stack() 110 if signature in stack: 111 raise ValueError( 112 "Attempting to lock a CriticalSection in which we are " 113 "already running. This is illegal and may cause deadlocks.") 114 stack.append(signature) 115 try: 116 yield 117 finally: 118 received_signature = stack.pop() 119 if received_signature != signature: 120 raise RuntimeError( 121 "CriticalSection stack inconsistency: expected signature " 122 "{} but saw {}".format(signature, received_signature)) 123 124 125@tf_export("CriticalSection") 126class CriticalSection(object): 127 """Critical section. 128 129 A `CriticalSection` object is a resource in the graph which executes subgraphs 130 in **serial** order. A common example of a subgraph one may wish to run 131 exclusively is the one given by the following function: 132 133 ```python 134 v = resource_variable_ops.ResourceVariable(0.0, name="v") 135 136 def count(): 137 value = v.read_value() 138 with tf.control_dependencies([value]): 139 with tf.control_dependencies([v.assign_add(1)]): 140 return tf.identity(value) 141 ``` 142 143 Here, a snapshot of `v` is captured in `value`; and then `v` is updated. 144 The snapshot value is returned. 145 146 If multiple workers or threads all execute `count` in parallel, there is no 147 guarantee that access to the variable `v` is atomic at any point within 148 any thread's calculation of `count`. In fact, even implementing an atomic 149 counter that guarantees that the user will see each value `0, 1, ...,` is 150 currently impossible. 151 152 The solution is to ensure any access to the underlying resource `v` is 153 only processed through a critical section: 154 155 ```python 156 cs = CriticalSection() 157 f1 = cs.execute(count) 158 f2 = cs.execute(count) 159 output = f1 + f2 160 session.run(output) 161 ``` 162 The functions `f1` and `f2` will be executed serially, and updates to `v` 163 will be atomic. 164 165 **NOTES** 166 167 All resource objects, including the critical section and any captured 168 variables of functions executed on that critical section, will be 169 colocated to the same device (host and cpu/gpu). 170 171 When using multiple critical sections on the same resources, there is no 172 guarantee of exclusive access to those resources. This behavior is disallowed 173 by default (but see the kwarg `exclusive_resource_access`). 174 175 For example, running the same function in two separate critical sections 176 will not ensure serial execution: 177 178 ```python 179 v = tf.compat.v1.get_variable("v", initializer=0.0, use_resource=True) 180 def accumulate(up): 181 x = v.read_value() 182 with tf.control_dependencies([x]): 183 with tf.control_dependencies([v.assign_add(up)]): 184 return tf.identity(x) 185 ex1 = CriticalSection().execute( 186 accumulate, 1.0, exclusive_resource_access=False) 187 ex2 = CriticalSection().execute( 188 accumulate, 1.0, exclusive_resource_access=False) 189 bad_sum = ex1 + ex2 190 sess.run(v.initializer) 191 sess.run(bad_sum) # May return 0.0 192 ``` 193 """ 194 195 def __init__(self, name=None, shared_name=None, 196 critical_section_def=None, import_scope=None): 197 """Creates a critical section.""" 198 context.ensure_initialized() 199 if critical_section_def and name is not None: 200 raise ValueError("critical_section_def and shared_name are " 201 "mutually exclusive.") 202 if critical_section_def: 203 raise ValueError("critical_section_def is not supported.") 204 else: 205 self._init_from_args(name, shared_name) 206 207 def _init_from_args(self, name, shared_name): # pylint: disable=invalid-name 208 """Initialize the CriticalSection from constructor arguments.""" 209 with ops.name_scope(name, "CriticalSection", []) as name: 210 with ops.init_scope(): 211 # pylint: disable=protected-access 212 container = ops.get_default_graph()._container 213 # pylint: enable=protected-access 214 if shared_name is None: 215 shared_name = name 216 if container is None: 217 container = "" 218 self._handle = gen_resource_variable_ops.mutex_v2( 219 shared_name=shared_name, container=container, name=name) 220 # Get a uniquely identifying signature for the handle. 221 self._signature = ( 222 container, 223 # If shared_name is empty, a unique CriticalSection is created. 224 shared_name or id(self._handle), 225 _get_device_or_colocation(self._handle)) 226 227 if not context.executing_eagerly(): 228 ops.add_to_collections(CRITICAL_SECTIONS, self) 229 230 @property 231 def name(self): 232 return self._handle.op.name 233 234 def execute(self, fn, exclusive_resource_access=True, name=None): 235 """Execute function `fn()` inside the critical section. 236 237 `fn` should not accept any arguments. To add extra arguments to when 238 calling `fn` in the critical section, create a lambda: 239 240 ```python 241 critical_section.execute(lambda: fn(*my_args, **my_kwargs)) 242 ``` 243 244 Args: 245 fn: The function to execute. Must return at least one tensor. 246 exclusive_resource_access: Whether the resources required by 247 `fn` should be exclusive to this `CriticalSection`. Default: `True`. 248 You may want to set this to `False` if you will be accessing a 249 resource in read-only mode in two different CriticalSections. 250 name: The name to use when creating the execute operation. 251 252 Returns: 253 The tensors returned from `fn()`. 254 255 Raises: 256 ValueError: If `fn` attempts to lock this `CriticalSection` in any nested 257 or lazy way that may cause a deadlock. 258 ValueError: If `exclusive_resource_access == True` and 259 another `CriticalSection` has an execution requesting the same 260 resources as `fn``. Note, even if `exclusive_resource_access` is 261 `True`, if another execution in another `CriticalSection` was created 262 without `exclusive_resource_access=True`, a `ValueError` will be raised. 263 """ 264 with ops.name_scope(name, "critical_section_execute", []): 265 # Ensure that mutex locking only happens *after* all args and 266 # kwargs have been executed. This avoids certain types of deadlocks. 267 with _push_critical_section_stack(self._signature): 268 lock = gen_resource_variable_ops.mutex_lock(self._handle) 269 270 if not context.executing_eagerly(): 271 # NOTE(ebrevdo): This is to ensure we don't pick up spurious 272 # Operations created by other threads. 273 with ops.get_default_graph()._lock: # pylint: disable=protected-access 274 existing_ops = ops.get_default_graph().get_operations() 275 with ops.control_dependencies([lock]): 276 r = fn() 277 # TODO(ebrevdo): If creating critical sections in a python loop, 278 # this makes graph creation time quadratic. Revisit if this 279 # becomes a problem. 280 created_ops = (set(ops.get_default_graph().get_operations()) 281 .difference(existing_ops)) 282 else: 283 with ops.control_dependencies([lock]): 284 r = fn() 285 286 if not context.executing_eagerly(): 287 self._add_control_dependencies_to_lock(created_ops, lock.op) 288 289 # captured_resources is a list of resources that are directly 290 # accessed only by ops created during fn(), not by any 291 # ancestors of those ops in the graph. 292 captured_resources = object_identity.ObjectIdentitySet([ 293 input_ for op in created_ops 294 for input_ in op.inputs 295 if input_.dtype == dtypes.resource 296 ]) 297 298 # NOTE(ebrevdo): The only time self._is_self_handle() is True 299 # in this call is if one of the recently created ops, within 300 # the execute(), themselves attempt to access the 301 # CriticalSection. This will cause a deadlock. 302 if any(self._is_self_handle(x) for x in captured_resources): 303 raise ValueError( 304 "Attempting to lock a CriticalSection in which we are " 305 "already running. This is illegal and may cause deadlocks.") 306 307 self._check_multiple_access_to_resources( 308 captured_resources, exclusive_resource_access) 309 310 r_flat = [_identity(x) for x in nest.flatten(r)] 311 312 with ops.control_dependencies(r_flat): 313 # The identity must run on the same machine as self._handle 314 with ops.colocate_with(self._handle): 315 # Do not use array_ops.identity as there are special 316 # optimizations within TensorFlow which seem to elide it 317 # even when optimizations are disabled(!). 318 ensure_lock_exists = gen_resource_variable_ops.consume_mutex_lock( 319 lock) 320 321 # Make sure that if any element of r is accessed, all of 322 # them are executed together. 323 r = nest.pack_sequence_as(r, control_flow_ops.tuple(nest.flatten(r))) 324 325 with ops.control_dependencies([ensure_lock_exists]): 326 outputs = nest.map_structure(_identity, r) 327 328 if not context.executing_eagerly(): 329 signature = _ExecutionSignature( 330 op=lock.op, 331 handle=self._handle, 332 resources=list(captured_resources), 333 exclusive_resource_access=exclusive_resource_access) 334 ops.add_to_collections( 335 CRITICAL_SECTION_EXECUTIONS, signature) 336 337 return outputs 338 339 def _add_control_dependencies_to_lock(self, created_ops, lock_op): 340 """To avoid deadlocks, all args must be executed before lock_op.""" 341 # Get all arguments (explicit and captured) of all ops created by fn(). 342 all_args = set([input_.op for op in created_ops for input_ in op.inputs]) 343 all_args.update( 344 input_op for op in created_ops for input_op in op.control_inputs) 345 # Unfortunately, we can't use sets throughout because TF seems to 346 # create new Operation objects for the same op sometimes; and we 347 # can't rely on id(op). 348 349 # pylint: disable=protected-access 350 all_args_dict = dict((op._id, op) for op in all_args) 351 352 # Remove ops created within fn, or that lock_op already has a 353 # control dependency on. Also remove a possible self-loop. 354 for op in created_ops: 355 all_args_dict.pop(op._id, None) 356 for op in lock_op.control_inputs: 357 all_args_dict.pop(op._id, None) 358 for input_ in lock_op.inputs: 359 all_args_dict.pop(input_.op._id, None) 360 all_args_dict.pop(lock_op._id, None) 361 362 all_args = all_args_dict.values() 363 364 if not all_args: 365 # No control dependencies to add; return early. 366 return 367 368 # This group is important: it ensures that any ops in all_args 369 # outside the control context of the lock_op (and this fn, which 370 # runs in the same context) are added to this context before 371 # being added to the control dependencies of lock_op. 372 all_args = control_flow_ops.group(*all_args) 373 374 lock_op._add_control_input(all_args) 375 # pylint: enable=protected-access 376 377 def _is_self_handle(self, x): 378 """Check if the tensor `x` is the same Mutex as `self._handle`.""" 379 if isinstance(x, ops.EagerTensor): 380 return x is self._handle 381 return (x.op.type == "MutexV2" 382 # blank shared_name means the op will create a unique one. 383 and x.op.get_attr("shared_name") 384 and (x.op.get_attr("shared_name") == 385 self._handle.op.get_attr("shared_name")) 386 and (x.op.device == self._handle.op.device 387 or _get_colocation(x.op) == _get_colocation(self._handle.op))) 388 389 def _check_multiple_access_to_resources( 390 self, captured_resources, exclusive_resource_access): 391 """Raise if captured_resources are accessed by another CriticalSection. 392 393 Args: 394 captured_resources: Set of tensors of type resource. 395 exclusive_resource_access: Whether this execution requires exclusive 396 resource access. 397 398 Raises: 399 ValueError: If any tensors in `captured_resources` are also accessed 400 by another `CriticalSection`, and at least one of them requires 401 exclusive resource access. 402 """ 403 # Collections and op introspection does not work in eager 404 # mode. This is generally ok; since eager mode (as of 405 # writing) executes sequentially anyway. 406 for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS): 407 if self._is_self_handle(sg.handle): 408 # Other executions in the same critical section are allowed. 409 continue 410 if not (exclusive_resource_access or sg.exclusive_resource_access): 411 # Neither execution requested exclusive access. 412 continue 413 resource_intersection = captured_resources.intersection(sg.resources) 414 if resource_intersection: 415 raise ValueError( 416 "This execution would access resources: %s. Either this " 417 "lock (CriticalSection: %s) or lock '%s' " 418 "(CriticalSection: %s) requested exclusive resource access " 419 "of this resource. Did you mean to call execute with keyword " 420 "argument exclusive_resource_access=False?" % 421 (list(resource_intersection), self._handle, sg, sg.handle)) 422