1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/graph/algorithm.h"
17
18 #include <string>
19 #include <vector>
20
21 #include "tensorflow/core/graph/graph.h"
22 #include "tensorflow/core/graph/graph_def_builder.h"
23 #include "tensorflow/core/graph/graph_def_builder_util.h"
24 #include "tensorflow/core/graph/subgraph.h"
25 #include "tensorflow/core/kernels/ops_util.h"
26 #include "tensorflow/core/lib/core/status.h"
27 #include "tensorflow/core/lib/core/status_test_util.h"
28 #include "tensorflow/core/platform/test.h"
29
30 // TODO(josh11b): Test setting the "device" field of a NodeDef.
31 // TODO(josh11b): Test that feeding won't prune targets.
32
33 namespace tensorflow {
34 namespace {
35
36 REGISTER_OP("TestParams").Output("o: float");
37 REGISTER_OP("TestInput").Output("a: float").Output("b: float");
38 REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float");
39 REGISTER_OP("TestUnary").Input("a: float").Output("o: float");
40 REGISTER_OP("TestBinary")
41 .Input("a: float")
42 .Input("b: float")
43 .Output("o: float");
44
45 // Compares that the order of nodes in 'inputs' respects the
46 // pair orders described in 'ordered_pairs'.
ExpectBefore(const std::vector<std::pair<string,string>> & ordered_pairs,const std::vector<Node * > & inputs,string * error)47 bool ExpectBefore(const std::vector<std::pair<string, string>>& ordered_pairs,
48 const std::vector<Node*>& inputs, string* error) {
49 for (const std::pair<string, string>& pair : ordered_pairs) {
50 const string& before_node = pair.first;
51 const string& after_node = pair.second;
52 bool seen_before = false;
53 bool seen_both = false;
54 for (const Node* node : inputs) {
55 if (!seen_before && after_node == node->name()) {
56 *error = strings::StrCat("Saw ", after_node, " before ", before_node);
57 return false;
58 }
59
60 if (before_node == node->name()) {
61 seen_before = true;
62 } else if (after_node == node->name()) {
63 seen_both = seen_before;
64 break;
65 }
66 }
67 if (!seen_both) {
68 *error = strings::StrCat("didn't see either ", before_node, " or ",
69 after_node);
70 return false;
71 }
72 }
73
74 return true;
75 }
76
// Builds a small graph mixing data and control dependencies, then checks that
// both GetReversePostOrder and GetPostOrder respect every dependency. It also
// checks that ExpectBefore rejects an ordering that contradicts them.
TEST(AlgorithmTest, ReversePostOrder) {
  GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
  using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
  Node* params1 = SourceOp("TestParams", builder.opts().WithName("W1"));
  Node* params2 = SourceOp("TestParams", builder.opts().WithName("W2"));
  Node* in = SourceOp(
      "TestInput", builder.opts().WithName("input").WithControlInput(params1));
  Node* mul1 =
      BinaryOp("TestMul", params1, {in, 1}, builder.opts().WithName("t1"));
  BinaryOp("TestMul", params1, {in, 1},
           builder.opts().WithName("t2").WithControlInput(mul1));
  BinaryOp("TestMul", params2, {in, 1}, builder.opts().WithName("t3"));

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(GraphDefBuilderToGraph(builder, &graph));
  std::vector<Node*> visit_order;
  string error;

  // Reverse post order: every producer precedes its consumers.
  GetReversePostOrder(graph, &visit_order);
  const std::vector<std::pair<string, string>> producer_first = {
      {"W1", "input"}, {"W1", "t1"},    {"W1", "t2"}, {"W1", "t3"},
      {"input", "t1"}, {"input", "t3"}, {"t1", "t2"}, {"W2", "t3"}};
  EXPECT_TRUE(ExpectBefore(producer_first, visit_order, &error)) << error;
  // A constraint contradicting the dependencies must be rejected.
  EXPECT_FALSE(ExpectBefore({{"input", "W1"}}, visit_order, &error));

  // Post order: every consumer precedes its producers.
  GetPostOrder(graph, &visit_order);
  const std::vector<std::pair<string, string>> consumer_first = {
      {"input", "W1"}, {"t1", "W1"},    {"t2", "W1"}, {"t3", "W1"},
      {"t1", "input"}, {"t3", "input"}, {"t2", "t1"}, {"t3", "W2"}};
  EXPECT_TRUE(ExpectBefore(consumer_first, visit_order, &error)) << error;
  // A constraint contradicting the dependencies must be rejected.
  EXPECT_FALSE(ExpectBefore({{"W1", "t3"}}, visit_order, &error));
}
120
// Checks that GetReversePostOrder with a name-based stable comparator orders
// the sibling nodes "t2" and "t3" deterministically. Each run places a
// different number of filler nodes between them.
TEST(AlgorithmTest, ReversePostOrderStable) {
  const int64 run_count = 100;
  using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)

  for (int64 i = 0; i < run_count; ++i) {
    // One source of nondeterminism is an unordered set keyed by pointer. For
    // example, the iteration order of FlatSet<Node*> depends on the raw
    // pointer values of the nodes. A stable post order is supposed to remove
    // this nondeterminism by ordering nodes with the given comparator.
    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
    string error;
    Node* w1 = SourceOp("TestParams", b.opts().WithName("W1"));
    Node* input =
        SourceOp("TestInput", b.opts().WithName("input").WithControlInput(w1));
    BinaryOp("TestMul", w1, {input, 1}, b.opts().WithName("t2"));
    // Insert a different number of nodes between the creation of t2 and t3 on
    // each run. This varies the memory distance between t2 and t3, so their
    // relative order would be effectively random if the stable DFS were not
    // implemented correctly.
    for (int64 j = 0; j < i; ++j) {
      BinaryOp("TestMul", w1, {input, 1},
               b.opts().WithName(strings::StrCat("internal", j)));
    }

    BinaryOp("TestMul", w1, {input, 1}, b.opts().WithName("t3"));

    Graph g(OpRegistry::Global());
    TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g));
    std::vector<Node*> order;

    // Test that reverse post order generates the expected ordering.
    GetReversePostOrder(g, &order, /*stable_comparator=*/NodeComparatorName());
    // Stream `error` so a failure explains which ordering was violated.
    EXPECT_TRUE(ExpectBefore({{"t2", "t3"}}, order, &error)) << error;
  }
}
156
// Checks that an edge filter passed to GetPostOrder / GetReversePostOrder
// hides the rejected edges from the traversal.
//
// The builder graph is n0 -> n1 -> n2 -> n3, plus an extra n0 -> n3 input.
// A back edge n3 -> n1 is then added directly to `g`, which forms a cycle.
// The filter rejects exactly that edge, so both traversals must produce the
// orders listed in `expected_post_order`.
TEST(AlgorithmTest, PostOrderWithEdgeFilter) {
  GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
  Node* n0 = ops::SourceOp("TestParams", b.opts().WithName("n0"));
  Node* n1 = ops::UnaryOp("TestUnary", n0, b.opts().WithName("n1"));
  Node* n2 = ops::UnaryOp("TestUnary", n1, b.opts().WithName("n2"));
  Node* n3 = ops::BinaryOp("TestBinary", n2, n0, b.opts().WithName("n3"));

  Graph g(OpRegistry::Global());
  TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g));

  // n0..n3 belong to the builder's graph. Their counterparts in `g` are
  // looked up by id.
  g.AddEdge(g.FindNodeId(n3->id()), 0, g.FindNodeId(n1->id()), 1);

  std::vector<Node*> post_order;
  // Reject only the n3 -> n1 back edge added above.
  auto edge_filter = [&](const Edge& e) {
    return !(e.src()->id() == n3->id() && e.dst()->id() == n1->id());
  };

  const std::vector<Node*> expected_post_order = {
      g.sink_node(),          g.FindNodeId(n3->id()), g.FindNodeId(n2->id()),
      g.FindNodeId(n1->id()), g.FindNodeId(n0->id()), g.source_node()};
  const std::vector<Node*> expected_reverse_post_order(
      expected_post_order.rbegin(), expected_post_order.rend());

  GetPostOrder(g, &post_order, /*stable_comparator=*/{},
               /*edge_filter=*/edge_filter);

  // Use EXPECT_* rather than CHECK_*, so a mismatch is reported as a test
  // failure instead of aborting the whole test binary.
  ASSERT_EQ(expected_post_order.size(), post_order.size());
  for (size_t i = 0; i < post_order.size(); ++i) {
    EXPECT_EQ(post_order[i], expected_post_order[i])
        << post_order[i]->name() << " vs. " << expected_post_order[i]->name();
  }

  std::vector<Node*> reverse_post_order;
  GetReversePostOrder(g, &reverse_post_order, /*stable_comparator=*/{},
                      /*edge_filter=*/edge_filter);

  ASSERT_EQ(expected_reverse_post_order.size(), reverse_post_order.size());
  for (size_t i = 0; i < reverse_post_order.size(); ++i) {
    EXPECT_EQ(reverse_post_order[i], expected_reverse_post_order[i])
        << reverse_post_order[i]->name() << " vs. "
        << expected_reverse_post_order[i]->name();
  }
}
202 } // namespace
203 } // namespace tensorflow
204