1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
"""Utility functions for control flow.
17
18This file is necessary to avoid cyclic dependencies between ops.py and
19control_flow_ops.py.
20"""
21
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26import traceback
27
28from tensorflow.python.platform import tf_logging as logging
29
30
def IsInXLAContext(op):
  """Return true if `op` is marked for XLA compilation or is in an XLAContext.

  Args:
    op: Operation

  Returns:
    True if the op's "_XlaCompile" attr is truthy, or if the op's control flow
    context is, or is nested inside, an XLAContext.
  """
  try:
    if op.get_attr("_XlaCompile"):
      return True
  except ValueError:
    # get_attr raised (presumably the attr is not set on this op); fall through
    # to the control flow context check.
    pass
  enclosing = op._get_control_flow_context()  # pylint: disable=protected-access
  xla_ctxt = GetContainingXLAContext(enclosing)
  return xla_ctxt is not None
39
40
def IsInWhileLoop(op):
  """Return true if `op`'s control flow context is in (or is) a WhileContext."""
  enclosing = op._get_control_flow_context()  # pylint: disable=protected-access
  loop_ctxt = GetContainingWhileContext(enclosing)
  return loop_ctxt is not None
44
45
def IsInCond(op):
  """Return true if `op`'s control flow context is in (or is) a CondContext."""
  enclosing = op._get_control_flow_context()  # pylint: disable=protected-access
  cond_ctxt = GetContainingCondContext(enclosing)
  return cond_ctxt is not None
49
50
def IsSwitch(op):
  """Return true if `op.type` is "Switch" or "RefSwitch"."""
  return op.type in ("Switch", "RefSwitch")
54
55
def IsLoopEnter(op):
  """Return true if `op.type` is "Enter" or "RefEnter"."""
  return op.type in ("Enter", "RefEnter")
59
60
def IsLoopExit(op):
  """Return true if `op.type` is "Exit" or "RefExit"."""
  return op.type in ("Exit", "RefExit")
64
65
def IsLoopSwitch(op):
  """Return true if `op` is the Switch for a while loop.

  Args:
    op: Operation

  Returns:
    True if `op` is a Switch (or RefSwitch) whose control flow context reports
    IsWhileContext(); False otherwise, including when `op` has no control flow
    context.
  """
  if not IsSwitch(op):
    return False
  ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  # Coerce to bool: a bare `ctxt and ...` would hand None (a missing context)
  # back to the caller instead of False.
  return bool(ctxt and ctxt.IsWhileContext())
72
73
def IsLoopConstantEnter(op):
  """Return true iff `op` is an Enter op whose "is_constant" attr is set.

  Args:
    op: Operation

  Returns:
    False if `op` is not an Enter/RefEnter; otherwise the value of the op's
    "is_constant" attr.
  """
  if not IsLoopEnter(op):
    return False
  return op.get_attr("is_constant")
77
78
def GetLoopConstantEnter(value):
  """Return the enter op if we can infer `value` to be a loop invariant.

  Starting at the op that produces `value`, follows the first input through any
  chain of Switch/RefSwitch/Identity/RefIdentity ops and inspects the op the
  chain ends at.

  Args:
    value: an object with an `op` attribute (the op that produces it).

  Returns:
    The op the chain ends at if it is a loop-constant Enter, otherwise None.
  """
  pass_through_types = ("Switch", "RefSwitch", "Identity", "RefIdentity")
  producer = value.op
  while producer.type in pass_through_types:
    producer = producer.inputs[0].op
  if IsLoopConstantEnter(producer):
    return producer
  return None
86
87
def GetOutputContext(op):
  """Return the control flow context for the output of an op.

  Args:
    op: Operation

  Returns:
    The op's own control flow context, except for Exit ops that have a context,
    for which the context's `outer_context` is returned instead.
  """
  own_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  # Exit nodes usually have a control flow context, except in the case where the
  # exit node was imported via import_graph_def (in which case no nodes have
  # control flow contexts).
  if own_ctxt is None or not IsLoopExit(op):
    return own_ctxt
  return own_ctxt.outer_context
97
98
def GetContainingWhileContext(ctxt, stop_ctxt=None):
  """Returns the nearest WhileContext at or above `ctxt`.

  Walks from `ctxt` outwards through `outer_context` links and returns the
  first context that is a WhileContext, or that compares equal to `stop_ctxt`.

  Args:
    ctxt: ControlFlowContext
    stop_ctxt: ControlFlowContext, optional. If provided, the search will end
      if it sees stop_ctxt.

  Returns:
    `ctxt` if `ctxt` is a WhileContext, the most nested WhileContext containing
    `ctxt`, or None if `ctxt` is not in a while loop.  If `stop_ctxt` is not
    `None`, this returns the context matching `stop_ctxt` if the traversal
    reaches it before any WhileContext.
  """
  node = ctxt
  while node:
    if node.IsWhileContext():
      return node
    if node == stop_ctxt:
      return node
    node = node.outer_context
  return None
119
120
def GetContainingXLAContext(ctxt):
  """Returns the nearest XLAContext at or above `ctxt`.

  Walks from `ctxt` outwards through `outer_context` links until it finds a
  context whose IsXLAContext() is true.

  Args:
    ctxt: ControlFlowContext

  Returns:
    `ctxt` if `ctxt` is a XLAContext, the most nested XLAContext containing
    `ctxt`, or None if `ctxt` is not in an XLA context.
  """
  node = ctxt
  while node and not node.IsXLAContext():
    node = node.outer_context
  return node if node else None
