Home
last modified time | relevance | path

Searched refs:body_grad_graph (Results 1 – 2 of 2) sorted by relevance

/external/tensorflow/tensorflow/python/ops/
Dwhile_v2_indexed_slices_rewriter.py31 def rewrite_grad_indexed_slices(grads, body_grad_graph, loop_vars, argument
67 structured_outputs = body_grad_graph.structured_outputs[3:]
76 loop_vars = _rewrite_input_as_indexed_slices(body_grad_graph, output,
79 _rewrite_output_as_tensor(body_grad_graph, output)
84 def _rewrite_output_as_tensor(body_grad_graph, grad_output_slices): argument
91 with body_grad_graph.as_default():
94 idx = body_grad_graph.structured_outputs.index(grad_output_slices)
95 body_grad_graph.structured_outputs[idx] = new_output
96 body_grad_graph.outputs = func_graph.flatten(
97 body_grad_graph.structured_outputs)
[all …]
Dwhile_v2.py272 body_grad_graph, args = _create_grad_func(
276 if body_grad_graph.while_op_needs_rewrite:
285 new_inputs = body_grad_graph.empty_tensor_lists
297 captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph,
303 grads, body_grad_graph, loop_vars, while_op.inputs)
314 _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))
319 util.create_new_tf_function(body_grad_graph),
320 output_shapes=[t.shape for t in body_grad_graph.outputs],
325 _copy_handle_data(body_grad_graph.outputs, outputs)
331 return _get_structured_grad_output(outputs, grads, body_grad_graph)
[all …]