1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Utility functions for control flow.
17
18This file is necessary to avoid cyclic dependencies between ops.py and
19control_flow_ops.py.
20"""
21
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26import os
27import traceback
28
29from tensorflow.python import tf2
30from tensorflow.python.platform import tf_logging as logging
31
# Control flow v2 is on by default when TF2 behavior is enabled, unless
# TF_ENABLE_CONTROL_FLOW_V2 is explicitly "0". Independently, setting any of the
# listed environment variables to a value other than "0" turns it on.
ENABLE_CONTROL_FLOW_V2 = (
    (tf2.enabled() and os.getenv("TF_ENABLE_CONTROL_FLOW_V2") != "0") or
    any(os.getenv(_flag_name, "0") != "0"
        for _flag_name in ("TF_ENABLE_CONTROL_FLOW_V2",
                           "TF_ENABLE_COND_V2",
                           "TF_ENABLE_WHILE_V2",
                           "TF_ENABLE_TENSOR_ARRAY_V2")))
38
39
# TODO(b/137793122): Remove this.
def enable_control_flow_v2():  # pylint: disable=invalid-name
  """Switches this module's `ENABLE_CONTROL_FLOW_V2` flag on.

  Do not use this symbol. This will be removed.
  """
  # Rebind the module-level flag rather than creating a local variable.
  global ENABLE_CONTROL_FLOW_V2
  ENABLE_CONTROL_FLOW_V2 = True
48
49
def EnableControlFlowV2(graph):
  """Reports whether control flow v2 should be used when building `graph`.

  Args:
    graph: the graph under construction. Only its `building_function`
      attribute and whether it has a `_captured` attribute are inspected.

  Returns:
    `ENABLE_CONTROL_FLOW_V2` when that module flag is truthy; otherwise whether
    `graph` is building a function and lacks a `_captured` attribute.
  """
  if ENABLE_CONTROL_FLOW_V2:
    return ENABLE_CONTROL_FLOW_V2
  # Enable new control flow in FuncGraphs (but not legacy _FuncGraphs, which
  # are recognized here by their `_captured` attribute).
  # TODO(skyewm): do something better than hasattr without messing up imports.
  return graph.building_function and not hasattr(graph, "_captured")
56
57
def IsInXLAContext(op):
  """Returns True if `op` is flagged for XLA compilation or is in an XLA context.

  The op's `_XlaCompile` attribute is checked first. If reading it raises
  ValueError (presumably because the attribute is absent), or it is falsy, the
  op's control flow context chain is searched for an XLA context instead.
  """
  try:
    if op.get_attr("_XlaCompile"):
      return True
  except ValueError:
    pass  # Fall through to the context-based check below.
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  return GetContainingXLAContext(op_ctxt) is not None
66
67
def InXlaContext(graph):
  """Returns True if `graph`'s current control flow context is within XLA."""
  graph_ctxt = graph._get_control_flow_context()  # pylint: disable=protected-access
  xla_ctxt = GetContainingXLAContext(graph_ctxt)
  return xla_ctxt is not None
71
72
def GraphOrParentsInXlaContext(graph):
  """Returns True if `graph` or any graph enclosing it is in an XLA context.

  Enclosing graphs are reached via `outer_graph`. The walk ends with False at
  the first graph that has no `outer_graph` attribute.
  """
  current = graph
  while not InXlaContext(current):
    try:
      current = current.outer_graph
    except AttributeError:
      return False
  return True
80
81
def IsInWhileLoop(op):
  """Returns True if `op`'s control flow context lies inside a while loop."""
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  while_ctxt = GetContainingWhileContext(op_ctxt)
  return while_ctxt is not None
85
86
def IsInCond(op):
  """Returns True if `op`'s control flow context lies inside a cond."""
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  cond_ctxt = GetContainingCondContext(op_ctxt)
  return cond_ctxt is not None
90
91
def IsSwitch(op):
  """Returns True if `op` is a Switch or RefSwitch op."""
  op_type = op.type
  return op_type in ("Switch", "RefSwitch")
