1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/graph.h"
17 
18 #include <set>
19 #include <unordered_map>
20 #include <vector>
21 #include "tensorflow/core/common_runtime/function.h"
22 #include "tensorflow/core/framework/function_testlib.h"
23 #include "tensorflow/core/graph/benchmark_testlib.h"
24 #include "tensorflow/core/graph/graph_constructor.h"
25 #include "tensorflow/core/graph/node_builder.h"
26 #include "tensorflow/core/kernels/ops_util.h"
27 #include "tensorflow/core/lib/core/status_test_util.h"
28 #include "tensorflow/core/lib/random/simple_philox.h"
29 #include "tensorflow/core/lib/strings/str_util.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/protobuf.h"
32 #include "tensorflow/core/platform/test.h"
33 #include "tensorflow/core/platform/test_benchmark.h"
34 
35 namespace tensorflow {
36 namespace {
37 
// Minimal float-typed ops registered for this test file. The tests below use
// them to create nodes with a known number of inputs and outputs.

// One float input `x`, no outputs.
REGISTER_OP("OneInput").Input("x: float");

// No inputs, one float output `y`.
REGISTER_OP("OneOutput").Output("y: float");

// One float input `x`, two float outputs `y` and `z`.
REGISTER_OP("OneInputTwoOutputs")
    .Input("x: float")
    .Output("y: float")
    .Output("z: float");

// Two float inputs `x` and `y`, one float output `z`.
REGISTER_OP("TwoInputsOneOutput")
    .Input("x: float")
    .Input("y: float")
    .Output("z: float");
51 
52 class GraphTest : public ::testing::Test {
53  protected:
GraphTest()54   GraphTest() : graph_(OpRegistry::Global()) {}
~GraphTest()55   ~GraphTest() override {}
56 
VerifyNodes(Node * node,const std::vector<Node * > & expected_in,const std::vector<Node * > & expected_out)57   static void VerifyNodes(Node* node, const std::vector<Node*>& expected_in,
58                           const std::vector<Node*>& expected_out) {
59     std::vector<Node*> in;
60     for (const Edge* e : node->in_edges()) {
61       in.push_back(e->src());
62     }
63     EXPECT_EQ(Stringify(expected_in), Stringify(in));
64 
65     std::vector<Node*> out;
66     for (const Edge* e : node->out_edges()) {
67       out.push_back(e->dst());
68     }
69     EXPECT_EQ(Stringify(expected_out), Stringify(out));
70   }
71 
VerifyGraphStats()72   void VerifyGraphStats() {
73     int nodes = 0;
74     for (const Node* n : graph_.nodes()) {
75       VLOG(1) << n->id();
76       ++nodes;
77     }
78     EXPECT_EQ(nodes, graph_.num_nodes());
79     int edges = 0;
80     for (const Edge* e : graph_.edges()) {
81       VLOG(1) << e->id();
82       ++edges;
83     }
84     EXPECT_EQ(edges, graph_.num_edges());
85   }
86 
AddNodeWithName(const string & name)87   Node* AddNodeWithName(const string& name) {
88     Node* node;
89     TF_CHECK_OK(NodeBuilder(name, "NoOp").Finalize(&graph_, &node));
90     return node;
91   }
92 
FromNodeDef(const string & name,const string & node_type,int num_inputs)93   Node* FromNodeDef(const string& name, const string& node_type,
94                     int num_inputs) {
95     auto builder = NodeDefBuilder(name, node_type);
96     for (int i = 0; i < num_inputs; ++i) {
97       builder = builder.Input(strings::StrCat("node_", i), i, DT_FLOAT);
98     }
99 
100     NodeDef node_def;
101     TF_CHECK_OK(builder.Finalize(&node_def));
102 
103     Status s;
104     Node* node = graph_.AddNode(node_def, &s);
105     TF_CHECK_OK(s);
106     return node;
107   }
108 
FromGraphDef(const string & gdef_ascii)109   void FromGraphDef(const string& gdef_ascii) {
110     GraphDef gdef;
111     CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &gdef));
112     GraphConstructorOptions opts;
113     TF_CHECK_OK(ConvertGraphDefToGraph(opts, gdef, &graph_));
114   }
115 
FindNode(const string & name)116   Node* FindNode(const string& name) {
117     for (Node* node : graph_.nodes()) {
118       if (node->name() == name) return node;
119     }
120     LOG(FATAL) << name;
121   }
122 
ControlEdgeExistsInGraphOrNodeDef(const Node * src,const Node * dst)123   bool ControlEdgeExistsInGraphOrNodeDef(const Node* src, const Node* dst) {
124     for (const Edge* e : dst->in_edges()) {
125       if (e->IsControlEdge() && e->src() == src &&
126           e->src_output() == Graph::kControlSlot &&
127           e->dst_input() == Graph::kControlSlot) {
128         return true;
129       }
130     }
131     std::string control_edge_name = strings::StrCat("^", src->name());
132     for (int i = 0; i < dst->def().input_size(); ++i) {
133       if (dst->def().input(i) == control_edge_name) {
134         return true;
135       }
136     }
137     return false;
138   }
139 
140   Graph graph_;
141 
142  private:
143   // Convert a list of nodes to a sorted list of strings so failure messages
144   // are readable.
Stringify(const std::vector<Node * > & nodes)145   static std::vector<string> Stringify(const std::vector<Node*>& nodes) {
146     std::vector<string> result;
147     result.reserve(nodes.size());
148     for (Node* n : nodes) {
149       result.push_back(n->DebugString());
150     }
151     std::sort(result.begin(), result.end());
152     return result;
153   }
154 };
155 
// A freshly constructed graph holds only the source and sink nodes, with the
// source feeding the sink.
TEST_F(GraphTest, Constructor) {
  Node* source = graph_.source_node();
  Node* sink = graph_.sink_node();
  // Assert rather than expect: VerifyNodes dereferences both pointers, so a
  // null node must stop the test here instead of crashing it.
  ASSERT_NE(source, nullptr);
  ASSERT_NE(sink, nullptr);
  VerifyNodes(source, {}, {sink});
  VerifyNodes(sink, {source}, {});
  EXPECT_EQ(2, graph_.num_node_ids());
  VerifyGraphStats();
}
166 
// Removing a node must not free its id for reuse by a later node.
TEST_F(GraphTest, RemoveThenAdd) {
  AddNodeWithName("A");
  Node* removed = AddNodeWithName("B");
  const int removed_id = removed->id();
  AddNodeWithName("C");
  // Source and sink plus A, B and C.
  EXPECT_EQ(5, graph_.num_node_ids());

  graph_.RemoveNode(removed);
  // The id space does not shrink on removal.
  EXPECT_EQ(5, graph_.num_node_ids());

  Node* added = AddNodeWithName("D");
  EXPECT_NE(removed_id, added->id());  // Ids should not be reused.
  EXPECT_EQ(6, graph_.num_node_ids());
  VerifyGraphStats();
}
180 
// Checks that in/out neighbours stay consistent as edges and nodes are added
// and removed.
TEST_F(GraphTest, InNodesAndOutNodes) {
  Node* node_a = FromNodeDef("A", "OneOutput", 0);
  Node* node_b = AddNodeWithName("B");
  Node* node_c = FromNodeDef("C", "OneInput", 1);
  graph_.RemoveNode(node_b);
  Node* node_d = AddNodeWithName("D");

  Node* source = graph_.source_node();
  Node* sink = graph_.sink_node();

  const Edge* source_to_a = graph_.AddControlEdge(source, node_a);
  graph_.AddControlEdge(node_a, sink);
  graph_.AddEdge(node_a, 0, node_c, 0);
  graph_.AddControlEdge(node_c, sink);

  EXPECT_EQ("A", node_a->name());
  VerifyNodes(node_a, {source}, {node_c, sink});

  EXPECT_EQ("C", node_c->name());
  VerifyNodes(node_c, {node_a}, {sink});

  // D was never connected to anything.
  EXPECT_EQ("D", node_d->name());
  VerifyNodes(node_d, {}, {});

  VerifyNodes(source, {}, {node_a, sink});
  VerifyNodes(sink, {node_a, node_c, source}, {});

  // Dropping source->A detaches A from the source.
  graph_.RemoveEdge(source_to_a);
  VerifyNodes(node_a, {}, {node_c, sink});
  VerifyNodes(source, {}, {sink});  // no more a

  // Removing C also removes the edges touching it.
  graph_.RemoveNode(node_c);
  VerifyNodes(node_a, {}, {sink});        // no more c
  VerifyNodes(sink, {node_a, source}, {});  // no more c
  EXPECT_EQ(6, graph_.num_node_ids());
  EXPECT_EQ(5, graph_.num_edge_ids());
  VerifyGraphStats();
}
216 
TEST_F(GraphTest,NodeByIndex)217 TEST_F(GraphTest, NodeByIndex) {
218   Node* a = FromNodeDef("A", "OneOutput", 0);
219   Node* c = FromNodeDef("C", "OneInput", 1);
220   graph_.AddEdge(a, 0, c, 0);
221 
222   // Ask for 'a' from 'c' by index.
223   const Node* a_copy;
224   TF_ASSERT_OK(c->input_node(0, &a_copy));
225   EXPECT_EQ(a, a_copy);
226 
227   const Edge* e;
228   TF_ASSERT_OK(c->input_edge(0, &e));
229   EXPECT_EQ(0, e->dst_input());
230   EXPECT_EQ(a, e->src());
231   EXPECT_EQ(c, e->dst());
232   EXPECT_EQ(0, e->src_output());
233 
234   Node* t = FromNodeDef("T", "TwoInputsOneOutput", 2);
235   graph_.AddEdge(a, 0, t, 0);
236   // Weird self edge
237   graph_.AddEdge(t, 0, t, 1);
238 
239   const Node* t_0;
240   const Node* t_1;
241   TF_ASSERT_OK(t->input_node(0, &t_0));
242   EXPECT_EQ(a, t_0);
243   TF_ASSERT_OK(t->input_node(1, &t_1));
244   EXPECT_EQ(t, t_1);
245 
246   TF_ASSERT_OK(t->input_edge(1, &e));
247   EXPECT_EQ(1, e->dst_input());
248   EXPECT_EQ(t, e->src());
249 
250   std::vector<const Edge*> t_input_edges;
251   TF_ASSERT_OK(t->input_edges(&t_input_edges));
252   ASSERT_EQ(2, t_input_edges.size());
253   EXPECT_EQ(a, t_input_edges[0]->src());
254   EXPECT_EQ(e, t_input_edges[1]);
255 
256   // Check out of bounds access
257   EXPECT_FALSE(c->input_node(1, &a_copy).ok());
258   EXPECT_FALSE(c->input_node(-1, &a_copy).ok());
259 
260   graph_.RemoveNode(a);
261 
262   // 'c's input_node entry should be invalidated.
263   Status s = c->input_node(0, &a_copy);
264   EXPECT_FALSE(s.ok());
265 
266   // Add two new nodes.
267   Node* a_new = FromNodeDef("A_new", "OneOutput", 0);
268   Node* b_new = FromNodeDef("B_new", "OneOutput", 0);
269 
270   // Connect one up to c.
271   graph_.AddEdge(a_new, 0, c, 0);
272   const Edge* a_new_c_edge;
273   TF_ASSERT_OK(c->input_edge(0, &a_new_c_edge));
274 
275   // Connect up the second edge
276   graph_.AddEdge(b_new, 0, c, 0);
277   const Edge* b_new_c_edge;
278   TF_ASSERT_OK(c->input_edge(0, &b_new_c_edge));
279 
280   // Now remove the old one
281   graph_.RemoveEdge(a_new_c_edge);
282 
283   // Check that the second edge can still be retrieved
284   TF_ASSERT_OK(c->input_edge(0, &b_new_c_edge));
285 
286   std::vector<const Edge*> c_input_edges;
287   TF_ASSERT_OK(c->input_edges(&c_input_edges));
288   ASSERT_EQ(1, c_input_edges.size());
289   EXPECT_EQ(b_new_c_edge, c_input_edges[0]);
290 }
291 
// Iterating by id and iterating with a range-based for must both visit
// exactly the live nodes, even when the id space has holes.
TEST_F(GraphTest, NodeIteration) {
  // Build a graph and then punch holes in it by removing a node and an edge.
  Node* node_a = FromNodeDef("A", "OneOutput", 0);
  Node* node_b = AddNodeWithName("B");
  Node* node_c = FromNodeDef("C", "OneInput", 1);
  graph_.RemoveNode(node_b);
  Node* node_d = AddNodeWithName("D");
  const Edge* source_to_a =
      graph_.AddControlEdge(graph_.source_node(), node_a);
  graph_.AddControlEdge(node_a, graph_.sink_node());
  graph_.AddEdge(node_a, 0, node_c, 0);
  graph_.AddControlEdge(node_c, graph_.sink_node());
  graph_.RemoveEdge(source_to_a);
  graph_.RemoveNode(node_c);

  // DebugStrings of every node that should still be in the graph.
  const std::set<string> expected = {
      graph_.source_node()->DebugString(),
      node_a->DebugString(),
      node_d->DebugString(),
      graph_.sink_node()->DebugString(),
  };

  // Walking the id space (skipping removed ids) yields the expected set.
  std::set<string> seen_by_id;
  for (int id = 0; id < graph_.num_node_ids(); ++id) {
    Node* node = graph_.FindNodeId(id);
    if (node == nullptr) continue;
    seen_by_id.insert(node->DebugString());
  }
  EXPECT_EQ(expected, seen_by_id);

  // The range-based for loop yields the same set.
  std::set<string> seen_by_range;
  for (Node* node : graph_.nodes()) {
    seen_by_range.insert(node->DebugString());
  }
  EXPECT_EQ(expected, seen_by_range);
  VerifyGraphStats();
}
331 
CheckType(Node * node,bool b)332 static void CheckType(Node* node, bool b) {
333   EXPECT_TRUE(b) << node->DebugString();
334   // Make sure none of the other IsFoo() methods return true.
335   int count = 0;
336   if (node->IsSource()) count++;
337   if (node->IsSink()) count++;
338   if (node->IsOp()) count++;
339   EXPECT_EQ(1, count) << node->DebugString();
340 }
341 
// The source, the sink and an ordinary op node each report exactly their own
// kind.
TEST_F(GraphTest, Type) {
  Node* op_node = AddNodeWithName("A");
  Node* source = graph_.source_node();
  Node* sink = graph_.sink_node();

  CheckType(source, source->IsSource());
  CheckType(sink, sink->IsSink());
  CheckType(op_node, op_node->IsOp());
  VerifyGraphStats();
}
349 
// Attributes added to a node are readable, are carried over by CopyNode, and
// attributes added after the copy do not appear on the copy.
TEST_F(GraphTest, AddAttr) {
  Node* original = AddNodeWithName("A");
  original->AddAttr("_a", "new_attr");

  string value;
  TF_EXPECT_OK(GetNodeAttr(original->attrs(), "_a", &value));
  EXPECT_EQ("new_attr", value);

  Node* copy = graph_.CopyNode(original);

  // Added only after the copy was taken.
  original->AddAttr("_b", "new_attr_2");

  TF_EXPECT_OK(GetNodeAttr(original->attrs(), "_a", &value));
  EXPECT_EQ("new_attr", value);
  TF_EXPECT_OK(GetNodeAttr(original->attrs(), "_b", &value));
  EXPECT_EQ("new_attr_2", value);

  // The copy has "_a" but must not have picked up "_b".
  TF_EXPECT_OK(GetNodeAttr(copy->attrs(), "_a", &value));
  EXPECT_EQ("new_attr", value);
  EXPECT_FALSE(GetNodeAttr(copy->attrs(), "_b", &value).ok());
}
372 
373 // Convert edge iteration results into a sorted string.
EdgeIter(const Graph & g)374 static string EdgeIter(const Graph& g) {
375   std::vector<std::pair<int, int> > edges;
376   for (const Edge* e : g.edges()) {
377     edges.push_back(std::make_pair(e->src()->id(), e->dst()->id()));
378   }
379   std::sort(edges.begin(), edges.end());
380   string result;
381   for (auto& p : edges) {
382     strings::StrAppend(&result, p.first, "->", p.second, ";");
383   }
384   return result;
385 }
386 
// Graph::edges() must reflect each data, control and self edge as it is added.
TEST_F(GraphTest, EdgeIteration) {
  // Initially only the source->sink edge exists.
  EXPECT_EQ("0->1;", EdgeIter(graph_));

  Node* node_a = FromNodeDef("A", "OneInputTwoOutputs", 1);
  Node* node_b = FromNodeDef("B", "OneInput", 1);
  // A and B are not connected to anything yet.
  EXPECT_EQ("0->1;", EdgeIter(graph_));

  graph_.AddEdge(node_a, 0, node_b, 0);
  EXPECT_EQ("0->1;2->3;", EdgeIter(graph_));

  graph_.AddControlEdge(graph_.source_node(), node_a);
  graph_.AddControlEdge(node_b, graph_.sink_node());
  EXPECT_EQ("0->1;0->2;2->3;3->1;", EdgeIter(graph_));

  // Self edge from A's second output to its own input.
  graph_.AddEdge(node_a, 1, node_a, 0);
  EXPECT_EQ("0->1;0->2;2->2;2->3;3->1;", EdgeIter(graph_));
  VerifyGraphStats();
}
405 
// Graph::NewName returns distinct names, and the first name requested for
// prefix "A" starts with that prefix.
TEST_F(GraphTest, NewName) {
  const string first_a = graph_.NewName("A");
  const string second_a = graph_.NewName("A");
  const string first_b = graph_.NewName("B");

  EXPECT_NE(first_a, second_a);
  EXPECT_NE(first_a, first_b);
  EXPECT_NE(second_a, first_b);
  EXPECT_TRUE(str_util::StartsWith(first_a, "A")) << first_a;
}
415 
// Graph::IsValidNode rejects null pointers and nodes that belong to a
// different graph, with a specific message for each case.
TEST_F(GraphTest, IsValidNode) {
  // graph_ gets one extra node, so its ids are 0..2.
  AddNodeWithName("g1_node1");

  // A second graph gets two extra nodes, so its ids are 0..3.
  Graph other_graph(OpRegistry::Global());
  Node* other_node1;
  Node* other_node2;
  TF_CHECK_OK(
      NodeBuilder("g2_node1", "NoOp").Finalize(&other_graph, &other_node1));
  TF_CHECK_OK(
      NodeBuilder("g2_node2", "NoOp").Finalize(&other_graph, &other_node2));

  // Case 1: null pointer.
  Status status = graph_.IsValidNode(nullptr);
  EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
  EXPECT_EQ(string("Node is null"), status.error_message());

  // Case 2: the node's id is beyond graph_'s id range.
  status = graph_.IsValidNode(other_node2);
  EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
  EXPECT_EQ(string("node id 3 is >= than number of nodes in graph 3"),
            status.error_message());

  // Case 3: the id exists in graph_ but maps to a different Node object.
  status = graph_.IsValidNode(other_node1);
  EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
  EXPECT_EQ(string("Node with id 2 is different from the passed in node. "
                   "Does it belong to a different graph?"),
            status.error_message());
}
446 
// Graph::AddControlEdge creates control edges, mirrors them into the
// destination's NodeDef where expected, and refuses duplicates unless
// explicitly allowed.
TEST_F(GraphTest, AddControlEdge) {
  FromGraphDef(
      "node { name: 'A' op: 'OneOutput' }"
      "node { name: 'B' op: 'OneInputTwoOutputs' input: [ 'A:0' ] }"
      "node { name: 'C' op: 'NoOp' } ");
  Node* node_a = FindNode("A");
  Node* node_b = FindNode("B");
  Node* node_c = FindNode("C");

  // A plain control edge C -> A.
  const Edge* added = graph_.AddControlEdge(node_c, node_a);
  ASSERT_TRUE(added != nullptr);
  // The new edge uses the control slot on both ends.
  EXPECT_EQ(added->src(), node_c);
  EXPECT_EQ(added->src_output(), Graph::kControlSlot);
  EXPECT_EQ(added->dst(), node_a);
  EXPECT_EQ(added->dst_input(), Graph::kControlSlot);
  // A's NodeDef now lists the control input.
  ASSERT_EQ(node_a->def().input_size(), 1);
  EXPECT_EQ(node_a->def().input(0), "^C");

  // A control edge may coexist with a data edge between the same nodes.
  added = graph_.AddControlEdge(node_a, node_b);
  EXPECT_TRUE(added != nullptr);
  ASSERT_EQ(node_b->def().input_size(), 2);
  EXPECT_EQ(node_b->def().input(0), "A:0");
  EXPECT_EQ(node_b->def().input(1), "^A");

  // A second identical control edge is rejected by default.
  added = graph_.AddControlEdge(node_a, node_b);
  EXPECT_TRUE(added == nullptr);
  EXPECT_EQ(node_b->def().input_size(), 2);

  // With allow_duplicates the redundant edge is created...
  added = graph_.AddControlEdge(node_a, node_b, /*allow_duplicates=*/true);
  EXPECT_TRUE(added != nullptr);
  // ...but allow_duplicates leaves the NodeDef untouched.
  ASSERT_EQ(node_b->def().input_size(), 2);
  EXPECT_EQ(node_b->def().input(0), "A:0");
  EXPECT_EQ(node_b->def().input(1), "^A");

  // Control edge from the source node.
  added = graph_.AddControlEdge(graph_.source_node(), node_b);
  EXPECT_TRUE(added != nullptr);
  // The source is not recorded as an input in the NodeDef.
  EXPECT_EQ(node_b->def().input_size(), 2);
  // Adding the same source edge again is rejected.
  added = graph_.AddControlEdge(graph_.source_node(), node_b);
  EXPECT_TRUE(added == nullptr);
  EXPECT_EQ(node_b->def().input_size(), 2);
}
498 
// Graph::RemoveControlEdge removes the control edge from both the graph and
// the destination's NodeDef, leaving other control edges alone.
TEST_F(GraphTest, RemoveControlEdge) {
  FromGraphDef(
      "node { name: 'A' op: 'OneOutput' }"
      "node { name: 'B' op: 'OneInputTwoOutputs' input: [ 'A:0' ] }"
      "node { name: 'C' op: 'NoOp' } ");
  Node* node_a = FindNode("A");
  Node* node_b = FindNode("B");
  Node* node_c = FindNode("C");

  // Two control edges: C -> A and A -> B.
  const Edge* c_to_a = graph_.AddControlEdge(node_c, node_a);
  const Edge* a_to_b = graph_.AddControlEdge(node_a, node_b);
  ASSERT_TRUE(c_to_a != nullptr);
  ASSERT_TRUE(a_to_b != nullptr);

  ASSERT_TRUE(ControlEdgeExistsInGraphOrNodeDef(node_c, node_a));
  ASSERT_TRUE(ControlEdgeExistsInGraphOrNodeDef(node_a, node_b));

  // Removing C -> A leaves A -> B in place.
  graph_.RemoveControlEdge(c_to_a);
  ASSERT_FALSE(ControlEdgeExistsInGraphOrNodeDef(node_c, node_a));
  ASSERT_TRUE(ControlEdgeExistsInGraphOrNodeDef(node_a, node_b));

  graph_.RemoveControlEdge(a_to_b);
  ASSERT_FALSE(ControlEdgeExistsInGraphOrNodeDef(node_c, node_a));
  ASSERT_FALSE(ControlEdgeExistsInGraphOrNodeDef(node_a, node_b));

  // Removal after an attempted duplicate add. Without allow_duplicates the
  // second AddControlEdge call adds nothing and returns a null pointer, so
  // only one edge needs removing. The allow_duplicates=true variant is not
  // covered here, as it is a highly unlikely use case that does not make
  // much sense.
  const Edge* readded = graph_.AddControlEdge(node_c, node_a);
  const Edge* duplicate = graph_.AddControlEdge(node_c, node_a);
  ASSERT_TRUE(readded != nullptr);
  ASSERT_TRUE(duplicate == nullptr);

  graph_.RemoveControlEdge(readded);
  ASSERT_FALSE(ControlEdgeExistsInGraphOrNodeDef(node_c, node_a));
}
539 
// Graph::UpdateEdge rewires an existing input to a new source and reports
// out-of-range output or input indices.
TEST_F(GraphTest, UpdateEdge) {
  // Build a small graph: ids are source=0, sink=1, A=2, B=3, C=4, D=5.
  Node* node_a = FromNodeDef("A", "OneOutput", 0);
  Node* node_b = FromNodeDef("B", "OneInputTwoOutputs", 1);
  Node* node_c = FromNodeDef("C", "OneInputTwoOutputs", 1);
  Node* node_d = FromNodeDef("D", "OneInput", 1);

  graph_.AddControlEdge(graph_.source_node(), node_a);
  graph_.AddControlEdge(node_a, graph_.sink_node());
  graph_.AddEdge(node_a, 0, node_c, 0);

  graph_.AddControlEdge(node_c, graph_.sink_node());
  graph_.AddEdge(node_c, 0, node_b, 0);
  graph_.AddEdge(node_c, 1, node_d, 0);

  // Connectivity before any update.
  EXPECT_EQ("0->1;0->2;2->1;2->4;4->1;4->3;4->5;", EdgeIter(graph_));

  // Feed B's input from A instead of C: A->B (2->3) appears and C->B (4->3)
  // disappears.
  TF_EXPECT_OK(graph_.UpdateEdge(node_a, 0, node_b, 0));
  EXPECT_EQ("0->1;0->2;2->1;2->3;2->4;4->1;4->5;", EdgeIter(graph_));

  // Route A's output 0 to D's input as well, replacing C->D (4->5).
  TF_EXPECT_OK(graph_.UpdateEdge(node_a, 0, node_d, 0));
  EXPECT_EQ("0->1;0->2;2->1;2->3;2->4;2->5;4->1;", EdgeIter(graph_));

  // Output index 1 is out of range for A, which has a single output.
  Status status = graph_.UpdateEdge(node_a, 1, node_d, 0);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(
      status.error_message(),
      "Node 'A' (type: 'OneOutput', num of outputs: 1) does not have output 1");

  // Input index 0 is out of range for A, which has no inputs.
  status = graph_.UpdateEdge(node_c, 0, node_a, 0);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(
      status.error_message(),
      "Node 'A' (type: 'OneOutput', num of inputs: 0) does not have input 0");
}
582 
// Node::input_edges fails with INVALID_ARGUMENT while an input is still
// unconnected and succeeds once every input has an edge.
TEST_F(GraphTest, InputEdges) {
  Node* producer = FromNodeDef("A", "OneOutput", 0);
  Node* consumer = FromNodeDef("B", "TwoInputsOneOutput", 2);
  std::vector<const Edge*> input_edges;

  // Only input 0 is connected so far.
  graph_.AddEdge(producer, 0, consumer, 0);
  const Status partial_status = consumer->input_edges(&input_edges);
  EXPECT_EQ(error::INVALID_ARGUMENT, partial_status.code());

  // Connecting input 1 completes the set.
  graph_.AddEdge(producer, 0, consumer, 1);
  TF_EXPECT_OK(consumer->input_edges(&input_edges));
}
592 
// Graph::AddFunctionLibrary accepts new and exact-duplicate functions and
// gradients, and rejects conflicting definitions with a descriptive error.
TEST_F(GraphTest, AddFunctionLibrary) {
  // Adding two functions makes both findable.
  FunctionDefLibrary library;
  *library.add_function() = test::function::XTimesTwo();
  *library.add_function() = test::function::XTimesFour();
  TF_EXPECT_OK(graph_.AddFunctionLibrary(library));
  EXPECT_TRUE(graph_.flib_def().Find("XTimesTwo") != nullptr);
  EXPECT_TRUE(graph_.flib_def().Find("XTimesFour") != nullptr);

  // Re-adding identical functions is accepted and changes nothing.
  TF_EXPECT_OK(graph_.AddFunctionLibrary(library));
  EXPECT_TRUE(graph_.flib_def().Find("XTimesTwo") != nullptr);
  EXPECT_TRUE(graph_.flib_def().Find("XTimesFour") != nullptr);

  // Same name but a different body (one extra node_def) is an error.
  FunctionDefLibrary conflicting = library;
  *conflicting.mutable_function(0)->add_node_def() =
      conflicting.function(0).node_def(0);
  Status status = graph_.AddFunctionLibrary(conflicting);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Cannot add function 'XTimesTwo' because a different function with "
            "the same name already exists.");

  // A function named like an existing op is an error.
  conflicting = library;
  conflicting.mutable_function(0)->mutable_signature()->set_name("Add");
  status = graph_.AddFunctionLibrary(conflicting);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Cannot add function 'Add' because an op with the same name "
            "already exists.");

  // Attaching a gradient to an already-added function is accepted.
  GradientDef* gradient = library.add_gradient();
  gradient->set_function_name("XTimesTwo");
  gradient->set_gradient_func("Undefined");  // undefined funcs in grads are ok
  TF_EXPECT_OK(graph_.AddFunctionLibrary(library));
  EXPECT_EQ(graph_.flib_def().FindGradient("XTimesTwo"), "Undefined");

  // Re-adding the identical gradient is accepted and changes nothing.
  TF_EXPECT_OK(graph_.AddFunctionLibrary(library));
  EXPECT_EQ(graph_.flib_def().FindGradient("XTimesTwo"), "Undefined");

  // A different gradient for the same function is an error.
  conflicting = library;
  conflicting.mutable_gradient(0)->set_gradient_func("Undefined2");
  status = graph_.AddFunctionLibrary(conflicting);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Cannot assign gradient function 'Undefined2' to 'XTimesTwo' "
            "because it already has gradient function 'Undefined'");
}
646 
// Graph::BuildNodeNameIndex maps every node name, including the source and
// sink, to the corresponding node.
TEST_F(GraphTest, BuildNodeNameIndex) {
  FromGraphDef(
      "node { name: 'A' op: 'OneOutput' }"
      "node { name: 'B' op: 'OneInputTwoOutputs' input: [ 'A:0' ] }"
      "node { name: 'C' op: 'NoOp' } ");

  const auto node_name_index = graph_.BuildNodeNameIndex();
  EXPECT_EQ(node_name_index.size(), 5u);

  const std::vector<string> node_names{"_SOURCE", "_SINK", "A", "B", "C"};
  for (const string& node_name : node_names) {
    // Reuse the lookup iterator rather than operator[], which would insert a
    // null entry for a missing name and so modify the index being verified.
    const auto entry = node_name_index.find(node_name);
    ASSERT_NE(entry, node_name_index.end());
    EXPECT_EQ(entry->second, FindNode(node_name));
  }
}
662 
BM_InEdgeIteration(int iters,int num_nodes,int num_edges_per_node)663 static void BM_InEdgeIteration(int iters, int num_nodes,
664                                int num_edges_per_node) {
665   testing::StopTiming();
666   const GraphDef graph_def =
667       test::CreateGraphDef(num_nodes, num_edges_per_node);
668   Graph graph(OpRegistry::Global());
669   GraphConstructorOptions opts;
670   TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
671 
672   int64 sum = 0;
673   testing::StartTiming();
674   for (int i = 0; i < iters; ++i) {
675     for (const Node* node : graph.nodes()) {
676       for (auto e : node->in_edges()) {
677         sum += e->id();
678       }
679     }
680   }
681   VLOG(1) << sum;
682   testing::StopTiming();
683 }
684 BENCHMARK(BM_InEdgeIteration)->ArgPair(10, 2);
685 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 6, 2);
686 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 9, 2);
687 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 12, 2);
688 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 15, 2);
689 BENCHMARK(BM_InEdgeIteration)->ArgPair(10, 4);
690 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 6, 4);
691 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 9, 4);
692 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 12, 4);
693 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 15, 4);
694 BENCHMARK(BM_InEdgeIteration)->ArgPair(10, 8);
695 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 6, 8);
696 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 9, 8);
697 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 12, 8);
698 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 15, 8);
699 BENCHMARK(BM_InEdgeIteration)->ArgPair(10, 16);
700 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 6, 16);
701 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 9, 16);
702 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 12, 16);
703 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 15, 16);
704 
BM_GraphCreation(int iters,int num_nodes,int num_edges_per_node)705 static void BM_GraphCreation(int iters, int num_nodes, int num_edges_per_node) {
706   testing::StopTiming();
707   const GraphDef graph_def =
708       test::CreateGraphDef(num_nodes, num_edges_per_node);
709   const auto registry = OpRegistry::Global();
710   GraphConstructorOptions opts;
711   // Warmup step.
712   Graph graph(registry);
713   TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
714   int64 sum = 0;
715   testing::StartTiming();
716   for (int i = 0; i < iters; ++i) {
717     Graph graph(registry);
718     TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
719     sum += graph.num_node_ids();
720   }
721   VLOG(1) << sum;
722   testing::StopTiming();
723 }
724 BENCHMARK(BM_GraphCreation)->ArgPair(10, 2);
725 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 6, 2);
726 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 9, 2);
727 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 12, 2);
728 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 15, 2);
729 BENCHMARK(BM_GraphCreation)->ArgPair(10, 4);
730 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 6, 4);
731 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 9, 4);
732 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 12, 4);
733 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 15, 4);
734 BENCHMARK(BM_GraphCreation)->ArgPair(10, 8);
735 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 6, 8);
736 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 9, 8);
737 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 12, 8);
738 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 15, 8);
739 BENCHMARK(BM_GraphCreation)->ArgPair(10, 16);
740 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 6, 16);
741 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 9, 16);
742 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 12, 16);
743 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 15, 16);
744 
BM_ToGraphDef(int iters,int num_nodes,int num_edges_per_node)745 static void BM_ToGraphDef(int iters, int num_nodes, int num_edges_per_node) {
746   testing::StopTiming();
747   const GraphDef graph_def =
748       test::CreateGraphDef(num_nodes, num_edges_per_node);
749   const auto registry = OpRegistry::Global();
750   GraphConstructorOptions opts;
751   // Warmup step.
752   Graph graph(registry);
753   TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
754   int64 sum = 0;
755   testing::StartTiming();
756   for (int i = 0; i < iters; ++i) {
757     GraphDef graph_def;
758     graph.ToGraphDef(&graph_def);
759     sum += graph_def.node_size();
760   }
761   VLOG(1) << sum;
762   testing::StopTiming();
763 }
764 BENCHMARK(BM_ToGraphDef)->ArgPair(10, 2);
765 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 6, 2);
766 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 2);
767 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 2);
768 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 2);
769 BENCHMARK(BM_ToGraphDef)->ArgPair(10, 4);
770 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 6, 4);
771 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 4);
772 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 4);
773 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 4);
774 BENCHMARK(BM_ToGraphDef)->ArgPair(10, 8);
775 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 6, 8);
776 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 8);
777 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 8);
778 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 8);
779 BENCHMARK(BM_ToGraphDef)->ArgPair(10, 16);
780 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 6, 16);
781 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 16);
782 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 16);
783 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 16);
784 
785 }  // namespace
786 }  // namespace tensorflow
787