1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""AutomaticControlDependencies and related functionality."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import enum
23
24from tensorflow.python.eager import context
25from tensorflow.python.framework import auto_control_deps_utils as utils
26from tensorflow.python.framework import dtypes as dtypes_module
27from tensorflow.python.framework import op_def_registry
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import registry
30from tensorflow.python.framework import sparse_tensor
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import control_flow_ops
33from tensorflow.python.ops import control_flow_util
34from tensorflow.python.ops import tensor_array_ops
35from tensorflow.python.util import nest
36from tensorflow.python.util import object_identity
37from tensorflow.python.util import tf_decorator
38
# LINT.IfChange
# Op types that should not run in program order, e.g. because they need to run
# asynchronously to avoid deadlock.
#
# Every name listed here is folded into `_ALL_DENYLISTED_OPS`, which
# `op_is_stateful` consults to exclude an op type from the stateful check.
ASYNC_STATEFUL_OPS = [
    "CollectiveGather",
    "CollectiveGatherV2",
    "CollectiveReduce",
    "CollectiveReduceV2",
    "CollectiveBcastSend",
    "CollectiveBcastSendV2",
    "CollectiveBcastRecv",
    "CollectiveBcastRecvV2",
    "NcclAllReduce",
    # We do not add "Send" here since we want it to be added as a control output
    # in order to avoid being pruned.
    "Recv",
]
56
# Legacy random op types that are exempt from automatic control dependencies.
# Like `ASYNC_STATEFUL_OPS`, these names end up in `_ALL_DENYLISTED_OPS`; the
# rationale for the exemption is spelled out inside the list below.
LEGACY_RANDOM_OPS = [
    # These may be used in variable initializers -- thus their execution should
    # not be dependent on other stateful operations.  This is because although
    # according to program order, tf.Variables may be created in sequence,
    # their initialization happens outside of the program order (specifically,
    # in graph mode their initialization happens by calling a grouped
    # initializer operation or in eager mode, where initialization is lifted
    # out of the tf.function and executed the first time the function is
    # executed).
    #
    # Unless there is a specific dependency between the initializers
    # themselves (e.g. one initializer depends on a Variable whose value depends
    # on another initializer), the initialization can happen in any order so
    # long as it's before the associated Variable read operations.
    #
    # Note that in general the randomness of legacy random operations is only
    # guaranteed by providing a graph-level and op-level seed (and ordering of
    # the same op across multiple iterations of a while_loop is specifically not
    # guaranteed; see the discussion below).
    #
    # There is a possible race condition inside while_loop where the same
    # random OpKernel instantiation is reused across multiple steps
    # of the loop.  Since legacy Random OpKernels have an internal rng state,
    # automatic dependency tracking across loop steps would likely
    # fix this race; and for that case this denylist is problematic.
    # However, since automatic dependency tracking inside while loops is not
    # currently supported, and there are no other examples of OpKernel reuse
    # (each OpKernel is associated with a unique op in graph mode),
    # this denylist has no effect on the aforementioned behavior.
    #
    # TODO(ebrevdo,skyewm): Modify the check against this denylist to
    # only occur when the op is inside a "variable initialization scope"; and
    # add proper autodeps inside while_loops that respects this updated check.
    "RandomUniform",
    "RandomUniformInt",
    "RandomStandardNormal",
    "ParameterizedTruncatedNormal",
    "TruncatedNormal",
    "RandomShuffle",
    "Multinomial",
    "RandomGamma",
    "RandomGammaGrad",
    "RandomPoisson",
    "RandomPoissonV2",
]
102
# Additional op types that are excluded from the stateful check in
# `op_is_stateful` (they are unioned into `_ALL_DENYLISTED_OPS` below).
# NOTE(review): the name suggests these ops are stateful but that their relative
# execution order does not matter -- confirm before adding new entries.
_ORDER_INSENSITIVE_STATEFUL_OPS = [
    "CudnnRNN", "CudnnRNNBackprop", "CudnnRNNV2", "CudnnRNNV3",
    "CudnnRNNBackpropV2", "CudnnRNNBackpropV3",
    "EnqueueTPUEmbeddingSparseBatch", "EnqueueTPUEmbeddingIntegerBatch",
    "EnqueueTPUEmbeddingSparseTensorBatch",
    "EnqueueTPUEmbeddingRaggedTensorBatch", "RestoreV2", "SaveV2"
]
# LINT.ThenChange(//tensorflow/core/grappler/optimizers/function_optimizer.cc)
111
# Union of the three denylists above, built once at import time as a set so
# that the membership test in `op_is_stateful` is a single hash lookup.
_ALL_DENYLISTED_OPS = (
    set(ASYNC_STATEFUL_OPS) | set(LEGACY_RANDOM_OPS)
    | set(_ORDER_INSENSITIVE_STATEFUL_OPS))