95
96
def IsMerge(op):
  """Returns True if `op` is a Merge or RefMerge op."""
  op_type = op.type
  return op_type in ("Merge", "RefMerge")
100
101
def IsLoopEnter(op):
  """Returns True if `op` is an Enter or RefEnter op."""
  op_type = op.type
  return op_type in ("Enter", "RefEnter")
105
106
def IsLoopExit(op):
  """Returns True if `op` is an Exit or RefExit op."""
  op_type = op.type
  return op_type in ("Exit", "RefExit")
110
111
def IsCondSwitch(op):
  """Return true if `op` is the Switch for a conditional.

  Switch nodes are not part of the cond control flow context that they
  represent, so the consumers of the Switch's outputs decide: a Switch is a
  cond Switch iff every consumer is in a cond context. For an Enter consumer,
  the outer context of the Enter's own context is examined instead.

  Args:
    op: the Operation to classify.

  Returns:
    False if `op` is not a Switch or has no outputs. Otherwise True iff every
    consumer of every output is (per the rule above) in a cond context. This
    includes the case where the outputs have no consumers at all.
  """
  if not IsSwitch(op):
    return False
  if not op.outputs:
    return False
  for out in op.outputs:
    for c in out.consumers():
      ctxt = c._get_control_flow_context()  # pylint: disable=protected-access
      # Ops imported via import_graph_def carry no control flow context (the
      # same case GetOutputContext guards against), so only step outward when
      # a context is actually present; dereferencing None would raise.
      if ctxt is not None and IsLoopEnter(c):
        ctxt = ctxt.outer_context
      if ctxt is None or not ctxt.IsCondContext():
        return False
  return True
131
132
def IsCondMerge(op):
  """Return true if `op` is the Merge for a conditional.

  Merge nodes are not part of the cond control flow context that they
  represent, so the producers of the Merge's inputs decide: a Merge is a cond
  Merge iff the output context of every input's producing op is a cond context.
  A Merge that is not a Merge op or has no inputs is never a cond Merge.
  """
  if not IsMerge(op) or not op.inputs:
    return False
  producer_ctxts = (GetOutputContext(tensor.op) for tensor in op.inputs)
  return all(pctxt is not None and pctxt.IsCondContext()
             for pctxt in producer_ctxts)
148
149
def IsLoopSwitch(op):
  """Return true if `op` is the Switch for a while loop.

  A Switch qualifies when its own context is a while context and it is not
  classified as a cond Switch by IsCondSwitch.
  """
  if not IsSwitch(op):
    return False
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  return (op_ctxt is not None and op_ctxt.IsWhileContext() and
          not IsCondSwitch(op))
156
157
def IsLoopMerge(op):
  """Return true if `op` is the Merge for a while loop.

  A Merge qualifies when its own context is a while context and it is not
  classified as a cond Merge by IsCondMerge.
  """
  if not IsMerge(op):
    return False
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  return (op_ctxt is not None and op_ctxt.IsWhileContext() and
          not IsCondMerge(op))
164
165
def IsLoopConstantEnter(op):
  """Return true iff `op` is an Enter op whose `is_constant` attr is set.

  Returns False for non-Enter ops; otherwise returns the `is_constant`
  attribute value itself.
  """
  if not IsLoopEnter(op):
    return False
  return op.get_attr("is_constant")
169
170
def GetLoopConstantEnter(value):
  """Return the enter op if we can infer `value` to be a loop invariant.

  Starting at the op producing `value`, follows input 0 back through
  Switch/Identity ops (and their Ref variants). The first op of any other type
  is returned if it is a constant loop Enter; otherwise None is returned.
  """
  passthrough_types = {"Switch", "RefSwitch", "Identity", "RefIdentity"}
  producer = value.op
  while producer.type in passthrough_types:
    producer = producer.inputs[0].op
  if IsLoopConstantEnter(producer):
    return producer
  return None
