1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/graph/graph.h"
17
18 #include <set>
19 #include <unordered_map>
20 #include <vector>
21
22 #include "tensorflow/core/common_runtime/function.h"
23 #include "tensorflow/core/common_runtime/graph_constructor.h"
24 #include "tensorflow/core/framework/function_testlib.h"
25 #include "tensorflow/core/graph/benchmark_testlib.h"
26 #include "tensorflow/core/graph/node_builder.h"
27 #include "tensorflow/core/kernels/ops_util.h"
28 #include "tensorflow/core/lib/core/status_test_util.h"
29 #include "tensorflow/core/lib/random/simple_philox.h"
30 #include "tensorflow/core/lib/strings/str_util.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/protobuf.h"
33 #include "tensorflow/core/platform/test.h"
34 #include "tensorflow/core/platform/test_benchmark.h"
35
36 namespace tensorflow {
37 namespace {
38
39 REGISTER_OP("OneInput").Input("x: float");
40
41 REGISTER_OP("OneOutput").Output("y: float");
42
43 REGISTER_OP("OneInputTwoOutputs")
44 .Input("x: float")
45 .Output("y: float")
46 .Output("z: float");
47
48 REGISTER_OP("TwoInputsOneOutput")
49 .Input("x: float")
50 .Input("y: float")
51 .Output("z: float");
52
53 class GraphTest : public ::testing::Test {
54 protected:
GraphTest()55 GraphTest() : graph_(OpRegistry::Global()) {}
~GraphTest()56 ~GraphTest() override {}
57
VerifyNodes(Node * node,const std::vector<Node * > & expected_in,const std::vector<Node * > & expected_out)58 static void VerifyNodes(Node* node, const std::vector<Node*>& expected_in,
59 const std::vector<Node*>& expected_out) {
60 std::vector<Node*> in;
61 for (const Edge* e : node->in_edges()) {
62 in.push_back(e->src());
63 }
64 EXPECT_EQ(Stringify(expected_in), Stringify(in));
65
66 std::vector<Node*> out;
67 for (const Edge* e : node->out_edges()) {
68 out.push_back(e->dst());
69 }
70 EXPECT_EQ(Stringify(expected_out), Stringify(out));
71 }
72
VerifyGraphStats()73 void VerifyGraphStats() {
74 int nodes = 0;
75 for (const Node* n : graph_.nodes()) {
76 VLOG(1) << n->id();
77 ++nodes;
78 }
79 EXPECT_EQ(nodes, graph_.num_nodes());
80 int edges = 0;
81 for (const Edge* e : graph_.edges()) {
82 VLOG(1) << e->id();
83 ++edges;
84 }
85 EXPECT_EQ(edges, graph_.num_edges());
86 }
87
AddNodeWithName(const string & name)88 Node* AddNodeWithName(const string& name) {
89 Node* node;
90 TF_CHECK_OK(NodeBuilder(name, "NoOp").Finalize(&graph_, &node));
91 return node;
92 }
93
FromNodeDef(const string & name,const string & node_type,int num_inputs)94 Node* FromNodeDef(const string& name, const string& node_type,
95 int num_inputs) {
96 auto builder = NodeDefBuilder(name, node_type);
97 for (int i = 0; i < num_inputs; ++i) {
98 builder = builder.Input(strings::StrCat("node_", i), i, DT_FLOAT);
99 }
100
101 NodeDef node_def;
102 TF_CHECK_OK(builder.Finalize(&node_def));
103
104 Status s;
105 Node* node = graph_.AddNode(node_def, &s);
106 TF_CHECK_OK(s);
107 return node;
108 }
109
FromGraphDef(const string & gdef_ascii)110 void FromGraphDef(const string& gdef_ascii) {
111 GraphDef gdef;
112 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &gdef));
113 GraphConstructorOptions opts;
114 TF_CHECK_OK(ConvertGraphDefToGraph(opts, gdef, &graph_));
115 }
116
FindNode(const string & name)117 Node* FindNode(const string& name) {
118 for (Node* node : graph_.nodes()) {
119 if (node->name() == name) return node;
120 }
121 LOG(FATAL) << name;
122 }
123
ControlEdgeExistsInGraphOrNodeDef(const Node * src,const Node * dst)124 bool ControlEdgeExistsInGraphOrNodeDef(const Node* src, const Node* dst) {
125 for (const Edge* e : dst->in_edges()) {
126 if (e->IsControlEdge() && e->src() == src &&
127 e->src_output() == Graph::kControlSlot &&
128 e->dst_input() == Graph::kControlSlot) {
129 return true;
130 }
131 }
132 std::string control_edge_name = strings::StrCat("^", src->name());
133 for (int i = 0; i < dst->def().input_size(); ++i) {
134 if (dst->def().input(i) == control_edge_name) {
135 return true;
136 }
137 }
138 return false;
139 }
140
141 Graph graph_;
142
143 private:
144 // Convert a list of nodes to a sorted list of strings so failure messages
145 // are readable.
Stringify(const std::vector<Node * > & nodes)146 static std::vector<string> Stringify(const std::vector<Node*>& nodes) {
147 std::vector<string> result;
148 result.reserve(nodes.size());
149 for (Node* n : nodes) {
150 result.push_back(n->DebugString());
151 }
152 std::sort(result.begin(), result.end());
153 return result;
154 }
155 };
156
TEST_F(GraphTest, Constructor) {
  // A freshly built graph holds exactly two nodes, source and sink, with a
  // single edge running from source to sink.
  Node* const src = graph_.source_node();
  EXPECT_NE(src, nullptr);
  Node* const snk = graph_.sink_node();
  EXPECT_NE(snk, nullptr);

  VerifyNodes(src, /*expected_in=*/{}, /*expected_out=*/{snk});
  VerifyNodes(snk, /*expected_in=*/{src}, /*expected_out=*/{});

  EXPECT_EQ(2, graph_.num_node_ids());
  VerifyGraphStats();
}
167
TEST_F(GraphTest, RemoveThenAdd) {
  // Node ids are never recycled: removing a node leaves num_node_ids()
  // untouched, and the next node added receives a brand-new id.
  AddNodeWithName("A");
  Node* const removed = AddNodeWithName("B");
  const int removed_id = removed->id();
  AddNodeWithName("C");
  // Source and sink account for two of the five ids.
  EXPECT_EQ(5, graph_.num_node_ids());

  graph_.RemoveNode(removed);
  EXPECT_EQ(5, graph_.num_node_ids());

  Node* const added = AddNodeWithName("D");
  EXPECT_NE(removed_id, added->id());
  EXPECT_EQ(6, graph_.num_node_ids());

  VerifyGraphStats();
}
181
TEST_F(GraphTest, InNodesAndOutNodes) {
  Node* const source = graph_.source_node();
  Node* const sink = graph_.sink_node();

  Node* a = FromNodeDef("A", "OneOutput", 0);
  Node* b = AddNodeWithName("B");
  Node* c = FromNodeDef("C", "OneInput", 1);
  graph_.RemoveNode(b);  // Leaves a hole in the id space.
  Node* d = AddNodeWithName("D");

  // Topology: source -> A -> C, plus control edges A -> sink and C -> sink.
  const Edge* source_to_a = graph_.AddControlEdge(source, a);
  graph_.AddControlEdge(a, sink);
  graph_.AddEdge(a, 0, c, 0);
  graph_.AddControlEdge(c, sink);

  EXPECT_EQ("A", a->name());
  VerifyNodes(a, {source}, {c, sink});

  EXPECT_EQ("C", c->name());
  VerifyNodes(c, {a}, {sink});

  // D was never connected to anything.
  EXPECT_EQ("D", d->name());
  VerifyNodes(d, {}, {});

  VerifyNodes(source, {}, {a, sink});
  VerifyNodes(sink, {a, c, source}, {});

  // Dropping source -> A detaches A from the source on both ends.
  graph_.RemoveEdge(source_to_a);
  VerifyNodes(a, {}, {c, sink});
  VerifyNodes(source, {}, {sink});

  // Removing C also removes every edge that touched it.
  graph_.RemoveNode(c);
  VerifyNodes(a, {}, {sink});
  VerifyNodes(sink, {a, source}, {});

  EXPECT_EQ(6, graph_.num_node_ids());
  EXPECT_EQ(5, graph_.num_edge_ids());
  VerifyGraphStats();
}
217
// Exercises Node::input_node / input_edge / input_edges lookups by input
// index, including a self edge, out-of-range indices, removed sources, and
// two edges competing for the same input slot.
TEST_F(GraphTest, NodeByIndex) {
  Node* a = FromNodeDef("A", "OneOutput", 0);
  Node* c = FromNodeDef("C", "OneInput", 1);
  graph_.AddEdge(a, 0, c, 0);

  // Ask for 'a' from 'c' by index.
  const Node* a_copy;
  TF_ASSERT_OK(c->input_node(0, &a_copy));
  EXPECT_EQ(a, a_copy);

  const Edge* e;
  TF_ASSERT_OK(c->input_edge(0, &e));
  EXPECT_EQ(0, e->dst_input());
  EXPECT_EQ(a, e->src());
  EXPECT_EQ(c, e->dst());
  EXPECT_EQ(0, e->src_output());

  Node* t = FromNodeDef("T", "TwoInputsOneOutput", 2);
  graph_.AddEdge(a, 0, t, 0);
  // Weird self edge: t's output 0 feeds its own input 1.
  graph_.AddEdge(t, 0, t, 1);

  const Node* t_0;
  const Node* t_1;
  TF_ASSERT_OK(t->input_node(0, &t_0));
  EXPECT_EQ(a, t_0);
  TF_ASSERT_OK(t->input_node(1, &t_1));
  EXPECT_EQ(t, t_1);

  TF_ASSERT_OK(t->input_edge(1, &e));
  EXPECT_EQ(1, e->dst_input());
  EXPECT_EQ(t, e->src());

  std::vector<const Edge*> t_input_edges;
  TF_ASSERT_OK(t->input_edges(&t_input_edges));
  // Unsigned literal avoids a signed/unsigned comparison against size().
  ASSERT_EQ(2u, t_input_edges.size());
  EXPECT_EQ(a, t_input_edges[0]->src());
  EXPECT_EQ(e, t_input_edges[1]);

  // Check out of bounds access
  EXPECT_FALSE(c->input_node(1, &a_copy).ok());
  EXPECT_FALSE(c->input_node(-1, &a_copy).ok());

  graph_.RemoveNode(a);

  // c's input_node entry should be invalidated.
  Status s = c->input_node(0, &a_copy);
  EXPECT_FALSE(s.ok());

  // Add two new nodes.
  Node* a_new = FromNodeDef("A_new", "OneOutput", 0);
  Node* b_new = FromNodeDef("B_new", "OneOutput", 0);

  // Connect one up to c. It is the only edge into input 0, so this is
  // a_new's edge.
  graph_.AddEdge(a_new, 0, c, 0);
  const Edge* a_new_c_edge;
  TF_ASSERT_OK(c->input_edge(0, &a_new_c_edge));

  // Connect up the second edge into the same input slot. With two edges on
  // input 0, this test only checks that the lookup succeeds, not which
  // edge comes back.
  graph_.AddEdge(b_new, 0, c, 0);
  const Edge* b_new_c_edge;
  TF_ASSERT_OK(c->input_edge(0, &b_new_c_edge));

  // Now remove the old one
  graph_.RemoveEdge(a_new_c_edge);

  // Only b_new's edge remains on input 0, so it must be the one returned.
  TF_ASSERT_OK(c->input_edge(0, &b_new_c_edge));
  EXPECT_EQ(b_new, b_new_c_edge->src());
  EXPECT_EQ(c, b_new_c_edge->dst());

  std::vector<const Edge*> c_input_edges;
  TF_ASSERT_OK(c->input_edges(&c_input_edges));
  ASSERT_EQ(1u, c_input_edges.size());
  EXPECT_EQ(b_new_c_edge, c_input_edges[0]);
}
292
TEST_F(GraphTest, NodeIteration) {
  Node* const source = graph_.source_node();
  Node* const sink = graph_.sink_node();

  // Build a graph whose id space has holes: B is removed immediately, and C
  // is removed after being wired up.
  Node* a = FromNodeDef("A", "OneOutput", 0);
  Node* b = AddNodeWithName("B");
  Node* c = FromNodeDef("C", "OneInput", 1);
  graph_.RemoveNode(b);
  Node* d = AddNodeWithName("D");
  const Edge* source_to_a = graph_.AddControlEdge(source, a);
  graph_.AddControlEdge(a, sink);
  graph_.AddEdge(a, 0, c, 0);
  graph_.AddControlEdge(c, sink);
  graph_.RemoveEdge(source_to_a);
  graph_.RemoveNode(c);

  // The surviving nodes, identified by their DebugString()s.
  const std::set<string> expected = {source->DebugString(), a->DebugString(),
                                     d->DebugString(), sink->DebugString()};

  // Probing every id and skipping empty slots finds exactly those nodes...
  std::set<string> found_by_id;
  for (int id = 0; id < graph_.num_node_ids(); ++id) {
    const Node* node = graph_.FindNodeId(id);
    if (node == nullptr) continue;
    found_by_id.insert(node->DebugString());
  }
  EXPECT_EQ(expected, found_by_id);

  // ...and so does iterating with graph_.nodes().
  std::set<string> found_by_iter;
  for (const Node* node : graph_.nodes()) {
    found_by_iter.insert(node->DebugString());
  }
  EXPECT_EQ(expected, found_by_iter);

  VerifyGraphStats();
}
332
CheckType(Node * node,bool b)333 static void CheckType(Node* node, bool b) {
334 EXPECT_TRUE(b) << node->DebugString();
335 // Make sure none of the other IsFoo() methods return true.
336 int count = 0;
337 if (node->IsSource()) count++;
338 if (node->IsSink()) count++;
339 if (node->IsOp()) count++;
340 EXPECT_EQ(1, count) << node->DebugString();
341 }
342
TEST_F(GraphTest, Type) {
  // Each of source, sink and a regular op node reports exactly one kind.
  Node* const op_node = AddNodeWithName("A");
  Node* const source = graph_.source_node();
  Node* const sink = graph_.sink_node();

  CheckType(source, source->IsSource());
  CheckType(sink, sink->IsSink());
  CheckType(op_node, op_node->IsOp());

  VerifyGraphStats();
}
350
// Node::AddAttr adds attributes visible through attrs(), and CopyNode takes
// a snapshot: attributes added to the original afterwards do not appear on
// the copy.
TEST_F(GraphTest, AddAttr) {
  Node* n1 = AddNodeWithName("A");

  n1->AddAttr("_a", "new_attr");

  string attr;
  TF_EXPECT_OK(GetNodeAttr(n1->attrs(), "_a", &attr));
  EXPECT_EQ("new_attr", attr);

  Node* n2 = graph_.CopyNode(n1);

  n1->AddAttr("_b", "new_attr_2");

  TF_EXPECT_OK(GetNodeAttr(n1->attrs(), "_a", &attr));
  EXPECT_EQ("new_attr", attr);
  TF_EXPECT_OK(GetNodeAttr(n1->attrs(), "_b", &attr));
  EXPECT_EQ("new_attr_2", attr);

  // The copy keeps "_a" but never sees "_b", which was added after copying.
  TF_EXPECT_OK(GetNodeAttr(n2->attrs(), "_a", &attr));
  EXPECT_EQ("new_attr", attr);
  EXPECT_FALSE(GetNodeAttr(n2->attrs(), "_b", &attr).ok());
}
373
374 // Convert edge iteration results into a sorted string.
EdgeIter(const Graph & g)375 static string EdgeIter(const Graph& g) {
376 std::vector<std::pair<int, int> > edges;
377 for (const Edge* e : g.edges()) {
378 edges.push_back(std::make_pair(e->src()->id(), e->dst()->id()));
379 }
380 std::sort(edges.begin(), edges.end());
381 string result;
382 for (auto& p : edges) {
383 strings::StrAppend(&result, p.first, "->", p.second, ";");
384 }
385 return result;
386 }
387
TEST_F(GraphTest, EdgeIteration) {
  // Initially the only edge is source (0) -> sink (1).
  EXPECT_EQ("0->1;", EdgeIter(graph_));

  Node* a = FromNodeDef("A", "OneInputTwoOutputs", 1);  // id 2
  Node* b = FromNodeDef("B", "OneInput", 1);            // id 3
  // Adding unconnected nodes does not add any edges.
  EXPECT_EQ("0->1;", EdgeIter(graph_));

  graph_.AddEdge(a, 0, b, 0);
  EXPECT_EQ("0->1;2->3;", EdgeIter(graph_));

  Node* const source = graph_.source_node();
  Node* const sink = graph_.sink_node();
  graph_.AddControlEdge(source, a);
  graph_.AddControlEdge(b, sink);
  EXPECT_EQ("0->1;0->2;2->3;3->1;", EdgeIter(graph_));

  // Self loop: a's output 1 feeds a's own input 0.
  graph_.AddEdge(a, 1, a, 0);
  EXPECT_EQ("0->1;0->2;2->2;2->3;3->1;", EdgeIter(graph_));

  VerifyGraphStats();
}
406
TEST_F(GraphTest, NewName) {
  // Generated names are pairwise distinct, whether or not they share a
  // prefix; the first one is also checked to begin with its prefix.
  const string first_a = graph_.NewName("A");
  const string second_a = graph_.NewName("A");
  const string first_b = graph_.NewName("B");

  EXPECT_NE(first_a, second_a);
  EXPECT_NE(first_a, first_b);
  EXPECT_NE(second_a, first_b);
  EXPECT_TRUE(absl::StartsWith(first_a, "A")) << first_a;
}
416
TEST_F(GraphTest, IsValidNode) {
  // graph_ gets a single op node on top of source and sink.
  Node* g1_node1;
  TF_CHECK_OK(NodeBuilder("g1_node1", "NoOp").Finalize(&graph_, &g1_node1));

  // A second, unrelated graph gets two op nodes (ids 2 and 3 there).
  Graph other_graph(OpRegistry::Global());
  Node* foreign_id2;
  Node* foreign_id3;
  TF_CHECK_OK(
      NodeBuilder("g2_node1", "NoOp").Finalize(&other_graph, &foreign_id2));
  TF_CHECK_OK(
      NodeBuilder("g2_node2", "NoOp").Finalize(&other_graph, &foreign_id3));

  // A null pointer is rejected.
  Status status = graph_.IsValidNode(nullptr);
  EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
  EXPECT_EQ(string("Node is null"), status.error_message());

  // An id past the end of graph_'s id range is rejected.
  status = graph_.IsValidNode(foreign_id3);
  EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
  EXPECT_EQ(string("node id 3 is >= than number of nodes in graph 3"),
            status.error_message());

  // An in-range id whose slot holds a different Node* is rejected.
  status = graph_.IsValidNode(foreign_id2);
  EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
  EXPECT_EQ(string("Node with id 2 is different from the passed in node. "
                   "Does it belong to a different graph?"),
            status.error_message());
}
447
// Graph::AddControlEdge creates the edge and mirrors it into the destination
// NodeDef as "^<src>". It skips redundant control edges unless
// allow_duplicates is set, and never records the source node in NodeDefs.
TEST_F(GraphTest, AddControlEdge) {
  FromGraphDef(
      "node { name: 'A' op: 'OneOutput' }"
      "node { name: 'B' op: 'OneInputTwoOutputs' input: [ 'A:0' ] }"
      "node { name: 'C' op: 'NoOp' } ");
  Node* a = FindNode("A");
  Node* b = FindNode("B");
  Node* c = FindNode("C");

  // Add a control edge.
  const Edge* edge = graph_.AddControlEdge(c, a);
  ASSERT_NE(edge, nullptr);
  // Check newly-created edge.
  EXPECT_EQ(edge->src(), c);
  EXPECT_EQ(edge->src_output(), Graph::kControlSlot);
  EXPECT_EQ(edge->dst(), a);
  EXPECT_EQ(edge->dst_input(), Graph::kControlSlot);
  // Check A's NodeDef.
  ASSERT_EQ(a->def().input_size(), 1);
  EXPECT_EQ(a->def().input(0), "^C");

  // Can add control edge redundant with data edge.
  edge = graph_.AddControlEdge(a, b);
  EXPECT_NE(edge, nullptr);
  ASSERT_EQ(b->def().input_size(), 2);
  EXPECT_EQ(b->def().input(0), "A:0");
  EXPECT_EQ(b->def().input(1), "^A");

  // Doesn't add edge redundant with control edge.
  edge = graph_.AddControlEdge(a, b);
  EXPECT_EQ(edge, nullptr);
  EXPECT_EQ(b->def().input_size(), 2);

  // Can add redundant control edge with allow_duplicates.
  edge = graph_.AddControlEdge(a, b, /*allow_duplicates=*/true);
  EXPECT_NE(edge, nullptr);
  // allow_duplicates adds the graph edge but leaves the NodeDef unchanged.
  ASSERT_EQ(b->def().input_size(), 2);
  EXPECT_EQ(b->def().input(0), "A:0");
  EXPECT_EQ(b->def().input(1), "^A");

  // Add control edge from source.
  edge = graph_.AddControlEdge(graph_.source_node(), b);
  EXPECT_NE(edge, nullptr);
  // Check that we don't include source input in the NodeDef.
  EXPECT_EQ(b->def().input_size(), 2);
  // Doesn't add redundant edge.
  edge = graph_.AddControlEdge(graph_.source_node(), b);
  EXPECT_EQ(edge, nullptr);
  EXPECT_EQ(b->def().input_size(), 2);
}
499
// Graph::RemoveControlEdge must remove the control edge both from the graph
// and from the destination's NodeDef ("^<src>" input).
TEST_F(GraphTest, RemoveControlEdge) {
  FromGraphDef(
      "node { name: 'A' op: 'OneOutput' }"
      "node { name: 'B' op: 'OneInputTwoOutputs' input: [ 'A:0' ] }"
      "node { name: 'C' op: 'NoOp' } ");
  Node* a = FindNode("A");
  Node* b = FindNode("B");
  Node* c = FindNode("C");

  // Add two control edges: C -> A and A -> B.
  const Edge* edge_1 = graph_.AddControlEdge(c, a);
  const Edge* edge_2 = graph_.AddControlEdge(a, b);
  ASSERT_NE(edge_1, nullptr);
  ASSERT_NE(edge_2, nullptr);

  ASSERT_TRUE(ControlEdgeExistsInGraphOrNodeDef(c, a));
  ASSERT_TRUE(ControlEdgeExistsInGraphOrNodeDef(a, b));

  graph_.RemoveControlEdge(edge_1);
  ASSERT_FALSE(ControlEdgeExistsInGraphOrNodeDef(c, a));
  ASSERT_TRUE(ControlEdgeExistsInGraphOrNodeDef(a, b));

  graph_.RemoveControlEdge(edge_2);
  ASSERT_FALSE(ControlEdgeExistsInGraphOrNodeDef(c, a));
  ASSERT_FALSE(ControlEdgeExistsInGraphOrNodeDef(a, b));

  // Test removing a duplicate control edge.
  // Note that unless allow_duplicates is true, the duplicate edge
  // will not be added. That's why we expect edge_4 to be a null
  // pointer. We are not testing with allow_duplicates set to true,
  // as that is a highly unlikely use case that does not make much
  // sense.
  const Edge* edge_3 = graph_.AddControlEdge(c, a);
  const Edge* edge_4 = graph_.AddControlEdge(c, a);
  ASSERT_NE(edge_3, nullptr);
  ASSERT_EQ(edge_4, nullptr);

  graph_.RemoveControlEdge(edge_3);
  ASSERT_FALSE(ControlEdgeExistsInGraphOrNodeDef(c, a));
}
540
// Graph::UpdateEdge(new_src, new_src_index, dst, dst_index) rewires dst's
// input to come from new_src, replacing the previous edge into that input,
// and rejects out-of-range output or input indices.
TEST_F(GraphTest, UpdateEdge) {
  // Build a little graph
  Node* a = FromNodeDef("A", "OneOutput", 0);
  Node* b = FromNodeDef("B", "OneInputTwoOutputs", 1);
  Node* c = FromNodeDef("C", "OneInputTwoOutputs", 1);
  Node* d = FromNodeDef("D", "OneInput", 1);

  graph_.AddControlEdge(graph_.source_node(), a);
  graph_.AddControlEdge(a, graph_.sink_node());
  graph_.AddEdge(a, 0, c, 0);

  graph_.AddControlEdge(c, graph_.sink_node());
  graph_.AddEdge(c, 0, b, 0);
  graph_.AddEdge(c, 1, d, 0);

  // Initial edge connections. Ids: source 0, sink 1, A 2, B 3, C 4, D 5.
  EXPECT_EQ("0->1;0->2;2->1;2->4;4->1;4->3;4->5;", EdgeIter(graph_));

  // Update the inputs, expect that Edge a to b (2->3) is now in the graph
  // and c to b (4->3) no longer appears.
  TF_EXPECT_OK(graph_.UpdateEdge(a, 0, b, 0));
  // Check that the edge is connecting the correct nodes.
  EXPECT_EQ("0->1;0->2;2->1;2->3;2->4;4->1;4->5;", EdgeIter(graph_));

  // Update a's 0th output again, now replacing c to d (4->5) with a to d.
  TF_EXPECT_OK(graph_.UpdateEdge(a, 0, d, 0));
  EXPECT_EQ("0->1;0->2;2->1;2->3;2->4;2->5;4->1;", EdgeIter(graph_));

  // Update a's output 1, which is out of range (A has a single output).
  Status s = graph_.UpdateEdge(a, 1, d, 0);
  EXPECT_FALSE(s.ok());
  EXPECT_EQ(
      s.error_message(),
      "Node 'A' (type: 'OneOutput', num of outputs: 1) does not have output 1");

  // Update a's input 0, which is out of range (A has no inputs at all).
  s = graph_.UpdateEdge(c, 0, a, 0);
  EXPECT_FALSE(s.ok());
  EXPECT_EQ(
      s.error_message(),
      "Node 'A' (type: 'OneOutput', num of inputs: 0) does not have input 0");
}
583
TEST_F(GraphTest, InputEdges) {
  Node* producer = FromNodeDef("A", "OneOutput", 0);
  Node* consumer = FromNodeDef("B", "TwoInputsOneOutput", 2);
  std::vector<const Edge*> collected;

  // Only input 0 is connected: input_edges() reports INVALID_ARGUMENT.
  graph_.AddEdge(producer, 0, consumer, 0);
  EXPECT_EQ(error::INVALID_ARGUMENT, consumer->input_edges(&collected).code());

  // With both inputs connected it succeeds.
  graph_.AddEdge(producer, 0, consumer, 1);
  TF_EXPECT_OK(consumer->input_edges(&collected));
}
593
TEST_F(GraphTest, AddFunctionLibrary) {
  // Basic functionality: a library of two functions is accepted.
  FunctionDefLibrary lib;
  *lib.add_function() = test::function::XTimesTwo();
  *lib.add_function() = test::function::XTimesFour();
  TF_EXPECT_OK(graph_.AddFunctionLibrary(lib));
  EXPECT_TRUE(graph_.flib_def().Find("XTimesTwo") != nullptr);
  EXPECT_TRUE(graph_.flib_def().Find("XTimesFour") != nullptr);

  // Re-adding identical functions is a no-op, not an error.
  TF_EXPECT_OK(graph_.AddFunctionLibrary(lib));
  EXPECT_TRUE(graph_.flib_def().Find("XTimesTwo") != nullptr);
  EXPECT_TRUE(graph_.flib_def().Find("XTimesFour") != nullptr);

  // Reusing a name for a different function body is rejected.
  FunctionDefLibrary bad_lib = lib;
  *bad_lib.mutable_function(0)->add_node_def() =
      bad_lib.function(0).node_def(0);
  Status status = graph_.AddFunctionLibrary(bad_lib);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Cannot add function 'XTimesTwo' because a different function with "
            "the same name already exists.");

  // A function may not take the name of a registered op.
  bad_lib = lib;
  bad_lib.mutable_function(0)->mutable_signature()->set_name("Add");
  status = graph_.AddFunctionLibrary(bad_lib);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Cannot add function 'Add' because an op with the same name "
            "already exists.");

  // A gradient may be attached to an existing function; the gradient
  // function itself need not be defined.
  GradientDef* grad_def = lib.add_gradient();
  grad_def->set_function_name("XTimesTwo");
  grad_def->set_gradient_func("Undefined");
  TF_EXPECT_OK(graph_.AddFunctionLibrary(lib));
  EXPECT_EQ(graph_.flib_def().FindGradient("XTimesTwo"), "Undefined");

  // Re-adding the same gradient is a no-op.
  TF_EXPECT_OK(graph_.AddFunctionLibrary(lib));
  EXPECT_EQ(graph_.flib_def().FindGradient("XTimesTwo"), "Undefined");

  // A conflicting gradient for the same function is rejected.
  bad_lib = lib;
  bad_lib.mutable_gradient(0)->set_gradient_func("Undefined2");
  status = graph_.AddFunctionLibrary(bad_lib);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Cannot assign gradient function 'Undefined2' to 'XTimesTwo' "
            "because it already has gradient function 'Undefined'");
}
647
// Graph::BuildNodeNameIndex maps every node name, including the implicit
// _SOURCE and _SINK nodes, to its Node*.
TEST_F(GraphTest, BuildNodeNameIndex) {
  FromGraphDef(
      "node { name: 'A' op: 'OneOutput' }"
      "node { name: 'B' op: 'OneInputTwoOutputs' input: [ 'A:0' ] }"
      "node { name: 'C' op: 'NoOp' } ");

  auto node_name_index = graph_.BuildNodeNameIndex();
  EXPECT_EQ(node_name_index.size(), 5u);

  std::vector<string> node_names{"_SOURCE", "_SINK", "A", "B", "C"};
  for (const string& node_name : node_names) {
    // Look the name up once. The previous operator[] lookup would insert a
    // default entry for a missing name, and FindNode would then LOG(FATAL)
    // instead of the test reporting a clean failure.
    auto it = node_name_index.find(node_name);
    ASSERT_NE(it, node_name_index.end()) << node_name;
    EXPECT_EQ(it->second, FindNode(node_name));
  }
}
663
BM_InEdgeIteration(::testing::benchmark::State & state)664 void BM_InEdgeIteration(::testing::benchmark::State& state) {
665 const int num_nodes = state.range(0);
666 const int num_edges_per_node = state.range(1);
667 const GraphDef graph_def =
668 test::CreateGraphDef(num_nodes, num_edges_per_node);
669 Graph graph(OpRegistry::Global());
670 GraphConstructorOptions opts;
671 TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
672
673 int64 sum = 0;
674 for (auto s : state) {
675 for (const Node* node : graph.nodes()) {
676 for (auto e : node->in_edges()) {
677 sum += e->id();
678 }
679 }
680 }
681 VLOG(1) << sum;
682 }
683 BENCHMARK(BM_InEdgeIteration)->ArgPair(10, 2);
684 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 6, 2);
685 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 9, 2);
686 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 12, 2);
687 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 15, 2);
688 BENCHMARK(BM_InEdgeIteration)->ArgPair(10, 4);
689 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 6, 4);
690 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 9, 4);
691 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 12, 4);
692 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 15, 4);
693 BENCHMARK(BM_InEdgeIteration)->ArgPair(10, 8);
694 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 6, 8);
695 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 9, 8);
696 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 12, 8);
697 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 15, 8);
698 BENCHMARK(BM_InEdgeIteration)->ArgPair(10, 16);
699 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 6, 16);
700 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 9, 16);
701 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 12, 16);
702 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 15, 16);
703
BM_GraphCreation(::testing::benchmark::State & state)704 void BM_GraphCreation(::testing::benchmark::State& state) {
705 const int num_nodes = state.range(0);
706 const int num_edges_per_node = state.range(1);
707 const GraphDef graph_def =
708 test::CreateGraphDef(num_nodes, num_edges_per_node);
709 const auto registry = OpRegistry::Global();
710 GraphConstructorOptions opts;
711 // Warmup step.
712 Graph graph(registry);
713 TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
714 int64 sum = 0;
715 for (auto s : state) {
716 Graph graph(registry);
717 TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
718 sum += graph.num_node_ids();
719 }
720 VLOG(1) << sum;
721 }
722 BENCHMARK(BM_GraphCreation)->ArgPair(10, 2);
723 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 6, 2);
724 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 9, 2);
725 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 12, 2);
726 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 15, 2);
727 BENCHMARK(BM_GraphCreation)->ArgPair(10, 4);
728 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 6, 4);
729 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 9, 4);
730 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 12, 4);
731 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 15, 4);
732 BENCHMARK(BM_GraphCreation)->ArgPair(10, 8);
733 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 6, 8);
734 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 9, 8);
735 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 12, 8);
736 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 15, 8);
737 BENCHMARK(BM_GraphCreation)->ArgPair(10, 16);
738 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 6, 16);
739 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 9, 16);
740 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 12, 16);
741 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 15, 16);
742
BM_ToGraphDef(::testing::benchmark::State & state)743 void BM_ToGraphDef(::testing::benchmark::State& state) {
744 const int num_nodes = state.range(0);
745 const int num_edges_per_node = state.range(1);
746 const GraphDef graph_def =
747 test::CreateGraphDef(num_nodes, num_edges_per_node);
748 const auto registry = OpRegistry::Global();
749 GraphConstructorOptions opts;
750 // Warmup step.
751 Graph graph(registry);
752 TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
753 int64 sum = 0;
754 for (auto s : state) {
755 GraphDef graph_def;
756 graph.ToGraphDef(&graph_def);
757 sum += graph_def.node_size();
758 }
759 VLOG(1) << sum;
760 }
761 BENCHMARK(BM_ToGraphDef)->ArgPair(10, 2);
762 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 6, 2);
763 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 2);
764 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 2);
765 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 2);
766 BENCHMARK(BM_ToGraphDef)->ArgPair(10, 4);
767 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 6, 4);
768 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 4);
769 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 4);
770 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 4);
771 BENCHMARK(BM_ToGraphDef)->ArgPair(10, 8);
772 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 6, 8);
773 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 8);
774 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 8);
775 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 8);
776 BENCHMARK(BM_ToGraphDef)->ArgPair(10, 16);
777 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 6, 16);
778 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 16);
779 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 16);
780 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 16);
781
BM_RemoveNode(::testing::benchmark::State & state)782 void BM_RemoveNode(::testing::benchmark::State& state) {
783 const int num_nodes = state.range(0);
784 const int num_edges_per_node = state.range(1);
785 const GraphDef graph_def =
786 test::CreateGraphDef(num_nodes, num_edges_per_node);
787 const auto registry = OpRegistry::Global();
788 GraphConstructorOptions opts;
789 for (auto s : state) {
790 state.PauseTiming();
791 Graph graph(registry);
792 TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
793 state.ResumeTiming();
794 for (Node* n : graph.op_nodes()) {
795 graph.RemoveNode(n);
796 }
797 }
798 }
799 BENCHMARK(BM_RemoveNode)->ArgPair(10, 2);
800 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 6, 2);
801 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 9, 2);
802 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 12, 2);
803 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 15, 2);
804 BENCHMARK(BM_RemoveNode)->ArgPair(10, 4);
805 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 6, 4);
806 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 9, 4);
807 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 12, 4);
808 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 15, 4);
809 BENCHMARK(BM_RemoveNode)->ArgPair(10, 8);
810 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 6, 8);
811 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 9, 8);
812 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 12, 8);
813 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 15, 8);
814 BENCHMARK(BM_RemoveNode)->ArgPair(10, 16);
815 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 6, 16);
816 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 9, 16);
817 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 12, 16);
818 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 15, 16);
819
820 } // namespace
821 } // namespace tensorflow
822