1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Utilty functions for control flow.
17
18This file is necessary to avoid cyclic dependencies between ops.py and
19control_flow_ops.py.
20"""
21
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26import os
27import traceback
28
29from tensorflow.python.platform import tf_logging as logging
30
# Control flow v2 is switched on globally when any of these environment
# variables is set to something other than "0" (unset counts as "0").
ENABLE_CONTROL_FLOW_V2 = any(
    os.getenv(env_var, "0") != "0"
    for env_var in ("TF_ENABLE_CONTROL_FLOW_V2",
                    "TF_ENABLE_COND_V2",
                    "TF_ENABLE_WHILE_V2",
                    "TF_ENABLE_TENSOR_ARRAY_V2"))
35
36
def EnableControlFlowV2(graph):
  """Reports whether control flow v2 should be used when building `graph`.

  It is used when `ENABLE_CONTROL_FLOW_V2` is set, or when `graph` is building
  a function and is not a legacy `_FuncGraph`.

  Args:
    graph: the graph being built.

  Returns:
    True if `ENABLE_CONTROL_FLOW_V2` is set; otherwise the falsy
    `graph.building_function` value, or whether `graph` lacks `_captured`.
  """
  if ENABLE_CONTROL_FLOW_V2:
    return True
  is_building_function = graph.building_function
  if not is_building_function:
    return is_building_function
  # Enable new control flow in FuncGraphs (but not legacy _FuncGraphs), which
  # are recognized by their `_captured` attribute.
  # TODO(skyewm): do something better than hasattr without messing up imports.
  return not hasattr(graph, "_captured")
43
44
def IsInXLAContext(op):
  """Reports whether `op` is compiled by XLA.

  `op` counts as XLA-compiled when its `_XlaCompile` attribute is truthy, or
  when its control flow context is nested in an XLA context.

  Args:
    op: Operation to inspect.

  Returns:
    True if `op` is marked for XLA compilation or sits inside an XLA context.
  """
  try:
    if op.get_attr("_XlaCompile"):
      return True
  except ValueError:
    # get_attr raises ValueError here (presumably when the attribute is
    # absent); fall back to inspecting the context chain.
    pass
  op_flow_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  return GetContainingXLAContext(op_flow_ctxt) is not None
53
54
def InXlaContext(graph):
  """Reports whether `graph`'s current control flow context is in XLA.

  Args:
    graph: graph whose current control flow context is inspected.

  Returns:
    True if that context is, or is nested in, an XLA context.
  """
  current_ctxt = graph._get_control_flow_context()  # pylint: disable=protected-access
  xla_ctxt = GetContainingXLAContext(current_ctxt)
  return xla_ctxt is not None
58
59
def GraphOrParentsInXlaContext(graph):
  """Reports whether `graph` or any graph reachable via `outer_graph` is in XLA.

  Starting from `graph`, follows `outer_graph` links and checks each graph
  with `InXlaContext`. The walk stops at the first graph that has no
  `outer_graph` attribute.

  Args:
    graph: the innermost graph to start from.

  Returns:
    True as soon as one graph on the chain is in an XLA context, else False.
  """
  current_graph = graph
  while not InXlaContext(current_graph):
    try:
      current_graph = current_graph.outer_graph
    except AttributeError:
      # No enclosing graph to climb to: nothing on the chain was in XLA.
      return False
  return True
67
68
def IsInWhileLoop(op):
  """Reports whether `op`'s control flow context is inside a while loop.

  Args:
    op: Operation to inspect.

  Returns:
    True if the op's context is, or is nested in, a WhileContext.
  """
  enclosing_loop = GetContainingWhileContext(
      op._get_control_flow_context())  # pylint: disable=protected-access
  return enclosing_loop is not None
72
73
def IsInCond(op):
  """Reports whether `op`'s control flow context is inside a cond.

  Args:
    op: Operation to inspect.

  Returns:
    True if the op's context is, or is nested in, a CondContext.
  """
  enclosing_cond = GetContainingCondContext(
      op._get_control_flow_context())  # pylint: disable=protected-access
  return enclosing_cond is not None
77
78
def IsSwitch(op):
  """Return true if `op` is a Switch (including the RefSwitch variant)."""
  return op.type in ("Switch", "RefSwitch")
82
83
def IsMerge(op):
  """Return true if `op` is a Merge (including the RefMerge variant)."""
  return op.type in ("Merge", "RefMerge")
87
88
def IsLoopEnter(op):
  """Return true if `op` is an Enter (including the RefEnter variant)."""
  return op.type in ("Enter", "RefEnter")
92
93
def IsLoopExit(op):
  """Return true if `op` is an Exit (including the RefExit variant)."""
  return op.type in ("Exit", "RefExit")
97
98
def IsCondSwitch(op):
  """Return true if `op` is the Switch for a conditional.

  Switch nodes are not part of the cond control flow context that they
  represent, so the consumers of the switch's outputs are examined instead: a
  switch is a cond switch iff all its consumers are in cond contexts. For an
  Enter consumer, the context one level out (`outer_context`) is checked
  rather than the consumer's own context.

  Args:
    op: Operation to test.

  Returns:
    False if `op` is not a Switch/RefSwitch or has no outputs. Otherwise True
    iff every consumer of every output is (after the Enter adjustment above)
    in a cond context. A switch whose outputs have no consumers yields True.
  """
  if not IsSwitch(op):
    return False
  if not op.outputs:
    return False
  for o in op.outputs:
    for c in o.consumers():
      ctxt = c._get_control_flow_context()  # pylint: disable=protected-access
      # Ops may have no control flow context at all (e.g. ops imported via
      # import_graph_def), so only step outward when a context exists;
      # otherwise `ctxt.outer_context` would raise AttributeError on None.
      if ctxt is not None and IsLoopEnter(c):
        ctxt = ctxt.outer_context
      if ctxt is None or not ctxt.IsCondContext():
        # One non-cond consumer is enough to rule out a cond switch.
        return False
  return True
118
119
def IsCondMerge(op):
  """Return true if `op` is the Merge for a conditional.

  Merge nodes are not part of the cond control flow context that they
  represent, so the inputs to the merge are inspected instead: a merge is a
  cond merge iff every input is produced in a cond context, as reported by
  `GetOutputContext` for the producing op.

  Args:
    op: Operation to test.

  Returns:
    False if `op` is not a Merge/RefMerge or has no inputs; otherwise whether
    all input contexts are cond contexts.
  """
  if not IsMerge(op) or not op.inputs:
    return False
  producer_ctxts = [GetOutputContext(tensor.op) for tensor in op.inputs]
  return all(pc is not None and pc.IsCondContext() for pc in producer_ctxts)
135
136
def IsLoopSwitch(op):
  """Return true if `op` is the Switch for a while loop.

  That is: `op` is a Switch whose own control flow context is a
  WhileContext, and `IsCondSwitch` does not classify it as a cond switch.
  """
  if not IsSwitch(op):
    return False
  own_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  if own_ctxt is None or not own_ctxt.IsWhileContext():
    return False
  return not IsCondSwitch(op)
143
144
def IsLoopMerge(op):
  """Return true if `op` is the Merge for a while loop.

  That is: `op` is a Merge whose own control flow context is a WhileContext,
  and `IsCondMerge` does not classify it as a cond merge.
  """
  if not IsMerge(op):
    return False
  own_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  if own_ctxt is None or not own_ctxt.IsWhileContext():
    return False
  return not IsCondMerge(op)
151
152
def IsLoopConstantEnter(op):
  """Return true iff op is a loop invariant.

  A loop invariant here is an Enter/RefEnter op whose `is_constant` attribute
  is set. Non-Enter ops return False; for Enter ops the attribute value is
  returned as-is.
  """
  if not IsLoopEnter(op):
    return False
  return op.get_attr("is_constant")
156
157
def GetLoopConstantEnter(value):
  """Return the enter op if we can infer `value` to be a loop invariant.

  Starting at the op producing `value`, repeatedly steps to the producer of
  the first input while the current op is a Switch or Identity (or a Ref
  variant of either). The op reached is returned if it is a loop-constant
  Enter (see `IsLoopConstantEnter`); otherwise None is returned.
  """
  pass_through_types = frozenset(
      ("Switch", "RefSwitch", "Identity", "RefIdentity"))
  producer = value.op
  while producer.type in pass_through_types:
    producer = producer.inputs[0].op
  if IsLoopConstantEnter(producer):
    return producer
  return None
165
166
def GetOutputContext(op):
  """Return the control flow context for the output of an op.

  This is the op's own context, except for Exit ops: their outputs belong to
  the enclosing context, so `outer_context` is returned for them.
  """
  own_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  # Exit nodes usually have a control flow context, except in the case where
  # the exit node was imported via import_graph_def (in which case no nodes
  # have control flow contexts). A missing context is returned unchanged.
  if own_ctxt is None or not IsLoopExit(op):
    return own_ctxt
  return own_ctxt.outer_context
176
177
def GetContainingWhileContext(ctxt, stop_ctxt=None):
  """Returns the first ancestor WhileContext of `ctxt`.

  Walks from `ctxt` outward through `outer_context` links and stops at the
  first context that is a WhileContext or that compares equal to `stop_ctxt`.

  Args:
    ctxt: ControlFlowContext, or None.
    stop_ctxt: ControlFlowContext, optional. If the walk reaches a context
      equal to it before finding a WhileContext, that context is returned.

  Returns:
    `ctxt` itself if it is a WhileContext, otherwise the innermost
    WhileContext containing `ctxt` (or the `stop_ctxt` match, whichever comes
    first), or None if the walk runs out of contexts.
  """
  current = ctxt
  while current:
    if current.IsWhileContext():
      return current
    if current == stop_ctxt:
      return current
    current = current.outer_context
  return None
198
199
def GetContainingXLAContext(ctxt):
  """Returns the first ancestor XLAContext of `ctxt`.

  Returns `ctxt` if `ctxt` is a XLAContext, or None if `ctxt` is not in an
  XLA context.

  Args:
    ctxt: ControlFlowContext, or None.

  Returns:
    `ctxt` if `ctxt` is a XLAContext, the most nested XLAContext containing
    `ctxt` (found by following `outer_context` links), or None if `ctxt` is
    not in an XLA context.
  """
  while ctxt:
    if ctxt.IsXLAContext():
      return ctxt
    ctxt = ctxt.outer_context
  return None
217
218
def GetContainingCondContext(ctxt):
  """Returns the innermost CondContext that is `ctxt` or one of its ancestors.

  Args:
    ctxt: ControlFlowContext, or None.

  Returns:
    `ctxt` itself if it is a CondContext, otherwise the nearest enclosing
    CondContext reached through `outer_context` links, or None if `ctxt` is
    not in a cond.
  """
  candidate = ctxt
  while candidate:
    if candidate.IsCondContext():
      break
    candidate = candidate.outer_context
  else:
    # Ran off the end of the chain without meeting a CondContext.
    return None
  return candidate
235
236
def IsContainingContext(ctxt, maybe_containing_ctxt):
  """Returns true if `maybe_containing_ctxt` is or contains `ctxt`.

  Contexts are compared by identity while walking outward from `ctxt` through
  `outer_context`. Because the walk ends at None, a `maybe_containing_ctxt` of
  None matches every chain.
  """
  current = ctxt
  while True:
    if current is maybe_containing_ctxt:
      return True
    if current is None:
      return False
    current = current.outer_context
243
244
def OpInContext(op, ctxt):
  """Returns True if `op`'s control flow context is `ctxt` or nested in it."""
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  return IsContainingContext(op_ctxt, ctxt)
247
248
def TensorInContext(tensor, ctxt):
  """Returns True if the op producing `tensor` is in `ctxt` or nested in it."""
  producer = tensor.op
  return OpInContext(producer, ctxt)
