1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/graph/algorithm.h"

#include <algorithm>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
29 
30 // TODO(josh11b): Test setting the "device" field of a NodeDef.
31 // TODO(josh11b): Test that feeding won't prune targets.
32 
33 namespace tensorflow {
34 namespace {
35 
36 REGISTER_OP("TestParams").Output("o: float");
37 REGISTER_OP("TestInput").Output("a: float").Output("b: float");
38 REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float");
39 REGISTER_OP("TestUnary").Input("a: float").Output("o: float");
40 REGISTER_OP("TestBinary")
41     .Input("a: float")
42     .Input("b: float")
43     .Output("o: float");
44 
45 // Compares that the order of nodes in 'inputs' respects the
46 // pair orders described in 'ordered_pairs'.
ExpectBefore(const std::vector<std::pair<string,string>> & ordered_pairs,const std::vector<Node * > & inputs,string * error)47 bool ExpectBefore(const std::vector<std::pair<string, string>>& ordered_pairs,
48                   const std::vector<Node*>& inputs, string* error) {
49   for (const std::pair<string, string>& pair : ordered_pairs) {
50     const string& before_node = pair.first;
51     const string& after_node = pair.second;
52     bool seen_before = false;
53     bool seen_both = false;
54     for (const Node* node : inputs) {
55       if (!seen_before && after_node == node->name()) {
56         *error = strings::StrCat("Saw ", after_node, " before ", before_node);
57         return false;
58       }
59 
60       if (before_node == node->name()) {
61         seen_before = true;
62       } else if (after_node == node->name()) {
63         seen_both = seen_before;
64         break;
65       }
66     }
67     if (!seen_both) {
68       *error = strings::StrCat("didn't see either ", before_node, " or ",
69                                after_node);
70       return false;
71     }
72   }
73 
74   return true;
75 }
76 
// Builds a small DAG containing both data and control edges, then checks that
// GetReversePostOrder places producers before their consumers and that
// GetPostOrder places consumers before their producers.
TEST(AlgorithmTest, ReversePostOrder) {
  GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
  using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)

  // Graph shape:
  //   W1 --ctrl--> input
  //   t1 = TestMul(W1, input:1)
  //   t2 = TestMul(W1, input:1), plus a control edge from t1
  //   t3 = TestMul(W2, input:1)
  Node* params1 = SourceOp("TestParams", builder.opts().WithName("W1"));
  Node* params2 = SourceOp("TestParams", builder.opts().WithName("W2"));
  Node* feed = SourceOp(
      "TestInput", builder.opts().WithName("input").WithControlInput(params1));
  Node* mul1 =
      BinaryOp("TestMul", params1, {feed, 1}, builder.opts().WithName("t1"));
  BinaryOp("TestMul", params1, {feed, 1},
           builder.opts().WithName("t2").WithControlInput(mul1));
  BinaryOp("TestMul", params2, {feed, 1}, builder.opts().WithName("t3"));

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(GraphDefBuilderToGraph(builder, &graph));
  string error;

  // Reverse post order: every producer must precede each of its consumers.
  std::vector<Node*> rpo;
  GetReversePostOrder(graph, &rpo);
  const std::vector<std::pair<string, string>> rpo_pairs = {
      {"W1", "input"}, {"W1", "t1"},    {"W1", "t2"}, {"W1", "t3"},
      {"input", "t1"}, {"input", "t3"}, {"t1", "t2"}, {"W2", "t3"}};
  EXPECT_TRUE(ExpectBefore(rpo_pairs, rpo, &error)) << error;
  // Sanity check: an ordering that contradicts the graph must be rejected.
  EXPECT_FALSE(ExpectBefore({{"input", "W1"}}, rpo, &error));

  // Post order: every consumer must precede each of its producers.
  std::vector<Node*> po;
  GetPostOrder(graph, &po);
  const std::vector<std::pair<string, string>> po_pairs = {
      {"input", "W1"}, {"t1", "W1"},    {"t2", "W1"}, {"t3", "W1"},
      {"t1", "input"}, {"t3", "input"}, {"t2", "t1"}, {"t3", "W2"}};
  EXPECT_TRUE(ExpectBefore(po_pairs, po, &error)) << error;
  // Sanity check: an ordering that contradicts the graph must be rejected.
  EXPECT_FALSE(ExpectBefore({{"W1", "t3"}}, po, &error));
}
120 
// Checks that GetReversePostOrder with a NodeComparatorName() stable comparator
// always places sibling node "t2" before "t3", regardless of how many other
// nodes are created between them.
TEST(AlgorithmTest, ReversePostOrderStable) {
  const int64 run_count = 100;
  using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)

  for (int64 i = 0; i < run_count; ++i) {
    // One source of nondeterminism is an unordered set keyed by a pointer
    // type: for example, the iteration order of FlatSet<Node*> depends on the
    // raw pointer values of the Nodes. A stable post order is supposed to
    // remove this nondeterminism by imposing the ordering defined by the
    // supplied comparator (here, NodeComparatorName(), i.e. by node name).
    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
    string error;
    Node* w1 = SourceOp("TestParams", b.opts().WithName("W1"));
    Node* input =
        SourceOp("TestInput", b.opts().WithName("input").WithControlInput(w1));
    BinaryOp("TestMul", w1, {input, 1}, b.opts().WithName("t2"));
    // Insert a different number of nodes between the creation of t2 and t3 on
    // each iteration. This varies the memory distance between t2 and t3, so
    // their relative order would be randomized if the stable DFS were not
    // implemented correctly.
    for (int64 j = 0; j < i; ++j) {
      BinaryOp("TestMul", w1, {input, 1},
               b.opts().WithName(strings::StrCat("internal", j)));
    }

    BinaryOp("TestMul", w1, {input, 1}, b.opts().WithName("t3"));

    Graph g(OpRegistry::Global());
    TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g));
    std::vector<Node*> order;

    // The stable reverse post order must put t2 before t3 on every iteration.
    GetReversePostOrder(g, &order, /*stable_comparator=*/NodeComparatorName());
    // Stream the diagnostic so a failure says which pair and iteration broke.
    EXPECT_TRUE(ExpectBefore({{"t2", "t3"}}, order, &error))
        << "iteration " << i << ": " << error;
  }
}
156 
// Adds an extra edge n3 -> n1, which closes the cycle n1 -> n2 -> n3 -> n1. It
// then checks that GetPostOrder / GetReversePostOrder produce the acyclic
// chain order when an edge_filter rejects that edge.
TEST(AlgorithmTest, PostOrderWithEdgeFilter) {
  GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
  Node* n0 = ops::SourceOp("TestParams", b.opts().WithName("n0"));
  Node* n1 = ops::UnaryOp("TestUnary", n0, b.opts().WithName("n1"));
  Node* n2 = ops::UnaryOp("TestUnary", n1, b.opts().WithName("n2"));
  Node* n3 = ops::BinaryOp("TestBinary", n2, n0, b.opts().WithName("n3"));

  Graph g(OpRegistry::Global());
  TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g));

  // n0..n3 were created through the builder; the corresponding nodes of 'g'
  // are looked up by id (assumes GraphDefBuilderToGraph preserves ids).
  g.AddEdge(g.FindNodeId(n3->id()), 0, g.FindNodeId(n1->id()), 1);

  // Skip exactly the n3 -> n1 edge added above.
  auto edge_filter = [&](const Edge& e) {
    return !(e.src()->id() == n3->id() && e.dst()->id() == n1->id());
  };

  const std::vector<Node*> expected_post_order = {
      g.sink_node(),          g.FindNodeId(n3->id()), g.FindNodeId(n2->id()),
      g.FindNodeId(n1->id()), g.FindNodeId(n0->id()), g.source_node()};
  const std::vector<Node*> expected_reverse_post_order(
      expected_post_order.rbegin(), expected_post_order.rend());

  std::vector<Node*> post_order;
  GetPostOrder(g, &post_order, /*stable_comparator=*/{},
               /*edge_filter=*/edge_filter);

  // EXPECT_EQ (not CHECK_EQ) so a mismatch is reported as a test failure
  // instead of aborting the whole test binary.
  ASSERT_EQ(expected_post_order.size(), post_order.size());
  for (size_t i = 0; i < post_order.size(); ++i) {
    EXPECT_EQ(post_order[i], expected_post_order[i])
        << post_order[i]->name() << " vs. " << expected_post_order[i]->name();
  }

  std::vector<Node*> reverse_post_order;
  GetReversePostOrder(g, &reverse_post_order, /*stable_comparator=*/{},
                      /*edge_filter=*/edge_filter);

  ASSERT_EQ(expected_reverse_post_order.size(), reverse_post_order.size());
  for (size_t i = 0; i < reverse_post_order.size(); ++i) {
    EXPECT_EQ(reverse_post_order[i], expected_reverse_post_order[i])
        << reverse_post_order[i]->name() << " vs. "
        << expected_reverse_post_order[i]->name();
  }
}
202 }  // namespace
203 }  // namespace tensorflow
204