/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
#include "tensorflow/cc/framework/while_gradients.h"

#include <string>
#include <vector>

#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/ops/while_loop.h"
23
24 namespace tensorflow {
25 namespace {
26
27 using ops::BodyGraphBuilderFn;
28 using ops::BuildWhileLoop;
29 using ops::CondGraphBuilderFn;
30
ToOutput(OutputTensor output_tensor)31 Output ToOutput(OutputTensor output_tensor) {
32 return Output(const_cast<Node*>(output_tensor.node), output_tensor.index);
33 }
34
ToOutputVector(const std::vector<OutputTensor> & output_tensors)35 std::vector<Output> ToOutputVector(
36 const std::vector<OutputTensor>& output_tensors) {
37 size_t n = output_tensors.size();
38 std::vector<Output> result;
39 result.reserve(n);
40 for (int i = 0; i < n; ++i) result.push_back(ToOutput(output_tensors[i]));
41 return result;
42 }
43
// The backprop loop counter and main backprop loop run in their own execution
// frame (conceptually, the main forward loop and forward loop counter run
// together in a frame, then the backprop loop counter and backprop loop run
// together in a different frame). This returns the frame name to use for the
// backprop while loops, derived from the forward loop's frame name.
// TODO(skyewm): make sure this is unique among existing frame names
std::string BackPropFrameName(const std::string& forward_frame_name) {
  return forward_frame_name + "_backprop";
}
53
54 // Creates a loop that counts the number of iterations performed by the
55 // while loop associated with `while_ctx`. The returned output yields the
56 // iteration count.
AddForwardLoopCounter(WhileContext * while_ctx,const Scope & scope,Output * count)57 Status AddForwardLoopCounter(WhileContext* while_ctx, const Scope& scope,
58 Output* count) {
59 // Create while loop:
60 // i = 0
61 // while forward loop predicate is true:
62 // ++i
63
64 Output zero = ops::Const(scope, 0, {});
65
66 // Condition function that returns condition output from original while loop.
67 CondGraphBuilderFn cond_fn = [while_ctx](const Scope& scope,
68 const std::vector<Output>& inputs,
69 Output* output) {
70 *output = ToOutput(while_ctx->cond_output());
71 return Status::OK();
72 };
73
74 // Body function that adds one to input.
75 BodyGraphBuilderFn body_fn = [](const Scope& scope,
76 const std::vector<Output>& inputs,
77 std::vector<Output>* outputs) {
78 DCHECK_EQ(inputs.size(), 1);
79 outputs->emplace_back(ops::Add(scope, inputs[0], 1));
80 return scope.status();
81 };
82
83 // Note that this loop runs in the same execution frame as the forward loop.
84 std::vector<Output> outputs;
85 TF_RETURN_IF_ERROR(BuildWhileLoop(scope, {zero}, cond_fn, body_fn,
86 while_ctx->frame_name(), &outputs,
87 /* create_while_ctx */ false));
88 *count = outputs[0];
89 return Status::OK();
90 }
91
92 // Creates a loop that executes `loop_count` times. The returned output is the
93 // boolean predicate indicating if the loop is still executing. This is used to
94 // drive the gradient computation for the while loop associated with
95 // `while_ctx`.
AddBackPropLoopCounter(WhileContext * while_ctx,const Output & loop_count,const Scope & scope,Output * backprop_execution_pred)96 Status AddBackPropLoopCounter(WhileContext* while_ctx, const Output& loop_count,
97 const Scope& scope,
98 Output* backprop_execution_pred) {
99 // Create while loop:
100 // n = loop_count
101 // while n > 0:
102 // --n
103
104 // Condition function that returns input > 0.
105 CondGraphBuilderFn cond_fn = [](const Scope& scope,
106 const std::vector<Output>& inputs,
107 Output* output) {
108 DCHECK_EQ(inputs.size(), 1);
109 *output = ops::Greater(scope, inputs[0], 0);
110 return scope.status();
111 };
112
113 // Body function that subtracts one from input.
114 BodyGraphBuilderFn body_fn = [](const Scope& scope,
115 const std::vector<Output>& inputs,
116 std::vector<Output>* outputs) {
117 DCHECK_EQ(inputs.size(), 1);
118 outputs->emplace_back(ops::Subtract(scope, inputs[0], 1));
119 return scope.status();
120 };
121
122 string frame_name = BackPropFrameName(while_ctx->frame_name());
123 std::vector<Output> outputs;
124 TF_RETURN_IF_ERROR(BuildWhileLoop(
125 scope, {loop_count}, cond_fn, body_fn, frame_name, &outputs,
126 /* create_while_ctx */ false, backprop_execution_pred));
127 return Status::OK();
128 }
129
130 // Creates the main backprop loop that computes the gradient of the loop
131 // associated with `while_ctx`. `grad_inputs` are the partial derivatives
132 // w.r.t. the loop outputs, i.e. the exit nodes. `backprop_execution_pred` is
133 // the predicate to use for the backprop loop (see AddBackPropLoopCounter()).
134 // The partial derivatives w.r.t. the loop inputs, i.e. the input loop vars, are
135 // returned in `grad_outputs`.
AddWhileGradientLoop(WhileContext * while_ctx,const std::vector<Output> & grad_inputs,const Output & backprop_execution_pred,const Scope & parent_scope,std::vector<Output> * grad_outputs)136 Status AddWhileGradientLoop(WhileContext* while_ctx,
137 const std::vector<Output>& grad_inputs,
138 const Output& backprop_execution_pred,
139 const Scope& parent_scope,
140 std::vector<Output>* grad_outputs) {
141 DCHECK_EQ(grad_inputs.size(), while_ctx->body_outputs().size());
142 DCHECK_EQ(while_ctx->body_inputs().size(), while_ctx->body_outputs().size());
143
144 Scope scope = parent_scope.NewSubScope("while");
145
146 // Create while loop:
147 // while backprop_execution_pred:
148 // forward loop body gradient
149
150 // Condition function that returns 'backprop_execution_pred'.
151 CondGraphBuilderFn cond_fn = [backprop_execution_pred](
152 const Scope& scope,
153 const std::vector<Output>& inputs,
154 Output* output) {
155 *output = backprop_execution_pred;
156 return Status::OK();
157 };
158
159 // Body function that builds while body gradient subgraph.
160 BodyGraphBuilderFn body_fn = [while_ctx](const Scope& scope,
161 const std::vector<Output>& inputs,
162 std::vector<Output>* outputs) {
163 std::vector<Output> body_outputs =
164 ToOutputVector(while_ctx->body_outputs());
165 std::vector<Output> body_inputs = ToOutputVector(while_ctx->body_inputs());
166 return AddSymbolicGradients(scope, body_outputs, body_inputs, inputs,
167 outputs);
168 };
169
170 string frame_name = BackPropFrameName(while_ctx->frame_name());
171 TF_RETURN_IF_ERROR(BuildWhileLoop(scope, grad_inputs, cond_fn, body_fn,
172 frame_name, grad_outputs,
173 /* create_while_ctx */ false));
174 return Status::OK();
175 }
176
177 } // namespace
178
AddWhileLoopGradient(WhileContext * while_ctx,const Scope & scope,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)179 Status AddWhileLoopGradient(WhileContext* while_ctx, const Scope& scope,
180 const std::vector<Output>& grad_inputs,
181 std::vector<Output>* grad_outputs) {
182 Output forward_loop_count;
183 TF_RETURN_IF_ERROR(AddForwardLoopCounter(
184 while_ctx, scope.NewSubScope("ForwardLoopCounter"), &forward_loop_count));
185
186 // TODO(skyewm): can we combine the backprop loop counter and main gradient
187 // loop into a single loop? The original Python code doesn't combine the
188 // loops, but I'm not sure why.
189 Output backprop_counter_cond;
190 TF_RETURN_IF_ERROR(AddBackPropLoopCounter(
191 while_ctx, forward_loop_count, scope.NewSubScope("BackPropLoopCounter"),
192 &backprop_counter_cond));
193
194 return AddWhileGradientLoop(while_ctx, grad_inputs, backprop_counter_cond,
195 scope, grad_outputs);
196 }
197
198 } // namespace tensorflow
199