115
# Op types that are marked as stateless, but should be allowlisted so that
# automatic control dependencies are still added for them. `op_is_stateful`
# returns True for any op whose type appears in this list.
_ALLOWLIST_STATELESS_OPS = [
    # As TPU collective ops are blocking, if there is more than one collective
    # op in the function, we need to make sure the different collective ops are
    # scheduled in a certain order. Otherwise, if all the replicas launch
    # different collective ops/programs at the same time, it may cause
    # deadlock.
    "AllToAll",
    "CrossReplicaSum",
    "CollectivePermute",
]
128
129
def op_is_stateful(op):
  """Returns whether `op` should be treated as stateful by this module.

  An op counts as stateful when it is marked stateful and its type is not in
  `_ALL_DENYLISTED_OPS`, or when its type is listed in
  `_ALLOWLIST_STATELESS_OPS` regardless of how the op itself is marked.

  Args:
    op: the `Operation` to inspect.

  Returns:
    True if the op should be treated as stateful, False otherwise.
  """
  # pylint: disable=protected-access
  marked_stateful = op._is_stateful
  op_type = op.type
  if marked_stateful and op_type not in _ALL_DENYLISTED_OPS:
    return True
  # Either not marked stateful or denylisted: only the allowlist can still
  # make this op count as stateful.
  return op_type in _ALLOWLIST_STATELESS_OPS
134
135
@enum.unique
class ResourceType(enum.Enum):
  """How an op uses one of its resource inputs."""

  # The op uses the resource as a read-only input.
  READ_ONLY = "read-only"
  # The op uses the resource as a read-write input.
  READ_WRITE = "read-write"
139
140
def collective_manager_ids_from_op(op):
  """Returns the CollectiveManager IDs attached to `op`, if any.

  CollectiveManager adds collective and no_op operations tagged with an ID,
  unique to the manager object. This function extracts those IDs:

    * for a "CollectiveReduce" op, the single `_collective_manager_id` attr,
      wrapped in a list;
    * for a "StatefulPartitionedCall" op, the value of the
      `utils.COLLECTIVE_MANAGER_IDS` attr, returned as-is.

  Args:
    op: `Operation` to get the collective manager IDs from.

  Returns:
    List of CollectiveManager IDs used by the op. The list is empty when the op
    has a different type or does not carry the attribute.
  """
  op_type = op.type
  if op_type not in ("CollectiveReduce", "StatefulPartitionedCall"):
    return []
  try:
    if op_type == "CollectiveReduce":
      return [op.get_attr("_collective_manager_id")]
    return op.get_attr(utils.COLLECTIVE_MANAGER_IDS)
  except ValueError:
    # `get_attr` signals a missing attribute with ValueError; such an op was
    # not tagged by a CollectiveManager.
    return []
