1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""cond_v2 and gradient. 16 17This is a version of cond that emits a single If op, as well as the gradient 18function for If ops produced by cond_v2. This will eventually replace the 19current tf.cond implementation once it reaches feature and performance parity. 20""" 21 22from __future__ import absolute_import 23from __future__ import division 24from __future__ import print_function 25 26import collections 27 28from tensorflow.core.framework import types_pb2 29from tensorflow.python.eager import backprop_util 30from tensorflow.python.framework import auto_control_deps 31from tensorflow.python.framework import auto_control_deps_utils as acd 32from tensorflow.python.framework import constant_op 33from tensorflow.python.framework import dtypes 34from tensorflow.python.framework import errors_impl 35from tensorflow.python.framework import func_graph as func_graph_module 36from tensorflow.python.framework import ops 37from tensorflow.python.framework import tensor_shape 38from tensorflow.python.framework import tensor_util 39from tensorflow.python.ops import array_ops 40from tensorflow.python.ops import control_flow_util 41from tensorflow.python.ops import control_flow_util_v2 as util 42from tensorflow.python.ops import custom_gradient 43from tensorflow.python.ops import default_gradient 44from tensorflow.python.ops import gen_dataset_ops 45from tensorflow.python.ops import gen_functional_ops 46from tensorflow.python.ops import gradients_util 47from tensorflow.python.ops import handle_data_util 48from tensorflow.python.ops import math_ops 49from tensorflow.python.util import nest 50 51 52# NOTE(skyewm): TensorFlow uses protected class methods and fields to signify 53# that they aren't part of the official public API. These protected members 54# often need to be used by implementation code however. Rather than litter the 55# code with pylint comments, we ignore protected access violations for 56# readability. 57# pylint: disable=protected-access 58 59_COND = 1 60_CASE = 2 61 62 63def cond_v2(pred, true_fn, false_fn, name="cond"): 64 """Like tf.cond, except emits a single If op.""" 65 if isinstance(pred, bool): 66 raise TypeError("pred must not be a Python bool", pred) 67 68 if not name: 69 name = "cond" 70 71 with ops.name_scope(name) as scope: 72 true_name = util.unique_fn_name(scope, "true") 73 false_name = util.unique_fn_name(scope, "false") 74 75 # Automatic control dependencies are added in defuns, but not in v1 76 # graphs. Propagate that behavior here. 77 add_control_dependencies = ops.get_default_graph()._add_control_dependencies 78 pred = ops.convert_to_tensor(pred) 79 if (tensor_util.is_tf_type(pred) and 80 (pred.shape.dims is None or pred.shape.dims)): 81 pred = array_ops.squeeze_v2(pred) 82 83 true_graph = func_graph_module.func_graph_from_py_func( 84 true_name, 85 true_fn, [], {}, 86 func_graph=util.CondBranchFuncGraph( 87 true_name, collections=ops.get_default_graph()._collections), # pylint: disable=protected-access 88 add_control_dependencies=add_control_dependencies, 89 op_return_value=pred) 90 false_graph = func_graph_module.func_graph_from_py_func( 91 false_name, 92 false_fn, [], {}, 93 func_graph=util.CondBranchFuncGraph( 94 false_name, collections=ops.get_default_graph()._collections), # pylint: disable=protected-access 95 add_control_dependencies=add_control_dependencies, 96 op_return_value=pred) 97 98 verify_captures(_COND, [true_graph, false_graph]) 99 return _build_cond( 100 pred, 101 true_graph, 102 false_graph, 103 true_graph.external_captures, 104 false_graph.external_captures, 105 building_gradient=False, 106 name=scope) 107 108 109@ops.RegisterGradient("StatelessIf") 110@ops.RegisterGradient("If") 111def _IfGrad(op, *grads): # pylint: disable=invalid-name 112 """The gradient of an If op produced by cond_v2.""" 113 # Get the if operator (this logic handles the case where op is a MockOp) 114 if_op = op.outputs[0].op 115 true_graph, false_graph = get_func_graphs(if_op) 116 # Note: op.graph != ops.get_default_graph() when we are computing the gradient 117 # of a nested cond. 118 assert true_graph.outer_graph == if_op.graph 119 assert false_graph.outer_graph == if_op.graph 120 121 # Create grad functions that compute the gradient of the true/false forward 122 # graphs. These functions will capture tensors from the forward pass 123 # functions. 124 true_grad_graph = _create_grad_func( 125 true_graph, grads, util.unique_grad_fn_name(true_graph.name)) 126 false_grad_graph = _create_grad_func( 127 false_graph, grads, util.unique_grad_fn_name(false_graph.name)) 128 129 # Replaces output None grads with zeros if at least one branch has non-None 130 # grad at that index. 131 _create_zeros_for_none_grads([true_graph, false_graph], 132 [true_grad_graph, false_grad_graph]) 133 134 if (true_grad_graph.op_needs_rewrite or false_grad_graph.op_needs_rewrite): 135 # Modify 'op' to output the intermediates needed by the grad functions. Note 136 # that all needed intermediates are wrapped in optionals. Each optional 137 # intermediate output will have a value iff its corresponding branch is 138 # taken. 139 # NOTE(skyewm): if there are any active sessions, this modification to `op` 140 # may make them unrunnable! 141 142 if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()): 143 # XLA does not yet support optionals, so output intermediates directly and 144 # make them match via FakeParams, which can be converted to zeros in XLA. 145 # TODO(skyewm,jpienaar): can XLA support optionals? 146 true_intermediates = true_grad_graph.xla_intermediates 147 false_intermediates = false_grad_graph.xla_intermediates 148 extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla( 149 [true_graph, false_graph], [true_intermediates, false_intermediates]) 150 else: 151 true_intermediates = true_grad_graph.wrapped_intermediates 152 false_intermediates = false_grad_graph.wrapped_intermediates 153 # Make outputs match by adding none optionals. 154 extra_true_outputs, extra_false_outputs = _make_intermediates_match( 155 [true_graph, false_graph], [true_intermediates, false_intermediates]) 156 157 true_graph.outputs.extend(extra_true_outputs) 158 false_graph.outputs.extend(extra_false_outputs) 159 # TODO(skyewm): indicate it's an internal bug if this fails. 160 _check_same_outputs(_COND, [true_graph, false_graph]) 161 162 true_graph.name += "_rewritten" 163 false_graph.name += "_rewritten" 164 165 if_op._set_func_attr("then_branch", util.create_new_tf_function(true_graph)) 166 if_op._set_func_attr("else_branch", 167 util.create_new_tf_function(false_graph)) 168 if_op._set_type_list_attr("Tout", true_graph.output_types) 169 if_op._set_shape_list_attr("output_shapes", true_graph.output_shapes) 170 if_op._add_outputs( 171 [t.dtype for t in extra_true_outputs], 172 [t.shape for t in extra_true_outputs]) 173 174 # Resolve references to forward graph tensors in grad graphs and ensure 175 # they are in-scope, i.e., belong to one of outer graphs of the grad graph. 176 true_grad_inputs = _resolve_grad_inputs(true_graph, true_grad_graph) 177 false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph) 178 179 # This modifies true_grad_graph and false_grad_graph. 180 _make_output_composite_tensors_match(_COND, 181 [true_grad_graph, false_grad_graph]) 182 183 outputs = _build_cond( 184 if_op.inputs[0], 185 true_grad_graph, 186 false_grad_graph, 187 true_grad_inputs, 188 false_grad_inputs, 189 building_gradient=True, 190 ) 191 192 # The predicate has no gradient. 193 return [None] + outputs 194 195 196def _build_cond(pred, 197 true_graph, 198 false_graph, 199 true_inputs, 200 false_inputs, 201 building_gradient, 202 name=None): 203 """Creates an If op from the specified predicate, branch functions and inputs. 204 205 Note that this modifies true_graph and false_graph to make the inputs match, 206 and to output all intermediates values so they're available for the gradient 207 computation. 208 209 true_graph and false_graph need not have the same input types, but they must 210 have the same output types. 211 212 Args: 213 pred: boolean Tensor 214 true_graph: FuncGraph 215 false_graph: FuncGraph 216 true_inputs: a list of Tensors to be passed to true_graph as input. 217 false_inputs: a list of Tensors to be passed to false_graph as input. 218 building_gradient: Whether this is a gradient If op. 219 name: the name for the If op. 220 221 Returns: 222 A list of Tensors which are the outputs of the If op. Does not include added 223 intermediate outputs. 224 """ 225 _make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph]) 226 _check_same_outputs(_COND, [true_graph, false_graph]) 227 228 # Add inputs to true_graph and false_graph to make them match. Note that 229 # this modifies true_graph and false_graph. 230 cond_inputs = _make_inputs_match([true_graph, false_graph], 231 [true_inputs, false_inputs]) 232 # We do not output intermediates of the gradient If op since this is just 233 # for backwards compatibility with existing code. 234 if not building_gradient and util.output_all_intermediates(): 235 # Add all intermediate tensors as function outputs so they're available for 236 # the gradient computation. Since the outputs of the two functions must 237 # match, we wrap all the intermediates in optionals. Each intermediate 238 # output will have a value iff its corresponding branch is taken. 239 240 true_intermediates = _get_intermediates(true_graph) 241 false_intermediates = _get_intermediates(false_graph) 242 243 # Wrap intermediates in optionals. 244 wrapped_true_intermediates = _wrap_intermediates(true_graph, 245 true_intermediates) 246 wrapped_false_intermediates = _wrap_intermediates(false_graph, 247 false_intermediates) 248 249 # Make outputs match by adding none optionals. 250 extra_true_outputs, extra_false_outputs = _make_intermediates_match( # pylint: disable=unbalanced-tuple-unpacking 251 [true_graph, false_graph], 252 [wrapped_true_intermediates, wrapped_false_intermediates]) 253 254 true_graph.outputs.extend(extra_true_outputs) 255 false_graph.outputs.extend(extra_false_outputs) 256 _check_same_outputs(_COND, [true_graph, false_graph]) 257 258 # Create the If op. 259 with ops.control_dependencies( 260 list(true_graph.control_captures) + list(false_graph.control_captures)): 261 true_stateful_ops = [ 262 op for op in true_graph.get_operations() if op._is_stateful 263 ] 264 false_stateful_ops = [ 265 op for op in false_graph.get_operations() if op._is_stateful 266 ] 267 if (true_stateful_ops or false_stateful_ops): 268 op_fn = gen_functional_ops._if 269 else: 270 op_fn = gen_functional_ops.stateless_if 271 272 def _make_op(inputs): 273 if_op, tensors = util.get_op_and_outputs(op_fn( 274 pred, 275 inputs, [t.dtype for t in true_graph.outputs], 276 util.create_new_tf_function(true_graph), 277 util.create_new_tf_function(false_graph), 278 output_shapes=_get_output_shapes(true_graph.outputs, 279 false_graph.outputs), 280 name=name)) 281 _copy_handle_data(tensors, true_graph.outputs, false_graph.outputs) 282 # `if_op` is None if this is a `StatelessIf` op with no outputs. 283 if if_op is not None: 284 # The true and false graphs have already been created, and we need that 285 # to happen before we know which tensors will be captured and so whether 286 # to wrap the cond in a tf.function. Post-hoc mutation of the branch 287 # `outer_graph` properties seems like the only option if we want to 288 # conditionally wrap in a function. 289 true_graph.outer_graph = ops.get_default_graph() 290 false_graph.outer_graph = ops.get_default_graph() 291 if_op._true_graph = true_graph 292 if_op._false_graph = false_graph 293 util.maybe_set_lowering_attr(if_op) 294 util.maybe_propagate_compile_time_consts_in_xla(if_op) 295 _set_read_only_resource_inputs_attr(if_op, [true_graph, false_graph]) 296 # Prevent fetching since the variant outputs can't be fetched directly. 297 if_op.graph.prevent_fetching(if_op) 298 return tensors 299 tensors = util.run_as_function_for_tape_gradients(_make_op, cond_inputs) 300 301 # Return identities for each output of the If op, rather than the output of 302 # the If op directly. This makes pruning work if the output of cond() is 303 # fetched: the lowering pass converts the If outputs into IdentityN outputs, 304 # which if fetched will cause all ops in the taken branch to be run (since 305 # it takes all merge ops as input). After lowering, each output identity op 306 # will end up with only the appropriate merge op as input. 307 # TODO(b/79984175): this doesn't have to be a tuple once we covert to the 308 # correct output structure 309 tensors = [array_ops.identity(t) for t in tensors] 310 311 return _pack_sequence_as(true_graph.structured_outputs, tensors) 312 313 314def get_func_graphs(op): 315 """Returns `FuncGraph`s for the input op branches. 316 317 Args: 318 op: The If or Case Operation. 319 320 Returns: 321 A tuple of the `FuncGraph`s of the then_branch and else_branch (all branches 322 for Case). 323 """ 324 325 def _get_func_graph_for_branch(name_attr_list, cached_attr_name=None): 326 """Generates and returns a FuncGraph for the given branch.""" 327 func_graph = None 328 if cached_attr_name is not None: 329 func_graph = getattr(op, cached_attr_name, None) 330 inputs = op.inputs[1:] # First input is pred. 331 if func_graph is None: 332 input_shapes = [t.shape for t in inputs] 333 func_graph = util.get_func_graph(op, input_shapes, name_attr_list.name) 334 for external_t, internal_t in zip(inputs, func_graph.inputs): 335 custom_gradient.copy_handle_data(external_t, internal_t) 336 func_graph.reset_captures(zip(inputs, func_graph.inputs)) 337 # Link the op so that the gradient code can use it. 338 func_graph._forward_cond = op 339 return func_graph 340 341 if op.type in ["If", "StatelessIf"]: 342 return (_get_func_graph_for_branch( 343 op.get_attr("then_branch"), "_true_graph"), 344 _get_func_graph_for_branch( 345 op.get_attr("else_branch"), "_false_graph")) 346 elif op.type in ["Case", "StatelessCase"]: 347 return [_get_func_graph_for_branch(branch_fn, "_branch_graph_{}".format(i)) 348 for i, branch_fn in enumerate(op.get_attr("branches"))] 349 else: 350 raise ValueError("Unsupported op type: {}".format(op.type)) 351 352 353def _grad_fn(func_graph, grads): 354 """The gradient function for each conditional branch. 355 356 This function builds the gradient graph of the corresponding forward-pass 357 conditional branch in `func_graph`. This is done by differentiating 358 func_graph's outputs w.r.t. its inputs. 359 360 Args: 361 func_graph: FuncGraph. The corresponding forward-pass function. 362 grads: The list of input gradient Tensors. 363 364 Returns: 365 The output gradient Tensors. 366 """ 367 # Filter out untrainable function outputs. 368 # NOTE(skyewm): If we don't do this, the untrainable tensors can sometimes 369 # cause _GradientsHelper to raise an exception (e.g. the implementation 370 # doesn't expect 'ys' to contain boolean tensors). 371 assert len(func_graph.outputs) == len(grads) 372 ys = [] 373 grad_ys = [] 374 for y, grad_y in zip(func_graph.outputs, grads): 375 if not backprop_util.IsTrainable(y): 376 continue 377 ys.append(y) 378 grad_ys.append(grad_y) 379 380 # Build the gradient graph. Note that this builds the gradient computation of 381 # func_graph in the current graph, which requires capturing tensors from 382 # func_graph. The captured func_graph tensors are resolved to external tensors 383 # in _resolve_grad_inputs. 384 result = gradients_util._GradientsHelper( 385 ys, func_graph.inputs, grad_ys=grad_ys, 386 src_graph=func_graph) 387 388 return result 389 390 391def _create_grad_func(func_graph, grads, name): 392 """Returns the FuncGraph representation of _grad_fn.""" 393 return func_graph_module.func_graph_from_py_func( 394 name, 395 lambda: _grad_fn(func_graph, grads), [], {}, 396 func_graph=_CondGradFuncGraph(name, func_graph)) 397 398 399def _resolve_grad_inputs(cond_graph, grad_graph): 400 """Returns the tensors to pass as inputs to `grad_graph`. 401 402 The `grad_graph` may have external references to 403 1. Its outer graph containing the input gradients. These references are kept 404 as is. 405 2. Tensors in the forward pass graph. These tensors may not be "live" 406 when the gradient is being computed. We replace such references by their 407 corresponding tensor in `cond_graph.outer_graph`. In the case of nested 408 control flow or functions, the gradient logic handling 409 `grad_graph.outer_graph` will make sure the tensor from 410 `cond_graph.outer_graph` is also correctly captured. 411 412 Args: 413 cond_graph: FuncGraph. The forward-pass function. 414 grad_graph: FuncGraph. The gradients function. 415 416 Returns: 417 A list of inputs tensors to be passed to grad_graph. 418 """ 419 new_inputs = [] 420 421 for t in grad_graph.external_captures: 422 # `t` must either be in `grad_graph.outer_graph` or in the forward 423 # `cond_graph`. 424 if t.graph != grad_graph.outer_graph: 425 assert t.graph == cond_graph 426 # `internal_captures` are not treated as intermediates and hence not added 427 # to If op outputs. So we get the outer tensor corresponding to those 428 # from the list of `external_captures`. 429 for i, output in enumerate(t.graph.outputs): 430 if output is t: 431 t = t.graph._forward_cond.outputs[i] 432 break 433 else: 434 for i, output in enumerate(t.graph.internal_captures): 435 if output is t: 436 t = t.graph.external_captures[i] 437 break 438 else: 439 raise ValueError("Could not find external tensor capture {tensor} in " 440 "captures or outputs".format(tensor=t)) 441 442 # Note: We rely on the capturing logic of the gradient If op graph to 443 # correctly capture the tensors in `cond_graph.outer_graph`. Both cond_v2 444 # and while_v2 handle this while building their gradient functions. 445 assert t.graph == cond_graph.outer_graph 446 new_inputs.append(t) 447 448 return new_inputs 449 450 451def _get_intermediates(func_graph): 452 """Returns intermediate tensors of `func_graph` for gradient computation.""" 453 intermediates = [] 454 for op in func_graph.get_operations(): 455 for t in op.outputs: 456 if t in func_graph.inputs: continue 457 if t in func_graph.outputs: continue 458 if t.dtype is dtypes.resource: 459 continue 460 # Accumulating mutexes can cause deadlock. 461 if op.type == "MutexLock": 462 continue 463 intermediates.append(t) 464 return intermediates 465 466 467def _make_intermediates_match(branch_graphs, branch_optionals): 468 """Returns new optionals lists that have matching signatures. 469 470 This is done by mirroring each list in the other using none optionals. 471 There is no merging of like optionals. 472 473 Args: 474 branch_graphs: `list` of `FuncGraph`. 475 branch_optionals: `list` of `list`s of optional `Tensor`s from other 476 branch_graphs 477 478 Returns: 479 A `list` of `list`s of `Tensor`s for each branch_graph. Each list has the 480 same number of `Tensor`s, all of which will be optionals of the same 481 shape/type. 482 """ 483 new_branch_optionals = [] 484 # Since the intermediates are optionals with dtype variant, we only need 485 # enough room for the longest list of intermediates. 486 intermediates_size = max(len(o) for o in branch_optionals) 487 for i, branch_graph in enumerate(branch_graphs): 488 other_optionals = _create_none_optionals( 489 branch_graph, intermediates_size - len(branch_optionals[i])) 490 new_branch_optionals.append(branch_optionals[i] + other_optionals) 491 return new_branch_optionals 492 493 494def _make_intermediates_match_xla(branch_graphs, branch_intermediates): 495 """Like _make_intermediates_match but for the XLA case.""" 496 new_branch_intermediates = [] 497 for i, branch_graph in enumerate(branch_graphs): 498 other_fakeparams = _create_fakeparams( 499 branch_graph, 500 sum((bi for bi in branch_intermediates 501 if bi is not branch_intermediates[i]), [])) 502 num_preceding = sum(len(bi) for bi in branch_intermediates[:i]) 503 new_branch_intermediates.append(other_fakeparams[:num_preceding] + 504 branch_intermediates[i] + 505 other_fakeparams[num_preceding:]) 506 return new_branch_intermediates 507 508 509def _make_inputs_match(branch_graphs, branch_inputs): 510 """Modifies branch_graphs so they have the same input signature. 511 512 This method reorders and/or adds parameters to each graph in branch_graphs so 513 they have the same input signature, and updates the 'inputs' and 'captured' 514 fields of each graph accordingly. It uses the input tensors from the outer 515 graph to avoid duplicating shared arguments. 516 517 Args: 518 branch_graphs: a `list` of `FuncGraph` 519 branch_inputs: a `list` of `list`s of `Tensor`s in the outer graph. The 520 inputs for the corresponding graph in `branch_graphs`. 521 522 Returns: 523 A new list of Tensors from the outer graph that are the new inputs for each 524 branch_graph. This is a deduped version of `sum(branch_inputs)`. 525 """ 526 assert len(branch_graphs) == len(branch_inputs) 527 added_inputs = set() 528 new_inputs = [] 529 for branch_in in branch_inputs: 530 for tensor in branch_in: 531 tensor_id = ops.tensor_id(tensor) 532 if tensor_id not in added_inputs: 533 added_inputs.add(tensor_id) 534 new_inputs.append(tensor) 535 536 for branch_graph, branch_in in zip(branch_graphs, branch_inputs): 537 input_ids = [ops.tensor_id(t) for t in branch_in] 538 branch_input_to_param = dict(zip(input_ids, branch_graph.inputs)) 539 input_list = [] 540 for in_t in new_inputs: 541 param = branch_input_to_param.get(ops.tensor_id(in_t)) 542 if param is None: 543 param = _create_dummy_input(branch_graph, in_t) 544 input_list.append(param) 545 546 branch_graph.inputs = input_list 547 548 # Rewrite the FuncGraphs' state to reflect the new inputs. 549 branch_graph.reset_captures(zip(new_inputs, branch_graph.inputs)) 550 551 return new_inputs 552 553 554def _create_zeros_for_none_grads(forward_graphs, grad_graphs): 555 """Creates zeros for None out grads if at least one branch has non-None grad. 556 557 Args: 558 forward_graphs: List of forward FuncGraphs. 559 grad_graphs: List of grad FuncGraphs. 560 """ 561 assert len(forward_graphs) == len(grad_graphs) 562 branch_outputs = [g.structured_outputs for g in grad_graphs] 563 num_outputs_per_branch = [len(outs) for outs in branch_outputs] 564 assert len(set(num_outputs_per_branch)) == 1, num_outputs_per_branch 565 for output_idx, branch_outs in enumerate(zip(*branch_outputs)): 566 if (any(t is None for t in branch_outs) and 567 any(t is not None for t in branch_outs)): 568 for branch_index, t in enumerate(branch_outs): 569 if t is None: 570 with grad_graphs[branch_index].as_default(): 571 zeros = default_gradient.zeros_like( 572 forward_graphs[branch_index].inputs[output_idx]) 573 grad_graphs[branch_index].structured_outputs[output_idx] = zeros 574 575 for grad_graph in grad_graphs: 576 grad_graph.outputs = [ 577 t for t in func_graph_module.flatten(grad_graph.structured_outputs) 578 if t is not None 579 ] 580 581 582def _make_output_composite_tensors_match(op_type, branch_graphs): 583 """Modifies each branch_graph's outputs to have the same output signature. 584 585 Currently the only transformation implemented is turning a Tensor into an 586 equivalent IndexedSlices if the other branch returns an IndexedSlices. 587 Updates branch_graph.{outputs,structured_outputs} for each branch_graph in 588 branch_graphs. 589 590 Args: 591 op_type: _COND or _CASE 592 branch_graphs: `list` of `FuncGraph` 593 594 Raises: 595 TypeError: if a set of outputs cannot be rewritten. 596 """ 597 # Note: since this is only used for gradient graphs, we do not expect the 598 # outputs to be structured (e.g. nested lists), and thus do not need to use 599 # nest.flatten, etc. 600 assert branch_graphs 601 branch_outputs = [g.structured_outputs for g in branch_graphs] 602 outputs_per_branch = list(len(outs) for outs in branch_outputs) 603 assert len(set(outputs_per_branch)) == 1, outputs_per_branch 604 605 for output_idx, branch_outs in enumerate(zip(*branch_outputs)): 606 if len(set(type(out) for out in branch_outs)) == 1: 607 continue 608 if not any(isinstance(out, ops.IndexedSlices) for out in branch_outs): 609 continue 610 for branch_idx, branch_out in enumerate(branch_outs): 611 if isinstance(branch_out, ops.IndexedSlices): 612 continue 613 elif isinstance(branch_out, ops.Tensor): 614 with branch_graphs[branch_idx].as_default(): 615 branch_outputs[branch_idx][output_idx] = math_ops._as_indexed_slices( 616 branch_out) 617 else: 618 raise TypeError( 619 "Cannot reconcile {op_name} {output_idx}-th outputs:\n" 620 " outputs from all branches: {outputs}".format( 621 op_name="tf.cond" if op_type == _COND else "tf.switch_case", 622 output_idx=output_idx, 623 outputs=branch_outs)) 624 625 for branch_graph, branch_outs in zip(branch_graphs, branch_outputs): 626 branch_graph.structured_outputs = branch_outs 627 branch_graph.outputs = [ 628 t for t in func_graph_module.flatten(branch_outs) if t is not None 629 ] 630 631 632def _make_indexed_slices_indices_types_match(op_type, branch_graphs): 633 """Match dtype of IndexedSlices.indices in outputs of branch_graphs.""" 634 assert branch_graphs 635 # Indices of `IndexedSlices.indices` tensors in `branch_graphs[i].outputs`. 636 indexed_slice_indices = [] 637 current_index = 0 638 # Note that this still contains Nones. We leave those in so that error 639 # messages contain the correct indices. We handle the Nones later when 640 # updating `current_index`. 641 branch_outputs_flat_with_composites = [ 642 nest.flatten(branch_graph.structured_outputs, expand_composites=False) 643 for branch_graph in branch_graphs 644 ] 645 outs_per_branch = [len(outs) for outs in branch_outputs_flat_with_composites] 646 assert len(set(outs_per_branch)) == 1, outs_per_branch 647 # Store indices of IndexedSlices.indices in `indexed_slice_indices`. 648 for output_idx, branch_outs in enumerate( 649 zip(*branch_outputs_flat_with_composites)): 650 if len(set(isinstance(out, ops.IndexedSlices) for out in branch_outs)) != 1: 651 raise TypeError("Cannot reconcile tf.{op_name} {output_idx}-th outputs:\n" 652 " branches returned: {outputs}".format( 653 op_name="cond" if op_type == _COND else "switch_case", 654 output_idx=output_idx, 655 outputs=branch_outs)) 656 if isinstance(branch_outs[0], ops.IndexedSlices): 657 # indices is the second component of the composite tensor. 658 indexed_slice_indices.append(current_index + 1) 659 if nest.is_sequence_or_composite(branch_outs[0]): 660 current_index += len(nest.flatten(branch_outs[0], expand_composites=True)) 661 elif branch_outs[0] is not None: 662 # `FuncGraph.outputs` does not contain Nones so no need to update the 663 # counter in that case. 664 current_index += 1 665 666 if not indexed_slice_indices: 667 return 668 669 # `FuncGraph.outputs` is the flattened `FuncGraph.structured_outputs` minus 670 # the Nones. 671 if current_index != len(branch_graphs[0].outputs): 672 raise ValueError("Insufficient elements in branch_graphs[0].outputs.\n" 673 "Expected: %i\n" 674 "Actual: %i" % 675 (current_index, len(branch_graphs[0].outputs))) 676 677 # Cast indices with mismatching types to int64. 678 for index in indexed_slice_indices: 679 if any(bg.outputs[index].dtype not in (dtypes.int32, dtypes.int64) 680 for bg in branch_graphs): 681 raise TypeError("Type of IndexedSlices.indices must be int32 or int64. " 682 "Found: %s" % 683 str([bg.outputs[index].dtype for bg in branch_graphs])) 684 if len(set(bg.outputs[index].dtype for bg in branch_graphs)) != 1: 685 for branch_graph in branch_graphs: 686 if branch_graph.outputs[index].dtype == dtypes.int32: 687 with branch_graph.as_default(): 688 branch_graph.outputs[index] = math_ops.cast( 689 branch_graph.outputs[index], dtypes.int64) 690 691 for branch_graph in branch_graphs: 692 branch_graph.structured_outputs = _pack_sequence_as( 693 branch_graph.structured_outputs, branch_graph.outputs) 694 695 696def _pack_sequence_as(structured_outputs, op_outputs): 697 """Packs the outputs of the gradient If/Case op. 698 699 The branch functions may contain None's in the list of `structured_outputs`. 700 `op_outputs` has those outputs missing. So we need to add those Nones to the 701 list of `op_outputs` and then pack it in the same structure as 702 `structured_outputs`. 703 704 Args: 705 structured_outputs: structured_outputs from one of the branch functions. 706 op_outputs: List of output tensors of the op. 707 708 Returns: 709 `op_outputs` packed like `structured_outputs`. 710 """ 711 outputs_with_nones = [] 712 counter = 0 713 for output in nest.flatten(structured_outputs, expand_composites=True): 714 if output is None: 715 outputs_with_nones.append(None) 716 else: 717 outputs_with_nones.append(op_outputs[counter]) 718 counter += 1 719 return func_graph_module.pack_sequence_as(structured_outputs, 720 outputs_with_nones) 721 722 723def _wrap_intermediates(func_graph, intermediates): 724 with func_graph.as_default(): 725 return [gen_dataset_ops.optional_from_value([t]) for t in intermediates] 726 727 728def _create_dummy_input(func_graph, template_tensor): 729 """Creates tensors in func_graph to represent template_tensors. 730 731 Args: 732 func_graph: FuncGraph. 733 template_tensor: a tensor in the outer graph. 734 735 Returns: 736 A tensor in func_graph. 737 """ 738 with func_graph.as_default(): 739 return array_ops.placeholder( 740 template_tensor.dtype, shape=template_tensor.shape) 741 742 743def _create_none_optionals(func_graph, n): 744 """Creates `n` `None` optionals in func_graph. 745 746 Args: 747 func_graph: FuncGraph. 748 n: `int` the number of `None` optionals to make. 749 750 Returns: 751 A list of tensors in func_graph. 752 """ 753 with func_graph.as_default(): 754 return [gen_dataset_ops.optional_none() for _ in range(n)] 755 756 757def _create_fakeparams(func_graph, template_tensors): 758 """Create FakeParams for the XLA case.""" 759 with func_graph.as_default(): 760 return [gen_functional_ops.fake_param(dtype=t.dtype, shape=t.shape) 761 for t in template_tensors] 762 763 764def _check_same_outputs(op_type, graphs): 765 """Raises an error if `graphs` have different outputs.""" 766 767 def error(branch_idx, error_detail): 768 raise TypeError( 769 "{b0_name} and {bn_name} arguments to {op_name} must have the same " 770 "number, type, and overall structure of return values.\n" 771 "\n" 772 "{b0_name} output: {b0_out}\n" 773 "{bn_name} output: {bn_out}\n" 774 "\n" 775 "Error details:\n" 776 "{detail}".format( 777 b0_name="true_fn" if op_type == _COND else "branches[0]", 778 bn_name=("false_fn" if op_type == _COND else 779 "branches[{}]".format(branch_idx)), 780 op_name="tf.cond" if op_type == _COND else "tf.switch_case", 781 b0_out=graphs[0].structured_outputs, 782 bn_out=graphs[branch_idx].structured_outputs, 783 detail=error_detail)) 784 785 for b in range(1, len(graphs)): 786 try: 787 nest.assert_same_structure( 788 graphs[0].structured_outputs, 789 graphs[b].structured_outputs, 790 expand_composites=True) 791 except (ValueError, TypeError) as e: 792 error(b, str(e)) 793 794 op_type_str = "cond" if op_type == _COND else "case" 795 if len(graphs[0].outputs) != len(graphs[b].outputs): 796 raise ValueError("Lengths of branch outputs of {op_type} must match.\n" 797 "len(graphs[0].outputs): {len_0}\n" 798 "len(graphs[{b}].outputs): {len_b}\n".format( 799 op_type=op_type_str, 800 len_0=len(graphs[0].outputs), 801 b=b, 802 len_b=len(graphs[b].outputs))) 803 for b0_out, bn_out in zip(graphs[0].outputs, graphs[b].outputs): 804 if b0_out.dtype != bn_out.dtype: 805 error(b, "%s and %s have different types" % (b0_out, bn_out)) 806 807 808def _get_output_shapes(*branch_graph_outputs): 809 output_shapes = [] 810 for out_by_branch in zip(*branch_graph_outputs): 811 shape = out_by_branch[0].shape 812 for other_out in out_by_branch[1:]: 813 shape = shape.most_specific_compatible_shape(other_out.shape) 814 output_shapes.append(shape) 815 return output_shapes 816 817 818def _copy_handle_data(external_tensors, *branch_graph_outputs): 819 """Combines shapes in handle data and sets metadata on `external_tensors`.""" 820 for tensors in zip(external_tensors, *branch_graph_outputs): 821 external = tensors[0] 822 internal = tensors[1:] 823 internal_handle_data = [] 824 for tensor in internal: 825 handle_data = handle_data_util.get_resource_handle_data(tensor) 826 # NOTE: Assumes handle data has only one ShapeAndType entry. It's 827 # unclear how to combine different lengths across branches. 828 if not handle_data.is_set or len(handle_data.shape_and_type) != 1: 829 break 830 internal_handle_data.append(handle_data) 831 else: # There is handle data, so we need to combine it. 832 combined_shape = tensor_shape.TensorShape(None) 833 combined_dtype = None 834 for handle_data in internal_handle_data: 835 handle_shape = tensor_shape.TensorShape( 836 handle_data.shape_and_type[0].shape) 837 combined_shape = combined_shape.most_specific_compatible_shape( 838 handle_shape) 839 if combined_dtype is None: 840 combined_dtype = handle_data.shape_and_type[0].dtype 841 elif handle_data.shape_and_type[0].dtype != combined_dtype: 842 # Variants from different branches have different dtypes. The 843 # combined variant has no static dtype. 844 combined_dtype = types_pb2.DT_INVALID 845 combined_handle_data = internal_handle_data[0] 846 combined_handle_data.shape_and_type[0].shape.CopyFrom( 847 combined_shape.as_proto()) 848 combined_handle_data.shape_and_type[0].dtype = combined_dtype 849 handle_data_util.set_handle_data(external, combined_handle_data) 850 851 852def verify_captures(op_type, branch_graphs): 853 """Verify that a branch's tensor is not accessed in another branch fn.""" 854 # Note: It is technically not possible for lower-branch_index branches to 855 # capture tensors from higher-branch_index branches, because of the order of 856 # branch graph construction, but we check all for completeness and to 857 # guard against potential future changes. 858 other_branch_graphs = {g: i for i, g in enumerate(branch_graphs)} 859 for i, branch_graph in enumerate(branch_graphs): 860 for t in branch_graph.external_captures: 861 if not isinstance(t, ops.EagerTensor) and t.graph in other_branch_graphs: 862 branch_names = ["true_fn", "false_fn"] if op_type == _COND else [ 863 "branch {}".format(bi) for bi in range(len(branch_graphs))] 864 raise ValueError( 865 "Tensor {tname} in {b0name} is accessed from {b1name}.".format( 866 tname=t.name, 867 b0name=branch_names[other_branch_graphs[t.graph]], 868 b1name=branch_names[i])) 869 870 871class _CondGradFuncGraph(util.CondBranchFuncGraph): 872 """FuncGraph for the gradient function of the branch of an If op. 873 874 Handles wrapping and unwrapping intermediate values that are captured by the 875 gradient computation in optionals. 876 877 Attributes: 878 op_needs_rewrite: True if any intermediates were captured, meaning the 879 forward If op needs to be written to output the wrapped intermediates. 880 """ 881 882 def __init__(self, name, forward_graph): 883 super(_CondGradFuncGraph, self).__init__( 884 name, collections=ops.get_default_graph()._collections) # pylint: disable=protected-access 885 self.op_needs_rewrite = False 886 self._forward_graph = forward_graph 887 # Maps from forward intermediate tensor -> the unwrapped captured 888 # intermediate. 889 self._indirect_captures = {} 890 # Maps unwrapped intermediate -> optional-wrapped intermediate in the 891 # forward graph. 892 self._wrapped_intermediates = collections.OrderedDict() 893 # Raw intermediates captured from the forward graph. Populated iff we're in 894 # an XLA context. 895 self._xla_intermediates = [] 896 # Maps forward intermediate constant valued tensor's id to the constant 897 # created in this graph for that tensor. 898 self._captured_constants = {} 899 900 @property 901 def wrapped_intermediates(self): 902 """The optional-wrapped intermediates captured from the forward graph.""" 903 return list(self._wrapped_intermediates.values()) 904 905 @property 906 def xla_intermediates(self): 907 """Raw intermediates captured from the forward graph if XLA is enabled.""" 908 return self._xla_intermediates 909 910 def _capture_helper(self, tensor, name): 911 if (tensor.graph is not self._forward_graph or 912 any(tensor is t for t in self._forward_graph.inputs) or 913 any(tensor is t for t in self._forward_graph.outputs)): 914 return super(_CondGradFuncGraph, self)._capture_helper(tensor, name) 915 916 tensor_id = ops.tensor_id(tensor) 917 918 # If `tensor` is a graph-building time constant, we create a constant with 919 # the same value in the backward graph instead of capturing it. 920 if tensor_id in self._captured_constants: 921 return self._captured_constants[tensor_id] 922 elif constant_op.is_constant(tensor): 923 self._captured_constants[tensor_id] = constant_op.constant( 924 tensor_util.constant_value(tensor), dtype=tensor.dtype) 925 return self._captured_constants[tensor_id] 926 927 if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()): 928 # XLA does not yet support optionals, so capture intermediates directly. 929 # TODO(skyewm,jpienaar): can XLA support optionals? 930 if all(tensor is not capture for capture in self.external_captures): 931 self.xla_intermediates.append(tensor) 932 self.op_needs_rewrite = True 933 return super(_CondGradFuncGraph, self)._capture_helper(tensor, name) 934 935 captured_tensor = self._indirect_captures.get(tensor_id) 936 if captured_tensor is not None: 937 return captured_tensor 938 939 # 'tensor' is an uncaptured intermediate in the forward graph. 940 # If it is not a resource, we wrap it in an optional in the forward graph 941 # and capture the optional normally. We then unwrap the captured optional 942 # value in the gradient graph to get the raw intermediate value. 943 # If it is a resource, we trace the resource up to the input in the forward 944 # graph and capture that. 945 946 if tensor.dtype == dtypes.resource: 947 # Index of the forward graph input corresponding to the resource tensor. 948 index = util.resource_input_index( 949 tensor.name, [t.name for t in self._forward_graph.inputs], 950 {op.name: op.node_def for op in self._forward_graph.get_operations()}, 951 self._forward_graph._functions) 952 # This gets mapped to the corresponding If op input in 953 # `_resolve_grad_inputs`. 954 captured_tensor = super(_CondGradFuncGraph, self)._capture_helper( 955 self._forward_graph.inputs[index], name) 956 else: 957 if tensor_id not in self._wrapped_intermediates: 958 # If the gradient has already been computed for this If op, 'tensor' may 959 # already be wrapped. 960 for consumer in tensor.consumers(): 961 if (consumer.type == "OptionalFromValue" and 962 any(consumer.outputs[0] is output 963 for output in self._forward_graph.outputs)): 964 optional = consumer.outputs[0] 965 break 966 else: 967 # 'tensor' hasn't been wrapped, do it now. 968 with self._forward_graph.as_default(): 969 optional = gen_dataset_ops.optional_from_value([tensor]) 970 self.op_needs_rewrite = True 971 self._wrapped_intermediates[tensor_id] = optional 972 973 optional = self._wrapped_intermediates[tensor_id] 974 captured_optional = super(_CondGradFuncGraph, 975 self)._capture_helper(optional, name) 976 captured_tensor = gen_dataset_ops.optional_get_value( 977 captured_optional, [tensor.dtype], [tensor.shape])[0] 978 979 self._indirect_captures[tensor_id] = captured_tensor 980 return captured_tensor 981 982 983def indexed_case(branch_index, 984 branch_fns, 985 name="indexed_case", 986 lower_using_switch_merge=None): 987 """Like conv_v2, except emits a Case op instead of an If.""" 988 if isinstance(branch_index, int): 989 raise TypeError("branch_index must not be a Python int", branch_index) 990 991 with ops.name_scope(name) as scope: 992 branch_names = [ 993 util.unique_fn_name(scope, "branch{}".format(b)) 994 for b in range(len(branch_fns)) 995 ] 996 997 # Automatic control dependencies are added in defuns, but not in v1 998 # graphs. Propagate that behavior here. 999 add_control_dependencies = ops.get_default_graph()._add_control_dependencies 1000 branch_index = ops.convert_to_tensor(branch_index, name="branch_index") 1001 1002 branch_graphs = [] 1003 for branch_name, branch_fn in zip(branch_names, branch_fns): 1004 branch_graphs.append( 1005 func_graph_module.func_graph_from_py_func( 1006 branch_name, 1007 branch_fn, 1008 [], 1009 {}, 1010 func_graph=util.CondBranchFuncGraph( 1011 branch_name, 1012 collections=ops.get_default_graph()._collections), # pylint: disable=protected-access 1013 add_control_dependencies=add_control_dependencies, 1014 op_return_value=branch_index)) 1015 1016 verify_captures(_CASE, branch_graphs) 1017 return _build_case( 1018 branch_index, 1019 branch_graphs, [g.external_captures for g in branch_graphs], 1020 name=scope, 1021 lower_using_switch_merge=lower_using_switch_merge) 1022 1023 1024@ops.RegisterGradient("Case") 1025@ops.RegisterGradient("StatelessCase") 1026def _CaseGrad(op, *grads): # pylint: disable=invalid-name 1027 """The gradient of a Case op produced by tf.switch_case.""" 1028 # Get the Case operator (this logic handles the case where op is a MockOp) 1029 case_op = op.outputs[0].op 1030 branch_graphs = get_func_graphs(case_op) 1031 assert branch_graphs 1032 # Note: op.graph != ops.get_default_graph() when we are computing the gradient 1033 # of a nested cond. 1034 for branch_graph in branch_graphs: 1035 assert branch_graph.outer_graph == case_op.graph 1036 1037 # Create grad functions that compute the gradient of the branch forward 1038 # graphs. These functions will capture tensors from the forward pass 1039 # functions. 1040 branch_grad_graphs = [] 1041 for branch_graph in branch_graphs: 1042 branch_grad_graphs.append( 1043 _create_grad_func(branch_graph, grads, 1044 util.unique_grad_fn_name(branch_graph.name))) 1045 # Replaces output None grads with zeros if at least one branch has non-None 1046 # grad at that index. 1047 _create_zeros_for_none_grads(branch_graphs, branch_grad_graphs) 1048 1049 if any(g.op_needs_rewrite for g in branch_grad_graphs): 1050 # Modify 'op' to output the intermediates needed by the grad functions. Note 1051 # that all needed intermediates are wrapped in optionals. Each optional 1052 # intermediate output will have a value iff its corresponding branch is 1053 # taken. 1054 # NOTE(bjp): if there are any active sessions, this modification to `op` 1055 # may make them unrunnable! 1056 1057 if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()): 1058 # XLA does not yet support optionals, so output intermediates directly and 1059 # make them match via FakeParams, which can be converted to zeros in XLA. 1060 # TODO(bjp,jpienaar): can XLA support optionals? 1061 branches_intermediates = [ 1062 branch_grad_graph.xla_intermediates 1063 for branch_grad_graph in branch_grad_graphs 1064 ] 1065 extra_branch_outputs = _make_intermediates_match_xla( 1066 branch_graphs, branches_intermediates) 1067 else: 1068 branch_intermediates = [ 1069 g.wrapped_intermediates for g in branch_grad_graphs 1070 ] 1071 # Make outputs match by adding none optionals. 1072 extra_branch_outputs = _make_intermediates_match(branch_graphs, 1073 branch_intermediates) 1074 1075 for branch_graph, extra_outputs in zip(branch_graphs, extra_branch_outputs): 1076 branch_graph.outputs.extend(extra_outputs) 1077 # TODO(bjp): indicate it's an internal bug if this fails. 1078 _check_same_outputs(_CASE, branch_graphs) 1079 1080 for branch_graph in branch_graphs: 1081 branch_graph.name += "_rewritten" 1082 1083 case_op._set_func_list_attr("branches", [ 1084 util.create_new_tf_function(branch_graph) 1085 for branch_graph in branch_graphs 1086 ]) 1087 case_op._set_type_list_attr("Tout", branch_graphs[0].output_types) 1088 case_op._set_shape_list_attr("output_shapes", 1089 branch_graphs[0].output_shapes) 1090 case_op._add_outputs([t.dtype for t in extra_branch_outputs[0]], 1091 [t.shape for t in extra_branch_outputs[0]]) 1092 1093 # Resolve references to forward graph tensors in grad graphs and ensure 1094 # they are in-scope, i.e., belong to one of outer graphs of the grad graph. 1095 branches_grad_inputs = [ 1096 _resolve_grad_inputs(branch_graph, branch_grad_graph) for branch_graph, 1097 branch_grad_graph in zip(branch_graphs, branch_grad_graphs) 1098 ] 1099 1100 # This modifies the graphs in branch_grad_graphs. 1101 _make_output_composite_tensors_match(_CASE, branch_grad_graphs) 1102 1103 try: 1104 lowering = case_op._get_attr_bool("_lower_using_switch_merge") 1105 except errors_impl.NotFoundError: 1106 lowering = None 1107 1108 outputs = _build_case( 1109 case_op.inputs[0], 1110 branch_grad_graphs, 1111 branches_grad_inputs, 1112 name="gradient", 1113 lower_using_switch_merge=lowering) 1114 1115 # The predicate has no gradient. 1116 return [None] + outputs 1117 1118 1119def _build_case(branch_index, 1120 branch_graphs, 1121 branch_inputs, 1122 name=None, 1123 lower_using_switch_merge=None): 1124 """Creates an `Case` op from `branch_index`, branch graphs and inputs. 1125 1126 Note that this modifies `branch_graphs` to make the inputs match, and to 1127 output all intermediates values so they're available for the gradient 1128 computation. 1129 1130 `branch_graphs` need not have the same input types, but they must 1131 have the same output types. 1132 1133 Args: 1134 branch_index: integer Tensor 1135 branch_graphs: List of FuncGraph 1136 branch_inputs: List of lists of Tensors to be passed to corresponding 1137 branch_graph as input. 1138 name: the name for the Case op. 1139 lower_using_switch_merge: Lower this op using switch merge ops (optional). 1140 1141 Returns: 1142 A list of Tensors which are the outputs of the Case op. Does not include 1143 added intermediate outputs. 1144 """ 1145 _make_indexed_slices_indices_types_match(_CASE, branch_graphs) 1146 _check_same_outputs(_CASE, branch_graphs) 1147 1148 # Add inputs to branch_graphs to make them match. Note that this modifies the 1149 # graphs in `branch_graphs`. 1150 case_inputs = _make_inputs_match(branch_graphs, branch_inputs) 1151 1152 stateful_ops = [] 1153 for bg in branch_graphs: 1154 stateful_ops.extend([ 1155 op for op in bg.get_operations() if auto_control_deps.op_is_stateful(op) 1156 ]) 1157 1158 if stateful_ops: 1159 op_fn = gen_functional_ops.case 1160 else: 1161 op_fn = gen_functional_ops.stateless_case 1162 1163 # Create the Case op. 1164 with ops.control_dependencies( 1165 sum((list(bg.control_captures) for bg in branch_graphs), [])): 1166 1167 def _make_op(inputs): 1168 case_op, tensors = util.get_op_and_outputs(op_fn( 1169 branch_index, 1170 inputs, [t.dtype for t in branch_graphs[0].outputs], 1171 [util.create_new_tf_function(g) for g in branch_graphs], 1172 output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]), 1173 name=name)) 1174 _copy_handle_data(tensors, *[g.outputs for g in branch_graphs]) 1175 if case_op is not None: 1176 util.maybe_set_lowering_attr(case_op, lower_using_switch_merge) 1177 util.maybe_propagate_compile_time_consts_in_xla(case_op) 1178 _set_read_only_resource_inputs_attr(case_op, branch_graphs) 1179 # Prevent fetching since the variant outputs can't be fetched directly. 1180 case_op.graph.prevent_fetching(case_op) 1181 1182 # Store the branch graphs so they can be reused during the gradient 1183 # pass. 1184 for i, bg in enumerate(branch_graphs): 1185 bg.outer_graph = ops.get_default_graph() 1186 setattr(case_op, "_branch_graph_{}".format(i), bg) 1187 1188 return tensors 1189 tensors = util.run_as_function_for_tape_gradients(_make_op, case_inputs) 1190 1191 # Return identities for each output of the Case op, rather than the output of 1192 # the Case op directly. This makes pruning work if the output of switch_case() 1193 # is fetched: the lowering pass converts the Case outputs into IdentityN 1194 # outputs, which if fetched will cause all ops in the taken branch to be run 1195 # (since it takes all merge ops as input). After lowering, each output 1196 # identity op will end up with only the appropriate merge op as input. 1197 # TODO(b/79984175): this doesn't have to be a tuple once we covert to the 1198 # correct output structure 1199 tensors = [array_ops.identity(t) for t in tensors] 1200 1201 return _pack_sequence_as(branch_graphs[0].structured_outputs, tensors) 1202 1203 1204def _set_read_only_resource_inputs_attr(op, branch_graphs): 1205 """Sets the list of resource inputs which are read-only. 1206 1207 This is used by AutomaticControlDependencies. 1208 1209 Args: 1210 op: If or Case Operation. 1211 branch_graphs: List of branch FuncGraphs. 1212 """ 1213 # The first entry in `op.inputs` is the predicate which is not passed to 1214 # branch graphs so len(branch_graph[i].inputs) == len(op.inputs) - 1. 1215 read_only_indices = set(range(len(op.inputs) - 1)) 1216 for branch_graph in branch_graphs: 1217 assert len(branch_graph.inputs) == len(op.inputs) - 1, "should never happen" 1218 if not read_only_indices: 1219 break 1220 branch_read_only_indices = acd.get_read_only_resource_input_indices_graph( 1221 branch_graph) 1222 read_only_indices = read_only_indices.intersection(branch_read_only_indices) 1223 # Convert indices in `branch_graphs[i].inputs` to `op.inputs`. 1224 read_only_indices = [i + 1 for i in read_only_indices] 1225 ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR, 1226 sorted(read_only_indices)) 1227