1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/c_api.h"
17 
18 #include "tensorflow/c/c_test_util.h"
19 #include "tensorflow/core/lib/strings/strcat.h"
20 #include "tensorflow/core/platform/logging.h"
21 #include "tensorflow/core/platform/test.h"
22 
23 using tensorflow::GraphDef;
24 
25 namespace {
26 
27 class CApiWhileLoopTest : public ::testing::Test {
28  protected:
CApiWhileLoopTest()29   CApiWhileLoopTest() : s_(TF_NewStatus()), graph_(TF_NewGraph()) {}
30 
~CApiWhileLoopTest()31   ~CApiWhileLoopTest() override {
32     TF_DeleteGraph(graph_);
33     TF_DeleteStatus(s_);
34   }
35 
Init(int ninputs)36   void Init(int ninputs) {
37     DCHECK(inputs_.empty());
38     DCHECK_GT(ninputs, 0);
39 
40     for (int i = 0; i < ninputs; ++i) {
41       TF_Operation* placeholder = Placeholder(
42           graph_, s_, ::tensorflow::strings::StrCat("p", i).c_str());
43       DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
44       inputs_.push_back({placeholder, 0});
45     }
46 
47     original_graph_description_ = GraphDebugString();
48 
49     params_.reset(new TF_WhileParams(
50         TF_NewWhile(graph_, &inputs_[0], inputs_.size(), s_)));
51     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
52     ASSERT_EQ(original_graph_description_, GraphDebugString())
53         << "TF_NewWhile() altered graph";
54 
55     params_->name = "test_loop";
56 
57     // Initialize outputs_ so we can easily detect errors/bugs
58     outputs_.resize(ninputs, {nullptr, -1});
59   }
60 
ExpectOK()61   void ExpectOK() {
62     TF_FinishWhile(params_.get(), s_, &outputs_[0]);
63     EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
64   }
65 
ExpectError(TF_Code expected_code,const string & expected_msg)66   void ExpectError(TF_Code expected_code, const string& expected_msg) {
67     TF_FinishWhile(params_.get(), s_, &outputs_[0]);
68     EXPECT_EQ(expected_code, TF_GetCode(s_));
69     EXPECT_EQ(expected_msg, TF_Message(s_));
70     // TODO(skyewm): this assert is currently broken. Fix or remove guarantee.
71     // ASSERT_EQ(original_graph_description_, GraphDebugString()) <<
72     //     "TF_FinishWhile() altered graph on error";
73   }
74 
Run(std::initializer_list<int> input_values)75   void Run(std::initializer_list<int> input_values) {
76     Run(outputs_, input_values);
77   }
78 
Run(const std::vector<TF_Output> & run_outputs,std::initializer_list<int> input_values)79   void Run(const std::vector<TF_Output>& run_outputs,
80            std::initializer_list<int> input_values) {
81     DCHECK_EQ(inputs_.size(), input_values.size());
82     std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs(inputs_.size());
83     int i = 0;
84     for (int v : input_values) {
85       inputs[i] = {inputs_[i].oper, Int32Tensor(v)};
86       ++i;
87     }
88     // TODO(skyewm): use std::make_unique or absl::make_unique when possible.
89     csession_.reset(new CSession(graph_, s_));
90     csession_->SetInputs(inputs);
91     csession_->SetOutputs(run_outputs);
92     csession_->Run(s_);
93     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
94   }
95 
ExpectOutputValue(int idx,int expected_value)96   void ExpectOutputValue(int idx, int expected_value) {
97     TF_Tensor* out = csession_->output_tensor(idx);
98     ASSERT_TRUE(out != nullptr);
99     EXPECT_EQ(TF_INT32, TF_TensorType(out));
100     EXPECT_EQ(0, TF_NumDims(out));
101     ASSERT_EQ(sizeof(int32_t), TF_TensorByteSize(out));
102     int32_t* data = static_cast<int32_t*>(TF_TensorData(out));
103     EXPECT_EQ(expected_value, *data);
104   }
105 
106   // Create a valid conditional graph. Useful for testing unrelated errors.
CreateCondGraph()107   void CreateCondGraph() {
108     TF_Operation* one = ScalarConst(1, params_->cond_graph, s_);
109     TF_Operation* less_than =
110         LessThan(params_->cond_inputs[0], {one, 0}, params_->cond_graph, s_);
111     DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
112     params_->cond_output = {less_than, 0};
113   }
114 
GraphDebugString() const115   string GraphDebugString() const {
116     TF_Buffer* buf = TF_NewBuffer();
117     TF_GraphToGraphDef(graph_, buf, s_);
118     DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
119     GraphDef def;
120     bool success = def.ParseFromArray(buf->data, buf->length);
121     DCHECK(success);
122     TF_DeleteBuffer(buf);
123     return def.DebugString();
124   }
125 
126   TF_Status* s_;
127   TF_Graph* graph_;
128   std::vector<TF_Output> inputs_;   // The inputs to the while loop
129   std::vector<TF_Output> outputs_;  // The final outputs of the while loop
130   std::unique_ptr<TF_WhileParams> params_;
131   std::unique_ptr<CSession> csession_;
132 
133  private:
134   // Used to verify that errors don't change graph_
135   string original_graph_description_;
136 };
137 
TEST_F(CApiWhileLoopTest,BasicLoop)138 TEST_F(CApiWhileLoopTest, BasicLoop) {
139   Init(2);
140 
141   // Validate TF_WhileParams returned by TF_NewWhile()
142   EXPECT_TRUE(params_->body_graph != nullptr);
143   EXPECT_TRUE(params_->cond_graph != nullptr);
144 
145   EXPECT_EQ(params_->ninputs, 2);
146 
147   ASSERT_TRUE(params_->cond_inputs != nullptr);
148   ASSERT_TRUE(params_->cond_inputs[0].oper != nullptr);
149   EXPECT_TRUE(params_->cond_inputs[1].oper != nullptr);
150 
151   ASSERT_TRUE(params_->body_inputs != nullptr);
152   EXPECT_TRUE(params_->body_inputs[0].oper != nullptr);
153   EXPECT_TRUE(params_->body_inputs[1].oper != nullptr);
154 
155   ASSERT_TRUE(params_->body_outputs != nullptr);
156 
157   // Create loop: while (input1 < input2) input1 += input2 + 1
158   TF_Operation* less_than =
159       LessThan(params_->cond_inputs[0], params_->cond_inputs[1],
160                params_->cond_graph, s_);
161   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
162   params_->cond_output = {less_than, 0};
163 
164   TF_Operation* add1 = Add(params_->body_inputs[0], params_->body_inputs[1],
165                            params_->body_graph, s_, "add1");
166   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
167   TF_Operation* one = ScalarConst(1, params_->body_graph, s_);
168   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
169   TF_Operation* add2 = Add(add1, one, params_->body_graph, s_, "add2");
170   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
171   params_->body_outputs[0] = {add2, 0};
172   params_->body_outputs[1] = params_->body_inputs[1];
173 
174   // Finalize while loop
175   ExpectOK();
176 
177   // Validate while loop outputs returned by TF_FinishWhile()
178   EXPECT_TRUE(outputs_[0].oper != nullptr);
179   EXPECT_GE(outputs_[0].index, 0);
180   EXPECT_TRUE(outputs_[1].oper != nullptr);
181   EXPECT_GE(outputs_[1].index, 0);
182 
183   // Check that cond and body inputs are not present
184   for (int i = 0; i < params_->ninputs; ++i) {
185     string cond_name =
186         ::tensorflow::strings::StrCat(params_->name, "/cond/cond_input", i);
187     string body_name =
188         ::tensorflow::strings::StrCat(params_->name, "/body/body_input", i);
189     EXPECT_TRUE(TF_GraphOperationByName(graph_, cond_name.c_str()) == nullptr);
190     EXPECT_TRUE(TF_GraphOperationByName(graph_, body_name.c_str()) == nullptr);
191   }
192 
193   // Run the graph
194   Run({-9, 2});
195   ExpectOutputValue(0, 3);
196   ExpectOutputValue(1, 2);
197 }
198 
TEST_F(CApiWhileLoopTest,NestedLoop)199 TEST_F(CApiWhileLoopTest, NestedLoop) {
200   Init(2);
201   // Create nested loop:
202   //  while (input1 < 6) {
203   //    inner_input1 = input1
204   //    while (inner_input1 < 3) {
205   //      input2 += 1
206   //      inner_input1 += 2
207   //    }
208   //    input1 += input2
209   //  }
210   //
211   // Expected execution with initial values input1 = input2 = 0:
212   //
213   // outer inner               inner_
214   // step# step# input1 input2 input1
215   // ------------------------------------
216   //   0     0     0      0      0
217   //   0     1     0      1      2
218   //   0     2     0      2      4
219   //   0     -     2      2      -
220   //   1     0     2      2      2
221   //   1     1     2      3      4
222   //   1     -     5      3      -
223   //   2     0     5      3      5
224   //   2     -     8      3      -
225 
226   // Create outer cond graph
227   TF_Operation* six = ScalarConst(6, params_->cond_graph, s_);
228   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
229   TF_Operation* less_than =
230       LessThan(params_->cond_inputs[0], {six, 0}, params_->cond_graph, s_);
231   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
232   params_->cond_output = {less_than, 0};
233 
234   // Create outer body graph
235   // Init inner graph
236   TF_Output inner_inputs[] = {params_->body_inputs[0], params_->body_inputs[1]};
237   TF_WhileParams inner_params =
238       TF_NewWhile(params_->body_graph, inner_inputs, 2, s_);
239   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
240   inner_params.name = "inner_loop";
241 
242   // Create inner cond graph
243   TF_Operation* three = ScalarConst(3, inner_params.cond_graph, s_);
244   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
245   TF_Operation* inner_less_than = LessThan(
246       inner_params.cond_inputs[0], {three, 0}, inner_params.cond_graph, s_);
247   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
248   inner_params.cond_output = {inner_less_than, 0};
249 
250   // Create inner body graph
251   TF_Operation* one = ScalarConst(1, inner_params.body_graph, s_, "one");
252   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
253   TF_Operation* two = ScalarConst(2, inner_params.body_graph, s_, "two");
254   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
255 
256   TF_Operation* input2_add =
257       Add(inner_params.body_inputs[1].oper, one, inner_params.body_graph, s_);
258   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
259   inner_params.body_outputs[1] = {input2_add, 0};
260 
261   TF_Operation* inner_input1_add = Add(inner_params.body_inputs[0].oper, two,
262                                        inner_params.body_graph, s_, "add2");
263   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
264   inner_params.body_outputs[0] = {inner_input1_add, 0};
265 
266   // Finalize inner graph
267   TF_Output inner_outputs[2] = {{nullptr, -1}};
268   TF_FinishWhile(&inner_params, s_, inner_outputs);
269   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
270 
271   TF_Operation* input1_add =
272       Add(params_->body_inputs[0], inner_outputs[1], params_->body_graph, s_);
273   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
274   params_->body_outputs[0] = {input1_add, 0};
275 
276   params_->body_outputs[1] = inner_outputs[1];
277 
278   // Finalize outer graph
279   ExpectOK();
280 
281   // Check for a few expected nodes
282   const char* node_name = "test_loop/cond/scalar";
283   EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
284   node_name = "test_loop/body/add";
285   EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
286   node_name = "test_loop/body/inner_loop/body/one";
287   EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
288   node_name = "test_loop/body/inner_loop/cond/less_than";
289   EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
290 
291   // Run the graph
292   Run({0, 0});
293   ExpectOutputValue(0, 8);
294   ExpectOutputValue(1, 3);
295 }
296 
TEST_F(CApiWhileLoopTest, UnsetCondOutput) {
  // Stop on setup failure instead of dereferencing empty TF_WhileParams.
  ASSERT_NO_FATAL_FAILURE(Init(1));
  params_->body_outputs[0] = params_->body_inputs[0];
  ExpectError(TF_INVALID_ARGUMENT,
              "TF_WhileParams `cond_output` field isn't set");
}
303 
TEST_F(CApiWhileLoopTest, WrongCondOutputType) {
  ASSERT_NO_FATAL_FAILURE(Init(1));
  // Use the (non-boolean) loop input directly as the condition.
  params_->cond_output = params_->cond_inputs[0];
  params_->body_outputs[0] = params_->body_inputs[0];
  ExpectError(TF_INVALID_ARGUMENT,
              "BuildWhileLoop: 'cond' argument must return a boolean output, "
              "got int32");
}
312 
TEST_F(CApiWhileLoopTest, InvalidCondOutputNode) {
  ASSERT_NO_FATAL_FAILURE(Init(1));
  // Try to reuse node from parent graph
  params_->cond_output = inputs_[0];
  params_->body_outputs[0] = params_->body_inputs[0];
  // TODO(skyewm): this error message could be more informative. Add explicit
  // checks for this case in the while loop implementation?
  ExpectError(TF_INVALID_ARGUMENT,
              "Requested return tensor 'p0:0' not found in graph def");
}
323 
TEST_F(CApiWhileLoopTest, InvalidCondOutputIndex) {
  ASSERT_NO_FATAL_FAILURE(Init(1));
  ASSERT_NO_FATAL_FAILURE(CreateCondGraph());
  // Point the condition at a non-existent output of the valid cond op.
  params_->cond_output.index = 100;
  params_->body_outputs[0] = params_->body_inputs[0];
  ExpectError(TF_INVALID_ARGUMENT,
              "Invalid return output 100 of node 'less_than', which has 1 "
              "output(s)");
}
333 
334 // TODO(skyewm): test bad cond output shape
335 
TEST_F(CApiWhileLoopTest, UnsetBodyOutput) {
  ASSERT_NO_FATAL_FAILURE(Init(1));
  ASSERT_NO_FATAL_FAILURE(CreateCondGraph());
  // body_outputs[0] is intentionally left unset.
  ExpectError(TF_INVALID_ARGUMENT,
              "TF_WhileParams `body_outputs[0]` field isn't set");
}
342 
343 // TODO(skyewm): enable this when it works (currently doesn't error)
344 // TEST_F(CApiWhileLoopTest, WrongBodyOutputType) {
345 //   Init(1);
346 //   CreateCondGraph();
347 //   TF_Operation* double_scalar =
348 //       ScalarConst(1.0, params_->body_graph, s_, "double_scalar");
349 //   params_->body_outputs[0] = {double_scalar, 0};
350 //   ExpectError(TF_INVALID_ARGUMENT, "bad body output type");
351 // }
352 
TEST_F(CApiWhileLoopTest, InvalidBodyOutputNode) {
  ASSERT_NO_FATAL_FAILURE(Init(1));
  ASSERT_NO_FATAL_FAILURE(CreateCondGraph());
  // Try to reuse node from parent graph
  params_->body_outputs[0] = inputs_[0];
  // TODO(skyewm): this error message could be more informative. Add explicit
  // checks for this case in the while loop implementation?
  ExpectError(TF_INVALID_ARGUMENT,
              "Requested return tensor 'p0:0' not found in graph def");
}
363 
364 // TODO(skyewm): enable this when it works (currently segfaults!)
365 // TEST_F(CApiWhileLoopTest, InvalidBodyOutputIndex) {
366 //   Init(1);
367 //   CreateCondGraph();
368 //   params_->body_outputs[0] = params_->body_inputs[0];
369 //   params_->body_outputs[0].index = 100;
370 //   ExpectError(TF_INVALID_ARGUMENT,
371 //               "Invalid return output 100 of node 'less_than', which has 1 "
372 //               "output(s)");
373 // }
374 
375 // TODO(skyewm): test bad body output shape
376 
TEST_F(CApiWhileLoopTest, NullName) {
  ASSERT_NO_FATAL_FAILURE(Init(1));
  ASSERT_NO_FATAL_FAILURE(CreateCondGraph());
  params_->body_outputs[0] = params_->body_inputs[0];
  // Clear the name that Init() set; TF_FinishWhile() must reject it.
  params_->name = nullptr;
  ExpectError(TF_INVALID_ARGUMENT, "TF_WhileParams `name` field is null");
}
384 
TEST_F(CApiWhileLoopTest, WrongGraph) {
  ASSERT_NO_FATAL_FAILURE(Init(1));
  ASSERT_NO_FATAL_FAILURE(CreateCondGraph());
  // Set body output to output from outer graph
  params_->body_outputs[0] = inputs_[0];
  // TODO(skyewm): improve error message
  ExpectError(TF_INVALID_ARGUMENT,
              "Requested return tensor 'p0:0' not found in graph def");
}
394 
TEST_F(CApiWhileLoopTest, BadTypes) {
  ASSERT_NO_FATAL_FAILURE(Init(1));
  ASSERT_NO_FATAL_FAILURE(CreateCondGraph());
  // Op that has a float input + output
  TF_OperationDescription* desc = TF_NewOperation(
      params_->body_graph, "FakeQuantWithMinMaxArgs", "float_op");
  TF_AddInput(desc, params_->body_inputs[0]);
  TF_FinishOperation(desc, s_);
  // EXPECT (not ASSERT) so that TF_AbortWhile() below is always reached, even
  // if the status code is not the one expected.
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  string msg(TF_Message(s_));
  EXPECT_NE(msg.find("Input 'inputs' passed int32 expected float while "
                     "building NodeDef 'float_op'"),
            msg.npos);
  TF_AbortWhile(params_.get());
}
410 
411 // This is a basic test to make sure the C++ gradient code can handle while
412 // loops created by the C API (which calls the C++ API under the hood). There
413 // are more while loop gradient tests in cc/framework/while_gradients_test.cc.
TEST_F(CApiWhileLoopTest, Gradients) {
  ASSERT_NO_FATAL_FAILURE(Init(1));

  // Create loop: while (i < 10) i += 1
  TF_Operation* ten = ScalarConst(10, params_->cond_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* less_than =
      LessThan(params_->cond_inputs[0], {ten, 0}, params_->cond_graph, s_);
  // ASSERT (as in the other tests) rather than DCHECK, which is compiled out
  // of optimized builds.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  params_->cond_output = {less_than, 0};

  TF_Operation* one = ScalarConst(1, params_->body_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* add =
      Add(params_->body_inputs[0], {one, 0}, params_->body_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  params_->body_outputs[0] = {add, 0};

  ExpectOK();
  // The loop output is differentiated below, so it must have been produced.
  ASSERT_TRUE(outputs_[0].oper != nullptr);

  // Create backprop graph: gradient of the loop output w.r.t. the loop input.
  TF_Output grad_output = {nullptr, -1};
  TF_AddGradients(graph_, outputs_.data(), outputs_.size(), inputs_.data(), 1,
                  nullptr, s_, &grad_output);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Run gradient. Each iteration only adds a constant, so the gradient is 1.
  ASSERT_NO_FATAL_FAILURE(Run({grad_output}, {0}));
  ExpectOutputValue(0, 1);
}
442 
443 }  // namespace
444