Searched refs:body_grad_graph (Results 1 – 2 of 2) sorted by relevance
/external/tensorflow/tensorflow/python/ops/ |
D | while_v2_indexed_slices_rewriter.py | 31 def rewrite_grad_indexed_slices(grads, body_grad_graph, loop_vars, argument 67 structured_outputs = body_grad_graph.structured_outputs[3:] 76 loop_vars = _rewrite_input_as_indexed_slices(body_grad_graph, output, 79 _rewrite_output_as_tensor(body_grad_graph, output) 84 def _rewrite_output_as_tensor(body_grad_graph, grad_output_slices): argument 91 with body_grad_graph.as_default(): 94 idx = body_grad_graph.structured_outputs.index(grad_output_slices) 95 body_grad_graph.structured_outputs[idx] = new_output 96 body_grad_graph.outputs = func_graph.flatten( 97 body_grad_graph.structured_outputs) [all …]
|
D | while_v2.py | 272 body_grad_graph, args = _create_grad_func( 276 if body_grad_graph.while_op_needs_rewrite: 285 new_inputs = body_grad_graph.empty_tensor_lists 297 captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph, 303 grads, body_grad_graph, loop_vars, while_op.inputs) 314 _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars)) 319 util.create_new_tf_function(body_grad_graph), 320 output_shapes=[t.shape for t in body_grad_graph.outputs], 325 _copy_handle_data(body_grad_graph.outputs, outputs) 331 return _get_structured_grad_output(outputs, grads, body_grad_graph) [all …]
|