1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""cond_v2 and gradient.
16
17This is a version of cond that emits a single If op, as well as the gradient
18function for If ops produced by cond_v2. This will eventually replace the
19current tf.cond implementation once it reaches feature and performance parity.
20"""
21
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26import collections
27
28from tensorflow.core.framework import types_pb2
29from tensorflow.python.eager import backprop_util
30from tensorflow.python.framework import auto_control_deps
31from tensorflow.python.framework import auto_control_deps_utils as acd
32from tensorflow.python.framework import constant_op
33from tensorflow.python.framework import dtypes
34from tensorflow.python.framework import errors_impl
35from tensorflow.python.framework import func_graph as func_graph_module
36from tensorflow.python.framework import ops
37from tensorflow.python.framework import tensor_shape
38from tensorflow.python.framework import tensor_util
39from tensorflow.python.ops import array_ops
40from tensorflow.python.ops import control_flow_util
41from tensorflow.python.ops import control_flow_util_v2 as util
42from tensorflow.python.ops import custom_gradient
43from tensorflow.python.ops import default_gradient
44from tensorflow.python.ops import gen_dataset_ops
45from tensorflow.python.ops import gen_functional_ops
46from tensorflow.python.ops import gradients_util
47from tensorflow.python.ops import handle_data_util
48from tensorflow.python.ops import math_ops
49from tensorflow.python.util import nest
50
51
52# NOTE(skyewm): TensorFlow uses protected class methods and fields to signify
53# that they aren't part of the official public API. These protected members
54# often need to be used by implementation code however. Rather than litter the
55# code with pylint comments, we ignore protected access violations for
56# readability.
57# pylint: disable=protected-access
58
59_COND = 1
60_CASE = 2
61
62
63def cond_v2(pred, true_fn, false_fn, name="cond"):
64  """Like tf.cond, except emits a single If op."""
65  if isinstance(pred, bool):
66    raise TypeError("pred must not be a Python bool", pred)
67
68  if not name:
69    name = "cond"
70
71  with ops.name_scope(name) as scope:
72    true_name = util.unique_fn_name(scope, "true")
73    false_name = util.unique_fn_name(scope, "false")
74
75    # Automatic control dependencies are added in defuns, but not in v1
76    # graphs. Propagate that behavior here.
77    add_control_dependencies = ops.get_default_graph()._add_control_dependencies
78    pred = ops.convert_to_tensor(pred)
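    # The If op requires a scalar boolean predicate, so squeeze out size-1
    # dimensions when `pred` is not statically known to be a scalar (e.g. a
    # shape [1] predicate becomes a scalar here).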
79    if (tensor_util.is_tf_type(pred) and
80        (pred.shape.dims is None or pred.shape.dims)):
81      pred = array_ops.squeeze_v2(pred)
82
83    true_graph = func_graph_module.func_graph_from_py_func(
84        true_name,
85        true_fn, [], {},
86        func_graph=util.CondBranchFuncGraph(
87            true_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
88        add_control_dependencies=add_control_dependencies,
89        op_return_value=pred)
90    false_graph = func_graph_module.func_graph_from_py_func(
91        false_name,
92        false_fn, [], {},
93        func_graph=util.CondBranchFuncGraph(
94            false_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
95        add_control_dependencies=add_control_dependencies,
96        op_return_value=pred)
97
98    verify_captures(_COND, [true_graph, false_graph])
99    return _build_cond(
100        pred,
101        true_graph,
102        false_graph,
103        true_graph.external_captures,
104        false_graph.external_captures,
105        building_gradient=False,
106        name=scope)
107
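# Illustrative usage sketch (the tensors and branch bodies below are
# hypothetical, not part of this module): cond_v2 mirrors tf.cond's calling
# convention but always emits a single If/StatelessIf op. Both branch functions
# are traced into FuncGraphs at graph-construction time; only the taken branch
# executes at runtime.
#
#   x = constant_op.constant(2.0)
#   y = constant_op.constant(5.0)
#   out = cond_v2(math_ops.less(x, y),                # scalar bool predicate
#                 lambda: math_ops.add(x, 1.0),       # true_fn
#                 lambda: math_ops.subtract(y, 1.0))  # false_fn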
108
109@ops.RegisterGradient("StatelessIf")
110@ops.RegisterGradient("If")
111def _IfGrad(op, *grads):  # pylint: disable=invalid-name
112  """The gradient of an If op produced by cond_v2."""
113  # Get the if operator (this logic handles the case where op is a MockOp)
114  if_op = op.outputs[0].op
115  true_graph, false_graph = get_func_graphs(if_op)
116  # Note: op.graph != ops.get_default_graph() when we are computing the gradient
117  # of a nested cond.
118  assert true_graph.outer_graph == if_op.graph
119  assert false_graph.outer_graph == if_op.graph
120
121  # Create grad functions that compute the gradient of the true/false forward
122  # graphs. These functions will capture tensors from the forward pass
123  # functions.
124  true_grad_graph = _create_grad_func(
125      true_graph, grads, util.unique_grad_fn_name(true_graph.name))
126  false_grad_graph = _create_grad_func(
127      false_graph, grads, util.unique_grad_fn_name(false_graph.name))
128
129  # Replaces output None grads with zeros if at least one branch has non-None
130  # grad at that index.
131  _create_zeros_for_none_grads([true_graph, false_graph],
132                               [true_grad_graph, false_grad_graph])
133
134  if (true_grad_graph.op_needs_rewrite or false_grad_graph.op_needs_rewrite):
135    # Modify 'op' to output the intermediates needed by the grad functions. Note
136    # that all needed intermediates are wrapped in optionals. Each optional
137    # intermediate output will have a value iff its corresponding branch is
138    # taken.
139    # NOTE(skyewm): if there are any active sessions, this modification to `op`
140    # may make them unrunnable!
141
142    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
143      # XLA does not yet support optionals, so output intermediates directly and
144      # make them match via FakeParams, which can be converted to zeros in XLA.
145      # TODO(skyewm,jpienaar): can XLA support optionals?
146      true_intermediates = true_grad_graph.xla_intermediates
147      false_intermediates = false_grad_graph.xla_intermediates
148      extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla(
149          [true_graph, false_graph], [true_intermediates, false_intermediates])
150    else:
151      true_intermediates = true_grad_graph.wrapped_intermediates
152      false_intermediates = false_grad_graph.wrapped_intermediates
153      # Make outputs match by adding none optionals.
154      extra_true_outputs, extra_false_outputs = _make_intermediates_match(
155          [true_graph, false_graph], [true_intermediates, false_intermediates])
156
157    true_graph.outputs.extend(extra_true_outputs)
158    false_graph.outputs.extend(extra_false_outputs)
159    # TODO(skyewm): indicate it's an internal bug if this fails.
160    _check_same_outputs(_COND, [true_graph, false_graph])
161
162    true_graph.name += "_rewritten"
163    false_graph.name += "_rewritten"
164
165    if_op._set_func_attr("then_branch", util.create_new_tf_function(true_graph))
166    if_op._set_func_attr("else_branch",
167                         util.create_new_tf_function(false_graph))
168    if_op._set_type_list_attr("Tout", true_graph.output_types)
169    if_op._set_shape_list_attr("output_shapes", true_graph.output_shapes)
170    if_op._add_outputs(
171        [t.dtype for t in extra_true_outputs],
172        [t.shape for t in extra_true_outputs])
173
174  # Resolve references to forward graph tensors in grad graphs and ensure
175  # they are in-scope, i.e., belong to one of outer graphs of the grad graph.
176  true_grad_inputs = _resolve_grad_inputs(true_graph, true_grad_graph)
177  false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)
178
179  # This modifies true_grad_graph and false_grad_graph.
180  _make_output_composite_tensors_match(_COND,
181                                       [true_grad_graph, false_grad_graph])
182
183  outputs = _build_cond(
184      if_op.inputs[0],
185      true_grad_graph,
186      false_grad_graph,
187      true_grad_inputs,
188      false_grad_inputs,
189      building_gradient=True,
190  )
191
192  # The predicate has no gradient.
193  return [None] + outputs
194
195
196def _build_cond(pred,
197                true_graph,
198                false_graph,
199                true_inputs,
200                false_inputs,
201                building_gradient,
202                name=None):
203  """Creates an If op from the specified predicate, branch functions and inputs.
204
205  Note that this modifies true_graph and false_graph to make the inputs match,
  and to output all intermediate values so they're available for the gradient
207  computation.
208
209  true_graph and false_graph need not have the same input types, but they must
210  have the same output types.
211
212  Args:
213    pred: boolean Tensor
214    true_graph: FuncGraph
215    false_graph: FuncGraph
216    true_inputs: a list of Tensors to be passed to true_graph as input.
217    false_inputs: a list of Tensors to be passed to false_graph as input.
218    building_gradient: Whether this is a gradient If op.
219    name: the name for the If op.
220
221  Returns:
222    A list of Tensors which are the outputs of the If op. Does not include added
223    intermediate outputs.
224  """
225  _make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph])
226  _check_same_outputs(_COND, [true_graph, false_graph])
227
228  # Add inputs to true_graph and false_graph to make them match. Note that
229  # this modifies true_graph and false_graph.
230  cond_inputs = _make_inputs_match([true_graph, false_graph],
231                                   [true_inputs, false_inputs])
232  # We do not output intermediates of the gradient If op since this is just
233  # for backwards compatibility with existing code.
234  if not building_gradient and util.output_all_intermediates():
235    # Add all intermediate tensors as function outputs so they're available for
236    # the gradient computation. Since the outputs of the two functions must
237    # match, we wrap all the intermediates in optionals. Each intermediate
238    # output will have a value iff its corresponding branch is taken.
239
240    true_intermediates = _get_intermediates(true_graph)
241    false_intermediates = _get_intermediates(false_graph)
242
243    # Wrap intermediates in optionals.
244    wrapped_true_intermediates = _wrap_intermediates(true_graph,
245                                                     true_intermediates)
246    wrapped_false_intermediates = _wrap_intermediates(false_graph,
247                                                      false_intermediates)
248
249    # Make outputs match by adding none optionals.
250    extra_true_outputs, extra_false_outputs = _make_intermediates_match(  # pylint: disable=unbalanced-tuple-unpacking
251        [true_graph, false_graph],
252        [wrapped_true_intermediates, wrapped_false_intermediates])
253
254    true_graph.outputs.extend(extra_true_outputs)
255    false_graph.outputs.extend(extra_false_outputs)
256    _check_same_outputs(_COND, [true_graph, false_graph])
257
258  # Create the If op.
259  with ops.control_dependencies(
260      list(true_graph.control_captures) + list(false_graph.control_captures)):
261    true_stateful_ops = [
262        op for op in true_graph.get_operations() if op._is_stateful
263    ]
264    false_stateful_ops = [
265        op for op in false_graph.get_operations() if op._is_stateful
266    ]
267    if (true_stateful_ops or false_stateful_ops):
268      op_fn = gen_functional_ops._if
269    else:
270      op_fn = gen_functional_ops.stateless_if
271
272    def _make_op(inputs):
273      if_op, tensors = util.get_op_and_outputs(op_fn(
274          pred,
275          inputs, [t.dtype for t in true_graph.outputs],
276          util.create_new_tf_function(true_graph),
277          util.create_new_tf_function(false_graph),
278          output_shapes=_get_output_shapes(true_graph.outputs,
279                                           false_graph.outputs),
280          name=name))
281      _copy_handle_data(tensors, true_graph.outputs, false_graph.outputs)
282      # `if_op` is None if this is a `StatelessIf` op with no outputs.
283      if if_op is not None:
284        # The true and false graphs have already been created, and we need that
285        # to happen before we know which tensors will be captured and so whether
286        # to wrap the cond in a tf.function. Post-hoc mutation of the branch
287        # `outer_graph` properties seems like the only option if we want to
288        # conditionally wrap in a function.
289        true_graph.outer_graph = ops.get_default_graph()
290        false_graph.outer_graph = ops.get_default_graph()
291        if_op._true_graph = true_graph
292        if_op._false_graph = false_graph
293        util.maybe_set_lowering_attr(if_op)
294        util.maybe_propagate_compile_time_consts_in_xla(if_op)
295        _set_read_only_resource_inputs_attr(if_op, [true_graph, false_graph])
296        # Prevent fetching since the variant outputs can't be fetched directly.
297        if_op.graph.prevent_fetching(if_op)
298      return tensors
299    tensors = util.run_as_function_for_tape_gradients(_make_op, cond_inputs)
300
301  # Return identities for each output of the If op, rather than the output of
302  # the If op directly. This makes pruning work if the output of cond() is
303  # fetched: the lowering pass converts the If outputs into IdentityN outputs,
304  # which if fetched will cause all ops in the taken branch to be run (since
305  # it takes all merge ops as input). After lowering, each output identity op
306  # will end up with only the appropriate merge op as input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
308  # correct output structure
309  tensors = [array_ops.identity(t) for t in tensors]
310
311  return _pack_sequence_as(true_graph.structured_outputs, tensors)
312
313
314def get_func_graphs(op):
315  """Returns `FuncGraph`s for the input op branches.
316
317  Args:
318    op: The If or Case Operation.
319
320  Returns:
321    A tuple of the `FuncGraph`s of the then_branch and else_branch (all branches
322    for Case).
323  """
324
325  def _get_func_graph_for_branch(name_attr_list, cached_attr_name=None):
326    """Generates and returns a FuncGraph for the given branch."""
327    func_graph = None
328    if cached_attr_name is not None:
329      func_graph = getattr(op, cached_attr_name, None)
330    inputs = op.inputs[1:]  # First input is pred.
331    if func_graph is None:
332      input_shapes = [t.shape for t in inputs]
333      func_graph = util.get_func_graph(op, input_shapes, name_attr_list.name)
334    for external_t, internal_t in zip(inputs, func_graph.inputs):
335      custom_gradient.copy_handle_data(external_t, internal_t)
336    func_graph.reset_captures(zip(inputs, func_graph.inputs))
337    # Link the op so that the gradient code can use it.
338    func_graph._forward_cond = op
339    return func_graph
340
341  if op.type in ["If", "StatelessIf"]:
342    return (_get_func_graph_for_branch(
343        op.get_attr("then_branch"), "_true_graph"),
344            _get_func_graph_for_branch(
345                op.get_attr("else_branch"), "_false_graph"))
346  elif op.type in ["Case", "StatelessCase"]:
347    return [_get_func_graph_for_branch(branch_fn, "_branch_graph_{}".format(i))
348            for i, branch_fn in enumerate(op.get_attr("branches"))]
349  else:
350    raise ValueError("Unsupported op type: {}".format(op.type))
351
352
353def _grad_fn(func_graph, grads):
354  """The gradient function for each conditional branch.
355
356  This function builds the gradient graph of the corresponding forward-pass
357  conditional branch in `func_graph`. This is done by differentiating
358  func_graph's outputs w.r.t. its inputs.
359
360  Args:
361    func_graph: FuncGraph. The corresponding forward-pass function.
362    grads: The list of input gradient Tensors.
363
364  Returns:
365    The output gradient Tensors.
366  """
367  # Filter out untrainable function outputs.
368  # NOTE(skyewm): If we don't do this, the untrainable tensors can sometimes
369  # cause _GradientsHelper to raise an exception (e.g. the implementation
370  # doesn't expect 'ys' to contain boolean tensors).
371  assert len(func_graph.outputs) == len(grads)
372  ys = []
373  grad_ys = []
374  for y, grad_y in zip(func_graph.outputs, grads):
375    if not backprop_util.IsTrainable(y):
376      continue
377    ys.append(y)
378    grad_ys.append(grad_y)
379
380  # Build the gradient graph. Note that this builds the gradient computation of
381  # func_graph in the current graph, which requires capturing tensors from
382  # func_graph. The captured func_graph tensors are resolved to external tensors
383  # in _resolve_grad_inputs.
384  result = gradients_util._GradientsHelper(
385      ys, func_graph.inputs, grad_ys=grad_ys,
386      src_graph=func_graph)
387
388  return result
389
390
391def _create_grad_func(func_graph, grads, name):
392  """Returns the FuncGraph representation of _grad_fn."""
393  return func_graph_module.func_graph_from_py_func(
394      name,
395      lambda: _grad_fn(func_graph, grads), [], {},
396      func_graph=_CondGradFuncGraph(name, func_graph))
397
398
399def _resolve_grad_inputs(cond_graph, grad_graph):
400  """Returns the tensors to pass as inputs to `grad_graph`.
401
402  The `grad_graph` may have external references to
403  1. Its outer graph containing the input gradients. These references are kept
404     as is.
405  2. Tensors in the forward pass graph. These tensors may not be "live"
406     when the gradient is being computed. We replace such references by their
407     corresponding tensor in `cond_graph.outer_graph`. In the case of nested
408     control flow or functions, the gradient logic handling
409     `grad_graph.outer_graph` will make sure the tensor from
410     `cond_graph.outer_graph` is also correctly captured.
411
412  Args:
413    cond_graph: FuncGraph. The forward-pass function.
414    grad_graph: FuncGraph. The gradients function.
415
416  Returns:
    A list of input tensors to be passed to grad_graph.
418  """
419  new_inputs = []
420
421  for t in grad_graph.external_captures:
422    # `t` must either be in `grad_graph.outer_graph` or in the forward
423    # `cond_graph`.
424    if t.graph != grad_graph.outer_graph:
425      assert t.graph == cond_graph
426      # `internal_captures` are not treated as intermediates and hence not added
427      # to If op outputs. So we get the outer tensor corresponding to those
428      # from the list of `external_captures`.
429      for i, output in enumerate(t.graph.outputs):
430        if output is t:
431          t = t.graph._forward_cond.outputs[i]
432          break
433      else:
434        for i, output in enumerate(t.graph.internal_captures):
435          if output is t:
436            t = t.graph.external_captures[i]
437            break
438        else:
439          raise ValueError("Could not find external tensor capture {tensor} in "
440                           "captures or outputs".format(tensor=t))
441
442      # Note: We rely on the capturing logic of the gradient If op graph to
443      # correctly capture the tensors in `cond_graph.outer_graph`. Both cond_v2
444      # and while_v2 handle this while building their gradient functions.
445      assert t.graph == cond_graph.outer_graph
446    new_inputs.append(t)
447
448  return new_inputs
449
450
451def _get_intermediates(func_graph):
452  """Returns intermediate tensors of `func_graph` for gradient computation."""
453  intermediates = []
454  for op in func_graph.get_operations():
455    for t in op.outputs:
456      if t in func_graph.inputs: continue
457      if t in func_graph.outputs: continue
458      if t.dtype is dtypes.resource:
459        continue
460      # Accumulating mutexes can cause deadlock.
461      if op.type == "MutexLock":
462        continue
463      intermediates.append(t)
464  return intermediates
465
466
467def _make_intermediates_match(branch_graphs, branch_optionals):
468  """Returns new optionals lists that have matching signatures.
469
470  This is done by mirroring each list in the other using none optionals.
471  There is no merging of like optionals.
472
473  Args:
474    branch_graphs: `list` of `FuncGraph`.
    branch_optionals: `list` of `list`s of optional `Tensor`s, one list for
      each graph in `branch_graphs`.
477
478  Returns:
479    A `list` of `list`s of `Tensor`s for each branch_graph. Each list has the
480    same number of `Tensor`s, all of which will be optionals of the same
481    shape/type.
482  """
483  new_branch_optionals = []
484  # Since the intermediates are optionals with dtype variant, we only need
485  # enough room for the longest list of intermediates.
486  intermediates_size = max(len(o) for o in branch_optionals)
487  for i, branch_graph in enumerate(branch_graphs):
488    other_optionals = _create_none_optionals(
489        branch_graph, intermediates_size - len(branch_optionals[i]))
490    new_branch_optionals.append(branch_optionals[i] + other_optionals)
491  return new_branch_optionals
492
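# A minimal worked example of _make_intermediates_match (the optionals below
# are hypothetical): given branch_optionals = [[opt_a, opt_b], [opt_c]], the
# shorter list is padded with "none" optionals so every branch outputs the same
# number of variant-dtype tensors:
#
#   [[opt_a, opt_b],
#    [opt_c, optional_none]]
#
# No attempt is made to line up like optionals across branches; only the list
# lengths are made to match.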
493
494def _make_intermediates_match_xla(branch_graphs, branch_intermediates):
495  """Like _make_intermediates_match but for the XLA case."""
496  new_branch_intermediates = []
497  for i, branch_graph in enumerate(branch_graphs):
498    other_fakeparams = _create_fakeparams(
499        branch_graph,
500        sum((bi for bi in branch_intermediates
501             if bi is not branch_intermediates[i]), []))
502    num_preceding = sum(len(bi) for bi in branch_intermediates[:i])
503    new_branch_intermediates.append(other_fakeparams[:num_preceding] +
504                                    branch_intermediates[i] +
505                                    other_fakeparams[num_preceding:])
506  return new_branch_intermediates
507
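# A minimal worked example of the XLA variant (hypothetical intermediates):
# with branch_intermediates = [[a1, a2], [b1]], every branch outputs a slot for
# every branch's intermediates, filling the other branches' slots with
# FakeParams of matching dtype/shape:
#
#   branch 0: [a1,       a2,       fake(b1)]
#   branch 1: [fake(a1), fake(a2), b1      ]
#
# so the i-th extra output has the same dtype/shape in every branch.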
508
509def _make_inputs_match(branch_graphs, branch_inputs):
510  """Modifies branch_graphs so they have the same input signature.
511
512  This method reorders and/or adds parameters to each graph in branch_graphs so
513  they have the same input signature, and updates the 'inputs' and 'captured'
514  fields of each graph accordingly. It uses the input tensors from the outer
515  graph to avoid duplicating shared arguments.
516
517  Args:
518    branch_graphs: a `list` of `FuncGraph`
519    branch_inputs: a `list` of `list`s of `Tensor`s in the outer graph. The
520      inputs for the corresponding graph in `branch_graphs`.
521
522  Returns:
523    A new list of Tensors from the outer graph that are the new inputs for each
524    branch_graph. This is a deduped version of `sum(branch_inputs)`.
525  """
526  assert len(branch_graphs) == len(branch_inputs)
527  added_inputs = set()
528  new_inputs = []
529  for branch_in in branch_inputs:
530    for tensor in branch_in:
531      tensor_id = ops.tensor_id(tensor)
532      if tensor_id not in added_inputs:
533        added_inputs.add(tensor_id)
534        new_inputs.append(tensor)
535
536  for branch_graph, branch_in in zip(branch_graphs, branch_inputs):
537    input_ids = [ops.tensor_id(t) for t in branch_in]
538    branch_input_to_param = dict(zip(input_ids, branch_graph.inputs))
539    input_list = []
540    for in_t in new_inputs:
541      param = branch_input_to_param.get(ops.tensor_id(in_t))
542      if param is None:
543        param = _create_dummy_input(branch_graph, in_t)
544      input_list.append(param)
545
546    branch_graph.inputs = input_list
547
548    # Rewrite the FuncGraphs' state to reflect the new inputs.
549    branch_graph.reset_captures(zip(new_inputs, branch_graph.inputs))
550
551  return new_inputs
552
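# A minimal worked example of _make_inputs_match (x, y, z are hypothetical
# outer-graph tensors): with branch_inputs = [[x, y], [y, z]], the deduped
# outer input list is [x, y, z]. Each branch graph then gets three parameters,
# reusing its existing parameter where it already captured the tensor and a
# fresh dummy placeholder (unused inside that branch) where it did not:
#
#   branch 0 inputs: [param_x, param_y, dummy_z]
#   branch 1 inputs: [dummy_x, param_y, param_z]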
553
554def _create_zeros_for_none_grads(forward_graphs, grad_graphs):
555  """Creates zeros for None out grads if at least one branch has non-None grad.
556
557  Args:
558    forward_graphs: List of forward FuncGraphs.
559    grad_graphs: List of grad FuncGraphs.
560  """
561  assert len(forward_graphs) == len(grad_graphs)
562  branch_outputs = [g.structured_outputs for g in grad_graphs]
563  num_outputs_per_branch = [len(outs) for outs in branch_outputs]
564  assert len(set(num_outputs_per_branch)) == 1, num_outputs_per_branch
565  for output_idx, branch_outs in enumerate(zip(*branch_outputs)):
566    if (any(t is None for t in branch_outs) and
567        any(t is not None for t in branch_outs)):
568      for branch_index, t in enumerate(branch_outs):
569        if t is None:
570          with grad_graphs[branch_index].as_default():
571            zeros = default_gradient.zeros_like(
572                forward_graphs[branch_index].inputs[output_idx])
573            grad_graphs[branch_index].structured_outputs[output_idx] = zeros
574
575  for grad_graph in grad_graphs:
576    grad_graph.outputs = [
577        t for t in func_graph_module.flatten(grad_graph.structured_outputs)
578        if t is not None
579    ]
580
581
582def _make_output_composite_tensors_match(op_type, branch_graphs):
583  """Modifies each branch_graph's outputs to have the same output signature.
584
585  Currently the only transformation implemented is turning a Tensor into an
  equivalent IndexedSlices if another branch returns an IndexedSlices.
587  Updates branch_graph.{outputs,structured_outputs} for each branch_graph in
588  branch_graphs.
589
590  Args:
591    op_type: _COND or _CASE
592    branch_graphs: `list` of `FuncGraph`
593
594  Raises:
595    TypeError: if a set of outputs cannot be rewritten.
596  """
597  # Note: since this is only used for gradient graphs, we do not expect the
598  # outputs to be structured (e.g. nested lists), and thus do not need to use
599  # nest.flatten, etc.
600  assert branch_graphs
601  branch_outputs = [g.structured_outputs for g in branch_graphs]
602  outputs_per_branch = list(len(outs) for outs in branch_outputs)
603  assert len(set(outputs_per_branch)) == 1, outputs_per_branch
604
605  for output_idx, branch_outs in enumerate(zip(*branch_outputs)):
606    if len(set(type(out) for out in branch_outs)) == 1:
607      continue
608    if not any(isinstance(out, ops.IndexedSlices) for out in branch_outs):
609      continue
610    for branch_idx, branch_out in enumerate(branch_outs):
611      if isinstance(branch_out, ops.IndexedSlices):
612        continue
613      elif isinstance(branch_out, ops.Tensor):
614        with branch_graphs[branch_idx].as_default():
615          branch_outputs[branch_idx][output_idx] = math_ops._as_indexed_slices(
616              branch_out)
617      else:
618        raise TypeError(
619            "Cannot reconcile {op_name} {output_idx}-th outputs:\n"
620            "  outputs from all branches: {outputs}".format(
621                op_name="tf.cond" if op_type == _COND else "tf.switch_case",
622                output_idx=output_idx,
623                outputs=branch_outs))
624
625  for branch_graph, branch_outs in zip(branch_graphs, branch_outputs):
626    branch_graph.structured_outputs = branch_outs
627    branch_graph.outputs = [
628        t for t in func_graph_module.flatten(branch_outs) if t is not None
629    ]
630
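# A minimal worked example (hypothetical branch gradients): if for some output
# index the true-branch grad graph produced a dense Tensor t while the
# false-branch grad graph produced an IndexedSlices (e.g. from the gradient of
# a gather), the dense t is rewritten to math_ops._as_indexed_slices(t) so both
# branches expose the same (values, indices, dense_shape) structure at that
# position.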
631
632def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
633  """Match dtype of IndexedSlices.indices in outputs of branch_graphs."""
634  assert branch_graphs
635  # Indices of `IndexedSlices.indices` tensors in `branch_graphs[i].outputs`.
636  indexed_slice_indices = []
637  current_index = 0
638  # Note that this still contains Nones. We leave those in so that error
639  # messages contain the correct indices. We handle the Nones later when
640  # updating `current_index`.
641  branch_outputs_flat_with_composites = [
642      nest.flatten(branch_graph.structured_outputs, expand_composites=False)
643      for branch_graph in branch_graphs
644  ]
645  outs_per_branch = [len(outs) for outs in branch_outputs_flat_with_composites]
646  assert len(set(outs_per_branch)) == 1, outs_per_branch
647  # Store indices of IndexedSlices.indices in `indexed_slice_indices`.
648  for output_idx, branch_outs in enumerate(
649      zip(*branch_outputs_flat_with_composites)):
650    if len(set(isinstance(out, ops.IndexedSlices) for out in branch_outs)) != 1:
651      raise TypeError("Cannot reconcile tf.{op_name} {output_idx}-th outputs:\n"
652                      "  branches returned: {outputs}".format(
653                          op_name="cond" if op_type == _COND else "switch_case",
654                          output_idx=output_idx,
655                          outputs=branch_outs))
656    if isinstance(branch_outs[0], ops.IndexedSlices):
657      # indices is the second component of the composite tensor.
658      indexed_slice_indices.append(current_index + 1)
659    if nest.is_sequence_or_composite(branch_outs[0]):
660      current_index += len(nest.flatten(branch_outs[0], expand_composites=True))
661    elif branch_outs[0] is not None:
662      # `FuncGraph.outputs` does not contain Nones so no need to update the
663      # counter in that case.
664      current_index += 1
665
666  if not indexed_slice_indices:
667    return
668
669  # `FuncGraph.outputs` is the flattened `FuncGraph.structured_outputs` minus
670  # the Nones.
671  if current_index != len(branch_graphs[0].outputs):
672    raise ValueError("Insufficient elements in branch_graphs[0].outputs.\n"
673                     "Expected: %i\n"
674                     "Actual: %i" %
675                     (current_index, len(branch_graphs[0].outputs)))
676
677  # Cast indices with mismatching types to int64.
678  for index in indexed_slice_indices:
679    if any(bg.outputs[index].dtype not in (dtypes.int32, dtypes.int64)
680           for bg in branch_graphs):
681      raise TypeError("Type of IndexedSlices.indices must be int32 or int64. "
682                      "Found: %s" %
683                      str([bg.outputs[index].dtype for bg in branch_graphs]))
684    if len(set(bg.outputs[index].dtype for bg in branch_graphs)) != 1:
685      for branch_graph in branch_graphs:
686        if branch_graph.outputs[index].dtype == dtypes.int32:
687          with branch_graph.as_default():
688            branch_graph.outputs[index] = math_ops.cast(
689                branch_graph.outputs[index], dtypes.int64)
690
691  for branch_graph in branch_graphs:
692    branch_graph.structured_outputs = _pack_sequence_as(
693        branch_graph.structured_outputs, branch_graph.outputs)
694
695
696def _pack_sequence_as(structured_outputs, op_outputs):
697  """Packs the outputs of the gradient If/Case op.
698
699  The branch functions may contain None's in the list of `structured_outputs`.
700  `op_outputs` has those outputs missing. So we need to add those Nones to the
701  list of `op_outputs` and then pack it in the same structure as
702  `structured_outputs`.
703
704  Args:
705    structured_outputs: structured_outputs from one of the branch functions.
706    op_outputs: List of output tensors of the op.
707
708  Returns:
709    `op_outputs` packed like `structured_outputs`.
710  """
711  outputs_with_nones = []
712  counter = 0
713  for output in nest.flatten(structured_outputs, expand_composites=True):
714    if output is None:
715      outputs_with_nones.append(None)
716    else:
717      outputs_with_nones.append(op_outputs[counter])
718      counter += 1
719  return func_graph_module.pack_sequence_as(structured_outputs,
720                                            outputs_with_nones)
721
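# A minimal worked example of _pack_sequence_as (hypothetical values): if a
# branch's structured_outputs are (a, None, {"k": b}), the op itself only has
# two outputs, say [o0, o1]; the None is re-inserted at its original position
# and the flat list is packed back into the original structure:
#
#   _pack_sequence_as((a, None, {"k": b}), [o0, o1])  ->  (o0, None, {"k": o1})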
722
723def _wrap_intermediates(func_graph, intermediates):
724  with func_graph.as_default():
725    return [gen_dataset_ops.optional_from_value([t]) for t in intermediates]
726
727
728def _create_dummy_input(func_graph, template_tensor):
729  """Creates tensors in func_graph to represent template_tensors.
730
731  Args:
732    func_graph: FuncGraph.
733    template_tensor: a tensor in the outer graph.
734
735  Returns:
736    A tensor in func_graph.
737  """
738  with func_graph.as_default():
739    return array_ops.placeholder(
740        template_tensor.dtype, shape=template_tensor.shape)
741
742
743def _create_none_optionals(func_graph, n):
744  """Creates `n` `None` optionals in func_graph.
745
746  Args:
747    func_graph: FuncGraph.
748    n: `int` the number of `None` optionals to make.
749
750  Returns:
751    A list of tensors in func_graph.
752  """
753  with func_graph.as_default():
754    return [gen_dataset_ops.optional_none() for _ in range(n)]
755
756
757def _create_fakeparams(func_graph, template_tensors):
758  """Create FakeParams for the XLA case."""
759  with func_graph.as_default():
760    return [gen_functional_ops.fake_param(dtype=t.dtype, shape=t.shape)
761            for t in template_tensors]
762
763
764def _check_same_outputs(op_type, graphs):
765  """Raises an error if `graphs` have different outputs."""
766
767  def error(branch_idx, error_detail):
768    raise TypeError(
769        "{b0_name} and {bn_name} arguments to {op_name} must have the same "
770        "number, type, and overall structure of return values.\n"
771        "\n"
772        "{b0_name} output: {b0_out}\n"
773        "{bn_name} output: {bn_out}\n"
774        "\n"
775        "Error details:\n"
776        "{detail}".format(
777            b0_name="true_fn" if op_type == _COND else "branches[0]",
778            bn_name=("false_fn" if op_type == _COND else
779                     "branches[{}]".format(branch_idx)),
780            op_name="tf.cond" if op_type == _COND else "tf.switch_case",
781            b0_out=graphs[0].structured_outputs,
782            bn_out=graphs[branch_idx].structured_outputs,
783            detail=error_detail))
784
785  for b in range(1, len(graphs)):
786    try:
787      nest.assert_same_structure(
788          graphs[0].structured_outputs,
789          graphs[b].structured_outputs,
790          expand_composites=True)
791    except (ValueError, TypeError) as e:
792      error(b, str(e))
793
794    op_type_str = "cond" if op_type == _COND else "case"
795    if len(graphs[0].outputs) != len(graphs[b].outputs):
796      raise ValueError("Lengths of branch outputs of {op_type} must match.\n"
797                       "len(graphs[0].outputs): {len_0}\n"
798                       "len(graphs[{b}].outputs): {len_b}\n".format(
799                           op_type=op_type_str,
800                           len_0=len(graphs[0].outputs),
801                           b=b,
802                           len_b=len(graphs[b].outputs)))
803    for b0_out, bn_out in zip(graphs[0].outputs, graphs[b].outputs):
804      if b0_out.dtype != bn_out.dtype:
805        error(b, "%s and %s have different types" % (b0_out, bn_out))
806
807
808def _get_output_shapes(*branch_graph_outputs):
809  output_shapes = []
810  for out_by_branch in zip(*branch_graph_outputs):
811    shape = out_by_branch[0].shape
812    for other_out in out_by_branch[1:]:
813      shape = shape.most_specific_compatible_shape(other_out.shape)
814    output_shapes.append(shape)
815  return output_shapes
816
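# A minimal worked example of _get_output_shapes (hypothetical shapes): for an
# output whose true-branch shape is [2, 3] and whose false-branch shape is
# [2, 4], the combined shape is the most specific compatible one, [2, None];
# if the ranks differ, the combined shape is fully unknown. These per-output
# shapes become the If/Case op's output_shapes attr.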
817
818def _copy_handle_data(external_tensors, *branch_graph_outputs):
819  """Combines shapes in handle data and sets metadata on `external_tensors`."""
820  for tensors in zip(external_tensors, *branch_graph_outputs):
821    external = tensors[0]
822    internal = tensors[1:]
823    internal_handle_data = []
824    for tensor in internal:
825      handle_data = handle_data_util.get_resource_handle_data(tensor)
826      # NOTE: Assumes handle data has only one ShapeAndType entry. It's
827      # unclear how to combine different lengths across branches.
828      if not handle_data.is_set or len(handle_data.shape_and_type) != 1:
829        break
830      internal_handle_data.append(handle_data)
831    else:  # There is handle data, so we need to combine it.
832      combined_shape = tensor_shape.TensorShape(None)
833      combined_dtype = None
834      for handle_data in internal_handle_data:
835        handle_shape = tensor_shape.TensorShape(
836            handle_data.shape_and_type[0].shape)
837        combined_shape = combined_shape.most_specific_compatible_shape(
838            handle_shape)
839        if combined_dtype is None:
840          combined_dtype = handle_data.shape_and_type[0].dtype
841        elif handle_data.shape_and_type[0].dtype != combined_dtype:
842          # Variants from different branches have different dtypes. The
843          # combined variant has no static dtype.
844          combined_dtype = types_pb2.DT_INVALID
845      combined_handle_data = internal_handle_data[0]
846      combined_handle_data.shape_and_type[0].shape.CopyFrom(
847          combined_shape.as_proto())
848      combined_handle_data.shape_and_type[0].dtype = combined_dtype
849      handle_data_util.set_handle_data(external, combined_handle_data)
850
851
852def verify_captures(op_type, branch_graphs):
853  """Verify that a branch's tensor is not accessed in another branch fn."""
854  # Note: It is technically not possible for lower-branch_index branches to
855  # capture tensors from higher-branch_index branches, because of the order of
856  # branch graph construction, but we check all for completeness and to
857  # guard against potential future changes.
858  other_branch_graphs = {g: i for i, g in enumerate(branch_graphs)}
859  for i, branch_graph in enumerate(branch_graphs):
860    for t in branch_graph.external_captures:
861      if not isinstance(t, ops.EagerTensor) and t.graph in other_branch_graphs:
862        branch_names = ["true_fn", "false_fn"] if op_type == _COND else [
863            "branch {}".format(bi) for bi in range(len(branch_graphs))]
864        raise ValueError(
865            "Tensor {tname} in {b0name} is accessed from {b1name}.".format(
866                tname=t.name,
867                b0name=branch_names[other_branch_graphs[t.graph]],
868                b1name=branch_names[i]))
869
870
871class _CondGradFuncGraph(util.CondBranchFuncGraph):
872  """FuncGraph for the gradient function of the branch of an If op.
873
874  Handles wrapping and unwrapping intermediate values that are captured by the
875  gradient computation in optionals.
876
877  Attributes:
878    op_needs_rewrite: True if any intermediates were captured, meaning the
      forward If op needs to be rewritten to output the wrapped intermediates.
880  """
881
882  def __init__(self, name, forward_graph):
883    super(_CondGradFuncGraph, self).__init__(
884        name, collections=ops.get_default_graph()._collections)  # pylint: disable=protected-access
885    self.op_needs_rewrite = False
886    self._forward_graph = forward_graph
887    # Maps from forward intermediate tensor -> the unwrapped captured
888    # intermediate.
889    self._indirect_captures = {}
890    # Maps unwrapped intermediate -> optional-wrapped intermediate in the
891    # forward graph.
892    self._wrapped_intermediates = collections.OrderedDict()
893    # Raw intermediates captured from the forward graph. Populated iff we're in
894    # an XLA context.
895    self._xla_intermediates = []
    # Maps the id of a constant-valued forward intermediate tensor to the
    # constant created in this graph for that tensor.
898    self._captured_constants = {}
899
900  @property
901  def wrapped_intermediates(self):
902    """The optional-wrapped intermediates captured from the forward graph."""
903    return list(self._wrapped_intermediates.values())
904
905  @property
906  def xla_intermediates(self):
907    """Raw intermediates captured from the forward graph if XLA is enabled."""
908    return self._xla_intermediates
909
910  def _capture_helper(self, tensor, name):
911    if (tensor.graph is not self._forward_graph or
912        any(tensor is t for t in self._forward_graph.inputs) or
913        any(tensor is t for t in self._forward_graph.outputs)):
914      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)
915
916    tensor_id = ops.tensor_id(tensor)
917
918    # If `tensor` is a graph-building time constant, we create a constant with
919    # the same value in the backward graph instead of capturing it.
920    if tensor_id in self._captured_constants:
921      return self._captured_constants[tensor_id]
922    elif constant_op.is_constant(tensor):
923      self._captured_constants[tensor_id] = constant_op.constant(
924          tensor_util.constant_value(tensor), dtype=tensor.dtype)
925      return self._captured_constants[tensor_id]
926
927    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
928      # XLA does not yet support optionals, so capture intermediates directly.
929      # TODO(skyewm,jpienaar): can XLA support optionals?
930      if all(tensor is not capture for capture in self.external_captures):
931        self.xla_intermediates.append(tensor)
932        self.op_needs_rewrite = True
933      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)
934
935    captured_tensor = self._indirect_captures.get(tensor_id)
936    if captured_tensor is not None:
937      return captured_tensor
938
939    # 'tensor' is an uncaptured intermediate in the forward graph.
940    # If it is not a resource, we wrap it in an optional in the forward graph
941    # and capture the optional normally. We then unwrap the captured optional
942    # value in the gradient graph to get the raw intermediate value.
943    # If it is a resource, we trace the resource up to the input in the forward
944    # graph and capture that.
945
946    if tensor.dtype == dtypes.resource:
947      # Index of the forward graph input corresponding to the resource tensor.
948      index = util.resource_input_index(
949          tensor.name, [t.name for t in self._forward_graph.inputs],
950          {op.name: op.node_def for op in self._forward_graph.get_operations()},
951          self._forward_graph._functions)
952      # This gets mapped to the corresponding If op input in
953      # `_resolve_grad_inputs`.
954      captured_tensor = super(_CondGradFuncGraph, self)._capture_helper(
955          self._forward_graph.inputs[index], name)
956    else:
957      if tensor_id not in self._wrapped_intermediates:
958        # If the gradient has already been computed for this If op, 'tensor' may
959        # already be wrapped.
960        for consumer in tensor.consumers():
961          if (consumer.type == "OptionalFromValue" and
962              any(consumer.outputs[0] is output
963                  for output in self._forward_graph.outputs)):
964            optional = consumer.outputs[0]
965            break
966        else:
967          # 'tensor' hasn't been wrapped, do it now.
968          with self._forward_graph.as_default():
969            optional = gen_dataset_ops.optional_from_value([tensor])
970          self.op_needs_rewrite = True
971        self._wrapped_intermediates[tensor_id] = optional
972
973      optional = self._wrapped_intermediates[tensor_id]
974      captured_optional = super(_CondGradFuncGraph,
975                                self)._capture_helper(optional, name)
976      captured_tensor = gen_dataset_ops.optional_get_value(
977          captured_optional, [tensor.dtype], [tensor.shape])[0]
978
979    self._indirect_captures[tensor_id] = captured_tensor
980    return captured_tensor
981
982
983def indexed_case(branch_index,
984                 branch_fns,
985                 name="indexed_case",
986                 lower_using_switch_merge=None):
987  """Like conv_v2, except emits a Case op instead of an If."""
988  if isinstance(branch_index, int):
989    raise TypeError("branch_index must not be a Python int", branch_index)
990
991  with ops.name_scope(name) as scope:
992    branch_names = [
993        util.unique_fn_name(scope, "branch{}".format(b))
994        for b in range(len(branch_fns))
995    ]
996
997    # Automatic control dependencies are added in defuns, but not in v1
998    # graphs. Propagate that behavior here.
999    add_control_dependencies = ops.get_default_graph()._add_control_dependencies
1000    branch_index = ops.convert_to_tensor(branch_index, name="branch_index")
1001
1002    branch_graphs = []
1003    for branch_name, branch_fn in zip(branch_names, branch_fns):
1004      branch_graphs.append(
1005          func_graph_module.func_graph_from_py_func(
1006              branch_name,
1007              branch_fn,
1008              [],
1009              {},
1010              func_graph=util.CondBranchFuncGraph(
1011                  branch_name,
1012                  collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
1013              add_control_dependencies=add_control_dependencies,
1014              op_return_value=branch_index))
1015
1016    verify_captures(_CASE, branch_graphs)
1017    return _build_case(
1018        branch_index,
1019        branch_graphs, [g.external_captures for g in branch_graphs],
1020        name=scope,
1021        lower_using_switch_merge=lower_using_switch_merge)
1022
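# Illustrative usage sketch (the index and branch bodies below are
# hypothetical): indexed_case is the Case-op analogue of cond_v2 and backs
# tf.switch_case, selecting one of N traced branch functions by a scalar
# integer index:
#
#   idx = constant_op.constant(1)
#   out = indexed_case(idx, [lambda: constant_op.constant(10.0),
#                            lambda: constant_op.constant(20.0),
#                            lambda: constant_op.constant(30.0)])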
1023
1024@ops.RegisterGradient("Case")
1025@ops.RegisterGradient("StatelessCase")
1026def _CaseGrad(op, *grads):  # pylint: disable=invalid-name
1027  """The gradient of a Case op produced by tf.switch_case."""
1028  # Get the Case operator (this logic handles the case where op is a MockOp)
1029  case_op = op.outputs[0].op
1030  branch_graphs = get_func_graphs(case_op)
1031  assert branch_graphs
1032  # Note: op.graph != ops.get_default_graph() when we are computing the gradient
1033  # of a nested cond.
1034  for branch_graph in branch_graphs:
1035    assert branch_graph.outer_graph == case_op.graph
1036
1037  # Create grad functions that compute the gradient of the branch forward
1038  # graphs. These functions will capture tensors from the forward pass
1039  # functions.
1040  branch_grad_graphs = []
1041  for branch_graph in branch_graphs:
1042    branch_grad_graphs.append(
1043        _create_grad_func(branch_graph, grads,
1044                          util.unique_grad_fn_name(branch_graph.name)))
1045  # Replaces output None grads with zeros if at least one branch has non-None
1046  # grad at that index.
1047  _create_zeros_for_none_grads(branch_graphs, branch_grad_graphs)
1048
1049  if any(g.op_needs_rewrite for g in branch_grad_graphs):
1050    # Modify 'op' to output the intermediates needed by the grad functions. Note
1051    # that all needed intermediates are wrapped in optionals. Each optional
1052    # intermediate output will have a value iff its corresponding branch is
1053    # taken.
1054    # NOTE(bjp): if there are any active sessions, this modification to `op`
1055    # may make them unrunnable!
1056
1057    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
1058      # XLA does not yet support optionals, so output intermediates directly and
1059      # make them match via FakeParams, which can be converted to zeros in XLA.
1060      # TODO(bjp,jpienaar): can XLA support optionals?
1061      branches_intermediates = [
1062          branch_grad_graph.xla_intermediates
1063          for branch_grad_graph in branch_grad_graphs
1064      ]
1065      extra_branch_outputs = _make_intermediates_match_xla(
1066          branch_graphs, branches_intermediates)
1067    else:
1068      branch_intermediates = [
1069          g.wrapped_intermediates for g in branch_grad_graphs
1070      ]
1071      # Make outputs match by adding none optionals.
1072      extra_branch_outputs = _make_intermediates_match(branch_graphs,
1073                                                       branch_intermediates)
1074
1075    for branch_graph, extra_outputs in zip(branch_graphs, extra_branch_outputs):
1076      branch_graph.outputs.extend(extra_outputs)
1077    # TODO(bjp): indicate it's an internal bug if this fails.
1078    _check_same_outputs(_CASE, branch_graphs)
1079
1080    for branch_graph in branch_graphs:
1081      branch_graph.name += "_rewritten"
1082
1083    case_op._set_func_list_attr("branches", [
1084        util.create_new_tf_function(branch_graph)
1085        for branch_graph in branch_graphs
1086    ])
1087    case_op._set_type_list_attr("Tout", branch_graphs[0].output_types)
1088    case_op._set_shape_list_attr("output_shapes",
1089                                 branch_graphs[0].output_shapes)
1090    case_op._add_outputs([t.dtype for t in extra_branch_outputs[0]],
1091                         [t.shape for t in extra_branch_outputs[0]])
1092
1093  # Resolve references to forward graph tensors in grad graphs and ensure
1094  # they are in-scope, i.e., belong to one of outer graphs of the grad graph.
1095  branches_grad_inputs = [
1096      _resolve_grad_inputs(branch_graph, branch_grad_graph) for branch_graph,
1097      branch_grad_graph in zip(branch_graphs, branch_grad_graphs)
1098  ]
1099
1100  # This modifies the graphs in branch_grad_graphs.
1101  _make_output_composite_tensors_match(_CASE, branch_grad_graphs)
1102
1103  try:
1104    lowering = case_op._get_attr_bool("_lower_using_switch_merge")
1105  except errors_impl.NotFoundError:
1106    lowering = None
1107
1108  outputs = _build_case(
1109      case_op.inputs[0],
1110      branch_grad_graphs,
1111      branches_grad_inputs,
1112      name="gradient",
1113      lower_using_switch_merge=lowering)
1114
1115  # The predicate has no gradient.
1116  return [None] + outputs
1117
1118
1119def _build_case(branch_index,
1120                branch_graphs,
1121                branch_inputs,
1122                name=None,
1123                lower_using_switch_merge=None):
1124  """Creates an `Case` op from `branch_index`, branch graphs and inputs.
1125
1126  Note that this modifies `branch_graphs` to make the inputs match, and to
  output all intermediate values so they're available for the gradient
1128  computation.
1129
1130  `branch_graphs` need not have the same input types, but they must
1131  have the same output types.
1132
1133  Args:
1134    branch_index: integer Tensor
1135    branch_graphs: List of FuncGraph
1136    branch_inputs: List of lists of Tensors to be passed to corresponding
1137      branch_graph as input.
1138    name: the name for the Case op.
1139    lower_using_switch_merge: Lower this op using switch merge ops (optional).
1140
1141  Returns:
1142    A list of Tensors which are the outputs of the Case op. Does not include
1143    added intermediate outputs.
1144  """
1145  _make_indexed_slices_indices_types_match(_CASE, branch_graphs)
1146  _check_same_outputs(_CASE, branch_graphs)
1147
1148  # Add inputs to branch_graphs to make them match. Note that this modifies the
1149  # graphs in `branch_graphs`.
1150  case_inputs = _make_inputs_match(branch_graphs, branch_inputs)
1151
1152  stateful_ops = []
1153  for bg in branch_graphs:
1154    stateful_ops.extend([
1155        op for op in bg.get_operations() if auto_control_deps.op_is_stateful(op)
1156    ])
1157
1158  if stateful_ops:
1159    op_fn = gen_functional_ops.case
1160  else:
1161    op_fn = gen_functional_ops.stateless_case
1162
1163  # Create the Case op.
1164  with ops.control_dependencies(
1165      sum((list(bg.control_captures) for bg in branch_graphs), [])):
1166
1167    def _make_op(inputs):
1168      case_op, tensors = util.get_op_and_outputs(op_fn(
1169          branch_index,
1170          inputs, [t.dtype for t in branch_graphs[0].outputs],
1171          [util.create_new_tf_function(g) for g in branch_graphs],
1172          output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
1173          name=name))
1174      _copy_handle_data(tensors, *[g.outputs for g in branch_graphs])
1175      if case_op is not None:
1176        util.maybe_set_lowering_attr(case_op, lower_using_switch_merge)
1177        util.maybe_propagate_compile_time_consts_in_xla(case_op)
1178        _set_read_only_resource_inputs_attr(case_op, branch_graphs)
1179        # Prevent fetching since the variant outputs can't be fetched directly.
1180        case_op.graph.prevent_fetching(case_op)
1181
1182        # Store the branch graphs so they can be reused during the gradient
1183        # pass.
1184        for i, bg in enumerate(branch_graphs):
1185          bg.outer_graph = ops.get_default_graph()
1186          setattr(case_op, "_branch_graph_{}".format(i), bg)
1187
1188      return tensors
1189    tensors = util.run_as_function_for_tape_gradients(_make_op, case_inputs)
1190
1191  # Return identities for each output of the Case op, rather than the output of
1192  # the Case op directly. This makes pruning work if the output of switch_case()
1193  # is fetched: the lowering pass converts the Case outputs into IdentityN
1194  # outputs, which if fetched will cause all ops in the taken branch to be run
1195  # (since it takes all merge ops as input). After lowering, each output
1196  # identity op will end up with only the appropriate merge op as input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
1198  # correct output structure
1199  tensors = [array_ops.identity(t) for t in tensors]
1200
1201  return _pack_sequence_as(branch_graphs[0].structured_outputs, tensors)
1202
1203
1204def _set_read_only_resource_inputs_attr(op, branch_graphs):
1205  """Sets the list of resource inputs which are read-only.
1206
1207  This is used by AutomaticControlDependencies.
1208
1209  Args:
1210    op: If or Case Operation.
1211    branch_graphs: List of branch FuncGraphs.
1212  """
1213  # The first entry in `op.inputs` is the predicate which is not passed to
1214  # branch graphs so len(branch_graph[i].inputs) == len(op.inputs) - 1.
1215  read_only_indices = set(range(len(op.inputs) - 1))
1216  for branch_graph in branch_graphs:
1217    assert len(branch_graph.inputs) == len(op.inputs) - 1, "should never happen"
1218    if not read_only_indices:
1219      break
1220    branch_read_only_indices = acd.get_read_only_resource_input_indices_graph(
1221        branch_graph)
1222    read_only_indices = read_only_indices.intersection(branch_read_only_indices)
1223  # Convert indices in `branch_graphs[i].inputs` to `op.inputs`.
1224  read_only_indices = [i + 1 for i in read_only_indices]
1225  ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
1226                        sorted(read_only_indices))
1227