1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""AutomaticControlDependencies and related functionality."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.eager import context
22from tensorflow.python.framework import dtypes as dtypes_module
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import sparse_tensor
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import control_flow_ops
27from tensorflow.python.ops import control_flow_util
28from tensorflow.python.ops import tensor_array_ops
29from tensorflow.python.util import nest
30from tensorflow.python.util import tf_decorator
31
32# Op types that should not run in program order, e.g. because they need to run
33# asynchronously to avoid deadlock.
34ASYNC_STATEFUL_OPS = [
35    "CollectiveGather",
36    "CollectiveReduce",
37    "CollectiveBcastSend",
38    "CollectiveBcastRecv",
39    "NcclAllReduce",
40]
41
42LEGACY_RANDOM_OPS = [
43    # These may be used in variable initializers -- thus their execution should
44    # not be dependent on other stateful operations.  This is because although
45    # according to program order, tf.Variables may be created in sequence,
46    # their initialization happens outside of the program order (specifically,
47    # in graph mode their initialization happens by calling a grouped
48    # initializer operation or in eager mode, where initialization is lifted
49    # out of the tf.function and executed the first time the function is
50    # executed).
51    #
52    # Unless there is a specific dependency between the initializers
53    # themselves (e.g. one initializer depends on a Variable whose value depends
54    # on another initializer), the initialization can happen in any order so
55    # long as it's before the associated Variable read operations.
56    #
57    # Note that in general the randomness of legacy random operations is only
58    # guaranteed by providing a graph-level and op-level seed (and ordering of
59    # the same op across multiple iterations of a while_loop is specifically not
60    # guaranteed; see the discussion below).
61    #
62    # There is a possible race condition inside while_loop where the same
63    # random OpKernel instantiation is reused across multiple steps
64    # of the loop.  Since legacy Random OpKernels have an internal rng state,
65    # automatic dependency tracking across loop steps would likely
66    # fix this race; and for that case this blacklist is problematic.
67    # However, since automatic dependency tracking inside while loops is not
68    # currently supported, and there are no other examples of OpKernel reuse
69    # (each OpKernel is associated with a unique op in graph mode),
70    # this blacklist has no effect on the aforementioned behavior.
71    #
72    # TODO(ebrevdo,skyewm): Modify the check against this blacklist to
73    # only occur when the op is inside a "variable initialization scope"; and
74    # add proper autodeps inside while_loops that respects this updated check.
75    "RandomUniform",
76    "RandomUniformInt",
77    "RandomStandardNormal",
78    "ParameterizedTruncatedNormal",
79    "TruncatedNormal",
80    "RandomShuffle",
81    "Multinomial",
82    "RandomGamma",
83    "RandomGammaGrad",
84    "RandomPoisson",
85    "RandomPoissonV2",
86]
87
88_ALL_BLACKLISTED_OPS = set(ASYNC_STATEFUL_OPS) | set(LEGACY_RANDOM_OPS)
89
90
def op_is_stateful(op_def):
  """Returns whether automatic control deps should treat an op as stateful.

  An op counts as stateful when its OpDef is marked stateful and its name is
  not one of the excluded types in `_ALL_BLACKLISTED_OPS`
  (`ASYNC_STATEFUL_OPS` and `LEGACY_RANDOM_OPS`).

  Args:
    op_def: an OpDef-like object exposing `is_stateful` and `name`.

  Returns:
    A boolean.
  """
  if not op_def.is_stateful:
    return False
  return op_def.name not in _ALL_BLACKLISTED_OPS