251
252
def CheckInputFromValidContext(op, input_op):
  """Checks that `input_op` can be used as an input from `op`'s context.

  Conceptually, only inputs from op's while context or any ancestor while
  context (including outside of any context) are valid. In practice, there are
  many other edge cases as well; each is handled by one branch below.

  Args:
    op: Operation that consumes an output of `input_op`.
    input_op: Operation whose output is used as an input to `op`.

  Raises:
    ValueError: if input_op is from an invalid context. Before raising, the
      error plus both ops' while contexts and tracebacks is written to the
      info log.
  """
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  input_ctxt = GetOutputContext(input_op)
  # Set up front so the error path below never reads an unbound name; they are
  # only computed when neither of the first two checks applies.
  while_ctxt = None
  input_while_ctxt = None
  valid = False

  if not input_ctxt:
    # input_op isn't in a control flow context.
    valid = True
  elif op_ctxt is input_ctxt:
    # input_op is in the same context as op.
    valid = True
  else:
    while_ctxt = GetContainingWhileContext(op_ctxt)
    input_while_ctxt = GetContainingWhileContext(input_ctxt)

    if while_ctxt is None:
      if input_while_ctxt is None:
        # Neither op nor input_op is in a while loop, but one or both are in
        # conds. We allow this, although execution will fail if the branch
        # corresponding to input_op's cond context isn't taken.
        valid = True
      # Invalid if op isn't in a while loop and input_op is. Unless...
      if IsLoopEnter(op):
        # WhileContext._BuildLoop clears context for Enter nodes.
        valid = True
      if IsSwitch(op):
        # CondContext.AddValue clears context for Switch nodes.
        valid = True
    elif IsContainingContext(while_ctxt, input_while_ctxt):
      # input_op is in a while loop which contains op's while loop (or not in a
      # while loop at all).
      valid = True
    elif (while_ctxt.grad_state and
          IsContainingContext(while_ctxt.grad_state.forward_context,
                              input_while_ctxt)):
      # op is in a gradient context and input_op is in the associated forward
      # pass context or an ancestor thereof. This case is needed to build while
      # loop gradients.
      # NOTE(skyewm): we theoretically also need this case for custom gradient
      # functions that close over tensors from ancestor contexts, but I haven't
      # verified this.
      valid = True
    elif (while_ctxt.grad_state and
          while_ctxt.grad_state.forward_context is
          input_while_ctxt._outer_context):  # pylint: disable=protected-access
      # op is in a gradient context and input_op is in a child of the associated
      # forward pass context. This case is needed for the gradients of while
      # loops with conds.
      valid = True
    elif (input_while_ctxt.grad_state and
          input_while_ctxt.grad_state.forward_context is while_ctxt):
      # input_op is in the gradient context of op's context. This case is needed
      # when the gradient of a while loop gradient is requested (this will
      # eventually fail unless there is a stop_gradient() or similar).
      valid = True
    elif (input_while_ctxt.grad_state and
          input_while_ctxt.grad_state.forward_context.grad_state and
          (input_while_ctxt.grad_state.forward_context.grad_state
           .forward_context is while_ctxt)):
      # input_op is in the grad grad context of op's context. This case is
      # needed when the gradient of a while loop gradient is requested (this
      # will eventually fail unless there is a stop_gradient() or similar).
      # The dereferences go through `input_while_ctxt`, the same object the
      # guard checked; `input_ctxt` may be a non-while context whose
      # grad_state is not that of the enclosing loop.
      valid = True

  if not valid:
    if while_ctxt:
      error_msg = (
          "Cannot use '%s' as input to '%s' because they are in different while"
          " loops." % (input_op.name, op.name))
    else:
      error_msg = (
          "Cannot use '%s' as input to '%s' because '%s' is in a while loop."
          % (input_op.name, op.name, input_op.name))

    # Log the error message plus the relevant stack traces. The stacks may be
    # useful for debugging this error, but we don't want to raise an
    # unreadable exception.
    log_msg = error_msg
    log_msg += "\n\n%s while context: %s" % (op.name, while_ctxt)
    log_msg += "\n%s while context: %s" % (input_op.name, input_while_ctxt)
    log_msg += "\n\nTraceback for %s:\n%s\nTraceback for %s:\n%s\n" % (
        op.name, "".join(traceback.format_list(op.traceback)),
        input_op.name, "".join(traceback.format_list(input_op.traceback)))
    logging.info(log_msg)
    raise ValueError(error_msg + " See info log for more details.")
351