1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/c_api.h"
17
18 #include "tensorflow/c/c_test_util.h"
19 #include "tensorflow/core/lib/strings/strcat.h"
20 #include "tensorflow/core/platform/logging.h"
21 #include "tensorflow/core/platform/test.h"
22
23 using tensorflow::GraphDef;
24
25 namespace {
26
27 class CApiWhileLoopTest : public ::testing::Test {
28 protected:
CApiWhileLoopTest()29 CApiWhileLoopTest() : s_(TF_NewStatus()), graph_(TF_NewGraph()) {}
30
~CApiWhileLoopTest()31 ~CApiWhileLoopTest() override {
32 TF_DeleteGraph(graph_);
33 TF_DeleteStatus(s_);
34 }
35
Init(int ninputs)36 void Init(int ninputs) {
37 DCHECK(inputs_.empty());
38 DCHECK_GT(ninputs, 0);
39
40 for (int i = 0; i < ninputs; ++i) {
41 TF_Operation* placeholder = Placeholder(
42 graph_, s_, ::tensorflow::strings::StrCat("p", i).c_str());
43 DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
44 inputs_.push_back({placeholder, 0});
45 }
46
47 original_graph_description_ = GraphDebugString();
48
49 params_.reset(new TF_WhileParams(
50 TF_NewWhile(graph_, &inputs_[0], inputs_.size(), s_)));
51 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
52 ASSERT_EQ(original_graph_description_, GraphDebugString())
53 << "TF_NewWhile() altered graph";
54
55 params_->name = "test_loop";
56
57 // Initialize outputs_ so we can easily detect errors/bugs
58 outputs_.resize(ninputs, {nullptr, -1});
59 }
60
ExpectOK()61 void ExpectOK() {
62 TF_FinishWhile(params_.get(), s_, &outputs_[0]);
63 EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
64 }
65
ExpectError(TF_Code expected_code,const string & expected_msg)66 void ExpectError(TF_Code expected_code, const string& expected_msg) {
67 TF_FinishWhile(params_.get(), s_, &outputs_[0]);
68 EXPECT_EQ(expected_code, TF_GetCode(s_));
69 EXPECT_EQ(expected_msg, TF_Message(s_));
70 // TODO(skyewm): this assert is currently broken. Fix or remove guarantee.
71 // ASSERT_EQ(original_graph_description_, GraphDebugString()) <<
72 // "TF_FinishWhile() altered graph on error";
73 }
74
Run(std::initializer_list<int> input_values)75 void Run(std::initializer_list<int> input_values) {
76 Run(outputs_, input_values);
77 }
78
Run(const std::vector<TF_Output> & run_outputs,std::initializer_list<int> input_values)79 void Run(const std::vector<TF_Output>& run_outputs,
80 std::initializer_list<int> input_values) {
81 DCHECK_EQ(inputs_.size(), input_values.size());
82 std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs(inputs_.size());
83 int i = 0;
84 for (int v : input_values) {
85 inputs[i] = {inputs_[i].oper, Int32Tensor(v)};
86 ++i;
87 }
88 // TODO(skyewm): use std::make_unique or absl::make_unique when possible.
89 csession_.reset(new CSession(graph_, s_));
90 csession_->SetInputs(inputs);
91 csession_->SetOutputs(run_outputs);
92 csession_->Run(s_);
93 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
94 }
95
ExpectOutputValue(int idx,int expected_value)96 void ExpectOutputValue(int idx, int expected_value) {
97 TF_Tensor* out = csession_->output_tensor(idx);
98 ASSERT_TRUE(out != nullptr);
99 EXPECT_EQ(TF_INT32, TF_TensorType(out));
100 EXPECT_EQ(0, TF_NumDims(out));
101 ASSERT_EQ(sizeof(int32_t), TF_TensorByteSize(out));
102 int32_t* data = static_cast<int32_t*>(TF_TensorData(out));
103 EXPECT_EQ(expected_value, *data);
104 }
105
106 // Create a valid conditional graph. Useful for testing unrelated errors.
CreateCondGraph()107 void CreateCondGraph() {
108 TF_Operation* one = ScalarConst(1, params_->cond_graph, s_);
109 TF_Operation* less_than =
110 LessThan(params_->cond_inputs[0], {one, 0}, params_->cond_graph, s_);
111 DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
112 params_->cond_output = {less_than, 0};
113 }
114
GraphDebugString() const115 string GraphDebugString() const {
116 TF_Buffer* buf = TF_NewBuffer();
117 TF_GraphToGraphDef(graph_, buf, s_);
118 DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
119 GraphDef def;
120 bool success = def.ParseFromArray(buf->data, buf->length);
121 DCHECK(success);
122 TF_DeleteBuffer(buf);
123 return def.DebugString();
124 }
125
126 TF_Status* s_;
127 TF_Graph* graph_;
128 std::vector<TF_Output> inputs_; // The inputs to the while loop
129 std::vector<TF_Output> outputs_; // The final outputs of the while loop
130 std::unique_ptr<TF_WhileParams> params_;
131 std::unique_ptr<CSession> csession_;
132
133 private:
134 // Used to verify that errors don't change graph_
135 string original_graph_description_;
136 };
137
// Builds `while (input1 < input2) input1 += input2 + 1`, validates the
// TF_WhileParams from TF_NewWhile() and the outputs from TF_FinishWhile(), then
// runs the loop from (-9, 2), which should finish at (3, 2).
TEST_F(CApiWhileLoopTest, BasicLoop) {
  ASSERT_NO_FATAL_FAILURE(Init(2));

  // Validate TF_WhileParams returned by TF_NewWhile(). Everything below is
  // dereferenced while building the loop, so require it rather than merely
  // expecting it.
  ASSERT_TRUE(params_->body_graph != nullptr);
  ASSERT_TRUE(params_->cond_graph != nullptr);

  ASSERT_EQ(params_->ninputs, 2);

  ASSERT_TRUE(params_->cond_inputs != nullptr);
  ASSERT_TRUE(params_->cond_inputs[0].oper != nullptr);
  ASSERT_TRUE(params_->cond_inputs[1].oper != nullptr);

  ASSERT_TRUE(params_->body_inputs != nullptr);
  ASSERT_TRUE(params_->body_inputs[0].oper != nullptr);
  ASSERT_TRUE(params_->body_inputs[1].oper != nullptr);

  ASSERT_TRUE(params_->body_outputs != nullptr);

  // Create loop: while (input1 < input2) input1 += input2 + 1
  TF_Operation* less_than =
      LessThan(params_->cond_inputs[0], params_->cond_inputs[1],
               params_->cond_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  params_->cond_output = {less_than, 0};

  TF_Operation* add1 = Add(params_->body_inputs[0], params_->body_inputs[1],
                           params_->body_graph, s_, "add1");
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* one = ScalarConst(1, params_->body_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* add2 = Add(add1, one, params_->body_graph, s_, "add2");
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  params_->body_outputs[0] = {add2, 0};
  params_->body_outputs[1] = params_->body_inputs[1];

  // Finalize while loop
  ExpectOK();

  // Validate while loop outputs returned by TF_FinishWhile(). Run() fetches
  // these, so a missing output must stop the test.
  ASSERT_TRUE(outputs_[0].oper != nullptr);
  EXPECT_GE(outputs_[0].index, 0);
  ASSERT_TRUE(outputs_[1].oper != nullptr);
  EXPECT_GE(outputs_[1].index, 0);

  // Check that the cond and body input ops do not appear in graph_
  for (int i = 0; i < params_->ninputs; ++i) {
    string cond_name =
        ::tensorflow::strings::StrCat(params_->name, "/cond/cond_input", i);
    string body_name =
        ::tensorflow::strings::StrCat(params_->name, "/body/body_input", i);
    EXPECT_TRUE(TF_GraphOperationByName(graph_, cond_name.c_str()) == nullptr);
    EXPECT_TRUE(TF_GraphOperationByName(graph_, body_name.c_str()) == nullptr);
  }

  // Run the graph
  Run({-9, 2});
  ExpectOutputValue(0, 3);
  ExpectOutputValue(1, 2);
}
198
// Builds a while loop whose body contains another while loop, checks that the
// inner loop's nodes are nested under the outer loop's name scope, and runs it.
TEST_F(CApiWhileLoopTest, NestedLoop) {
  ASSERT_NO_FATAL_FAILURE(Init(2));
  // Create nested loop:
  //  while (input1 < 6) {
  //    inner_input1 = input1
  //    while (inner_input1 < 3) {
  //      input2 += 1
  //      inner_input1 += 2
  //    }
  //    input1 += input2
  //  }
  //
  // Expected execution with initial values input1 = input2 = 0:
  //
  //   outer  inner                  inner_
  //   step#  step#  input1  input2  input1
  //   ------------------------------------
  //     0      0      0       0       0
  //     0      1      0       1       2
  //     0      2      0       2       4
  //     0      -      2       2       -
  //     1      0      2       2       2
  //     1      1      2       3       4
  //     1      -      5       3       -
  //     2      0      5       3       5
  //     2      -      8       3       -

  // Create outer cond graph
  TF_Operation* six = ScalarConst(6, params_->cond_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* less_than =
      LessThan(params_->cond_inputs[0], {six, 0}, params_->cond_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  params_->cond_output = {less_than, 0};

  // Create outer body graph
  // Init inner graph
  TF_Output inner_inputs[] = {params_->body_inputs[0], params_->body_inputs[1]};
  TF_WhileParams inner_params =
      TF_NewWhile(params_->body_graph, inner_inputs, 2, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  inner_params.name = "inner_loop";

  // Create inner cond graph
  TF_Operation* three = ScalarConst(3, inner_params.cond_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* inner_less_than = LessThan(
      inner_params.cond_inputs[0], {three, 0}, inner_params.cond_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  inner_params.cond_output = {inner_less_than, 0};

  // Create inner body graph
  TF_Operation* one = ScalarConst(1, inner_params.body_graph, s_, "one");
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* two = ScalarConst(2, inner_params.body_graph, s_, "two");
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_Operation* input2_add =
      Add(inner_params.body_inputs[1].oper, one, inner_params.body_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  inner_params.body_outputs[1] = {input2_add, 0};

  TF_Operation* inner_input1_add = Add(inner_params.body_inputs[0].oper, two,
                                       inner_params.body_graph, s_, "add2");
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  inner_params.body_outputs[0] = {inner_input1_add, 0};

  // Finalize inner graph. Both elements get the {nullptr, -1} sentinel; a
  // single-element initializer would leave the second one as {nullptr, 0}.
  TF_Output inner_outputs[2] = {{nullptr, -1}, {nullptr, -1}};
  TF_FinishWhile(&inner_params, s_, inner_outputs);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  ASSERT_TRUE(inner_outputs[1].oper != nullptr);

  TF_Operation* input1_add =
      Add(params_->body_inputs[0], inner_outputs[1], params_->body_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  params_->body_outputs[0] = {input1_add, 0};

  params_->body_outputs[1] = inner_outputs[1];

  // Finalize outer graph
  ExpectOK();

  // Check for a few expected nodes
  const char* node_name = "test_loop/cond/scalar";
  EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
  node_name = "test_loop/body/add";
  EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
  node_name = "test_loop/body/inner_loop/body/one";
  EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
  node_name = "test_loop/body/inner_loop/cond/less_than";
  EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);

  // Run the graph
  Run({0, 0});
  ExpectOutputValue(0, 8);
  ExpectOutputValue(1, 3);
}
296
// TF_FinishWhile() must reject params whose cond_output was never set.
TEST_F(CApiWhileLoopTest, UnsetCondOutput) {
  ASSERT_NO_FATAL_FAILURE(Init(1));
  params_->body_outputs[0] = params_->body_inputs[0];
  ExpectError(TF_INVALID_ARGUMENT,
              "TF_WhileParams `cond_output` field isn't set");
}
303
// The cond output must be boolean; passing the int32 cond input through
// directly is rejected.
TEST_F(CApiWhileLoopTest, WrongCondOutputType) {
  ASSERT_NO_FATAL_FAILURE(Init(1));
  params_->cond_output = params_->cond_inputs[0];
  params_->body_outputs[0] = params_->body_inputs[0];
  ExpectError(TF_INVALID_ARGUMENT,
              "BuildWhileLoop: 'cond' argument must return a boolean output, "
              "got int32");
}
312
// The cond output must come from the cond graph, not from the parent graph.
TEST_F(CApiWhileLoopTest, InvalidCondOutputNode) {
  ASSERT_NO_FATAL_FAILURE(Init(1));
  // Try to reuse node from parent graph
  params_->cond_output = inputs_[0];
  params_->body_outputs[0] = params_->body_inputs[0];
  // TODO(skyewm): this error message could be more informative. Add explicit
  // checks for this case in the while loop implementation?
  ExpectError(TF_INVALID_ARGUMENT,
              "Requested return tensor 'p0:0' not found in graph def");
}
323
// An out-of-range output index on the cond output node is rejected.
TEST_F(CApiWhileLoopTest, InvalidCondOutputIndex) {
  ASSERT_NO_FATAL_FAILURE(Init(1));
  ASSERT_NO_FATAL_FAILURE(CreateCondGraph());
  params_->cond_output.index = 100;
  params_->body_outputs[0] = params_->body_inputs[0];
  ExpectError(TF_INVALID_ARGUMENT,
              "Invalid return output 100 of node 'less_than', which has 1 "
              "output(s)");
}
333
334 // TODO(skyewm): test bad cond output shape
335
// TF_FinishWhile() must reject params with an unset body output.
TEST_F(CApiWhileLoopTest, UnsetBodyOutput) {
  ASSERT_NO_FATAL_FAILURE(Init(1));
  ASSERT_NO_FATAL_FAILURE(CreateCondGraph());
  ExpectError(TF_INVALID_ARGUMENT,
              "TF_WhileParams `body_outputs[0]` field isn't set");
}
342
343 // TODO(skyewm): enable this when it works (currently doesn't error)
344 // TEST_F(CApiWhileLoopTest, WrongBodyOutputType) {
345 // Init(1);
346 // CreateCondGraph();
347 // TF_Operation* double_scalar =
348 // ScalarConst(1.0, params_->body_graph, s_, "double_scalar");
349 // params_->body_outputs[0] = {double_scalar, 0};
350 // ExpectError(TF_INVALID_ARGUMENT, "bad body output type");
351 // }
352
// A body output must come from the body graph, not from the parent graph.
TEST_F(CApiWhileLoopTest, InvalidBodyOutputNode) {
  ASSERT_NO_FATAL_FAILURE(Init(1));
  ASSERT_NO_FATAL_FAILURE(CreateCondGraph());
  // Try to reuse node from parent graph
  params_->body_outputs[0] = inputs_[0];
  // TODO(skyewm): this error message could be more informative. Add explicit
  // checks for this case in the while loop implementation?
  ExpectError(TF_INVALID_ARGUMENT,
              "Requested return tensor 'p0:0' not found in graph def");
}
363
364 // TODO(skyewm): enable this when it works (currently segfaults!)
365 // TEST_F(CApiWhileLoopTest, InvalidBodyOutputIndex) {
366 // Init(1);
367 // CreateCondGraph();
368 // params_->body_outputs[0] = params_->body_inputs[0];
369 // params_->body_outputs[0].index = 100;
370 // ExpectError(TF_INVALID_ARGUMENT,
371 // "Invalid return output 100 of node 'less_than', which has 1 "
372 // "output(s)");
373 // }
374
375 // TODO(skyewm): test bad body output shape
376
// TF_FinishWhile() must reject params whose `name` is null.
TEST_F(CApiWhileLoopTest, NullName) {
  ASSERT_NO_FATAL_FAILURE(Init(1));
  ASSERT_NO_FATAL_FAILURE(CreateCondGraph());
  params_->body_outputs[0] = params_->body_inputs[0];
  params_->name = nullptr;
  ExpectError(TF_INVALID_ARGUMENT, "TF_WhileParams `name` field is null");
}
384
// A body output taken from the outer graph is rejected.
// NOTE(review): this currently sets up the same scenario and expects the same
// message as InvalidBodyOutputNode; consider merging or differentiating them.
TEST_F(CApiWhileLoopTest, WrongGraph) {
  ASSERT_NO_FATAL_FAILURE(Init(1));
  ASSERT_NO_FATAL_FAILURE(CreateCondGraph());
  // Set body output to output from outer graph
  params_->body_outputs[0] = inputs_[0];
  // TODO(skyewm): improve error message
  ExpectError(TF_INVALID_ARGUMENT,
              "Requested return tensor 'p0:0' not found in graph def");
}
394
// Feeding an int32 body input into an op that requires a float input fails at
// TF_FinishOperation() time, and the half-built loop is then abandoned with
// TF_AbortWhile().
TEST_F(CApiWhileLoopTest, BadTypes) {
  ASSERT_NO_FATAL_FAILURE(Init(1));
  CreateCondGraph();
  // Op that has a float input + output
  TF_OperationDescription* desc = TF_NewOperation(
      params_->body_graph, "FakeQuantWithMinMaxArgs", "float_op");
  TF_AddInput(desc, params_->body_inputs[0]);
  TF_FinishOperation(desc, s_);
  const TF_Code code = TF_GetCode(s_);
  const string msg(TF_Message(s_));

  // Release the loop's resources before asserting, so a failing check below
  // cannot skip the cleanup and leak the cond/body graphs.
  TF_AbortWhile(params_.get());

  ASSERT_EQ(TF_INVALID_ARGUMENT, code);
  EXPECT_NE(msg.find("Input 'inputs' passed int32 expected float while "
                     "building NodeDef 'float_op'"),
            string::npos);
}
410
411 // This is a basic test to make sure the C++ gradient code can handle while
412 // loops created by the C API (which calls the C++ API under the hood). There
413 // are more while loop gradient tests in cc/framework/while_gradients_test.cc.
TEST_F(CApiWhileLoopTest,Gradients)414 TEST_F(CApiWhileLoopTest, Gradients) {
415 Init(1);
416
417 // Create loop: while (i < 10) i += 1
418 TF_Operation* ten = ScalarConst(10, params_->cond_graph, s_);
419 TF_Operation* less_than =
420 LessThan(params_->cond_inputs[0], {ten, 0}, params_->cond_graph, s_);
421 DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
422 params_->cond_output = {less_than, 0};
423
424 TF_Operation* one = ScalarConst(1, params_->body_graph, s_);
425 TF_Operation* add =
426 Add(params_->body_inputs[0], {one, 0}, params_->body_graph, s_);
427 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
428 params_->body_outputs[0] = {add, 0};
429
430 ExpectOK();
431
432 // Create backprop graph
433 TF_Output grad_output;
434 TF_AddGradients(graph_, outputs_.data(), outputs_.size(), inputs_.data(), 1,
435 nullptr, s_, &grad_output);
436 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
437
438 // Run gradient
439 Run({grad_output}, {0});
440 ExpectOutputValue(0, 1);
441 }
442
443 } // namespace
444