1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/graph.h"
17 
18 #include <set>
19 #include <unordered_map>
20 #include <vector>
21 
22 #include "tensorflow/core/common_runtime/function.h"
23 #include "tensorflow/core/common_runtime/graph_constructor.h"
24 #include "tensorflow/core/framework/function_testlib.h"
25 #include "tensorflow/core/graph/benchmark_testlib.h"
26 #include "tensorflow/core/graph/node_builder.h"
27 #include "tensorflow/core/kernels/ops_util.h"
28 #include "tensorflow/core/lib/core/status_test_util.h"
29 #include "tensorflow/core/lib/random/simple_philox.h"
30 #include "tensorflow/core/lib/strings/str_util.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/protobuf.h"
33 #include "tensorflow/core/platform/test.h"
34 #include "tensorflow/core/platform/test_benchmark.h"
35 
36 namespace tensorflow {
37 namespace {
38 
39 REGISTER_OP("OneInput").Input("x: float");
40 
41 REGISTER_OP("OneOutput").Output("y: float");
42 
43 REGISTER_OP("OneInputTwoOutputs")
44     .Input("x: float")
45     .Output("y: float")
46     .Output("z: float");
47 
48 REGISTER_OP("TwoInputsOneOutput")
49     .Input("x: float")
50     .Input("y: float")
51     .Output("z: float");
52 
53 class GraphTest : public ::testing::Test {
54  protected:
GraphTest()55   GraphTest() : graph_(OpRegistry::Global()) {}
~GraphTest()56   ~GraphTest() override {}
57 
VerifyNodes(Node * node,const std::vector<Node * > & expected_in,const std::vector<Node * > & expected_out)58   static void VerifyNodes(Node* node, const std::vector<Node*>& expected_in,
59                           const std::vector<Node*>& expected_out) {
60     std::vector<Node*> in;
61     for (const Edge* e : node->in_edges()) {
62       in.push_back(e->src());
63     }
64     EXPECT_EQ(Stringify(expected_in), Stringify(in));
65 
66     std::vector<Node*> out;
67     for (const Edge* e : node->out_edges()) {
68       out.push_back(e->dst());
69     }
70     EXPECT_EQ(Stringify(expected_out), Stringify(out));
71   }
72 
VerifyGraphStats()73   void VerifyGraphStats() {
74     int nodes = 0;
75     for (const Node* n : graph_.nodes()) {
76       VLOG(1) << n->id();
77       ++nodes;
78     }
79     EXPECT_EQ(nodes, graph_.num_nodes());
80     int edges = 0;
81     for (const Edge* e : graph_.edges()) {
82       VLOG(1) << e->id();
83       ++edges;
84     }
85     EXPECT_EQ(edges, graph_.num_edges());
86   }
87 
AddNodeWithName(const string & name)88   Node* AddNodeWithName(const string& name) {
89     Node* node;
90     TF_CHECK_OK(NodeBuilder(name, "NoOp").Finalize(&graph_, &node));
91     return node;
92   }
93 
FromNodeDef(const string & name,const string & node_type,int num_inputs)94   Node* FromNodeDef(const string& name, const string& node_type,
95                     int num_inputs) {
96     auto builder = NodeDefBuilder(name, node_type);
97     for (int i = 0; i < num_inputs; ++i) {
98       builder = builder.Input(strings::StrCat("node_", i), i, DT_FLOAT);
99     }
100 
101     NodeDef node_def;
102     TF_CHECK_OK(builder.Finalize(&node_def));
103 
104     Status s;
105     Node* node = graph_.AddNode(node_def, &s);
106     TF_CHECK_OK(s);
107     return node;
108   }
109 
FromGraphDef(const string & gdef_ascii)110   void FromGraphDef(const string& gdef_ascii) {
111     GraphDef gdef;
112     CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &gdef));
113     GraphConstructorOptions opts;
114     TF_CHECK_OK(ConvertGraphDefToGraph(opts, gdef, &graph_));
115   }
116 
FindNode(const string & name)117   Node* FindNode(const string& name) {
118     for (Node* node : graph_.nodes()) {
119       if (node->name() == name) return node;
120     }
121     LOG(FATAL) << name;
122   }
123 
ControlEdgeExistsInGraphOrNodeDef(const Node * src,const Node * dst)124   bool ControlEdgeExistsInGraphOrNodeDef(const Node* src, const Node* dst) {
125     for (const Edge* e : dst->in_edges()) {
126       if (e->IsControlEdge() && e->src() == src &&
127           e->src_output() == Graph::kControlSlot &&
128           e->dst_input() == Graph::kControlSlot) {
129         return true;
130       }
131     }
132     std::string control_edge_name = strings::StrCat("^", src->name());
133     for (int i = 0; i < dst->def().input_size(); ++i) {
134       if (dst->def().input(i) == control_edge_name) {
135         return true;
136       }
137     }
138     return false;
139   }
140 
141   Graph graph_;
142 
143  private:
144   // Convert a list of nodes to a sorted list of strings so failure messages
145   // are readable.
Stringify(const std::vector<Node * > & nodes)146   static std::vector<string> Stringify(const std::vector<Node*>& nodes) {
147     std::vector<string> result;
148     result.reserve(nodes.size());
149     for (Node* n : nodes) {
150       result.push_back(n->DebugString());
151     }
152     std::sort(result.begin(), result.end());
153     return result;
154   }
155 };
156 
TEST_F(GraphTest, Constructor) {
  // A freshly constructed graph holds only the built-in source and sink
  // nodes, joined by exactly one source -> sink edge.
  Node* const src = graph_.source_node();
  Node* const snk = graph_.sink_node();
  EXPECT_NE(src, nullptr);
  EXPECT_NE(snk, nullptr);

  VerifyNodes(src, /*expected_in=*/{}, /*expected_out=*/{snk});
  VerifyNodes(snk, /*expected_in=*/{src}, /*expected_out=*/{});

  // Ids 0 and 1 are taken by source and sink.
  EXPECT_EQ(2, graph_.num_node_ids());
  VerifyGraphStats();
}
167 
TEST_F(GraphTest, RemoveThenAdd) {
  // Three nodes on top of source and sink occupy 5 ids in total.
  AddNodeWithName("A");
  Node* victim = AddNodeWithName("B");
  const int victim_id = victim->id();
  AddNodeWithName("C");
  EXPECT_EQ(5, graph_.num_node_ids());

  // Removing a node frees its slot but does not shrink the id space.
  graph_.RemoveNode(victim);
  EXPECT_EQ(5, graph_.num_node_ids());

  // A newly added node gets a fresh id rather than recycling B's.
  const Node* replacement = AddNodeWithName("D");
  EXPECT_NE(victim_id, replacement->id());
  EXPECT_EQ(6, graph_.num_node_ids());
  VerifyGraphStats();
}
181 
TEST_F(GraphTest, InNodesAndOutNodes) {
  Node* const source = graph_.source_node();
  Node* const sink = graph_.sink_node();

  // Leave a hole in the id space by removing B before D is added.
  Node* node_a = FromNodeDef("A", "OneOutput", 0);
  Node* node_b = AddNodeWithName("B");
  Node* node_c = FromNodeDef("C", "OneInput", 1);
  graph_.RemoveNode(node_b);
  Node* node_d = AddNodeWithName("D");

  // Wire source -> A -> C, with both A and C also control-feeding the sink.
  const Edge* source_to_a = graph_.AddControlEdge(source, node_a);
  graph_.AddControlEdge(node_a, sink);
  graph_.AddEdge(node_a, 0, node_c, 0);
  graph_.AddControlEdge(node_c, sink);

  EXPECT_EQ("A", node_a->name());
  VerifyNodes(node_a, {source}, {node_c, sink});

  EXPECT_EQ("C", node_c->name());
  VerifyNodes(node_c, {node_a}, {sink});

  // D was never connected to anything.
  EXPECT_EQ("D", node_d->name());
  VerifyNodes(node_d, {}, {});

  VerifyNodes(source, {}, {node_a, sink});
  VerifyNodes(sink, {node_a, node_c, source}, {});

  // Dropping source -> A updates both of its endpoints.
  graph_.RemoveEdge(source_to_a);
  VerifyNodes(node_a, {}, {node_c, sink});
  VerifyNodes(source, {}, {sink});

  // Removing C also removes every edge touching it.
  graph_.RemoveNode(node_c);
  VerifyNodes(node_a, {}, {sink});
  VerifyNodes(sink, {node_a, source}, {});

  // Ids are never reclaimed: six node ids (source, sink, A, B, C, D) and five
  // edge ids (source->sink plus the four edges added above).
  EXPECT_EQ(6, graph_.num_node_ids());
  EXPECT_EQ(5, graph_.num_edge_ids());
  VerifyGraphStats();
}
217 
TEST_F(GraphTest,NodeByIndex)218 TEST_F(GraphTest, NodeByIndex) {
219   Node* a = FromNodeDef("A", "OneOutput", 0);
220   Node* c = FromNodeDef("C", "OneInput", 1);
221   graph_.AddEdge(a, 0, c, 0);
222 
223   // Ask for 'a' from 'c' by index.
224   const Node* a_copy;
225   TF_ASSERT_OK(c->input_node(0, &a_copy));
226   EXPECT_EQ(a, a_copy);
227 
228   const Edge* e;
229   TF_ASSERT_OK(c->input_edge(0, &e));
230   EXPECT_EQ(0, e->dst_input());
231   EXPECT_EQ(a, e->src());
232   EXPECT_EQ(c, e->dst());
233   EXPECT_EQ(0, e->src_output());
234 
235   Node* t = FromNodeDef("T", "TwoInputsOneOutput", 2);
236   graph_.AddEdge(a, 0, t, 0);
237   // Weird self edge
238   graph_.AddEdge(t, 0, t, 1);
239 
240   const Node* t_0;
241   const Node* t_1;
242   TF_ASSERT_OK(t->input_node(0, &t_0));
243   EXPECT_EQ(a, t_0);
244   TF_ASSERT_OK(t->input_node(1, &t_1));
245   EXPECT_EQ(t, t_1);
246 
247   TF_ASSERT_OK(t->input_edge(1, &e));
248   EXPECT_EQ(1, e->dst_input());
249   EXPECT_EQ(t, e->src());
250 
251   std::vector<const Edge*> t_input_edges;
252   TF_ASSERT_OK(t->input_edges(&t_input_edges));
253   ASSERT_EQ(2, t_input_edges.size());
254   EXPECT_EQ(a, t_input_edges[0]->src());
255   EXPECT_EQ(e, t_input_edges[1]);
256 
257   // Check out of bounds access
258   EXPECT_FALSE(c->input_node(1, &a_copy).ok());
259   EXPECT_FALSE(c->input_node(-1, &a_copy).ok());
260 
261   graph_.RemoveNode(a);
262 
263   // 'c's input_node entry should be invalidated.
264   Status s = c->input_node(0, &a_copy);
265   EXPECT_FALSE(s.ok());
266 
267   // Add two new nodes.
268   Node* a_new = FromNodeDef("A_new", "OneOutput", 0);
269   Node* b_new = FromNodeDef("B_new", "OneOutput", 0);
270 
271   // Connect one up to c.
272   graph_.AddEdge(a_new, 0, c, 0);
273   const Edge* a_new_c_edge;
274   TF_ASSERT_OK(c->input_edge(0, &a_new_c_edge));
275 
276   // Connect up the second edge
277   graph_.AddEdge(b_new, 0, c, 0);
278   const Edge* b_new_c_edge;
279   TF_ASSERT_OK(c->input_edge(0, &b_new_c_edge));
280 
281   // Now remove the old one
282   graph_.RemoveEdge(a_new_c_edge);
283 
284   // Check that the second edge can still be retrieved
285   TF_ASSERT_OK(c->input_edge(0, &b_new_c_edge));
286 
287   std::vector<const Edge*> c_input_edges;
288   TF_ASSERT_OK(c->input_edges(&c_input_edges));
289   ASSERT_EQ(1, c_input_edges.size());
290   EXPECT_EQ(b_new_c_edge, c_input_edges[0]);
291 }
292 
TEST_F(GraphTest, NodeIteration) {
  // Build a graph whose id space has holes: B and C are removed along the way.
  Node* a = FromNodeDef("A", "OneOutput", 0);
  Node* b = AddNodeWithName("B");
  Node* c = FromNodeDef("C", "OneInput", 1);
  graph_.RemoveNode(b);
  Node* d = AddNodeWithName("D");
  const Edge* source_to_a = graph_.AddControlEdge(graph_.source_node(), a);
  graph_.AddControlEdge(a, graph_.sink_node());
  graph_.AddEdge(a, 0, c, 0);
  graph_.AddControlEdge(c, graph_.sink_node());
  graph_.RemoveEdge(source_to_a);
  graph_.RemoveNode(c);

  // Only source, sink, A and D should survive.
  const std::set<string> expected = {
      graph_.source_node()->DebugString(), a->DebugString(), d->DebugString(),
      graph_.sink_node()->DebugString()};

  // Scanning every id and skipping the empty slots finds exactly those nodes.
  std::set<string> actual;
  const int id_limit = graph_.num_node_ids();
  for (int id = 0; id < id_limit; ++id) {
    Node* node = graph_.FindNodeId(id);
    if (node == nullptr) continue;
    actual.insert(node->DebugString());
  }
  EXPECT_EQ(expected, actual);

  // The range-based nodes() view yields the same set.
  actual.clear();
  for (Node* node : graph_.nodes()) actual.insert(node->DebugString());
  EXPECT_EQ(expected, actual);
  VerifyGraphStats();
}
332 
CheckType(Node * node,bool b)333 static void CheckType(Node* node, bool b) {
334   EXPECT_TRUE(b) << node->DebugString();
335   // Make sure none of the other IsFoo() methods return true.
336   int count = 0;
337   if (node->IsSource()) count++;
338   if (node->IsSink()) count++;
339   if (node->IsOp()) count++;
340   EXPECT_EQ(1, count) << node->DebugString();
341 }
342 
TEST_F(GraphTest, Type) {
  // Source, sink and an ordinary op node each report exactly their own kind.
  Node* const op_node = AddNodeWithName("A");
  Node* const source = graph_.source_node();
  Node* const sink = graph_.sink_node();
  CheckType(source, source->IsSource());
  CheckType(sink, sink->IsSink());
  CheckType(op_node, op_node->IsOp());
  VerifyGraphStats();
}
350 
TEST_F(GraphTest, AddAttr) {
  Node* n1 = AddNodeWithName("A");

  n1->AddAttr("_a", "new_attr");

  string attr;
  TF_EXPECT_OK(GetNodeAttr(n1->attrs(), "_a", &attr));
  EXPECT_EQ("new_attr", attr);

  // The copy takes n1's attributes as of now; attributes added to n1 later
  // must not appear on the copy.
  Node* n2 = graph_.CopyNode(n1);

  n1->AddAttr("_b", "new_attr_2");

  TF_EXPECT_OK(GetNodeAttr(n1->attrs(), "_a", &attr));
  EXPECT_EQ("new_attr", attr);
  TF_EXPECT_OK(GetNodeAttr(n1->attrs(), "_b", &attr));
  EXPECT_EQ("new_attr_2", attr);

  TF_EXPECT_OK(GetNodeAttr(n2->attrs(), "_a", &attr));
  EXPECT_EQ("new_attr", attr);
  EXPECT_FALSE(GetNodeAttr(n2->attrs(), "_b", &attr).ok());
}
373 
374 // Convert edge iteration results into a sorted string.
EdgeIter(const Graph & g)375 static string EdgeIter(const Graph& g) {
376   std::vector<std::pair<int, int> > edges;
377   for (const Edge* e : g.edges()) {
378     edges.push_back(std::make_pair(e->src()->id(), e->dst()->id()));
379   }
380   std::sort(edges.begin(), edges.end());
381   string result;
382   for (auto& p : edges) {
383     strings::StrAppend(&result, p.first, "->", p.second, ";");
384   }
385   return result;
386 }
387 
TEST_F(GraphTest, EdgeIteration) {
  // A new graph has only the source (id 0) -> sink (id 1) edge.
  EXPECT_EQ("0->1;", EdgeIter(graph_));

  Node* const a = FromNodeDef("A", "OneInputTwoOutputs", 1);  // id 2
  Node* const b = FromNodeDef("B", "OneInput", 1);            // id 3
  // Adding unconnected nodes adds no edges.
  EXPECT_EQ("0->1;", EdgeIter(graph_));

  graph_.AddEdge(a, 0, b, 0);
  EXPECT_EQ("0->1;2->3;", EdgeIter(graph_));

  Node* const source = graph_.source_node();
  Node* const sink = graph_.sink_node();
  graph_.AddControlEdge(source, a);
  graph_.AddControlEdge(b, sink);
  EXPECT_EQ("0->1;0->2;2->3;3->1;", EdgeIter(graph_));

  // A's output 1 looping back into its input 0 shows up as 2->2.
  graph_.AddEdge(a, 1, a, 0);
  EXPECT_EQ("0->1;0->2;2->2;2->3;3->1;", EdgeIter(graph_));
  VerifyGraphStats();
}
406 
TEST_F(GraphTest, NewName) {
  // Every NewName() call yields a distinct name, even for a repeated prefix.
  const string first_a = graph_.NewName("A");
  const string second_a = graph_.NewName("A");
  const string first_b = graph_.NewName("B");
  EXPECT_NE(first_a, second_a);
  EXPECT_NE(first_a, first_b);
  EXPECT_NE(second_a, first_b);
  // The generated name keeps the requested prefix.
  EXPECT_TRUE(absl::StartsWith(first_a, "A")) << first_a;
}
416 
TEST_F(GraphTest, IsValidNode) {
  // graph_ receives one op node.
  Node* g1_only;
  TF_CHECK_OK(NodeBuilder("g1_node1", "NoOp").Finalize(&graph_, &g1_only));

  // A separate graph receives two op nodes.
  Graph other(OpRegistry::Global());
  Node* g2_first;
  Node* g2_second;
  TF_CHECK_OK(NodeBuilder("g2_node1", "NoOp").Finalize(&other, &g2_first));
  TF_CHECK_OK(NodeBuilder("g2_node2", "NoOp").Finalize(&other, &g2_second));

  // A null pointer is rejected outright.
  Status status = graph_.IsValidNode(nullptr);
  EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
  EXPECT_EQ(string("Node is null"), status.error_message());

  // A foreign node whose id lies past graph_'s node table is rejected.
  status = graph_.IsValidNode(g2_second);
  EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
  EXPECT_EQ(string("node id 3 is >= than number of nodes in graph 3"),
            status.error_message());

  // A foreign node whose id is in range, but whose slot in graph_ holds a
  // different Node*, is rejected as well.
  status = graph_.IsValidNode(g2_first);
  EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
  EXPECT_EQ(string("Node with id 2 is different from the passed in node. "
                   "Does it belong to a different graph?"),
            status.error_message());
}
447 
TEST_F(GraphTest, AddControlEdge) {
  FromGraphDef(
      "node { name: 'A' op: 'OneOutput' }"
      "node { name: 'B' op: 'OneInputTwoOutputs' input: [ 'A:0' ] }"
      "node { name: 'C' op: 'NoOp' } ");
  Node* a = FindNode("A");
  Node* b = FindNode("B");
  Node* c = FindNode("C");

  // Add a control edge C -> A.
  const Edge* edge = graph_.AddControlEdge(c, a);
  ASSERT_NE(edge, nullptr);
  // The new edge joins the control slots of C and A...
  EXPECT_EQ(edge->src(), c);
  EXPECT_EQ(edge->src_output(), Graph::kControlSlot);
  EXPECT_EQ(edge->dst(), a);
  EXPECT_EQ(edge->dst_input(), Graph::kControlSlot);
  // ...and is mirrored into A's NodeDef as a "^C" input.
  ASSERT_EQ(a->def().input_size(), 1);
  EXPECT_EQ(a->def().input(0), "^C");

  // A control edge may duplicate the existing data edge A:0 -> B.
  edge = graph_.AddControlEdge(a, b);
  EXPECT_NE(edge, nullptr);
  ASSERT_EQ(b->def().input_size(), 2);
  EXPECT_EQ(b->def().input(0), "A:0");
  EXPECT_EQ(b->def().input(1), "^A");

  // A second identical control edge is rejected and B's NodeDef is unchanged.
  edge = graph_.AddControlEdge(a, b);
  EXPECT_EQ(edge, nullptr);
  EXPECT_EQ(b->def().input_size(), 2);

  // With allow_duplicates=true the duplicate edge is added to the graph...
  edge = graph_.AddControlEdge(a, b, /*allow_duplicates=*/true);
  EXPECT_NE(edge, nullptr);
  // ...but allow_duplicates leaves B's NodeDef untouched (no second "^A").
  ASSERT_EQ(b->def().input_size(), 2);
  EXPECT_EQ(b->def().input(0), "A:0");
  EXPECT_EQ(b->def().input(1), "^A");

  // A control edge from the source node is added to the graph...
  edge = graph_.AddControlEdge(graph_.source_node(), b);
  EXPECT_NE(edge, nullptr);
  // ...but is not recorded in B's NodeDef.
  EXPECT_EQ(b->def().input_size(), 2);
  // Adding the same source edge again is rejected.
  edge = graph_.AddControlEdge(graph_.source_node(), b);
  EXPECT_EQ(edge, nullptr);
  EXPECT_EQ(b->def().input_size(), 2);
}
499 
TEST_F(GraphTest, RemoveControlEdge) {
  FromGraphDef(
      "node { name: 'A' op: 'OneOutput' }"
      "node { name: 'B' op: 'OneInputTwoOutputs' input: [ 'A:0' ] }"
      "node { name: 'C' op: 'NoOp' } ");
  Node* a = FindNode("A");
  Node* b = FindNode("B");
  Node* c = FindNode("C");

  // Add two control edges: C -> A and A -> B.
  const Edge* edge_1 = graph_.AddControlEdge(c, a);
  const Edge* edge_2 = graph_.AddControlEdge(a, b);
  ASSERT_NE(edge_1, nullptr);
  ASSERT_NE(edge_2, nullptr);

  ASSERT_TRUE(ControlEdgeExistsInGraphOrNodeDef(c, a));
  ASSERT_TRUE(ControlEdgeExistsInGraphOrNodeDef(a, b));

  // Removing a control edge drops it from both the graph and the NodeDef,
  // and leaves the other control edge alone.
  graph_.RemoveControlEdge(edge_1);
  ASSERT_FALSE(ControlEdgeExistsInGraphOrNodeDef(c, a));
  ASSERT_TRUE(ControlEdgeExistsInGraphOrNodeDef(a, b));

  graph_.RemoveControlEdge(edge_2);
  ASSERT_FALSE(ControlEdgeExistsInGraphOrNodeDef(c, a));
  ASSERT_FALSE(ControlEdgeExistsInGraphOrNodeDef(a, b));

  // Test removing a duplicate control edge.
  // Note that unless allow_duplicates is true, the duplicate edge
  // will not be added. That's why we expect edge_4 to be a null
  // pointer. We are not testing with allow_duplicates set to true,
  // as that is a highly unlikely use case that does not make much
  // sense.
  const Edge* edge_3 = graph_.AddControlEdge(c, a);
  const Edge* edge_4 = graph_.AddControlEdge(c, a);
  ASSERT_NE(edge_3, nullptr);
  ASSERT_EQ(edge_4, nullptr);

  graph_.RemoveControlEdge(edge_3);
  ASSERT_FALSE(ControlEdgeExistsInGraphOrNodeDef(c, a));
}
540 
TEST_F(GraphTest, UpdateEdge) {
  // Build a little graph. Node ids: source=0, sink=1, A=2, B=3, C=4, D=5.
  Node* a = FromNodeDef("A", "OneOutput", 0);
  Node* b = FromNodeDef("B", "OneInputTwoOutputs", 1);
  Node* c = FromNodeDef("C", "OneInputTwoOutputs", 1);
  Node* d = FromNodeDef("D", "OneInput", 1);

  graph_.AddControlEdge(graph_.source_node(), a);
  graph_.AddControlEdge(a, graph_.sink_node());
  graph_.AddEdge(a, 0, c, 0);

  graph_.AddControlEdge(c, graph_.sink_node());
  graph_.AddEdge(c, 0, b, 0);
  graph_.AddEdge(c, 1, d, 0);

  // Initial edge connections.
  EXPECT_EQ("0->1;0->2;2->1;2->4;4->1;4->3;4->5;", EdgeIter(graph_));

  // Re-point B's input 0 at A:0. Edge A -> B (2->3) should now be in the
  // graph and C -> B (4->3) should be gone.
  TF_EXPECT_OK(graph_.UpdateEdge(a, 0, b, 0));
  EXPECT_EQ("0->1;0->2;2->1;2->3;2->4;4->1;4->5;", EdgeIter(graph_));

  // Re-point D's input 0 at A:0 as well, so C -> D (4->5) becomes 2->5.
  TF_EXPECT_OK(graph_.UpdateEdge(a, 0, d, 0));
  EXPECT_EQ("0->1;0->2;2->1;2->3;2->4;2->5;4->1;", EdgeIter(graph_));

  // A has a single output (index 0), so A:1 is not a valid source.
  Status s = graph_.UpdateEdge(a, 1, d, 0);
  EXPECT_FALSE(s.ok());
  EXPECT_EQ(
      s.error_message(),
      "Node 'A' (type: 'OneOutput', num of outputs: 1) does not have output 1");

  // A has no inputs, so input 0 of A is not a valid destination.
  s = graph_.UpdateEdge(c, 0, a, 0);
  EXPECT_FALSE(s.ok());
  EXPECT_EQ(
      s.error_message(),
      "Node 'A' (type: 'OneOutput', num of inputs: 0) does not have input 0");
}
583 
TEST_F(GraphTest, InputEdges) {
  Node* producer = FromNodeDef("A", "OneOutput", 0);
  Node* consumer = FromNodeDef("B", "TwoInputsOneOutput", 2);
  std::vector<const Edge*> edges;

  // With only input 0 wired, input_edges() fails with INVALID_ARGUMENT.
  graph_.AddEdge(producer, 0, consumer, 0);
  EXPECT_EQ(error::INVALID_ARGUMENT, consumer->input_edges(&edges).code());

  // Once input 1 is wired too, it succeeds.
  graph_.AddEdge(producer, 0, consumer, 1);
  TF_EXPECT_OK(consumer->input_edges(&edges));
}
593 
TEST_F(GraphTest, AddFunctionLibrary) {
  // True if graph_'s function library currently defines `fn_name`.
  auto has_function = [this](const char* fn_name) {
    return graph_.flib_def().Find(fn_name) != nullptr;
  };

  // Adding a library registers each of its functions.
  FunctionDefLibrary lib;
  *lib.add_function() = test::function::XTimesTwo();
  *lib.add_function() = test::function::XTimesFour();
  TF_EXPECT_OK(graph_.AddFunctionLibrary(lib));
  EXPECT_TRUE(has_function("XTimesTwo"));
  EXPECT_TRUE(has_function("XTimesFour"));

  // Re-adding identical functions is accepted and changes nothing.
  TF_EXPECT_OK(graph_.AddFunctionLibrary(lib));
  EXPECT_TRUE(has_function("XTimesTwo"));
  EXPECT_TRUE(has_function("XTimesFour"));

  // Reusing a name for a different function body is an error. The body is
  // altered by appending a copy of its first node.
  FunctionDefLibrary bad_lib = lib;
  const NodeDef first_node = bad_lib.function(0).node_def(0);
  *bad_lib.mutable_function(0)->add_node_def() = first_node;
  Status status = graph_.AddFunctionLibrary(bad_lib);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Cannot add function 'XTimesTwo' because a different function with "
            "the same name already exists.");

  // A function may not take the name of an existing op.
  bad_lib = lib;
  bad_lib.mutable_function(0)->mutable_signature()->set_name("Add");
  status = graph_.AddFunctionLibrary(bad_lib);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Cannot add function 'Add' because an op with the same name "
            "already exists.");

  // Attaching a gradient to an existing function is accepted, even when the
  // gradient function itself is not defined.
  GradientDef* gradient = lib.add_gradient();
  gradient->set_function_name("XTimesTwo");
  gradient->set_gradient_func("Undefined");
  TF_EXPECT_OK(graph_.AddFunctionLibrary(lib));
  EXPECT_EQ(graph_.flib_def().FindGradient("XTimesTwo"), "Undefined");

  // Re-adding the same gradient is accepted and changes nothing.
  TF_EXPECT_OK(graph_.AddFunctionLibrary(lib));
  EXPECT_EQ(graph_.flib_def().FindGradient("XTimesTwo"), "Undefined");

  // Assigning a different gradient to the same function is an error.
  bad_lib = lib;
  bad_lib.mutable_gradient(0)->set_gradient_func("Undefined2");
  status = graph_.AddFunctionLibrary(bad_lib);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Cannot assign gradient function 'Undefined2' to 'XTimesTwo' "
            "because it already has gradient function 'Undefined'");
}
647 
TEST_F(GraphTest, BuildNodeNameIndex) {
  FromGraphDef(
      "node { name: 'A' op: 'OneOutput' }"
      "node { name: 'B' op: 'OneInputTwoOutputs' input: [ 'A:0' ] }"
      "node { name: 'C' op: 'NoOp' } ");

  auto node_name_index = graph_.BuildNodeNameIndex();
  // The built-in source and sink nodes plus A, B and C.
  EXPECT_EQ(node_name_index.size(), 5u);

  const std::vector<string> node_names{"_SOURCE", "_SINK", "A", "B", "C"};
  for (const string& node_name : node_names) {
    // Look each name up once. operator[] would silently insert a null entry
    // for a missing name and then compare against it.
    const auto it = node_name_index.find(node_name);
    ASSERT_NE(it, node_name_index.end()) << node_name;
    EXPECT_EQ(it->second, FindNode(node_name));
  }
}
663 
BM_InEdgeIteration(::testing::benchmark::State & state)664 void BM_InEdgeIteration(::testing::benchmark::State& state) {
665   const int num_nodes = state.range(0);
666   const int num_edges_per_node = state.range(1);
667   const GraphDef graph_def =
668       test::CreateGraphDef(num_nodes, num_edges_per_node);
669   Graph graph(OpRegistry::Global());
670   GraphConstructorOptions opts;
671   TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
672 
673   int64 sum = 0;
674   for (auto s : state) {
675     for (const Node* node : graph.nodes()) {
676       for (auto e : node->in_edges()) {
677         sum += e->id();
678       }
679     }
680   }
681   VLOG(1) << sum;
682 }
683 BENCHMARK(BM_InEdgeIteration)->ArgPair(10, 2);
684 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 6, 2);
685 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 9, 2);
686 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 12, 2);
687 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 15, 2);
688 BENCHMARK(BM_InEdgeIteration)->ArgPair(10, 4);
689 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 6, 4);
690 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 9, 4);
691 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 12, 4);
692 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 15, 4);
693 BENCHMARK(BM_InEdgeIteration)->ArgPair(10, 8);
694 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 6, 8);
695 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 9, 8);
696 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 12, 8);
697 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 15, 8);
698 BENCHMARK(BM_InEdgeIteration)->ArgPair(10, 16);
699 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 6, 16);
700 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 9, 16);
701 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 12, 16);
702 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 15, 16);
703 
BM_GraphCreation(::testing::benchmark::State & state)704 void BM_GraphCreation(::testing::benchmark::State& state) {
705   const int num_nodes = state.range(0);
706   const int num_edges_per_node = state.range(1);
707   const GraphDef graph_def =
708       test::CreateGraphDef(num_nodes, num_edges_per_node);
709   const auto registry = OpRegistry::Global();
710   GraphConstructorOptions opts;
711   // Warmup step.
712   Graph graph(registry);
713   TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
714   int64 sum = 0;
715   for (auto s : state) {
716     Graph graph(registry);
717     TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
718     sum += graph.num_node_ids();
719   }
720   VLOG(1) << sum;
721 }
722 BENCHMARK(BM_GraphCreation)->ArgPair(10, 2);
723 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 6, 2);
724 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 9, 2);
725 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 12, 2);
726 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 15, 2);
727 BENCHMARK(BM_GraphCreation)->ArgPair(10, 4);
728 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 6, 4);
729 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 9, 4);
730 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 12, 4);
731 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 15, 4);
732 BENCHMARK(BM_GraphCreation)->ArgPair(10, 8);
733 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 6, 8);
734 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 9, 8);
735 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 12, 8);
736 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 15, 8);
737 BENCHMARK(BM_GraphCreation)->ArgPair(10, 16);
738 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 6, 16);
739 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 9, 16);
740 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 12, 16);
741 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 15, 16);
742 
BM_ToGraphDef(::testing::benchmark::State & state)743 void BM_ToGraphDef(::testing::benchmark::State& state) {
744   const int num_nodes = state.range(0);
745   const int num_edges_per_node = state.range(1);
746   const GraphDef graph_def =
747       test::CreateGraphDef(num_nodes, num_edges_per_node);
748   const auto registry = OpRegistry::Global();
749   GraphConstructorOptions opts;
750   // Warmup step.
751   Graph graph(registry);
752   TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
753   int64 sum = 0;
754   for (auto s : state) {
755     GraphDef graph_def;
756     graph.ToGraphDef(&graph_def);
757     sum += graph_def.node_size();
758   }
759   VLOG(1) << sum;
760 }
761 BENCHMARK(BM_ToGraphDef)->ArgPair(10, 2);
762 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 6, 2);
763 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 2);
764 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 2);
765 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 2);
766 BENCHMARK(BM_ToGraphDef)->ArgPair(10, 4);
767 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 6, 4);
768 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 4);
769 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 4);
770 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 4);
771 BENCHMARK(BM_ToGraphDef)->ArgPair(10, 8);
772 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 6, 8);
773 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 8);
774 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 8);
775 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 8);
776 BENCHMARK(BM_ToGraphDef)->ArgPair(10, 16);
777 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 6, 16);
778 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 16);
779 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 16);
780 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 16);
781 
BM_RemoveNode(::testing::benchmark::State & state)782 void BM_RemoveNode(::testing::benchmark::State& state) {
783   const int num_nodes = state.range(0);
784   const int num_edges_per_node = state.range(1);
785   const GraphDef graph_def =
786       test::CreateGraphDef(num_nodes, num_edges_per_node);
787   const auto registry = OpRegistry::Global();
788   GraphConstructorOptions opts;
789   for (auto s : state) {
790     state.PauseTiming();
791     Graph graph(registry);
792     TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
793     state.ResumeTiming();
794     for (Node* n : graph.op_nodes()) {
795       graph.RemoveNode(n);
796     }
797   }
798 }
799 BENCHMARK(BM_RemoveNode)->ArgPair(10, 2);
800 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 6, 2);
801 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 9, 2);
802 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 12, 2);
803 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 15, 2);
804 BENCHMARK(BM_RemoveNode)->ArgPair(10, 4);
805 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 6, 4);
806 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 9, 4);
807 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 12, 4);
808 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 15, 4);
809 BENCHMARK(BM_RemoveNode)->ArgPair(10, 8);
810 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 6, 8);
811 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 9, 8);
812 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 12, 8);
813 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 15, 8);
814 BENCHMARK(BM_RemoveNode)->ArgPair(10, 16);
815 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 6, 16);
816 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 9, 16);
817 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 12, 16);
818 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 15, 16);
819 
820 }  // namespace
821 }  // namespace tensorflow
822