165
166
class AutomaticControlDependencies(object):
  """Context manager to automatically add control dependencies.

  Code under this context manager will act as if a sensible set of control
  dependencies were present. More specifically:
    1. All stateful ops in the scope will execute (with the exception of ops in
       ASYNC_STATEFUL_OPS, LEGACY_RANDOM_OPS and
       _ORDER_INSENSITIVE_STATEFUL_OPS)
    2. Stateful ops which modify the same resource will execute in program order

  Note: creating variables in an automatic control dependencies context is not
  supported (the value of the variables will never change as they will keep
  getting reinitialized).

  Attributes:
    ops_which_must_run: set of ops which must run; filled in by `__exit__`.
    collective_manager_ids_used: dict from collective manager id to the last op
      that used it, for manager scopes that were not opened inside this
      context. Empty until `__exit__` has run in graph mode.

  NOT THREAD SAFE
  """

  __slots__ = [
      "_returned_tensors", "ops_which_must_run", "_graph", "_n_operations",
      "collective_manager_ids_used"
  ]

  def __init__(self):
    """Creates a context with no returned tensors and no recorded graph."""
    self._returned_tensors = object_identity.ObjectIdentitySet()
    self.ops_which_must_run = set()
    # Every name in __slots__ gets a value here. `__enter__`/`__exit__` return
    # early when executing eagerly, so without these defaults reading e.g.
    # `collective_manager_ids_used` afterwards would raise AttributeError.
    self._graph = None
    self._n_operations = None
    self.collective_manager_ids_used = {}

  def mark_as_return(self, tensor):
    """Acts like identity but marks the `Tensor` as a return value.

    This will possibly return a copy of the `Tensor`. Usage:

    ```
      with AutomaticControlDependencies() as a:
       ...
       t = a.mark_as_return(t)
      _ = ...(t...)  # i.e. it's safe to use t here
    ```

    `IndexedSlices`, `SparseTensor` and `TensorArray` inputs are handled by
    passing their component tensors (values and indices, or the flow tensor)
    through identity and rebuilding the composite from those copies.

    Args:
      tensor: the `Tensor` to be marked

    Returns:
      a copy of the `Tensor`.
    """
    if isinstance(tensor, ops.IndexedSlices):
      values = array_ops.identity(tensor.values)
      indices = array_ops.identity(tensor.indices)
      self._returned_tensors.add(indices)
      self._returned_tensors.add(values)
      return ops.IndexedSlices(values, indices, dense_shape=tensor.dense_shape)
    elif isinstance(tensor, sparse_tensor.SparseTensor):
      values = array_ops.identity(tensor.values)
      indices = array_ops.identity(tensor.indices)
      self._returned_tensors.add(indices)
      self._returned_tensors.add(values)
      return sparse_tensor.SparseTensor(
          indices, values, dense_shape=tensor.dense_shape)
    elif isinstance(tensor, tensor_array_ops.TensorArray):
      flow = array_ops.identity(tensor.flow)
      self._returned_tensors.add(flow)
      return tensor_array_ops.build_ta_with_new_flow(tensor, flow)
    # We want to make the return values depend on the stateful operations, but
    # we don't want to introduce a cycle, so we make the return value the result
    # of a new identity operation that the stateful operations definitely don't
    # depend on.
    tensor = array_ops.identity(tensor)
    self._returned_tensors.add(tensor)
    return tensor

  def __enter__(self):
    """Starts tracking ops added to the default graph.

    When executing eagerly this does nothing. Otherwise it records the default
    graph, sets that graph's `_add_control_dependencies` flag, and remembers
    how many operations the graph already holds so that `__exit__` only
    processes the operations added inside the context.

    Returns:
      This context manager.
    """
    if context.executing_eagerly():
      return self
    # This code assumes no other thread is adding ops to the graph while
    # we're adding ops to the graph.
    # TODO(apassos): Fix this by locking the graph or using a temporary
    # graph (but that would mess up devices and collections at least,
    # probably other things as well).
    self._graph = ops.get_default_graph()
    self._graph._add_control_dependencies = True  # pylint: disable=protected-access
    self._n_operations = len(self._graph.get_operations())
    return self

  def _process_switch(self, switch_op, ops_which_must_run,
                      last_write_to_resource, merge_for_resource):
    """Processes a switch node for a resource input.

    When tensorflow creates a cond, it creates a control flow context for each
    branch of the cond. Each external tensor accessed by that branch is routed
    through a switch op, which gets created in the graph _after_ the op which
    uses that tensor gets created.

    If the resource comes from another switch op we process that one first.

    _process_switch creates a corresponding merge node for the switch node. This
    merge node is added to the outer control flow context of the switch
    node. We also ensure that:

      1. The switch node executes after the previous op which used the resource
         tensor

      2. Any op which uses a resource output of the switch node executes before
         the merge for the switch node.

      3. The next op which uses the input resource to the switch node (which
         might be another switch node for the other branch of the conditional)
         will execute after the merge node is done.

      4. The merge node is marked as must_run so it will run even if no
         subsequent operation uses the resource.

    A switch whose first output already has an entry in `merge_for_resource`
    has been processed before and is left untouched.

    Args:
      switch_op: the switch op to be processed
      ops_which_must_run: the set of ops which must run; updated in place
      last_write_to_resource: map from resource tensor id to last op updating
        it; updated in place
      merge_for_resource: map from resource tensor id to merge which must
        follow all usages of it; updated in place
    """
    # pylint: disable=protected-access
    inp = switch_op.inputs[0]
    input_id = ops.tensor_id(inp)
    if inp.dtype == dtypes_module.resource and inp.op.type == "Switch":
      self._process_switch(inp.op, ops_which_must_run, last_write_to_resource,
                           merge_for_resource)
    output = switch_op.outputs[0]
    output_id = ops.tensor_id(output)
    if output_id in merge_for_resource:
      return
    new_merge = control_flow_ops.merge(
        switch_op.outputs, name="artificial_merge")
    new_merge[0].op._control_flow_context = (
        switch_op._control_flow_context.outer_context)
    # Ensures the merge always runs
    ops_which_must_run.add(new_merge[0].op)
    if input_id in last_write_to_resource:
      # Ensures the switch executes after the previous op using the resource.
      switch_op._add_control_input(last_write_to_resource[input_id])
    # Ensure the next op outside the cond happens after the merge.
    last_write_to_resource[input_id] = new_merge[0].op
    if input_id in merge_for_resource:
      merge_for_resource[input_id]._add_control_input(new_merge[0].op)
    for o in switch_op.outputs:
      # Ensures the merge will execute after all ops inside the cond
      merge_for_resource[ops.tensor_id(o)] = new_merge[0].op

  def __exit__(self, unused_type, unused_value, unused_traceback):
    """Adds control dependencies for the ops created inside the context.

    When executing eagerly this does nothing. Otherwise it restores the graph's
    `_add_control_dependencies` flag, walks every operation added since
    `__enter__` in graph-construction order adding control inputs, and finally
    makes every tensor passed to `mark_as_return` depend on the ops which must
    run. The results are published on `ops_which_must_run` and
    `collective_manager_ids_used`.

    Args:
      unused_type: ignored.
      unused_value: ignored.
      unused_traceback: ignored.

    Raises:
      RuntimeError: if the default graph is no longer the graph recorded by
        `__enter__`.
    """
    # pylint: disable=protected-access
    if context.executing_eagerly():
      return

    if self._graph is not ops.get_default_graph():
      raise RuntimeError(
          "Graph changed while trying to add control dependencies.")

    if hasattr(self._graph, "outer_graph"):
      outer_val = self._graph.outer_graph._add_control_dependencies
      self._graph._add_control_dependencies = outer_val
    else:
      self._graph._add_control_dependencies = False

    # map from resource tensor to the last op which wrote to it
    last_write_to_resource = {}
    # map from resource tensor to the list of reads from it since the last
    # write or since the beginning of the function.
    reads_since_last_write_to_resource = collections.defaultdict(list)
    # CollectiveManager manager_ids within a particular function call should not
    # be needed outside of that function call. So we keep them separate (though
    # the general idea of the maps is the same, in the future, we'll need to
    # correctly thread the control output outside).
    # Map from collective manager scope to the last op which used it
    collective_manager_scopes_opened = {}
    collective_manager_scopes_used = {}
    # set of conditional and loop exits
    ops_which_must_run = set()
    # merge which must depend on ops which use this resource
    merge_for_resource = {}

    new_operations = self._graph.get_operations()[self._n_operations:]

    # Ensures that uses of resource tensors get serialized properly and all
    # execute. This is done by keeping a map from resource tensor to the last op
    # in graph-construction order which used it (last_write_to_resource).
    #
    # Conditionals are written in TensorFlow such that every external tensor
    # accessed in the conditional goes through a switch op and every return
    # tensor (it's guaranteed that there will be at least one) goes through a
    # merge op.
    #
    # To handle conditionals, switches are handled in a special way (see
    # comments for _process_switch). Merge nodes created by TF's conditional
    # logic (as opposed to by _process_switch) are forced to run and also get a
    # control dependency added to them to ensure all stateful ops inside their
    # control flow context run.
    #
    # We also ensure that if an op is using a resource output by a switch node
    # (that is, a resource tensor for which there's a value in
    # merge_for_resource) this op will run before the merge for that resource.
    #
    # We try to add control inputs to nodes respecting their control flow
    # contexts to avoid dead nodes propagating everywhere and leading to
    # "retval[0] doesn't have value" errors. If a node gets a control dependency
    # on a dead node (i.e. a node from an untaken control flow branch) that node
    # will be marked as dead unless it's a merge node.
    #
    # TODO(apassos): serialize non-resource-taking stateful ops as well, and
    # test that it works. Support while loops. Support init_scope escaping from
    # this.
    for op in new_operations:
      # TODO(apassos) make this code safely support while loops.
      if control_flow_util.IsInWhileLoop(op):
        continue
      control_inputs = set()
      # Ensure stateful ops run.
      # Read-only ops are added to control outputs if the read value is
      # consumed. This covers the case when the read value is returned from
      # the function since that goes through a tf.identity in mark_as_return.
      if (op_def_registry.get(op.type) is None or
          (op_is_stateful(op) and
           (op.type not in utils.RESOURCE_READ_OPS or
            any(output.consumers() for output in op.outputs)))):
        ops_which_must_run.add(op)
      # Make a note of all opened manager_ids.
      if op.type == "NoOp":
        try:
          collective_manager_scopes_opened[op.get_attr(
              "_collective_manager_id")] = op
        except ValueError:
          pass
      # Ignore switches (they're handled separately)
      if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
        continue
      # Make merges trigger all other computation which must run
      if op.type == "Merge":
        for o in ops_which_must_run:
          op._add_control_input(o)
          for inp in o.inputs:
            input_id = ops.tensor_id(inp)
            if input_id in last_write_to_resource:
              last_write_to_resource[input_id] = op
        ops_which_must_run = {op}
        continue

      # Bound once per op, before the resource loop, so that the filtering
      # step after the loop never depends on the loop body having executed.
      is_building_function = op.graph.building_function

      resource_inputs = set()
      # Check for any resource inputs. If we find any, we update control_inputs
      # and last_write_to_resource.
      for inp, resource_type in _get_resource_inputs(op):
        is_read = resource_type == ResourceType.READ_ONLY
        input_id = ops.tensor_id(inp)

        # If the op receives the same resource tensor twice as an input, we skip
        # to avoid the op getting a control dependency on itself.
        if input_id in resource_inputs:
          continue

        resource_inputs.add(input_id)
        # Deal with switches, finally.
        if inp.op.type == "Switch":
          self._process_switch(inp.op, ops_which_must_run,
                               last_write_to_resource, merge_for_resource)
        # Ensure uses of resources are serialized
        if input_id in last_write_to_resource:
          if is_building_function or (
              last_write_to_resource[input_id]._control_flow_context
              is op._control_flow_context):
            control_inputs.add(last_write_to_resource[input_id])
        # Ensure merges happen after the closing of a cond block
        if input_id in merge_for_resource:
          merge_for_resource[input_id]._add_control_input(op)
        if is_read:
          reads_since_last_write_to_resource[input_id].append(op)
        else:
          control_inputs.update(reads_since_last_write_to_resource[input_id])
          reads_since_last_write_to_resource[input_id] = []
          last_write_to_resource[input_id] = op

      # Stateful ops without resource inputs that sit outside any control flow
      # context are chained together under the `None` key.
      if (op_is_stateful(op) and not resource_inputs
          and op._control_flow_context is None):
        if None in last_write_to_resource:
          op._add_control_input(last_write_to_resource[None])
        last_write_to_resource[None] = op

      # Ensure ordering of collective ops
      manager_ids = collective_manager_ids_from_op(op)
      for manager_id in manager_ids:
        if manager_id in collective_manager_scopes_opened:
          # Chain this function call if the scope was opened.
          op._add_control_input(collective_manager_scopes_opened[manager_id])
          collective_manager_scopes_opened[manager_id] = op
        else:
          # If this op is in a scope not created here, create a chain starting
          # at this op.
          if manager_id in collective_manager_scopes_used:
            op._add_control_input(collective_manager_scopes_used[manager_id])
          collective_manager_scopes_used[manager_id] = op

      # Outside of function building, only keep control inputs that live in the
      # same control flow context as the op itself.
      if control_inputs and not is_building_function:
        control_inputs = [
            c for c in control_inputs
            if c._control_flow_context is op._control_flow_context
        ]

      op._add_control_inputs(control_inputs)

    # Ensure all ops which must run do run
    self.ops_which_must_run.update(ops_which_must_run)
    for r in nest.flatten(list(self._returned_tensors), expand_composites=True):
      if self.ops_which_must_run:
        if r.graph.building_function:
          updated_ops_which_must_run = self.ops_which_must_run
        else:
          updated_ops_which_must_run = [
              o for o in self.ops_which_must_run
              if o._control_flow_context is r.op._control_flow_context
          ]
        r.op._add_control_inputs(updated_ops_which_must_run)

    self.collective_manager_ids_used = collective_manager_scopes_used