93
94
class AutomaticControlDependencies(object):
  """Context manager to automatically add control dependencies.

  Code under this context manager will act as if a sensible set of control
  dependencies were present. More specifically:
    1. All stateful ops in the scope will execute (with the exception of ops in
       ASYNC_STATEFUL_OPS and LEGACY_RANDOM_OPS)
    2. Stateful ops which modify the same resource will execute in program order

  Note: creating variables in an automatic control dependencies context is not
  supported (the value of the variables will never change as they will keep
  getting reinitialized).

  When executing eagerly, entering and exiting the context do nothing.

  NOT THREAD SAFE
  """

  def __init__(self):
    # Tensors created by `mark_as_return`; on exit they get control inputs on
    # the ops which must run (restricted to their control flow context).
    self._returned_tensors = set()
    # Ops which must run, filled in by `__exit__`.
    self.ops_which_must_run = set()
    # Graph and its number of operations when the context was entered. Only
    # set by `__enter__` in graph mode.
    self._graph = None
    self._n_operations = 0

  def mark_as_return(self, tensor):
    """Acts like identity but marks the `Tensor` as a return value.

    This will possibly return a copy of the `Tensor`. Usage:

    ```
      with AutomaticControlDependencies() as a:
        ...
        t = a.mark_as_return(t)
      _ = ...(t...)  # i.e. it's safe to use t here
    ```

    Composite values are handled component-wise. For `IndexedSlices` and
    `SparseTensor`, the `values` and `indices` are copied with identity ops
    and marked, and `dense_shape` is passed through unchanged. For a
    `TensorArray`, its `flow` tensor is copied and marked.

    Args:
      tensor: the `Tensor`, `IndexedSlices`, `SparseTensor` or `TensorArray`
        to be marked.

    Returns:
      a copy of `tensor` of the same kind whose marked components are new
      identity ops.
    """
    if isinstance(tensor, ops.IndexedSlices):
      values = array_ops.identity(tensor.values)
      indices = array_ops.identity(tensor.indices)
      self._returned_tensors.add(indices)
      self._returned_tensors.add(values)
      return ops.IndexedSlices(values, indices, dense_shape=tensor.dense_shape)
    elif isinstance(tensor, sparse_tensor.SparseTensor):
      values = array_ops.identity(tensor.values)
      indices = array_ops.identity(tensor.indices)
      self._returned_tensors.add(indices)
      self._returned_tensors.add(values)
      return sparse_tensor.SparseTensor(
          indices, values, dense_shape=tensor.dense_shape)
    elif isinstance(tensor, tensor_array_ops.TensorArray):
      flow = array_ops.identity(tensor.flow)
      self._returned_tensors.add(flow)
      return tensor_array_ops.build_ta_with_new_flow(tensor, flow)
    # We want to make the return values depend on the stateful operations, but
    # we don't want to introduce a cycle, so we make the return value the result
    # of a new identity operation that the stateful operations definitely don't
    # depend on.
    tensor = array_ops.identity(tensor)
    self._returned_tensors.add(tensor)
    return tensor

  def __enter__(self):
    """Turns on automatic control dependencies for the default graph.

    Records the default graph and its current number of operations, so that
    `__exit__` only processes ops created inside the context. Does nothing
    when executing eagerly.
    """
    if context.executing_eagerly():
      return self
    # This code assumes no other thread is adding ops to the graph while
    # we're adding ops to the graph.
    # TODO(apassos): Fix this by locking the graph or using a temporary
    # graph (but that would mess up devices and collections at least,
    # probably other things as well).
    self._graph = ops.get_default_graph()
    self._graph._add_control_dependencies = True  # pylint: disable=protected-access
    self._n_operations = len(self._graph.get_operations())
    return self

  def _process_switch(self, switch_op, ops_which_must_run,
                      last_op_using_resource_tensor, merge_for_resource):
    """Processes a switch node for a resource input.

    When tensorflow creates a cond, it creates a control flow context for each
    branch of the cond. Each external tensor accessed by that branch is routed
    through a switch op, which gets created in the graph _after_ the op which
    uses that tensor gets created.

    If the resource comes from another switch op we process that one first.

    _process_switch creates a corresponding merge node for the switch node. This
    merge node is added to the outer control flow context of the switch
    node. We also ensure that:

      1. The switch node executes after the previous op which used the resource
         tensor

      2. Any op which uses a resource output of the switch node executes before
         the merge for the switch node.

      3. The next op which uses the input resource to the switch node (which
         might be another switch node for the other branch of the conditional)
         will execute after the merge node is done.

      4. The merge node is marked as must_run so it will run even if no
         subsequent operation uses the resource.

    Args:
      switch_op: the switch op to be processed
      ops_which_must_run: the set of ops which must run
      last_op_using_resource_tensor: map from resource tensor to last op using
        it
      merge_for_resource: map from resource tensor to merge which must follow
        all usages of it.
    """
    inp = switch_op.inputs[0]
    if inp.dtype == dtypes_module.resource and inp.op.type == "Switch":
      self._process_switch(inp.op, ops_which_must_run,
                           last_op_using_resource_tensor, merge_for_resource)
    if switch_op.outputs[0] in merge_for_resource:
      return
    new_merge = control_flow_ops.merge(switch_op.outputs,
                                       name="artificial_merge")
    new_merge[0].op._control_flow_context = (  # pylint: disable=protected-access
        switch_op._control_flow_context.outer_context)  # pylint: disable=protected-access
    # Ensures the merge always runs
    ops_which_must_run.add(new_merge[0].op)
    if inp in last_op_using_resource_tensor:
      # Ensures the switch executes after the previous op using the resource.
      switch_op._add_control_input(last_op_using_resource_tensor[inp])  # pylint: disable=protected-access
    # Ensure the next op outside the cond happens after the merge.
    last_op_using_resource_tensor[inp] = new_merge[0].op
    if inp in merge_for_resource:
      merge_for_resource[inp]._add_control_input(new_merge[0].op)  # pylint: disable=protected-access
    for o in switch_op.outputs:
      # Ensures the merge will execute after all ops inside the cond
      merge_for_resource[o] = new_merge[0].op

  def __exit__(self, unused_type, unused_value, unused_traceback):
    """Adds automatic control dependencies to the ops created in the context.

    Processes every op created since `__enter__` (ops inside while loops are
    skipped). It serializes uses of each resource tensor, makes stateful ops
    and merges run, and gives every tensor returned by `mark_as_return` control
    inputs on the ops which must run in its control flow context. Does nothing
    when executing eagerly.

    Raises:
      RuntimeError: if the default graph differs from the one that was the
        default when the context was entered.
    """
    if context.executing_eagerly():
      return

    # pylint: disable=protected-access
    # Restore the graph's flag before the graph-change check, so that raising
    # RuntimeError below does not leave the flag set to True. The flag goes
    # back to the outer graph's value when there is a (non-None) outer graph,
    # and to False otherwise. `getattr` with a None check also covers an
    # `outer_graph` attribute that exists but is None.
    outer_graph = getattr(self._graph, "outer_graph", None)
    if outer_graph is not None:
      self._graph._add_control_dependencies = (
          outer_graph._add_control_dependencies)
    else:
      self._graph._add_control_dependencies = False
    # pylint: enable=protected-access

    if self._graph is not ops.get_default_graph():
      raise RuntimeError(
          "Graph changed while trying to add control dependencies.")

    # Map from resource tensor to the last op which used it. The key None
    # tracks the last stateful op that takes no resource input and is outside
    # any control flow context.
    last_op_using_resource_tensor = {}
    # Ops which must run: stateful ops and merges. Reset to just the merge
    # whenever a Merge op is processed, since the merge now depends on them.
    ops_which_must_run = set()
    # merge which must depend on ops which use this resource
    merge_for_resource = {}

    new_operations = self._graph.get_operations()[self._n_operations:]

    # Ensures that uses of resource tensors get serialized properly and all
    # execute. This is done by keeping a map from resource tensor to the last op
    # in graph-construction order which used it (last_op_using_resource_tensor).
    #
    # Conditionals are written in TensorFlow such that every external tensor
    # accessed in the conditional goes through a switch op and every return
    # tensor (it's guaranteed that there will be at least one) goes through a
    # merge op.
    #
    # To handle conditionals, switches are handled in a special way (see
    # comments for _process_switch). Merge nodes created by TF's conditional
    # logic (as opposed to by _process_switch) are forced to run and also get a
    # control dependency added to them to ensure all stateful ops inside their
    # control flow context run.
    #
    # We also ensure that if an op is using a resource output by a switch node
    # (that is, a resource tensor for which there's a value in
    # merge_for_resource) this op will run before the merge for that resource.
    #
    # We try to add control inputs to nodes respecting their control flow
    # contexts to avoid dead nodes propagating everywhere and leading to
    # "retval[0] doesn't have value" errors. If a node gets a control dependency
    # on a dead node (i.e. a node from an untaken control flow branch) that node
    # will be marked as dead unless it's a merge node.
    #
    # TODO(apassos): serialize non-resource-taking stateful ops as well, and
    # test that it works. Support while loops. Support init_scope escaping from
    # this.
    for op in new_operations:
      # TODO(apassos) make this code safely support while loops.
      if control_flow_util.IsInWhileLoop(op):
        continue
      control_inputs = set()
      # Ensure stateful ops run
      if (op.type not in self._graph._registered_ops  # pylint: disable=protected-access
          or op_is_stateful(self._graph._registered_ops[op.type])):  # pylint: disable=protected-access
        ops_which_must_run.add(op)
      # Ignore switches (they're handled separately)
      if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
        continue
      # Make merges trigger all other computation which must run
      if op.type == "Merge":
        for o in ops_which_must_run:
          op._add_control_input(o)  # pylint: disable=protected-access
          for inp in o.inputs:
            if inp in last_op_using_resource_tensor:
              last_op_using_resource_tensor[inp] = op
        ops_which_must_run = set([op])
        continue
      found_resource = False
      # Check for any resource inputs. If we find any, we update control_inputs
      # and last_op_using_resource_tensor. Note that we dedup op.inputs in case
      # op receives the same resource tensor twice as input, which would result
      # in op getting a control dependency on itself.
      for inp in set(op.inputs):
        if inp.dtype != dtypes_module.resource:
          continue
        found_resource = True
        # Deal with switches, finally.
        if inp.op.type == "Switch":
          self._process_switch(inp.op, ops_which_must_run,
                               last_op_using_resource_tensor,
                               merge_for_resource)
        # Ensure uses of resources are serialized
        if inp in last_op_using_resource_tensor:
          if (last_op_using_resource_tensor[inp]._control_flow_context  # pylint: disable=protected-access
              is op._control_flow_context):  # pylint: disable=protected-access
            control_inputs.add(last_op_using_resource_tensor[inp])
        # Ensure merges happen after the closing of a cond block
        if inp in merge_for_resource:
          merge_for_resource[inp]._add_control_input(op)  # pylint: disable=protected-access
        last_op_using_resource_tensor[inp] = op
      # Serialize stateful ops that take no resource input and live outside any
      # control flow context, using the None key as their shared "resource".
      if (op_is_stateful(op.op_def) and not found_resource
          and op._control_flow_context is None):  # pylint: disable=protected-access
        if None in last_op_using_resource_tensor:
          op._add_control_input(last_op_using_resource_tensor[None])  # pylint: disable=protected-access
        last_op_using_resource_tensor[None] = op
      control_inputs = [c for c in control_inputs
                        if c._control_flow_context is op._control_flow_context]  # pylint: disable=protected-access
      op._add_control_inputs(control_inputs)  # pylint: disable=protected-access

    # Ensure all ops which must run do run
    self.ops_which_must_run.update(ops_which_must_run)
    if self.ops_which_must_run:
      for r in self._returned_tensors:
        r.op._add_control_inputs(  # pylint: disable=protected-access
            [o for o in self.ops_which_must_run
             if o._control_flow_context is r.op._control_flow_context])  # pylint: disable=protected-access
343
344
def automatic_control_dependencies(f):
  """Decorator that runs `f` under `AutomaticControlDependencies`.

  Calling the decorated function traces `f` inside an
  `AutomaticControlDependencies` context. Every element of its (possibly
  nested) result is passed through `mark_as_return`, which guarantees:
    1. All stateful ops in f run when the result of f runs
    2. Updates to the same resources happen in order.

  Args:
    f: the function to be wrapped.

  Returns:
    The wrapped function, decorated via `tf_decorator.make_decorator`.
  """

  def wrapped_f(*args, **kwargs):
    with AutomaticControlDependencies() as acd:
      outputs = f(*args, **kwargs)
      marked_outputs = []
      for output in nest.flatten(outputs):
        marked_outputs.append(acd.mark_as_return(output))
      # Repack while still inside the context, so the context exits only after
      # every output has been marked.
      return nest.pack_sequence_as(outputs, marked_outputs)

  return tf_decorator.make_decorator(f, wrapped_f)
366