178
179
def GetOutputContext(op):
  """Return the control flow context for the output of an op.

  This is the op's own context, except for an Exit op that has a context: its
  output belongs to that context's outer context.
  """
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  # Exit nodes usually have a control flow context, except when the exit node
  # was imported via import_graph_def (in which case no nodes have control
  # flow contexts), so only step outward when a context exists.
  if op_ctxt is None or not IsLoopExit(op):
    return op_ctxt
  return op_ctxt.outer_context
189
190
def GetContainingWhileContext(ctxt, stop_ctxt=None):
  """Walks outward from `ctxt` to the nearest WhileContext.

  Args:
    ctxt: ControlFlowContext, or None.
    stop_ctxt: ControlFlowContext, optional. If the walk reaches a context equal
      to `stop_ctxt` before finding a WhileContext, that context is returned.

  Returns:
    `ctxt` itself if it is a WhileContext (or equals `stop_ctxt`); otherwise the
    innermost enclosing context that is a WhileContext or equals `stop_ctxt`;
    or None if the chain is exhausted without a match.
  """
  node = ctxt
  while node:
    hit = node.IsWhileContext() or node == stop_ctxt
    if hit:
      return node
    node = node.outer_context
  return None
211
212
def GetContainingXLAContext(ctxt):
  """Returns the first ancestor XLAContext of `ctxt`.

  Returns `ctxt` if `ctxt` is an XLAContext, or None if `ctxt` is not inside an
  XLA context.

  Args:
    ctxt: ControlFlowContext, or None.

  Returns:
    `ctxt` if `ctxt` is an XLAContext, the most nested XLAContext containing
    `ctxt`, or None if `ctxt` is not in an XLA context.
  """
  while ctxt:
    if ctxt.IsXLAContext():
      return ctxt
    ctxt = ctxt.outer_context
  return None
230
231
def GetContainingCondContext(ctxt):
  """Walks outward from `ctxt` to the nearest CondContext.

  Args:
    ctxt: ControlFlowContext, or None.

  Returns:
    `ctxt` itself when it is a CondContext, otherwise the most nested
    CondContext containing `ctxt`, or None when no context in the chain is a
    cond.
  """
  current = ctxt
  while current:
    if current.IsCondContext():
      break
    current = current.outer_context
  else:
    # The chain ran out (hit a falsy context) without finding a cond.
    return None
  return current
248
249
def IsContainingContext(ctxt, maybe_containing_ctxt):
  """Returns true if `maybe_containing_ctxt` is or contains `ctxt`.

  Follows `outer_context` links from `ctxt`, comparing by identity. A None
  `maybe_containing_ctxt` is treated as containing everything, since every
  chain ends at None.
  """
  node = ctxt
  while True:
    if node is maybe_containing_ctxt:
      return True
    if node is None:
      return False
    node = node.outer_context
256
257
def OpInContext(op, ctxt):
  """Returns True if `op`'s control flow context is `ctxt` or nested in it."""
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  return IsContainingContext(op_ctxt, ctxt)
260
261
def TensorInContext(tensor, ctxt):
  """Returns True if the op producing `tensor` is in `ctxt` or nested in it."""
  producer = tensor.op
  return OpInContext(producer, ctxt)