485
486
# Registry of resolver callables. It is populated through
# `register_acd_resource_resolver` and consulted by `_get_resource_inputs`.
_acd_resource_resolvers_registry = registry.Registry("acd_resource_resolvers")
488
489
def register_acd_resource_resolver(f):
  """Register a function for resolving resources touched by an op.

  `f` is called for every Operation added in the ACD context with the op's
  original resource reads and writes. `f` is expected to update the sets of
  resource reads and writes in-place and return True if it updated either of the
  sets, False otherwise.

  All registered resolvers are re-run for an op until none of them reports an
  update, so a resolver must eventually return False for any given op.

  Can be used as a decorator, since `f` itself is returned.

  Example:
  @register_acd_resource_resolver
  def ResolveIdentity(op, resource_reads, resource_writes):
    # op: The `Operation` being processed by ACD currently.
    # resource_reads: An `ObjectIdentitySet` of read-only resources.
    # resource_writes: An `ObjectIdentitySet` of read-write resources.
    if not resource_reads and not resource_writes:
      return False
    def update(resource_inputs):
      to_add = []
      to_remove = []
      for t in resource_inputs:
        if t.op.type == "Identity":
          to_remove.append(t)
          to_add.append(t.op.inputs[0])
      if not to_add and not to_remove:
        return False
      for t in to_remove:
        resource_inputs.discard(t)
      resource_inputs.update(to_add)
      return True
    return update(resource_reads) or update(resource_writes)

  Args:
    f: Python function with signature
    (Operation, ObjectIdentitySet, ObjectIdentitySet) -> bool

  Returns:
    The function `f` after adding it to the registry.
  """
  _acd_resource_resolvers_registry.register(f)
  return f
