# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utilities for V2 control flow."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import function_def_to_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework.func_graph import FuncGraph
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_v2_func_graphs
from tensorflow.python.ops import gradients_util
from tensorflow.python.util import keras_deps
from tensorflow.python.util import tf_contextlib


_EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = None
_DISABLE_LOWER_USING_SWITCH_MERGE = False


CondBranchFuncGraph = control_flow_v2_func_graphs.CondBranchFuncGraph
WhileCondFuncGraph = control_flow_v2_func_graphs.WhileCondFuncGraph
WhileBodyFuncGraph = control_flow_v2_func_graphs.WhileBodyFuncGraph


def in_defun():
  """Returns whether the current graph is, or is nested in, a defun."""
  if context.executing_eagerly(): return False

  graph = ops.get_default_graph()
  while (isinstance(graph, CondBranchFuncGraph) or
         isinstance(graph, WhileBodyFuncGraph) or
         isinstance(graph, WhileCondFuncGraph)):
    graph = graph.outer_graph
  return isinstance(graph, FuncGraph)
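
# Illustrative sketch for `in_defun` (assumed usage, shown as a comment so
# nothing runs at import time):
#
#   @def_function.function
#   def f(pred, x):
#     return cond_v2.cond_v2(pred, lambda: x, lambda: -x)
#
# While tracing the branch lambdas above, the default graph is a
# CondBranchFuncGraph nested inside the tf.function's FuncGraph, so in_defun()
# returns True there; the same cond built directly in a plain v1 Graph, with no
# enclosing function, reports False.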


def in_while_loop_defun(graph):
  """Returns whether `graph` is a while loop FuncGraph."""
  if context.executing_eagerly(): return False
  return (isinstance(graph, WhileCondFuncGraph) or
          isinstance(graph, WhileBodyFuncGraph))


def create_new_tf_function(func_graph):
  """Converts func_graph to a TF_Function and adds it to the outer graph.

  Args:
    func_graph: FuncGraph

  Returns:
    The name of the new TF_Function.
  """
  func = function._EagerDefinedFunction(  # pylint: disable=protected-access
      func_graph.name, func_graph, func_graph.inputs, func_graph.outputs, {})
  func.add_to_graph(func_graph.outer_graph)
  return func_graph.name
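
# Rough usage sketch for `create_new_tf_function` (assumed caller, shown as a
# comment; `true_fn` and the cond scope are hypothetical):
#
#   branch_graph = func_graph_module.func_graph_from_py_func(
#       unique_fn_name("cond/", "true"), true_fn, [], {})
#   branch_name = create_new_tf_function(branch_graph)
#   # `branch_name` can now be passed as the function-valued attr of an `If` op.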


def unique_fn_name(scope, name):
  """Returns a unique name to use for a control flow function.

  Args:
    scope: A name scope string.
    name: An identifier for this function (e.g. "true", "body").

  Returns:
    A string, the name to use for the function.
  """
  return ("%s%s_%s" % (scope, name, ops.uid())).replace("/", "_")
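
# For illustration only (the trailing uid varies from run to run):
#   unique_fn_name("while/", "body")        could yield e.g. "while_body_42"
#   unique_fn_name("outer/cond/", "true")   could yield e.g. "outer_cond_true_43"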


def unique_grad_fn_name(forward_name):
  """Returns a unique name for the gradient function of `forward_name`."""
  return "%s_grad_%s" % (forward_name, ops.uid())


def maybe_set_lowering_attr(op, lower_using_switch_merge=None):
  """Sets the flag to enable lowering on `op` if necessary.

  Lowering allows cond_v2 and while_v2 to avoid some of the limitations of
  Functions, allowing users to specify devices & colocation inside of cond_v2
  and while_v2 input functions, and enabling non-strict evaluation & partial
  pruning. This brings v2 control flow closer to feature parity with v1 control
  flow.

  However, we do not lower in the following cases:
    - When the `If` or `While` ops are in the XLA context, because it is easier
      for XLA to apply its own optimizations when dealing with un-lowered
      control flow operators than with low-level control flow primitives.
    - When the eager execution context specifies the executor of functions to
      be the single threaded executor (see context.function_executor_type()),
      because the single threaded executor does not support v1 control flow ops.
    - When `lower_using_switch_merge` is explicitly set to False.

  Args:
    op: An `If` or `While` Operation.
    lower_using_switch_merge: Explicit value to lower or not (optional).
  """
  if lower_using_switch_merge is not None:
    # pylint: disable=protected-access
    op._set_attr("_lower_using_switch_merge",
                 attr_value_pb2.AttrValue(b=lower_using_switch_merge))
    # pylint: enable=protected-access
  elif (not _DISABLE_LOWER_USING_SWITCH_MERGE and
        not control_flow_util.GraphOrParentsInXlaContext(op.graph) and
        context.context().function_call_options.executor_type !=
        "SINGLE_THREADED_EXECUTOR"):
    # pylint: disable=protected-access
    op._set_attr("_lower_using_switch_merge", attr_value_pb2.AttrValue(b=True))
    # pylint: enable=protected-access
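
# Rough usage sketch for `maybe_set_lowering_attr` (assumed caller, shown as a
# comment; `op` stands for a functional `If`/`While` Operation built elsewhere):
#
#   maybe_set_lowering_attr(op)   # lower unless XLA / single-threaded executor
#   maybe_set_lowering_attr(op, lower_using_switch_merge=False)   # never lower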


def maybe_propagate_compile_time_consts_in_xla(op):
  """Tells XLA whether to propagate compile-time consts in the loop body.

  This is needed to make compile time constants available to ops, for example
  `max_num_elements` in `EmptyTensorList`, inside the loop body. Ideally this
  would always be turned on, but that doesn't work with legacy functionalized
  while_loops.

  Args:
    op: A `While` Operation.
  """
  if control_flow_util.GraphOrParentsInXlaContext(op.graph):
    # pylint: disable=protected-access
    op._set_attr("_xla_propagate_compile_time_consts",
                 attr_value_pb2.AttrValue(b=True))
    # pylint: enable=protected-access


def resource_input_index(tensor_name, input_names, node_defs, functions):
  """Returns the index of the input corresponding to `tensor_name`.

  This method is used to find the corresponding index of an arbitrary resource
  tensor in a function (the function could be a loop body). We assume that
  resource handles are never created in functions, so that every resource
  tensor can be traced back to a function input.

  The awkward signature of this method is to make it work with both FuncGraphs
  and FunctionDefs. This is so we can recurse on function call ops without
  building the corresponding FuncGraph (note that even if a FuncGraph for a
  FunctionDef already exists, the input/output/node names may have been
  changed when the FuncGraph was serialized to the FunctionDef, which makes it
  unusable with this algorithm).

  Args:
    tensor_name: the name of the resource tensor to be resolved to an input.
    input_names: a list of the names of all inputs to the function.
    node_defs: a dict mapping op name -> NodeDef for every op in the function.
    functions: a dict mapping function name -> _EagerDefinedFunction.

  Returns:
    The index into input_names corresponding to `tensor_name`.
  """
  while tensor_name not in input_names:
    # FunctionDefs and graphs use different tensor naming conventions.
    parts = tensor_name.split(":")
    if len(parts) == 3:
      op_name, _, output_idx = parts
    elif len(parts) == 2:
      op_name, output_idx = parts
    else:
      assert len(parts) == 1
      op_name = parts[0]
      output_idx = 0
      tensor_name = "%s:%d" % (tensor_name, output_idx)
      # Check again for cases where the tensor suffix (":0") is stripped out.
      if tensor_name in input_names:
        break
    output_idx = int(output_idx)
    node_def = node_defs[op_name]

    if node_def.op in ("Identity", "While"):
      # Captured resources occur at the same index in the lists of inputs and
      # outputs of a while or identity op. So we look up the input of
      # `tensor.op` at the same index as the index of `tensor` in
      # `tensor.op.outputs`.
      tensor_name = node_def.input[output_idx]
    elif node_def.op in ("PartitionedCall", "StatefulPartitionedCall"):
      # Functions output any captured resource tensors used by their
      # gradients.  `tensor_name` is one of these outputs from a nested
      # function call, so recursively find the corresponding input in the
      # nested FunctionDef.
      func_name = node_def.attr["f"].func.name
      fdef = functions[func_name].definition
      output_arg_name = fdef.signature.output_arg[output_idx].name
      output_tensor_name = fdef.ret[output_arg_name]
      input_index = resource_input_index(
          output_tensor_name, [arg.name for arg in fdef.signature.input_arg],
          {ndef.name: ndef for ndef in fdef.node_def}, functions)
      tensor_name = node_def.input[input_index]
    else:
      # We assume there are no other op types that will "forward" resource
      # handles like this, so all other handles must have been created by the
      # op. (Note that cond_v2 wraps resource handle outputs in optionals,
      # which we'll end up accumulating).
      raise ValueError("Taking gradient of a while loop which creates "
                       "a resource in its body is not supported: %s" % op_name)

  return input_names.index(tensor_name)
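
# Worked example for `resource_input_index` (hypothetical names, for
# illustration): suppose a while-loop body FuncGraph has inputs
# ["x:0", "v:0"], where `v` is a captured resource handle, and the body
# contains an `Identity` node named "v_identity" whose input is `v`. Then
#
#   resource_input_index("v_identity:0", ["x:0", "v:0"], node_defs, functions)
#
# follows the Identity's input back to the captured resource and returns 1,
# the index of "v:0" in `input_names`.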


@tf_contextlib.contextmanager
def clear_control_inputs():
  """Clears the control inputs but preserves the ControlFlowContext.

  This is needed to preserve the XLAControlFlowContext when clearing
  control inputs for the gradient accumulators in while_v2.
  `ops.control_dependencies` does not allow that.

  Yields:
    A context manager in which the ops created will not have any control inputs
    by default but the control flow context is the same.
  """
  # pylint: disable=protected-access
  control_flow_context = ops.get_default_graph()._get_control_flow_context()
  with ops.control_dependencies(None):
    ops.get_default_graph()._set_control_flow_context(control_flow_context)
    yield
  # pylint: enable=protected-access
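
# Rough usage sketch for `clear_control_inputs` (assumed caller, shown as a
# comment; the accumulator op is hypothetical): while_v2-style gradient code can
# create accumulators that must not inherit pending control dependencies but
# must stay in the surrounding (XLA) control flow context:
#
#   with clear_control_inputs():
#     acc = list_ops.empty_tensor_list(element_shape=..., element_dtype=...)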


def _is_tpu_strategy(strategy):
  return (strategy is not None and
          strategy.__class__.__name__.startswith("TPUStrategy"))


def _is_building_keras_layer():
  # TODO(srbs): Remove this function when we no longer support sessions with
  # Keras.
  keras_call_context_function = keras_deps.get_call_context_function()
  if keras_call_context_function:
    return keras_call_context_function().layer is not None
  else:
    return False


def output_all_intermediates():
  """Whether to output all intermediates of a functional control flow op.

  The default behavior is to output intermediates only when building a Keras
  Layer in graph mode, and even then only when certain other conditions are
  met:
  1. We do not output intermediates if the functional control flow op
     is being built inside a FuncGraph which is not an If/While graph. This
     guards against outputting intermediates in eager mode since Keras adds
     tensors to a FuncGraph named "keras_graph" in that case. Also, because we
     do not output intermediates of tf.function (this feature is only for
     backwards compatibility), outputting intermediates of functional control
     flow ops built inside tf.function is of no value.
  2. We do not output intermediates when compiling with XLA or for a TPU.
  3. We do not output intermediates when a single threaded executor is used
     since that does not perform inlining and pruning.

  Returns:
    A bool telling whether to output all intermediates.
  """
  if _EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE is not None:
    return _EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
  if in_defun():
    return False
  if (control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()) or
      _is_tpu_strategy(distribution_strategy_context.get_strategy())):
    return False
  if (context.context().function_call_options.executor_type ==
      "SINGLE_THREADED_EXECUTOR"):
    return False
  return _is_building_keras_layer()


def get_func_graph(op, input_shapes, func_name):
  """Generates and returns a FuncGraph for the given op and input_shapes."""
  fdef = None
  graph = op.graph
  # Search for the function in this graph and its enclosing (outer) graphs.
  while graph is not None:
    func = graph._get_function(func_name)  # pylint: disable=protected-access
    if func is not None:
      fdef = func.definition
      break
    if hasattr(graph, "outer_graph"):
      graph = graph.outer_graph
    else:
      break

  if fdef is None:
    raise KeyError("%s cannot be found in the graph" % func_name)

  # `op.graph` may not be the same as `ops.get_default_graph()` e.g.
  # in the case of nested if ops or when the gradient is being computed
  # from inside a Defun. We build the `func_graph` with `op.graph` as its
  # `outer_graph`. This resembles how the `FuncGraph` was built in the
  # forward pass. We need this so that we can resolve references to tensors
  # in `func_graph` from its gradient graph in `_resolve_grad_inputs`.
  with op.graph.as_default():
    func_graph = function_def_to_graph.function_def_to_graph(
        fdef, input_shapes)
  return func_graph


def get_op_and_outputs(op_or_outputs):
  """Returns the Operation and list of outputs for `op_or_outputs`."""
  if isinstance(op_or_outputs, ops.Operation):
    return op_or_outputs, []
  elif not op_or_outputs:  # Empty list.
    return None, []
  else:
    return op_or_outputs[0].op, op_or_outputs
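
# Behavior sketch for `get_op_and_outputs` (illustrative, shown as a comment):
#   get_op_and_outputs(an_op)          -> (an_op, [])
#   get_op_and_outputs([])             -> (None, [])
#   get_op_and_outputs([out0, out1])   -> (out0.op, [out0, out1])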


def graph_wrapped_for_higher_order_tape_gradients(graph):
  """Check if `graph` is wrapped by `run_as_function_for_tape_gradients`."""
  while graph is not None:
    if "cflow_gradient_wrapper" in getattr(graph, "name", ""):
      return True
    graph = getattr(graph, "outer_graph", None)
  return False


def run_as_function_for_tape_gradients(make_op, inputs):
  """Fix higher-order tape gradients by wrapping `make_op` in a function.

  Args:
    make_op: A function that takes a list of inputs and returns a list of output
      tensors. This function should set any handle data relevant to its outputs
      before returning.
    inputs: A list of tensors to check for tape gradients and pass to
      `make_op`. These should include all tensors used in `make_op`.

  Returns:
    Tensors corresponding to `make_op`'s output.
  """
  # GradientTapes created inside a function currently don't work well with
  # un-wrapped control flow ops in that same function. Wrapping in an extra
  # layer of intermediate function means we run extra logic in the function
  # gradient code to record the correct intermediates on the tape.
  #
  # The function attribute inputs to control flow ops are not hashable, so we
  # pass everything as a capture to bypass defun's caching.
  if (gradients_util.PossibleTapeGradientTypes(inputs)
      == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER
      # We only need one function between the tape and the op; if we've already
      # wrapped once, we stop wrapping to avoid infinite recursion.
      and not (ops.get_default_graph().building_function
               and "cflow_gradient_wrapper" in ops.get_default_graph().name)):
    results = function.defun_with_attributes(
        make_op,
        autograph=False,
        attributes=dict(func_name="cflow_gradient_wrapper"))(inputs)
    return results
  else:
    return make_op(inputs)
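
# Rough usage sketch for `run_as_function_for_tape_gradients` (assumed caller
# shape, shown as a comment; `_build_if_op` is a hypothetical builder):
#
#   outputs = run_as_function_for_tape_gradients(
#       lambda inputs: _build_if_op(inputs), cond_inputs)
#
# Under a higher-order GradientTape the op is then created inside a
# "cflow_gradient_wrapper" function; otherwise `_build_if_op` runs directly.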