264
265
def CheckInputFromValidContext(op, input_op):
  """Checks that `input_op` can be used as an input from `op`'s context.

  Conceptually, only inputs from op's while context or any ancestor while
  context (including outside of any context) are valid. In practice, there are
  many other edge cases as well.

  Args:
    op: Operation
    input_op: Operation

  Raises:
    ValueError: if input_op is from an invalid context. The error message plus
      both ops' while contexts and tracebacks are also logged at INFO level.
  """
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  input_ctxt = GetOutputContext(input_op)
  # Only computed when the two contexts differ; pre-bound so the error path
  # below can always refer to them.
  while_ctxt = None
  input_while_ctxt = None
  valid = False

  if not input_ctxt:
    # input_op isn't in a control flow context.
    valid = True
  elif op_ctxt is input_ctxt:
    # input_op is in the same context as op.
    valid = True
  else:
    while_ctxt = GetContainingWhileContext(op_ctxt)
    input_while_ctxt = GetContainingWhileContext(input_ctxt)

    if while_ctxt is None:
      if input_while_ctxt is None:
        # Neither op nor input_op is in a while loop, but one or both are in
        # conds. We allow this, although execution will fail if the branch
        # corresponding to input_op's cond context isn't taken.
        valid = True
      # Invalid if op isn't in a while loop and input_op is. Unless...
      if IsLoopEnter(op):
        # WhileContext._BuildLoop clears context for Enter nodes.
        valid = True
      if IsSwitch(op):
        # CondContext.AddValue clears context for Switch nodes.
        valid = True
    elif IsContainingContext(while_ctxt, input_while_ctxt):
      # input_op is in a while loop which contains op's while loop (or not in a
      # while loop at all). This branch also catches input_while_ctxt being
      # None, so every branch below may dereference input_while_ctxt.
      valid = True
    elif (while_ctxt.grad_state and
          IsContainingContext(while_ctxt.grad_state.forward_context,
                              input_while_ctxt)):
      # op is in a gradient context and input_op is in the associated forward
      # pass context or an ancestor thereof. This case is needed to build while
      # loop gradients.
      # NOTE(skyewm): we theoretically also need this case for custom gradient
      # functions that close over tensors from ancestor contexts, but I haven't
      # verified this.
      valid = True
    elif (while_ctxt.grad_state and
          while_ctxt.grad_state.forward_context is
          input_while_ctxt.outer_context):
      # op is in a gradient context and input_op is in a child of the associated
      # forward pass context. This case is needed for the gradients of while
      # loops with conds.
      valid = True
    elif (input_while_ctxt.grad_state and
          input_while_ctxt.grad_state.forward_context is while_ctxt):
      # input_op is in the gradient context of op's context. This case is needed
      # when the gradient of a while loop gradient is requested (this will
      # eventually fail unless there is a stop_gradient() or similar).
      valid = True
    elif (input_while_ctxt.grad_state and
          input_while_ctxt.grad_state.forward_context.grad_state and
          input_while_ctxt.grad_state.forward_context.grad_state.forward_context
          is while_ctxt):
      # input_op is in the grad grad context of op's context. This case is
      # needed when the gradient of a while loop gradient is requested (this
      # will eventually fail unless there is a stop_gradient() or similar).
      # The chain is read from input_while_ctxt -- the context whose grad_state
      # was just checked -- rather than from input_ctxt, which may be a nested
      # non-while context.
      valid = True

  if not valid:
    if while_ctxt:
      error_msg = (
          "Cannot use '%s' as input to '%s' because they are in different while"
          " loops." % (input_op.name, op.name))
    else:
      error_msg = (
          "Cannot use '%s' as input to '%s' because '%s' is in a while loop."
          % (input_op.name, op.name, input_op.name))

    # Log the error message plus the relevant stack traces. The stacks may be
    # useful for debugging this error, but we don't want to raise an
    # unreadable exception.
    log_msg = error_msg
    log_msg += "\n\n%s while context: %s" % (op.name, while_ctxt)
    log_msg += "\n%s while context: %s" % (input_op.name, input_while_ctxt)
    log_msg += "\n\nTraceback for %s:\n%s\nTraceback for %s:\n%s\n" % (
        op.name, "".join(traceback.format_list(op.traceback)),
        input_op.name, "".join(traceback.format_list(input_op.traceback)))
    logging.info(log_msg)
    raise ValueError(error_msg + " See info log for more details.")
364
365
def GetWhileContext(op):
  """Get the WhileContext to which this op belongs.

  If `op` has no (i.e. a falsy) control flow context, that value is returned
  unchanged. Otherwise the lookup is delegated to the context's own
  `GetWhileContext()`.
  """
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  return op_ctxt.GetWhileContext() if op_ctxt else op_ctxt
372