1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/graph/control_flow.h"
17
18 #include <string>
19 #include <vector>
20
21 #include "tensorflow/cc/ops/standard_ops.h"
22 #include "tensorflow/cc/ops/while_loop.h"
23 #include "tensorflow/core/lib/core/status_test_util.h"
24 #include "tensorflow/core/lib/strings/str_util.h"
25 #include "tensorflow/core/platform/test.h"
26
27 namespace tensorflow {
28 namespace {
LessThanTenCond(const Scope & scope,const std::vector<Output> & inputs,Output * output)29 Status LessThanTenCond(const Scope& scope, const std::vector<Output>& inputs,
30 Output* output) {
31 *output = ops::Less(scope, inputs[0], 10);
32 return scope.status();
33 }
34
AddOneBody(const Scope & scope,const std::vector<Output> & inputs,std::vector<Output> * outputs)35 Status AddOneBody(const Scope& scope, const std::vector<Output>& inputs,
36 std::vector<Output>* outputs) {
37 outputs->push_back(ops::AddN(scope, {inputs[0], 1}));
38 return scope.status();
39 }
40
NestedLoopBody(const Scope & scope,const std::vector<Output> & inputs,std::vector<Output> * outputs)41 Status NestedLoopBody(const Scope& scope, const std::vector<Output>& inputs,
42 std::vector<Output>* outputs) {
43 return ops::BuildWhileLoop(scope.NewSubScope("inner"), inputs,
44 LessThanTenCond, AddOneBody, "inner_loop",
45 outputs);
46 }
47
// Builds an outer while loop ("outer_loop") whose body builds an inner while
// loop ("inner_loop"), then checks that BuildControlFlowInfo rejects the graph
// with a "has inputs from different frames" error naming the offending node
// and both of its inputs.
TEST(ValidateControlFlowTest, InputsFromDifferentFrames) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  std::vector<Output> inputs;
  inputs.push_back(ops::Placeholder(scope, DT_INT32));
  std::vector<Output> outputs;
  TF_ASSERT_OK(ops::BuildWhileLoop(scope.NewSubScope("outer"), inputs,
                                   LessThanTenCond, NestedLoopBody,
                                   "outer_loop", &outputs));
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));
  // {'outer/body/inner/Enter', 'outer/Switch'} --> 'outer/body/inner/Merge'.
  // 'outer/body/inner/Enter' is in frame 'inner_loop'; 'outer/Switch' is in
  // frame 'outer_loop'.
  std::vector<ControlFlowInfo> info;
  Status status = BuildControlFlowInfo(graph.get(), &info);
  // If validation unexpectedly succeeds, the message checks below would only
  // add confusing follow-up failures, so stop here.
  ASSERT_FALSE(status.ok());
  const auto& message = status.error_message();
  EXPECT_TRUE(absl::StrContains(message, "has inputs from different frames"))
      << message;
  EXPECT_TRUE(absl::StrContains(message, "{{node outer/body/inner/Merge}}"))
      << message;
  EXPECT_TRUE(absl::StrContains(message, "{{node outer/body/inner/Enter}}"))
      << message;
  EXPECT_TRUE(absl::StrContains(message, "{{node outer/Switch}}")) << message;
}
76
// Builds a single while loop ("test_loop"), then hand-adds a second Enter node
// in the same frame that is fed by the loop's existing Enter. The two Enters
// then disagree about the parent of "test_loop", which BuildControlFlowInfo
// must report as "Mismatched parent frames".
TEST(ValidateControlFlowTest, MismatchedParentFrames) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  std::vector<Output> inputs;
  inputs.push_back(ops::Placeholder(scope, DT_INT32));
  std::vector<Output> outputs;
  TF_ASSERT_OK(ops::BuildWhileLoop(scope, inputs, LessThanTenCond, AddOneBody,
                                   "test_loop", &outputs));
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Locate the loop's Enter node (one loop variable, so presumably only one
  // Enter exists; the last one found is used).
  Node* enter_1 = nullptr;
  for (Node* node : graph->op_nodes()) {
    if (IsEnter(node)) {
      enter_1 = node;
    }
  }
  ASSERT_NE(enter_1, nullptr);

  // Second Enter in the same frame. Its input refers to the Enter found above
  // by its actual name rather than a hard-coded "Enter", so the NodeDef input
  // and the control edge below always point at the same node.
  NodeDef enter;
  enter.set_name("Enter2");
  enter.set_op("Enter");
  (*enter.mutable_attr())["T"].set_type(DT_INT32);
  (*enter.mutable_attr())["frame_name"].set_s("test_loop");
  *enter.add_input() = enter_1->name();
  Status status;
  Node* enter_2 = graph->AddNode(enter, &status);
  TF_ASSERT_OK(status);
  graph->AddControlEdge(enter_1, enter_2);

  // SOURCE("") --> Enter("test_loop") --> Enter2("test_loop")
  // For node 'Enter', the parent frame of "test_loop" is empty.
  // For node 'Enter2', the parent frame of "test_loop" is "test_loop".
  std::vector<ControlFlowInfo> info;
  status = BuildControlFlowInfo(graph.get(), &info);
  EXPECT_FALSE(status.ok());
  EXPECT_TRUE(
      absl::StrContains(status.error_message(), "Mismatched parent frames"))
      << status.error_message();
  EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node Enter2}}"))
      << status.error_message();
}
117
// A single frame may hold at most one LoopCond node; functionalize control
// flow depends on this invariant. Two loops are built that share the frame
// name "test_loop" (the second under sub-scope "sub"), and validation is
// expected to reject the resulting graph, naming both LoopCond nodes.
TEST(ValidateControlFlowTest, TwoLoopCond) {
  Scope root = Scope::NewRootScope().ExitOnError();
  std::vector<Output> loop_inputs = {ops::Placeholder(root, DT_INT32)};

  std::vector<Output> first_outputs;
  TF_ASSERT_OK(ops::BuildWhileLoop(root, loop_inputs, LessThanTenCond,
                                   AddOneBody, "test_loop", &first_outputs));

  // The second loop is built without a while context (final argument false).
  std::vector<Output> second_outputs;
  TF_ASSERT_OK(ops::BuildWhileLoop(root.NewSubScope("sub"), loop_inputs,
                                   LessThanTenCond, AddOneBody, "test_loop",
                                   &second_outputs,
                                   /*create_while_ctx=*/false));

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  std::vector<ControlFlowInfo> info;
  const Status status = BuildControlFlowInfo(graph.get(), &info);
  EXPECT_FALSE(status.ok());
  const auto& message = status.error_message();
  EXPECT_TRUE(absl::StrContains(message, "more than one LoopCond node"))
      << message;
  EXPECT_TRUE(absl::StrContains(message, "{{node sub/LoopCond}}")) << message;
  EXPECT_TRUE(absl::StrContains(message, "{{node LoopCond}}")) << message;
}
145
146 } // namespace
147 } // namespace tensorflow
148