# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utilities for V2 control flow."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import function_def_to_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework.func_graph import FuncGraph
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_v2_func_graphs
from tensorflow.python.ops import gradients_util
from tensorflow.python.util import keras_deps
from tensorflow.python.util import tf_contextlib


_EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = None
_DISABLE_LOWER_USING_SWITCH_MERGE = False


CondBranchFuncGraph = control_flow_v2_func_graphs.CondBranchFuncGraph
WhileCondFuncGraph = control_flow_v2_func_graphs.WhileCondFuncGraph
WhileBodyFuncGraph = control_flow_v2_func_graphs.WhileBodyFuncGraph


def in_defun():
  """Returns whether the current graph is, or is nested in, a defun."""
  if context.executing_eagerly(): return False

  graph = ops.get_default_graph()
  while (isinstance(graph, CondBranchFuncGraph) or
         isinstance(graph, WhileBodyFuncGraph) or
         isinstance(graph, WhileCondFuncGraph)):
    graph = graph.outer_graph
  return isinstance(graph, FuncGraph)


def in_while_loop_defun(graph):
  """Returns whether the graph is a while loop FuncGraph."""
  if context.executing_eagerly(): return False
  return (isinstance(graph, WhileCondFuncGraph) or
          isinstance(graph, WhileBodyFuncGraph))


def create_new_tf_function(func_graph):
  """Converts func_graph to a TF_Function and adds it to the current graph.

  Args:
    func_graph: FuncGraph

  Returns:
    The name of the new TF_Function.
  """
  func = function._EagerDefinedFunction(  # pylint: disable=protected-access
      func_graph.name, func_graph, func_graph.inputs, func_graph.outputs, {})
  func.add_to_graph(func_graph.outer_graph)
  return func_graph.name


def unique_fn_name(scope, name):
  """Returns a unique name to use for a control flow function.

  Args:
    scope: A name scope string.
    name: An identifier for this function (e.g. "true", "body").

  Returns:
    A string, the name to use for the function.
  """
  return ("%s%s_%s" % (scope, name, ops.uid())).replace("/", "_")


def unique_grad_fn_name(forward_name):
  """Returns a unique name for the gradient of the given forward function."""
  return "%s_grad_%s" % (forward_name, ops.uid())

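# Illustrative sketch (not part of the original module): the names produced
# above embed `ops.uid()`, so the numeric suffixes below are examples only,
# assuming the uid counter happens to return 12 and then 13.
#
#   unique_fn_name("while/", "body")      # -> "while_body_12"
#   unique_grad_fn_name("while_body_12")  # -> "while_body_12_grad_13"
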
def maybe_set_lowering_attr(op, lower_using_switch_merge=None):
  """Sets the flag to enable lowering on `op` if necessary.

  Lowering allows cond_v2 and while_v2 to avoid some of the limitations of
  Functions, allowing users to specify devices & colocation inside of cond_v2
  and while_v2 input functions, and enabling non-strict evaluation & partial
  pruning. This brings v2 control flow closer to feature parity with v1 control
  flow.

  However, we do not lower in the following cases:
    - When the `If` or `While` ops are in the XLA context. Because it is easier
      for XLA to apply its own optimizations when dealing with un-lowered
      control flow operators than with low-level control flow primitives.
    - When the eager execution context specifies the executor of functions to
      be the single threaded executor (see context.function_executor_type()).
      Because the single threaded executor does not support v1 control flow ops.
    - When `lower_using_switch_merge` is explicitly set to False.

  Args:
    op: An `If` or `While` Operation.
    lower_using_switch_merge: Explicit value to lower or not (optional).
  """
  if lower_using_switch_merge is not None:
    # pylint: disable=protected-access
    op._set_attr("_lower_using_switch_merge",
                 attr_value_pb2.AttrValue(b=lower_using_switch_merge))
    # pylint: enable=protected-access
  elif (not _DISABLE_LOWER_USING_SWITCH_MERGE and
        not control_flow_util.GraphOrParentsInXlaContext(op.graph) and
        context.context().function_call_options.executor_type !=
        "SINGLE_THREADED_EXECUTOR"):
    # pylint: disable=protected-access
    op._set_attr("_lower_using_switch_merge", attr_value_pb2.AttrValue(b=True))
    # pylint: enable=protected-access

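# Illustrative usage sketch (an assumption about callers, not part of this
# module): the v2 control flow builders call this right after creating the
# functional op. `if_op` and `while_op` below are hypothetical names for
# freshly built `If`/`While` Operations.
#
#   maybe_set_lowering_attr(if_op)   # apply the default lowering heuristics
#   maybe_set_lowering_attr(while_op, lower_using_switch_merge=False)  # opt out
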
def maybe_propagate_compile_time_consts_in_xla(op):
  """Tells XLA whether to propagate compile-time consts in the loop body.

  This is needed to make compile-time constants available to ops, for example
  `max_num_elements` in `EmptyTensorList`, inside the loop body. Ideally this
  would always be turned on, but that doesn't work with legacy functionalized
  while_loops.

  Args:
    op: A `While` Operation.
  """
  if control_flow_util.GraphOrParentsInXlaContext(op.graph):
    # pylint: disable=protected-access
    op._set_attr("_xla_propagate_compile_time_consts",
                 attr_value_pb2.AttrValue(b=True))
    # pylint: enable=protected-access


def resource_input_index(tensor_name, input_names, node_defs, functions):
  """Returns the index of the input corresponding to `tensor_name`.

  This method is used to find the corresponding index of an arbitrary resource
  tensor in a function (the function could be a loop body). We assume that
  resource handles are never created in functions, so that every resource
  tensor can be traced back to a function input.

  The awkward signature of this method is to make it work with both FuncGraphs
  and FunctionDefs. This is so we can recurse on function call ops without
  building the corresponding FuncGraph (note that even if a FuncGraph for a
  FunctionDef already exists, the input/output/node names may have been
  changed when the FuncGraph was serialized to the FunctionDef, which makes it
  unusable with this algorithm).

  Args:
    tensor_name: the name of the resource tensor to be resolved to an input.
    input_names: a list of the names of all inputs to the function.
    node_defs: a dict mapping op name -> NodeDef for every op in the function.
    functions: a dict mapping function name -> _EagerDefinedFunction.

  Returns:
    The index into input_names corresponding to `tensor_name`.
  """
  while tensor_name not in input_names:
    # FunctionDefs and graphs use different tensor naming conventions.
    parts = tensor_name.split(":")
    if len(parts) == 3:
      op_name, _, output_idx = parts
    elif len(parts) == 2:
      op_name, output_idx = parts
    else:
      assert len(parts) == 1
      op_name = parts[0]
      output_idx = 0
      tensor_name = "%s:%d" % (tensor_name, output_idx)
      # Check again for cases where the tensor suffix (":0") is stripped out.
      if tensor_name in input_names:
        break
    output_idx = int(output_idx)
    node_def = node_defs[op_name]

    if node_def.op in ("Identity", "While"):
      # Captured resources occur at the same index in the lists of inputs and
      # outputs of a while or identity op. So we look up the input of
      # `tensor.op` at the same index as the index of `tensor` in
      # `tensor.op.outputs`.
      tensor_name = node_def.input[output_idx]
    elif node_def.op in ("PartitionedCall", "StatefulPartitionedCall"):
      # Functions output any captured resource tensors used by their
      # gradients. `tensor_name` is one of these outputs from a nested
      # function call, so recursively find the corresponding input in the
      # nested FunctionDef.
      func_name = node_def.attr["f"].func.name
      fdef = functions[func_name].definition
      output_arg_name = fdef.signature.output_arg[output_idx].name
      output_tensor_name = fdef.ret[output_arg_name]
      input_index = resource_input_index(
          output_tensor_name, [arg.name for arg in fdef.signature.input_arg],
          {ndef.name: ndef for ndef in fdef.node_def}, functions)
      tensor_name = node_def.input[input_index]
    else:
      # We assume there are no other op types that will "forward" resource
      # handles like this, so all other handles must have been created by the
      # op. (Note that cond_v2 wraps resource handle outputs in optionals,
      # which we'll end up accumulating).
      raise ValueError("Taking gradient of a while loop which creates "
                       "a resource in its body is not supported: %s" % op_name)

  return input_names.index(tensor_name)


@tf_contextlib.contextmanager
def clear_control_inputs():
  """Clears the control inputs but preserves the ControlFlowContext.

  This is needed to preserve the XLAControlFlowContext when clearing
  control inputs for the gradient accumulators in while_v2.
  `ops.control_dependencies` does not allow that.

  Yields:
    A context manager in which the ops created will not have any control inputs
    by default but the control flow context is the same.
  """
  # pylint: disable=protected-access
  control_flow_context = ops.get_default_graph()._get_control_flow_context()
  with ops.control_dependencies(None):
    ops.get_default_graph()._set_control_flow_context(control_flow_context)
    yield
  # pylint: enable=protected-access

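# Illustrative usage sketch (an assumption about the while_v2 gradient code,
# not part of this module): accumulator ops are created under this context
# manager so they keep the surrounding (possibly XLA) control flow context
# without picking up implicit control inputs. `accumulator_push` is a
# hypothetical placeholder for whatever op the caller builds.
#
#   with clear_control_inputs():
#     acc = accumulator_push(acc, value)
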
def _is_tpu_strategy(strategy):
  return (strategy is not None and
          strategy.__class__.__name__.startswith("TPUStrategy"))


def _is_building_keras_layer():
  # TODO(srbs): Remove this function when we no longer support sessions with
  # Keras.
  keras_call_context_function = keras_deps.get_call_context_function()
  if keras_call_context_function:
    return keras_call_context_function().layer is not None
  else:
    return False


def output_all_intermediates():
  """Whether to output all intermediates of a functional control flow op.

  The default behavior is to output intermediates only when building a Keras
  Layer in graph mode, and only when certain other conditions are met:
  1. We do not output intermediates if the functional control flow op
     is being built inside a FuncGraph which is not an If/While graph. This
     guards against outputting intermediates in eager mode since Keras adds
     tensors to a FuncGraph named "keras_graph" in that case. Also, because we
     do not output intermediates of tf.function (since this feature is only for
     backwards compatibility), outputting intermediates of functional control
     flow ops built inside tf.function is of no value.
  2. We do not output intermediates when the compilation is using XLA or for a
     TPU.
  3. We do not output intermediates when a single threaded executor is used
     since that does not perform inlining and pruning.

  Returns:
    A bool telling whether to output all intermediates.
  """
  if _EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE is not None:
    return _EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
  if in_defun():
    return False
  if (control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()) or
      _is_tpu_strategy(distribution_strategy_context.get_strategy())):
    return False
  if (context.context().function_call_options.executor_type ==
      "SINGLE_THREADED_EXECUTOR"):
    return False
  return _is_building_keras_layer()


def get_func_graph(op, input_shapes, func_name):
  """Generates and returns a FuncGraph for the given op and input_shapes."""
  fdef = None
  graph = op.graph
  # Recursively search for the func in this graph and its outer graphs.
  while graph is not None:
    func = graph._get_function(func_name)  # pylint: disable=protected-access
    if func is not None:
      fdef = func.definition
      break
    if hasattr(graph, "outer_graph"):
      graph = graph.outer_graph
    else:
      break

  if fdef is None:
    raise KeyError("%s cannot be found in the graph" % func_name)

  # `op.graph` may not be the same as `ops.get_default_graph()` e.g.
  # in the case of nested if ops or when the gradient is being computed
  # from inside a Defun. We build the `func_graph` with `op.graph` as its
  # `outer_graph`. This resembles how the `FuncGraph` was built in the
  # forward pass. We need this so that we can resolve references to tensors
  # in `func_graph` from its gradient graph in `_resolve_grad_inputs`.
  with op.graph.as_default():
    func_graph = function_def_to_graph.function_def_to_graph(
        fdef, input_shapes)
  return func_graph


def get_op_and_outputs(op_or_outputs):
  if isinstance(op_or_outputs, ops.Operation):
    return op_or_outputs, []
  elif not op_or_outputs:  # Empty list.
    return None, []
  else:
    return op_or_outputs[0].op, op_or_outputs

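# Illustrative sketch of the normalization above (not part of the original
# module); `some_op`, `t0` and `t1` are hypothetical placeholders for an
# Operation and two of its output tensors:
#
#   get_op_and_outputs(some_op)    # -> (some_op, [])
#   get_op_and_outputs([])         # -> (None, [])
#   get_op_and_outputs([t0, t1])   # -> (t0.op, [t0, t1])
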
def graph_wrapped_for_higher_order_tape_gradients(graph):
  """Check if `graph` is wrapped by `run_as_function_for_tape_gradients`."""
  while graph is not None:
    if "cflow_gradient_wrapper" in getattr(graph, "name", ""):
      return True
    graph = getattr(graph, "outer_graph", None)
  return False


def run_as_function_for_tape_gradients(make_op, inputs):
  """Fix higher-order tape gradients by wrapping `make_op` in a function.

  Args:
    make_op: A function that takes a list of inputs and returns a list of
      output tensors. This function should set any handle data relevant to its
      outputs before returning.
    inputs: A list of tensors to check for tape gradients and pass to
      `make_op`. These should include all tensors used in `make_op`.

  Returns:
    Tensors corresponding to `make_op`'s output.
  """
  # GradientTapes created inside a function currently don't work well with
  # un-wrapped control flow ops in that same function. Wrapping in an extra
  # layer of intermediate function means we run extra logic in the function
  # gradient code to record the correct intermediates on the tape.
  #
  # The function attribute inputs to control flow ops are not hashable, so we
  # pass everything as a capture to bypass defun's caching.
  if (gradients_util.PossibleTapeGradientTypes(inputs)
      == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER
      # We only need one function between the tape and the op; if we've already
      # wrapped once, we stop wrapping to avoid infinite recursion.
      and not (ops.get_default_graph().building_function
               and "cflow_gradient_wrapper" in ops.get_default_graph().name)):
    results = function.defun_with_attributes(
        make_op,
        autograph=False,
        attributes=dict(func_name="cflow_gradient_wrapper"))(inputs)
    return results
  else:
    return make_op(inputs)

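# Illustrative usage sketch (an assumption about callers, not part of this
# module): a control flow builder can route op construction through the
# wrapper so that higher-order tape gradients see a function boundary.
# `_build_if_op` and `branch_inputs` below are hypothetical names.
#
#   outputs = run_as_function_for_tape_gradients(
#       lambda inputs: _build_if_op(inputs), branch_inputs)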