138
139
def GetContainingCondContext(ctxt):
  """Returns the nearest CondContext at or above `ctxt`.

  Walks from `ctxt` outwards through `outer_context` links until it finds a
  context whose IsCondContext() is true.

  Args:
    ctxt: ControlFlowContext

  Returns:
    `ctxt` if `ctxt` is a CondContext, the most nested CondContext containing
    `ctxt`, or None if `ctxt` is not in a cond.
  """
  node = ctxt
  while node and not node.IsCondContext():
    node = node.outer_context
  return node if node else None
156
157
def IsContainingContext(ctxt, maybe_containing_ctxt):
  """Returns true if `maybe_containing_ctxt` is or contains `ctxt`.

  Follows `outer_context` links from `ctxt`, comparing by identity. Note that a
  `maybe_containing_ctxt` of None (i.e. no context) is reached from every
  chain, so it contains everything.

  Args:
    ctxt: ControlFlowContext or None
    maybe_containing_ctxt: ControlFlowContext or None

  Returns:
    True if `maybe_containing_ctxt` is `ctxt` itself or one of its ancestors,
    False otherwise.
  """
  node = ctxt
  while True:
    if node is maybe_containing_ctxt:
      return True
    if node is None:
      return False
    node = node.outer_context
164
165
def CheckInputFromValidContext(op, input_op):
  """Raises an error if `input_op` cannot be used from `op`s context.

  Conceptually, only inputs from op's while context or any ancestor while
  context (including outside of any context) are valid. In practice, there are
  many other edge cases as well.

  On failure, a detailed message (both while contexts and both ops' stack
  traces) is written to the info log, and a shorter ValueError is raised.

  Args:
    op: Operation
    input_op: Operation

  Returns:
    None. A normal return means `input_op` is usable as an input to `op`.

  Raises:
    ValueError: if input_op is from an invalid context.
  """
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  input_ctxt = GetOutputContext(input_op)

  if not input_ctxt:
    # input_op isn't in a control flow context.
    return
  if op_ctxt is input_ctxt:
    # input_op is in the same context as op.
    return

  while_ctxt = GetContainingWhileContext(op_ctxt)
  input_while_ctxt = GetContainingWhileContext(input_ctxt)
  valid = False

  if while_ctxt is None:
    if input_while_ctxt is None:
      # Neither op nor input_op is in a while loop, but one or both are in
      # conds. We allow this, although execution will fail if the branch
      # corresponding to input_op's cond context isn't taken.
      valid = True
    # Invalid if op isn't in a while loop and input_op is. Unless...
    if IsLoopEnter(op):
      # WhileContext._BuildLoop clears context for Enter nodes.
      valid = True
    if IsSwitch(op):
      # CondContext.AddValue clears context for Switch nodes.
      valid = True
  elif IsContainingContext(while_ctxt, input_while_ctxt):
    # input_op is in a while loop which contains op's while loop (or not in a
    # while loop at all).
    # Note: a None input_while_ctxt always lands here, so the branches below
    # may dereference input_while_ctxt safely.
    valid = True
  elif (while_ctxt.grad_state and
        IsContainingContext(while_ctxt.grad_state.forward_context,
                            input_while_ctxt)):
    # op is in a gradient context and input_op is in the associated forward
    # pass context or an ancestor thereof. This case is needed to build while
    # loop gradients.
    # NOTE(skyewm): we theoretically also need this case for custom gradient
    # functions that close over tensors from ancestor contexts, but I haven't
    # verified this.
    valid = True
  elif (while_ctxt.grad_state and
        while_ctxt.grad_state.forward_context is
        input_while_ctxt._outer_context):  # pylint: disable=protected-access
    # op is in a gradient context and input_op is in a child of the associated
    # forward pass context. This case is needed for the gradients of while
    # loops with conds.
    valid = True
  elif (input_while_ctxt.grad_state and
        input_while_ctxt.grad_state.forward_context is while_ctxt):
    # input_op is in the gradient context of op's context. This case is needed
    # when the gradient of a while loop gradient is requested (this will
    # eventually fail unless there is a stop_gradient() or similar).
    valid = True
  elif (input_while_ctxt.grad_state and
        input_while_ctxt.grad_state.forward_context.grad_state and
        input_while_ctxt.grad_state.forward_context.grad_state.forward_context
        is while_ctxt):
    # input_op is in the grad grad context of op's context. This case is
    # needed when the gradient of a while loop gradient is requested (this
    # will eventually fail unless there is a stop_gradient() or similar).
    # The whole chain is read from input_while_ctxt, the same object whose
    # grad_state was just checked, rather than from input_ctxt.
    valid = True

  if valid:
    return

  # input_op is the producer and op the consumer, so input_op is named first in
  # both messages.
  if while_ctxt:
    error_msg = (
        "Cannot use '%s' as input to '%s' because they are in different while"
        " loops." % (input_op.name, op.name))
  else:
    error_msg = (
        "Cannot use '%s' as input to '%s' because '%s' is in a while loop."
        % (input_op.name, op.name, input_op.name))

  # Log the error message plus the relevant stack traces. The stacks may be
  # useful for debugging this error, but we don't want to raise an
  # unreadable exception.
  log_msg = error_msg
  log_msg += "\n\n%s while context: %s" % (op.name, while_ctxt)
  log_msg += "\n%s while context: %s" % (input_op.name, input_while_ctxt)
  log_msg += "\n\nTraceback for %s:\n%s\nTraceback for %s:\n%s\n" % (
      op.name, "".join(traceback.format_list(op.traceback)),
      input_op.name, "".join(traceback.format_list(input_op.traceback)))
  logging.info(log_msg)
  raise ValueError(error_msg + " See info log for more details.")
264