1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/graph/graph.h"

#include <algorithm>
#include <set>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/graph/benchmark_testlib.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
34
35 namespace tensorflow {
36 namespace {
37
38 REGISTER_OP("OneInput").Input("x: float");
39
40 REGISTER_OP("OneOutput").Output("y: float");
41
42 REGISTER_OP("OneInputTwoOutputs")
43 .Input("x: float")
44 .Output("y: float")
45 .Output("z: float");
46
47 REGISTER_OP("TwoInputsOneOutput")
48 .Input("x: float")
49 .Input("y: float")
50 .Output("z: float");
51
52 class GraphTest : public ::testing::Test {
53 protected:
GraphTest()54 GraphTest() : graph_(OpRegistry::Global()) {}
~GraphTest()55 ~GraphTest() override {}
56
VerifyNodes(Node * node,const std::vector<Node * > & expected_in,const std::vector<Node * > & expected_out)57 static void VerifyNodes(Node* node, const std::vector<Node*>& expected_in,
58 const std::vector<Node*>& expected_out) {
59 std::vector<Node*> in;
60 for (const Edge* e : node->in_edges()) {
61 in.push_back(e->src());
62 }
63 EXPECT_EQ(Stringify(expected_in), Stringify(in));
64
65 std::vector<Node*> out;
66 for (const Edge* e : node->out_edges()) {
67 out.push_back(e->dst());
68 }
69 EXPECT_EQ(Stringify(expected_out), Stringify(out));
70 }
71
VerifyGraphStats()72 void VerifyGraphStats() {
73 int nodes = 0;
74 for (const Node* n : graph_.nodes()) {
75 VLOG(1) << n->id();
76 ++nodes;
77 }
78 EXPECT_EQ(nodes, graph_.num_nodes());
79 int edges = 0;
80 for (const Edge* e : graph_.edges()) {
81 VLOG(1) << e->id();
82 ++edges;
83 }
84 EXPECT_EQ(edges, graph_.num_edges());
85 }
86
AddNodeWithName(const string & name)87 Node* AddNodeWithName(const string& name) {
88 Node* node;
89 TF_CHECK_OK(NodeBuilder(name, "NoOp").Finalize(&graph_, &node));
90 return node;
91 }
92
FromNodeDef(const string & name,const string & node_type,int num_inputs)93 Node* FromNodeDef(const string& name, const string& node_type,
94 int num_inputs) {
95 auto builder = NodeDefBuilder(name, node_type);
96 for (int i = 0; i < num_inputs; ++i) {
97 builder = builder.Input(strings::StrCat("node_", i), i, DT_FLOAT);
98 }
99
100 NodeDef node_def;
101 TF_CHECK_OK(builder.Finalize(&node_def));
102
103 Status s;
104 Node* node = graph_.AddNode(node_def, &s);
105 TF_CHECK_OK(s);
106 return node;
107 }
108
FromGraphDef(const string & gdef_ascii)109 void FromGraphDef(const string& gdef_ascii) {
110 GraphDef gdef;
111 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &gdef));
112 GraphConstructorOptions opts;
113 TF_CHECK_OK(ConvertGraphDefToGraph(opts, gdef, &graph_));
114 }
115
FindNode(const string & name)116 Node* FindNode(const string& name) {
117 for (Node* node : graph_.nodes()) {
118 if (node->name() == name) return node;
119 }
120 LOG(FATAL) << name;
121 }
122
ControlEdgeExistsInGraphOrNodeDef(const Node * src,const Node * dst)123 bool ControlEdgeExistsInGraphOrNodeDef(const Node* src, const Node* dst) {
124 for (const Edge* e : dst->in_edges()) {
125 if (e->IsControlEdge() && e->src() == src &&
126 e->src_output() == Graph::kControlSlot &&
127 e->dst_input() == Graph::kControlSlot) {
128 return true;
129 }
130 }
131 std::string control_edge_name = strings::StrCat("^", src->name());
132 for (int i = 0; i < dst->def().input_size(); ++i) {
133 if (dst->def().input(i) == control_edge_name) {
134 return true;
135 }
136 }
137 return false;
138 }
139
140 Graph graph_;
141
142 private:
143 // Convert a list of nodes to a sorted list of strings so failure messages
144 // are readable.
Stringify(const std::vector<Node * > & nodes)145 static std::vector<string> Stringify(const std::vector<Node*>& nodes) {
146 std::vector<string> result;
147 result.reserve(nodes.size());
148 for (Node* n : nodes) {
149 result.push_back(n->DebugString());
150 }
151 std::sort(result.begin(), result.end());
152 return result;
153 }
154 };
155
// A freshly constructed graph contains only the source and sink nodes, joined
// by one edge.
TEST_F(GraphTest, Constructor) {
  Node* const src = graph_.source_node();
  EXPECT_NE(src, nullptr);
  Node* const snk = graph_.sink_node();
  EXPECT_NE(snk, nullptr);

  VerifyNodes(src, /*expected_in=*/{}, /*expected_out=*/{snk});
  VerifyNodes(snk, /*expected_in=*/{src}, /*expected_out=*/{});

  EXPECT_EQ(2, graph_.num_node_ids());
  VerifyGraphStats();
}
166
// Removing a node does not release its id: num_node_ids() keeps counting and
// a node added afterwards gets a fresh id.
TEST_F(GraphTest, RemoveThenAdd) {
  AddNodeWithName("A");
  Node* const removed = AddNodeWithName("B");
  const int removed_id = removed->id();
  AddNodeWithName("C");
  EXPECT_EQ(5, graph_.num_node_ids());

  graph_.RemoveNode(removed);
  EXPECT_EQ(5, graph_.num_node_ids());

  Node* const added = AddNodeWithName("D");
  EXPECT_NE(removed_id, added->id());  // Ids should not be reused.
  EXPECT_EQ(6, graph_.num_node_ids());
  VerifyGraphStats();
}
180
// in_edges()/out_edges() stay consistent across node removal (leaving an id
// hole), edge removal, and removal of a connected node.
TEST_F(GraphTest, InNodesAndOutNodes) {
  Node* const src = graph_.source_node();
  Node* const snk = graph_.sink_node();

  Node* a = FromNodeDef("A", "OneOutput", 0);
  Node* b = AddNodeWithName("B");
  Node* c = FromNodeDef("C", "OneInput", 1);
  graph_.RemoveNode(b);  // Leaves a hole in the node ids.
  Node* d = AddNodeWithName("D");

  const Edge* src_to_a = graph_.AddControlEdge(src, a);
  graph_.AddControlEdge(a, snk);
  graph_.AddEdge(a, 0, c, 0);
  graph_.AddControlEdge(c, snk);

  EXPECT_EQ("A", a->name());
  VerifyNodes(a, {src}, {c, snk});

  EXPECT_EQ("C", c->name());
  VerifyNodes(c, {a}, {snk});

  // D was never wired up.
  EXPECT_EQ("D", d->name());
  VerifyNodes(d, {}, {});

  VerifyNodes(src, {}, {a, snk});
  VerifyNodes(snk, {a, c, src}, {});

  // Dropping source -> A disconnects A from the source.
  graph_.RemoveEdge(src_to_a);
  VerifyNodes(a, {}, {c, snk});
  VerifyNodes(src, {}, {snk});

  // Removing C takes its edges with it.
  graph_.RemoveNode(c);
  VerifyNodes(a, {}, {snk});
  VerifyNodes(snk, {a, src}, {});
  EXPECT_EQ(6, graph_.num_node_ids());
  EXPECT_EQ(5, graph_.num_edge_ids());
  VerifyGraphStats();
}
216
// Node::input_node() / input_edge() / input_edges() look up data inputs by
// slot, report errors for out-of-range or dangling slots, and stay correct as
// edges into the same slot are added and removed.
TEST_F(GraphTest, NodeByIndex) {
  Node* a = FromNodeDef("A", "OneOutput", 0);
  Node* c = FromNodeDef("C", "OneInput", 1);
  graph_.AddEdge(a, 0, c, 0);

  // Ask for 'a' from 'c' by index.
  const Node* a_copy;
  TF_ASSERT_OK(c->input_node(0, &a_copy));
  EXPECT_EQ(a, a_copy);

  const Edge* e;
  TF_ASSERT_OK(c->input_edge(0, &e));
  EXPECT_EQ(0, e->dst_input());
  EXPECT_EQ(a, e->src());
  EXPECT_EQ(c, e->dst());
  EXPECT_EQ(0, e->src_output());

  Node* t = FromNodeDef("T", "TwoInputsOneOutput", 2);
  graph_.AddEdge(a, 0, t, 0);
  // Self edge: t's output 0 feeds its own input 1.
  graph_.AddEdge(t, 0, t, 1);

  const Node* t_0;
  const Node* t_1;
  TF_ASSERT_OK(t->input_node(0, &t_0));
  EXPECT_EQ(a, t_0);
  TF_ASSERT_OK(t->input_node(1, &t_1));
  EXPECT_EQ(t, t_1);

  TF_ASSERT_OK(t->input_edge(1, &e));
  EXPECT_EQ(1, e->dst_input());
  EXPECT_EQ(t, e->src());

  // input_edges() lists one edge per input slot, in slot order. Unsigned
  // literals avoid a signed/unsigned comparison against size().
  std::vector<const Edge*> t_input_edges;
  TF_ASSERT_OK(t->input_edges(&t_input_edges));
  ASSERT_EQ(2u, t_input_edges.size());
  EXPECT_EQ(a, t_input_edges[0]->src());
  EXPECT_EQ(e, t_input_edges[1]);

  // Check out of bounds access
  EXPECT_FALSE(c->input_node(1, &a_copy).ok());
  EXPECT_FALSE(c->input_node(-1, &a_copy).ok());

  graph_.RemoveNode(a);

  // 'c's input_node entry should be invalidated.
  Status s = c->input_node(0, &a_copy);
  EXPECT_FALSE(s.ok());

  // Add two new nodes.
  Node* a_new = FromNodeDef("A_new", "OneOutput", 0);
  Node* b_new = FromNodeDef("B_new", "OneOutput", 0);

  // Connect one up to c.
  graph_.AddEdge(a_new, 0, c, 0);
  const Edge* a_new_c_edge;
  TF_ASSERT_OK(c->input_edge(0, &a_new_c_edge));

  // Connect a second edge into the same slot.
  graph_.AddEdge(b_new, 0, c, 0);
  const Edge* b_new_c_edge;
  TF_ASSERT_OK(c->input_edge(0, &b_new_c_edge));

  // Now remove the old one
  graph_.RemoveEdge(a_new_c_edge);

  // Check that the second edge can still be retrieved
  TF_ASSERT_OK(c->input_edge(0, &b_new_c_edge));

  std::vector<const Edge*> c_input_edges;
  TF_ASSERT_OK(c->input_edges(&c_input_edges));
  ASSERT_EQ(1u, c_input_edges.size());
  EXPECT_EQ(b_new_c_edge, c_input_edges[0]);
}
291
// Iterating by id (FindNodeId) and iterating nodes() both skip the holes
// left by removed nodes and visit exactly the surviving nodes.
TEST_F(GraphTest, NodeIteration) {
  Node* const src = graph_.source_node();
  Node* const snk = graph_.sink_node();

  // Build a graph with holes: B and C are removed after being added.
  Node* a = FromNodeDef("A", "OneOutput", 0);
  Node* b = AddNodeWithName("B");
  Node* c = FromNodeDef("C", "OneInput", 1);
  graph_.RemoveNode(b);
  Node* d = AddNodeWithName("D");
  const Edge* src_to_a = graph_.AddControlEdge(src, a);
  graph_.AddControlEdge(a, snk);
  graph_.AddEdge(a, 0, c, 0);
  graph_.AddControlEdge(c, snk);
  graph_.RemoveEdge(src_to_a);
  graph_.RemoveNode(c);

  // DebugStrings of every node that should survive.
  const std::set<string> expected = {src->DebugString(), a->DebugString(),
                                     d->DebugString(), snk->DebugString()};

  // Walking every id and skipping null slots yields the same nodes.
  std::set<string> by_id;
  const int num_ids = graph_.num_node_ids();
  for (int id = 0; id < num_ids; ++id) {
    if (Node* node = graph_.FindNodeId(id)) {
      by_id.insert(node->DebugString());
    }
  }
  EXPECT_EQ(expected, by_id);

  // So does the range-based view.
  std::set<string> by_range;
  for (Node* node : graph_.nodes()) {
    by_range.insert(node->DebugString());
  }
  EXPECT_EQ(expected, by_range);
  VerifyGraphStats();
}
331
CheckType(Node * node,bool b)332 static void CheckType(Node* node, bool b) {
333 EXPECT_TRUE(b) << node->DebugString();
334 // Make sure none of the other IsFoo() methods return true.
335 int count = 0;
336 if (node->IsSource()) count++;
337 if (node->IsSink()) count++;
338 if (node->IsOp()) count++;
339 EXPECT_EQ(1, count) << node->DebugString();
340 }
341
// Source, sink and an ordinary op node each report exactly their own kind.
TEST_F(GraphTest, Type) {
  Node* const op_node = AddNodeWithName("A");
  Node* const src = graph_.source_node();
  Node* const snk = graph_.sink_node();
  CheckType(src, src->IsSource());
  CheckType(snk, snk->IsSink());
  CheckType(op_node, op_node->IsOp());
  VerifyGraphStats();
}
349
// AddAttr() affects only the node it is called on: a CopyNode() taken earlier
// keeps the attrs present at copy time.
TEST_F(GraphTest, AddAttr) {
  Node* const original = AddNodeWithName("A");
  original->AddAttr("_a", "new_attr");

  string value;
  EXPECT_EQ(Status::OK(), GetNodeAttr(original->attrs(), "_a", &value));
  EXPECT_EQ("new_attr", value);

  Node* const copy = graph_.CopyNode(original);
  original->AddAttr("_b", "new_attr_2");

  // The original now has both attrs.
  EXPECT_EQ(Status::OK(), GetNodeAttr(original->attrs(), "_a", &value));
  EXPECT_EQ("new_attr", value);
  EXPECT_EQ(Status::OK(), GetNodeAttr(original->attrs(), "_b", &value));
  EXPECT_EQ("new_attr_2", value);

  // The copy only has the attr that existed when it was made.
  EXPECT_EQ(Status::OK(), GetNodeAttr(copy->attrs(), "_a", &value));
  EXPECT_EQ("new_attr", value);
  EXPECT_NE(Status::OK(), GetNodeAttr(copy->attrs(), "_b", &value));
}
372
373 // Convert edge iteration results into a sorted string.
EdgeIter(const Graph & g)374 static string EdgeIter(const Graph& g) {
375 std::vector<std::pair<int, int> > edges;
376 for (const Edge* e : g.edges()) {
377 edges.push_back(std::make_pair(e->src()->id(), e->dst()->id()));
378 }
379 std::sort(edges.begin(), edges.end());
380 string result;
381 for (auto& p : edges) {
382 strings::StrAppend(&result, p.first, "->", p.second, ";");
383 }
384 return result;
385 }
386
// edges() reflects data edges, control edges and self loops as they are added.
// Node ids: source=0, sink=1, A=2, B=3.
TEST_F(GraphTest, EdgeIteration) {
  // Only the built-in source -> sink edge exists at first.
  EXPECT_EQ("0->1;", EdgeIter(graph_));

  Node* a = FromNodeDef("A", "OneInputTwoOutputs", 1);
  Node* b = FromNodeDef("B", "OneInput", 1);
  // Adding unconnected nodes adds no edges.
  EXPECT_EQ("0->1;", EdgeIter(graph_));

  graph_.AddEdge(a, 0, b, 0);
  EXPECT_EQ("0->1;2->3;", EdgeIter(graph_));

  Node* const src = graph_.source_node();
  Node* const snk = graph_.sink_node();
  graph_.AddControlEdge(src, a);
  graph_.AddControlEdge(b, snk);
  EXPECT_EQ("0->1;0->2;2->3;3->1;", EdgeIter(graph_));

  // Self loop: a's output 1 feeds its own input 0.
  graph_.AddEdge(a, 1, a, 0);
  EXPECT_EQ("0->1;0->2;2->2;2->3;3->1;", EdgeIter(graph_));
  VerifyGraphStats();
}
405
// NewName() hands out distinct names that keep the requested prefix.
TEST_F(GraphTest, NewName) {
  const string first_a = graph_.NewName("A");
  const string second_a = graph_.NewName("A");
  const string first_b = graph_.NewName("B");

  EXPECT_NE(first_a, second_a);
  EXPECT_NE(first_a, first_b);
  EXPECT_NE(second_a, first_b);
  EXPECT_TRUE(str_util::StartsWith(first_a, "A")) << first_a;
}
415
// IsValidNode() rejects null, nodes whose id is out of range for this graph,
// and nodes whose id is in range but belong to a different graph.
TEST_F(GraphTest, IsValidNode) {
  // One extra node in graph_ (ids 0 and 1 are source and sink).
  Node* g1_node1;
  TF_CHECK_OK(NodeBuilder("g1_node1", "NoOp").Finalize(&graph_, &g1_node1));

  // Two extra nodes in a second graph.
  Graph graph2(OpRegistry::Global());
  Node* g2_node1;
  Node* g2_node2;
  TF_CHECK_OK(NodeBuilder("g2_node1", "NoOp").Finalize(&graph2, &g2_node1));
  TF_CHECK_OK(NodeBuilder("g2_node2", "NoOp").Finalize(&graph2, &g2_node2));

  // Null pointer.
  Status status = graph_.IsValidNode(nullptr);
  EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
  EXPECT_EQ(string("Node is null"), status.error_message());

  // Id beyond graph_'s node count.
  status = graph_.IsValidNode(g2_node2);
  EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
  EXPECT_EQ(string("node id 3 is >= than number of nodes in graph 3"),
            status.error_message());

  // Id in range, but graph_ holds a different Node* for it.
  status = graph_.IsValidNode(g2_node1);
  EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
  EXPECT_EQ(string("Node with id 2 is different from the passed in node. "
                   "Does it belong to a different graph?"),
            status.error_message());
}
446
// AddControlEdge() records "^src" in the destination's NodeDef, refuses
// redundant control edges unless allow_duplicates is set, and never records
// the source node in a NodeDef.
TEST_F(GraphTest, AddControlEdge) {
  FromGraphDef(
      "node { name: 'A' op: 'OneOutput' }"
      "node { name: 'B' op: 'OneInputTwoOutputs' input: [ 'A:0' ] }"
      "node { name: 'C' op: 'NoOp' } ");
  Node* a = FindNode("A");
  Node* b = FindNode("B");
  Node* c = FindNode("C");

  // Add a control edge C -> A.
  const Edge* edge = graph_.AddControlEdge(c, a);
  ASSERT_NE(edge, nullptr);
  // Check newly-created edge.
  EXPECT_EQ(edge->src(), c);
  EXPECT_EQ(edge->src_output(), Graph::kControlSlot);
  EXPECT_EQ(edge->dst(), a);
  EXPECT_EQ(edge->dst_input(), Graph::kControlSlot);
  // Check A's NodeDef.
  ASSERT_EQ(a->def().input_size(), 1);
  EXPECT_EQ(a->def().input(0), "^C");

  // Can add control edge redundant with data edge.
  edge = graph_.AddControlEdge(a, b);
  EXPECT_NE(edge, nullptr);
  ASSERT_EQ(b->def().input_size(), 2);
  EXPECT_EQ(b->def().input(0), "A:0");
  EXPECT_EQ(b->def().input(1), "^A");

  // Doesn't add edge redundant with control edge.
  edge = graph_.AddControlEdge(a, b);
  EXPECT_EQ(edge, nullptr);
  EXPECT_EQ(b->def().input_size(), 2);

  // Can add redundant control edge with allow_duplicates.
  edge = graph_.AddControlEdge(a, b, /*allow_duplicates=*/true);
  EXPECT_NE(edge, nullptr);
  // allow_duplicates=true leaves the NodeDef unchanged.
  ASSERT_EQ(b->def().input_size(), 2);
  EXPECT_EQ(b->def().input(0), "A:0");
  EXPECT_EQ(b->def().input(1), "^A");

  // Add control edge from source.
  edge = graph_.AddControlEdge(graph_.source_node(), b);
  EXPECT_NE(edge, nullptr);
  // Check that we don't include source input in the NodeDef.
  EXPECT_EQ(b->def().input_size(), 2);
  // Doesn't add redundant edge.
  edge = graph_.AddControlEdge(graph_.source_node(), b);
  EXPECT_EQ(edge, nullptr);
  EXPECT_EQ(b->def().input_size(), 2);
}
498
// RemoveControlEdge() removes the edge from the graph and its "^src" entry
// from the destination's NodeDef, leaving other control edges intact.
TEST_F(GraphTest, RemoveControlEdge) {
  FromGraphDef(
      "node { name: 'A' op: 'OneOutput' }"
      "node { name: 'B' op: 'OneInputTwoOutputs' input: [ 'A:0' ] }"
      "node { name: 'C' op: 'NoOp' } ");
  Node* a = FindNode("A");
  Node* b = FindNode("B");
  Node* c = FindNode("C");

  // Add two control edges: C -> A and A -> B.
  const Edge* edge_1 = graph_.AddControlEdge(c, a);
  const Edge* edge_2 = graph_.AddControlEdge(a, b);
  ASSERT_NE(edge_1, nullptr);
  ASSERT_NE(edge_2, nullptr);

  ASSERT_TRUE(ControlEdgeExistsInGraphOrNodeDef(c, a));
  ASSERT_TRUE(ControlEdgeExistsInGraphOrNodeDef(a, b));

  graph_.RemoveControlEdge(edge_1);
  ASSERT_FALSE(ControlEdgeExistsInGraphOrNodeDef(c, a));
  ASSERT_TRUE(ControlEdgeExistsInGraphOrNodeDef(a, b));

  graph_.RemoveControlEdge(edge_2);
  ASSERT_FALSE(ControlEdgeExistsInGraphOrNodeDef(c, a));
  ASSERT_FALSE(ControlEdgeExistsInGraphOrNodeDef(a, b));

  // Test removing a duplicate control edge.
  // Note that unless allow_duplicates is true, the duplicate edge
  // will not be added. That's why we expect edge_4 to be a null
  // pointer. We are not testing with allow_duplicates set to true,
  // as that is a highly unlikely use case that does not make much
  // sense.
  const Edge* edge_3 = graph_.AddControlEdge(c, a);
  const Edge* edge_4 = graph_.AddControlEdge(c, a);
  ASSERT_NE(edge_3, nullptr);
  ASSERT_EQ(edge_4, nullptr);

  graph_.RemoveControlEdge(edge_3);
  ASSERT_FALSE(ControlEdgeExistsInGraphOrNodeDef(c, a));
}
539
// UpdateEdge(src, src_output, dst, dst_input) rewires dst's data input to come
// from src's output, replacing the edge that previously fed that input, and
// rejects outputs/inputs the nodes do not have.
// Node ids: source=0, sink=1, A=2, B=3, C=4, D=5.
TEST_F(GraphTest, UpdateEdge) {
  Node* a = FromNodeDef("A", "OneOutput", 0);
  Node* b = FromNodeDef("B", "OneInputTwoOutputs", 1);
  Node* c = FromNodeDef("C", "OneInputTwoOutputs", 1);
  Node* d = FromNodeDef("D", "OneInput", 1);
  Node* const src = graph_.source_node();
  Node* const snk = graph_.sink_node();

  graph_.AddControlEdge(src, a);
  graph_.AddControlEdge(a, snk);
  graph_.AddEdge(a, 0, c, 0);

  graph_.AddControlEdge(c, snk);
  graph_.AddEdge(c, 0, b, 0);
  graph_.AddEdge(c, 1, d, 0);
  EXPECT_EQ("0->1;0->2;2->1;2->4;4->1;4->3;4->5;", EdgeIter(graph_));

  // Feed B's input 0 from A instead of C: 2->3 appears, 4->3 disappears.
  TF_EXPECT_OK(graph_.UpdateEdge(a, 0, b, 0));
  EXPECT_EQ("0->1;0->2;2->1;2->3;2->4;4->1;4->5;", EdgeIter(graph_));

  // Feed D's input 0 from A as well: 2->5 replaces 4->5.
  TF_EXPECT_OK(graph_.UpdateEdge(a, 0, d, 0));
  EXPECT_EQ("0->1;0->2;2->1;2->3;2->4;2->5;4->1;", EdgeIter(graph_));

  // A has only output 0, so output 1 is rejected.
  Status status = graph_.UpdateEdge(a, 1, d, 0);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(
      status.error_message(),
      "Node 'A' (type: 'OneOutput', num of outputs: 1) does not have output 0"
      == nullptr ? "" :
      "Node 'A' (type: 'OneOutput', num of outputs: 1) does not have output 1");

  // A has no inputs, so rewiring C's output 0 into A's input 0 is rejected.
  status = graph_.UpdateEdge(c, 0, a, 0);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(
      status.error_message(),
      "Node 'A' (type: 'OneOutput', num of inputs: 0) does not have input 0");
}
582
// input_edges() fails with INVALID_ARGUMENT until every data input of the
// node is connected.
TEST_F(GraphTest, InputEdges) {
  Node* const producer = FromNodeDef("A", "OneOutput", 0);
  Node* const consumer = FromNodeDef("B", "TwoInputsOneOutput", 2);
  std::vector<const Edge*> edges;

  // Only input 0 is connected; input 1 is still missing.
  graph_.AddEdge(producer, 0, consumer, 0);
  EXPECT_EQ(error::INVALID_ARGUMENT, consumer->input_edges(&edges).code());

  // Connecting input 1 completes the node.
  graph_.AddEdge(producer, 0, consumer, 1);
  TF_EXPECT_OK(consumer->input_edges(&edges));
}
592
// AddFunctionLibrary() merges functions and gradients into graph_'s library,
// ignores exact duplicates, and rejects conflicting definitions as well as
// functions that shadow a registered op.
TEST_F(GraphTest, AddFunctionLibrary) {
  // True if graph_'s library currently defines a function called `name`.
  auto has_function = [this](const string& name) {
    return graph_.flib_def().Find(name) != nullptr;
  };
  // Gradient function name registered for `name` in graph_'s library.
  auto gradient_of = [this](const string& name) {
    return graph_.flib_def().FindGradient(name);
  };

  // Basic functionality
  FunctionDefLibrary proto;
  *proto.add_function() = test::function::XTimesTwo();
  *proto.add_function() = test::function::XTimesFour();
  TF_EXPECT_OK(graph_.AddFunctionLibrary(proto));
  EXPECT_TRUE(has_function("XTimesTwo"));
  EXPECT_TRUE(has_function("XTimesFour"));

  // Duplicate functions are ignored
  TF_EXPECT_OK(graph_.AddFunctionLibrary(proto));
  EXPECT_TRUE(has_function("XTimesTwo"));
  EXPECT_TRUE(has_function("XTimesFour"));

  // Same name, different body (one node_def duplicated) is an error.
  FunctionDefLibrary error_proto = proto;
  const NodeDef duplicated_node = error_proto.function(0).node_def(0);
  *error_proto.mutable_function(0)->add_node_def() = duplicated_node;
  Status status = graph_.AddFunctionLibrary(error_proto);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Cannot add function 'XTimesTwo' because a different function with "
            "the same name already exists.");

  // A function named like an existing op is an error.
  error_proto = proto;
  error_proto.mutable_function(0)->mutable_signature()->set_name("Add");
  status = graph_.AddFunctionLibrary(error_proto);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Cannot add function 'Add' because an op with the same name "
            "already exists.");

  // Attaching a gradient to an existing function is fine, even if the
  // gradient function itself is undefined.
  GradientDef* grad = proto.add_gradient();
  grad->set_function_name("XTimesTwo");
  grad->set_gradient_func("Undefined");
  TF_EXPECT_OK(graph_.AddFunctionLibrary(proto));
  EXPECT_EQ(gradient_of("XTimesTwo"), "Undefined");

  // Duplicate gradients are ignored
  TF_EXPECT_OK(graph_.AddFunctionLibrary(proto));
  EXPECT_EQ(gradient_of("XTimesTwo"), "Undefined");

  // A different gradient for the same function is an error.
  error_proto = proto;
  error_proto.mutable_gradient(0)->set_gradient_func("Undefined2");
  status = graph_.AddFunctionLibrary(error_proto);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Cannot assign gradient function 'Undefined2' to 'XTimesTwo' "
            "because it already has gradient function 'Undefined'");
}
646
// BuildNodeNameIndex() maps every node name, including _SOURCE and _SINK, to
// the corresponding Node*.
TEST_F(GraphTest, BuildNodeNameIndex) {
  FromGraphDef(
      "node { name: 'A' op: 'OneOutput' }"
      "node { name: 'B' op: 'OneInputTwoOutputs' input: [ 'A:0' ] }"
      "node { name: 'C' op: 'NoOp' } ");

  const auto node_name_index = graph_.BuildNodeNameIndex();
  // Unsigned literal: size() is unsigned.
  EXPECT_EQ(node_name_index.size(), 5u);

  const std::vector<string> node_names{"_SOURCE", "_SINK", "A", "B", "C"};
  for (const string& node_name : node_names) {
    // One find() instead of find() + operator[]; operator[] would also
    // silently insert a null entry for a missing name.
    const auto it = node_name_index.find(node_name);
    ASSERT_TRUE(it != node_name_index.end()) << node_name;
    EXPECT_EQ(it->second, FindNode(node_name));
  }
}
662
BM_InEdgeIteration(int iters,int num_nodes,int num_edges_per_node)663 static void BM_InEdgeIteration(int iters, int num_nodes,
664 int num_edges_per_node) {
665 testing::StopTiming();
666 const GraphDef graph_def =
667 test::CreateGraphDef(num_nodes, num_edges_per_node);
668 Graph graph(OpRegistry::Global());
669 GraphConstructorOptions opts;
670 TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
671
672 int64 sum = 0;
673 testing::StartTiming();
674 for (int i = 0; i < iters; ++i) {
675 for (const Node* node : graph.nodes()) {
676 for (auto e : node->in_edges()) {
677 sum += e->id();
678 }
679 }
680 }
681 VLOG(1) << sum;
682 testing::StopTiming();
683 }
684 BENCHMARK(BM_InEdgeIteration)->ArgPair(10, 2);
685 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 6, 2);
686 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 9, 2);
687 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 12, 2);
688 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 15, 2);
689 BENCHMARK(BM_InEdgeIteration)->ArgPair(10, 4);
690 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 6, 4);
691 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 9, 4);
692 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 12, 4);
693 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 15, 4);
694 BENCHMARK(BM_InEdgeIteration)->ArgPair(10, 8);
695 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 6, 8);
696 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 9, 8);
697 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 12, 8);
698 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 15, 8);
699 BENCHMARK(BM_InEdgeIteration)->ArgPair(10, 16);
700 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 6, 16);
701 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 9, 16);
702 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 12, 16);
703 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 15, 16);
704
BM_GraphCreation(int iters,int num_nodes,int num_edges_per_node)705 static void BM_GraphCreation(int iters, int num_nodes, int num_edges_per_node) {
706 testing::StopTiming();
707 const GraphDef graph_def =
708 test::CreateGraphDef(num_nodes, num_edges_per_node);
709 const auto registry = OpRegistry::Global();
710 GraphConstructorOptions opts;
711 // Warmup step.
712 Graph graph(registry);
713 TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
714 int64 sum = 0;
715 testing::StartTiming();
716 for (int i = 0; i < iters; ++i) {
717 Graph graph(registry);
718 TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
719 sum += graph.num_node_ids();
720 }
721 VLOG(1) << sum;
722 testing::StopTiming();
723 }
724 BENCHMARK(BM_GraphCreation)->ArgPair(10, 2);
725 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 6, 2);
726 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 9, 2);
727 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 12, 2);
728 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 15, 2);
729 BENCHMARK(BM_GraphCreation)->ArgPair(10, 4);
730 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 6, 4);
731 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 9, 4);
732 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 12, 4);
733 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 15, 4);
734 BENCHMARK(BM_GraphCreation)->ArgPair(10, 8);
735 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 6, 8);
736 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 9, 8);
737 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 12, 8);
738 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 15, 8);
739 BENCHMARK(BM_GraphCreation)->ArgPair(10, 16);
740 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 6, 16);
741 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 9, 16);
742 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 12, 16);
743 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 15, 16);
744
BM_ToGraphDef(int iters,int num_nodes,int num_edges_per_node)745 static void BM_ToGraphDef(int iters, int num_nodes, int num_edges_per_node) {
746 testing::StopTiming();
747 const GraphDef graph_def =
748 test::CreateGraphDef(num_nodes, num_edges_per_node);
749 const auto registry = OpRegistry::Global();
750 GraphConstructorOptions opts;
751 // Warmup step.
752 Graph graph(registry);
753 TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
754 int64 sum = 0;
755 testing::StartTiming();
756 for (int i = 0; i < iters; ++i) {
757 GraphDef graph_def;
758 graph.ToGraphDef(&graph_def);
759 sum += graph_def.node_size();
760 }
761 VLOG(1) << sum;
762 testing::StopTiming();
763 }
764 BENCHMARK(BM_ToGraphDef)->ArgPair(10, 2);
765 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 6, 2);
766 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 2);
767 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 2);
768 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 2);
769 BENCHMARK(BM_ToGraphDef)->ArgPair(10, 4);
770 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 6, 4);
771 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 4);
772 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 4);
773 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 4);
774 BENCHMARK(BM_ToGraphDef)->ArgPair(10, 8);
775 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 6, 8);
776 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 8);
777 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 8);
778 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 8);
779 BENCHMARK(BM_ToGraphDef)->ArgPair(10, 16);
780 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 6, 16);
781 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 16);
782 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 16);
783 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 16);
784
785 } // namespace
786 } // namespace tensorflow
787