530
531
def _get_resource_inputs(op):
  """Yields `(tensor, ResourceType)` pairs for the resources `op` touches.

  Starts from the read/write resource inputs reported by
  `utils.get_read_write_resource_inputs` and then runs every registered
  resolver repeatedly until a full pass makes no further update.

  Args:
    op: the `Operation` whose resource inputs are requested.

  Yields:
    `(tensor, ResourceType.READ_ONLY)` for each read-only resource, followed by
    `(tensor, ResourceType.READ_WRITE)` for each read-write resource.
  """
  reads, writes = utils.get_read_write_resource_inputs(op)
  while True:
    changed_this_pass = False
    for key in _acd_resource_resolvers_registry.list():
      resolver = _acd_resource_resolvers_registry.lookup(key)
      # Resolvers should return true if they are updating the list of
      # resource_inputs.
      # TODO(srbs): An alternate would be to just compare the old and new set
      # but that may not be as fast.
      if resolver(op, reads, writes):
        # Conservatively remove any resources from `reads` that are also writes.
        reads = reads.difference(writes)
        changed_this_pass = True
    if not changed_this_pass:
      break

  # Note: A resource handle that is not written to is treated as read-only. We
  # don't have a special way of denoting an unused resource.
  for tensors, kind in ((reads, ResourceType.READ_ONLY),
                        (writes, ResourceType.READ_WRITE)):
    for t in tensors:
      yield (t, kind)
555
556
def automatic_control_dependencies(f):
  """Wraps f to automatically insert control dependencies.

  The wrapped function runs `f` inside an `AutomaticControlDependencies`
  context and passes every flattened element of its result through
  `mark_as_return`, so that:
    1. All stateful ops in f run when the result of f runs
    2. Updates to the same resources happen in order.

  Args:
    f: the function to be wrapped.

  Returns:
    The wrapped function.
  """

  def wrapper(*args, **kwargs):
    with AutomaticControlDependencies() as acd:
      outputs = f(*args, **kwargs)
      marked = []
      for tensor in nest.flatten(outputs):
        marked.append(acd.mark_as_return(tensor))
      # Rebuild the caller-visible structure from the marked copies.
      return nest.pack_sequence_as(outputs, marked)

  return tf_decorator.make_decorator(f, wrapper)
578