1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/mutable_graph_view.h"
17 #include "absl/strings/substitute.h"
18 #include "absl/types/span.h"
19 #include "tensorflow/cc/ops/standard_ops.h"
20 #include "tensorflow/core/framework/function_testlib.h"
21 #include "tensorflow/core/framework/types.pb.h"
22 #include "tensorflow/core/graph/tensor_id.h"
23 #include "tensorflow/core/grappler/grappler_item.h"
24 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
25 #include "tensorflow/core/grappler/utils.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/platform/test.h"
28 
29 namespace tensorflow {
30 namespace grappler {
31 namespace {
32 
33 using ::tensorflow::test::function::NDef;
34 using FDH = FunctionDefHelper;
35 
CompareNodeFanins(const MutableGraphView & graph,NodeDef * node,absl::Span<const string> fanins)36 void CompareNodeFanins(const MutableGraphView& graph, NodeDef* node,
37                        absl::Span<const string> fanins) {
38   ASSERT_EQ(node->input_size(), fanins.size());
39   for (int i = 0; i < node->input_size(); ++i) {
40     TensorId tensor_id = ParseTensorName(fanins[i]);
41     EXPECT_EQ(ParseTensorName(node->input(i)), tensor_id);
42     int port;
43     if (tensor_id.index() == Graph::kControlSlot) {
44       port = Graph::kControlSlot;
45     } else {
46       port = i;
47     }
48     MutableGraphView::InputPort input_port(node, port);
49     MutableGraphView::OutputPort output_port =
50         graph.GetOutputPort(tensor_id.node(), tensor_id.index());
51     EXPECT_TRUE(graph.GetFanin(input_port).contains(output_port));
52     EXPECT_TRUE(graph.GetFanout(output_port).contains(input_port));
53   }
54 }
55 
CompareNodeFanouts(const MutableGraphView & graph,NodeDef * node,absl::Span<const string> fanouts)56 void CompareNodeFanouts(const MutableGraphView& graph, NodeDef* node,
57                         absl::Span<const string> fanouts) {
58   auto node_fanouts =
59       graph.GetFanouts(*node, /*include_controlled_nodes=*/true);
60   EXPECT_EQ(node_fanouts.size(), fanouts.size());
61   for (const string& fanout : fanouts) {
62     TensorId tensor_id = ParseTensorName(fanout);
63     MutableGraphView::InputPort input_port(graph.GetNode(tensor_id.node()),
64                                            tensor_id.index());
65     EXPECT_TRUE(node_fanouts.contains(input_port));
66   }
67 }
68 
CheckNode(const MutableGraphView & graph,absl::string_view node_name,absl::string_view op,absl::string_view device,absl::Span<const std::pair<string,FDH::AttrValueWrapper>> attrs,absl::Span<const string> fanins,absl::Span<const string> fanouts)69 void CheckNode(const MutableGraphView& graph, absl::string_view node_name,
70                absl::string_view op, absl::string_view device,
71                absl::Span<const std::pair<string, FDH::AttrValueWrapper>> attrs,
72                absl::Span<const string> fanins,
73                absl::Span<const string> fanouts) {
74   NodeDef* node = graph.GetNode(node_name);
75   ASSERT_NE(node, nullptr);
76   EXPECT_EQ(node->op(), op);
77   EXPECT_EQ(node->device(), device);
78   EXPECT_EQ(node->attr_size(), attrs.size());
79   for (const auto& attr : attrs) {
80     auto it = node->attr().find(attr.first);
81     ASSERT_NE(it, node->attr().end());
82     EXPECT_TRUE(AreAttrValuesEqual(it->second, attr.second.proto));
83   }
84   CompareNodeFanins(graph, node, fanins);
85   CompareNodeFanouts(graph, node, fanouts);
86 }
87 
CheckGraph(const MutableGraphView & mutable_graph)88 void CheckGraph(const MutableGraphView& mutable_graph) {
89   GraphView immutable_graph(mutable_graph.graph());
90   EXPECT_EQ(mutable_graph.graph()->node_size(),
91             immutable_graph.graph()->node_size());
92   EXPECT_EQ(mutable_graph.graph(), immutable_graph.graph());
93 
94   auto check_edges =
95       [](const absl::flat_hash_set<MutableGraphView::Edge>& mutable_edges,
96          const absl::flat_hash_set<GraphView::Edge>& immutable_edges) {
97         EXPECT_EQ(mutable_edges.size(), immutable_edges.size());
98         for (const auto& fanin_edge : mutable_edges) {
99           GraphView::Edge immutable_edge(
100               {fanin_edge.src.node, fanin_edge.src.port_id},
101               {fanin_edge.dst.node, fanin_edge.dst.port_id});
102           EXPECT_TRUE(immutable_edges.contains(immutable_edge));
103         }
104       };
105 
106   // Check graph connectivity.
107   for (auto& node : *mutable_graph.graph()->mutable_node()) {
108     EXPECT_EQ(&node, immutable_graph.GetNode(node.name()));
109 
110     auto mutable_fanins =
111         mutable_graph.GetFanins(node, /*include_controlling_nodes=*/true);
112     auto immutable_fanins =
113         immutable_graph.GetFanins(node, /*include_controlling_nodes=*/true);
114     EXPECT_EQ(mutable_fanins.size(), immutable_fanins.size());
115     for (const auto& fanin : mutable_fanins) {
116       GraphView::OutputPort immutable_fanin(fanin.node, fanin.port_id);
117       EXPECT_TRUE(immutable_fanins.contains(immutable_fanin));
118     }
119 
120     auto mutable_fanouts =
121         mutable_graph.GetFanouts(node, /*include_controlled_nodes=*/true);
122     auto immutable_fanouts =
123         immutable_graph.GetFanouts(node, /*include_controlled_nodes=*/true);
124     EXPECT_EQ(mutable_fanouts.size(), immutable_fanouts.size());
125     for (const auto& fanout : mutable_fanouts) {
126       GraphView::InputPort immutable_fanout(fanout.node, fanout.port_id);
127       EXPECT_TRUE(immutable_fanouts.contains(immutable_fanout));
128     }
129 
130     auto mutable_fanin_edges =
131         mutable_graph.GetFaninEdges(node, /*include_controlling_edges=*/true);
132     auto immutable_fanin_edges =
133         immutable_graph.GetFaninEdges(node, /*include_controlling_edges=*/true);
134     check_edges(mutable_fanin_edges, immutable_fanin_edges);
135 
136     auto mutable_fanout_edges =
137         mutable_graph.GetFanoutEdges(node, /*include_controlled_edges=*/true);
138     auto immutable_fanout_edges =
139         immutable_graph.GetFanoutEdges(node, /*include_controlled_edges=*/true);
140     check_edges(mutable_fanout_edges, immutable_fanout_edges);
141   }
142 }
143 
TEST(MutableGraphViewTest, AddSubgraph) {
  GraphDef graph_def = test::function::GDef(
      {
          NDef("foo", "NotImportant", {}, {}),
          NDef("bar", "NotImportant", {}, {}),
          NDef("baz", "NotImportant", {"foo", "bar"}),
      },
      /*funcs=*/{});
  MutableGraphView graph(&graph_def);

  // `s/n1` consumes `bar`, which exists only in the original graph, so its
  // inputs are valid only once the subgraph is added into the original graph.
  GraphDef subgraph = test::function::GDef(
      {
          NDef("s/n0", "NotImportant", {}, {}),
          NDef("s/n1", "NotImportant", {"bar", "s/n0"}, {}),
      },
      /*funcs=*/{});

  TF_EXPECT_OK(graph.AddSubgraph(std::move(subgraph)));

  // Fanins and fanouts must be updated for the nodes of both the original
  // graph and the added subgraph.
  CheckNode(graph, "bar", "NotImportant", "", {}, {}, {"baz:1", "s/n1"});
  CheckNode(graph, "s/n1", "NotImportant", "", {}, {"bar", "s/n0"}, {});
  CheckGraph(graph);
}
171 
TEST(MutableGraphViewTest, AddSubgraphAndAddFunction) {
  // Start from a completely empty graph (no nodes, empty library).
  GraphDef gdef;
  MutableGraphView view(&gdef);

  // A node-less subgraph whose only content is one function definition.
  GraphDef subgraph =
      test::function::GDef({}, {test::function::XTimesTwo()});

  TF_EXPECT_OK(view.AddSubgraph(std::move(subgraph)));
  // The function has been merged into the target library.
  EXPECT_EQ(gdef.library().function_size(), 1);
}
182 
TEST(MutableGraphViewTest, AddSubgraphAndSkipSameFunction) {
  const FunctionDef x_times_two = test::function::XTimesTwo();

  // The target graph and the subgraph carry identical definitions.
  GraphDef gdef = test::function::GDef({}, {x_times_two});
  MutableGraphView view(&gdef);
  GraphDef subgraph = test::function::GDef({}, {x_times_two});

  TF_EXPECT_OK(view.AddSubgraph(std::move(subgraph)));
  // The duplicate is not added a second time.
  EXPECT_EQ(gdef.library().function_size(), 1);
}
194 
TEST(MutableGraphViewTest, AddSubgraphAndFailIfFunctionDifferent) {
  // The target library holds XTimesFour's body registered as "XTimesTwo".
  FunctionDef impostor = test::function::XTimesFour();
  impostor.mutable_signature()->set_name("XTimesTwo");
  GraphDef gdef = test::function::GDef({}, {impostor});
  MutableGraphView view(&gdef);

  // The subgraph brings the genuine XTimesTwo, whose body differs.
  GraphDef subgraph = test::function::GDef({}, {test::function::XTimesTwo()});

  const Status status = view.AddSubgraph(std::move(subgraph));
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "MutableGraphView::AddSubgraph(function_size=1) error: Found "
            "different function definition with the same name: XTimesTwo.");
}
211 
TEST(MutableGraphViewTest, UpdateNodeNoDedupControlDependency) {
  constexpr char kDevice[] = "/device:foo:0";
  // bar_2 feeds foo_1 (inputs 0 and 2) and foo_2 (input 1) and is also a
  // control input of both consumers.
  const std::vector<NodeDef> nodes = {
      NDef("bar_1", "Switch", {}, {}),
      NDef("bar_2", "Identity", {"bar_1:1"}),
      NDef("other", "NotImportant", {}, {}),
      NDef("foo_1", "NotImportant", {"bar_2", "other", "bar_2:1", "^bar_2"}),
      NDef("foo_2", "NotImportant", {"other:1", "bar_2:2", "^bar_2"}),
  };
  GraphDef gdef = test::function::GDef(nodes, /*funcs=*/{});
  MutableGraphView view(&gdef);

  AttrValue type_list;
  type_list.mutable_list()->add_type(DT_FLOAT);
  TF_EXPECT_OK(
      view.UpdateNode("bar_2", "IdentityN", kDevice, {{"T", type_list}}));

  // bar_2 picks up the new op/device/attrs; the control inputs on bar_2 are
  // kept even though they duplicate regular inputs.
  CheckNode(view, "bar_1", "Switch", "", {}, {}, {"bar_2"});
  CheckNode(view, "bar_2", "IdentityN", kDevice, {{"T", type_list}},
            {"bar_1:1"}, {"foo_1", "foo_1:2", "^foo_1", "foo_2:1", "^foo_2"});
  CheckNode(view, "other", "NotImportant", "", {}, {}, {"foo_1:1", "foo_2"});
  CheckNode(view, "foo_1", "NotImportant", "", {},
            {"bar_2", "other", "bar_2:1", "^bar_2"}, {});
  CheckNode(view, "foo_2", "NotImportant", "", {},
            {"other:1", "bar_2:2", "^bar_2"}, {});

  CheckGraph(view);
}
239 
TEST(MutableGraphViewTest, UpdateNodeDedupControlDependency) {
  constexpr char kDevice[] = "/device:foo:0";
  const std::vector<NodeDef> nodes = {
      NDef("bar_1", "Switch", {}, {}),
      NDef("bar_2", "Identity", {"bar_1:1"}),
      NDef("other", "NotImportant", {}, {}),
      NDef("foo_1", "NotImportant", {"bar_2", "other", "bar_2:1", "^bar_2"}),
      NDef("foo_2", "NotImportant", {"other:1", "bar_2:2", "^bar_2"}),
  };
  GraphDef gdef = test::function::GDef(nodes, /*funcs=*/{});
  MutableGraphView view(&gdef);

  TF_EXPECT_OK(view.UpdateNode("bar_2", "NotImportant", kDevice, {}));

  // After the update the "^bar_2" control inputs, which duplicated regular
  // inputs from bar_2, are gone from foo_1 and foo_2.
  CheckNode(view, "bar_1", "Switch", "", {}, {}, {"bar_2"});
  CheckNode(view, "bar_2", "NotImportant", kDevice, {}, {"bar_1:1"},
            {"foo_1", "foo_1:2", "foo_2:1"});
  CheckNode(view, "other", "NotImportant", "", {}, {}, {"foo_1:1", "foo_2"});
  CheckNode(view, "foo_1", "NotImportant", "", {},
            {"bar_2", "other", "bar_2:1"}, {});
  CheckNode(view, "foo_2", "NotImportant", "", {}, {"other:1", "bar_2:2"}, {});

  CheckGraph(view);
}
263 
TEST(MutableGraphViewTest, UpdateNodeSwitchNoControlDependency) {
  constexpr char kDevice[] = "/device:foo:0";
  // foo only drives a regular input of bar, so it may become a Switch.
  const std::vector<NodeDef> nodes = {NDef("foo", "NotImportant", {}, {}),
                                      NDef("bar", "NotImportant", {"foo:1"})};
  GraphDef gdef = test::function::GDef(nodes, /*funcs=*/{});
  MutableGraphView view(&gdef);

  TF_EXPECT_OK(view.UpdateNode("foo", "Switch", kDevice, {}));

  CheckNode(view, "foo", "Switch", kDevice, {}, {}, {"bar"});
  CheckNode(view, "bar", "NotImportant", "", {}, {"foo:1"}, {});

  CheckGraph(view);
}
280 
TEST(MutableGraphViewTest, UpdateNodeSwitchControlDependency) {
  constexpr char kDevice[] = "/device:foo:0";
  // foo drives a control dependency of bar, so turning it into a Switch must
  // be rejected.
  GraphDef gdef =
      test::function::GDef({NDef("foo", "NotImportant", {}, {}),
                            NDef("bar", "NotImportant", {"^foo"})},
                           /*funcs=*/{});
  MutableGraphView view(&gdef);

  AttrValue dtype;
  dtype.set_type(DT_FLOAT);
  const Status status =
      view.UpdateNode("foo", "Switch", kDevice, {{"T", dtype}});
  EXPECT_FALSE(status.ok());
  const string expected_msg =
      "MutableGraphView::UpdateNodeOp(node_name='foo', op='Switch', "
      "device='/device:foo:0', attrs={('T', type: DT_FLOAT)}) error: can't "
      "change node op to Switch when node drives a control dependency "
      "(alternatively, we could add the identity node needed, but it seems "
      "like an unlikely event and probably a mistake).";
  EXPECT_EQ(status.error_message(), expected_msg);

  // The rejected update must leave both nodes untouched.
  CheckNode(view, "foo", "NotImportant", "", {}, {}, {"^bar"});
  CheckNode(view, "bar", "NotImportant", "", {}, {"^foo"}, {});

  CheckGraph(view);
}
307 
GetNodeInputsFromGraph(const GraphDef & graph,absl::string_view node_to_exclude)308 absl::flat_hash_map<string, std::vector<string>> GetNodeInputsFromGraph(
309     const GraphDef& graph, absl::string_view node_to_exclude) {
310   absl::flat_hash_map<string, std::vector<string>> node_inputs;
311   for (const auto& node : graph.node()) {
312     if (node.name() == node_to_exclude) {
313       continue;
314     }
315     node_inputs[node.name()] =
316         std::vector<string>(node.input().begin(), node.input().end());
317   }
318   return node_inputs;
319 }
320 
CheckUnmodifiedNodeFanins(const GraphDef & graph,absl::string_view node_to_exclude,const absl::flat_hash_map<string,std::vector<string>> & unmodified_node_inputs)321 void CheckUnmodifiedNodeFanins(
322     const GraphDef& graph, absl::string_view node_to_exclude,
323     const absl::flat_hash_map<string, std::vector<string>>&
324         unmodified_node_inputs) {
325   for (const auto& node : graph.node()) {
326     if (node.name() == node_to_exclude) {
327       continue;
328     }
329     auto it = unmodified_node_inputs.find(node.name());
330     ASSERT_NE(it, unmodified_node_inputs.end());
331     ASSERT_EQ(it->second.size(), node.input_size());
332     for (int i = 0; i < node.input_size(); ++i) {
333       EXPECT_EQ(node.input(i), it->second[i]);
334     }
335   }
336 }
337 
TestUpdateNodeName(absl::string_view from_node_name,bool node_exists,absl::string_view to_node_name,bool update_fanouts,bool success,const string & error_msg,absl::Span<const string> expected_fanins)338 void TestUpdateNodeName(absl::string_view from_node_name, bool node_exists,
339                         absl::string_view to_node_name, bool update_fanouts,
340                         bool success, const string& error_msg,
341                         absl::Span<const string> expected_fanins) {
342   GraphDef graph_def = test::function::GDef(
343       {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {"a"}),
344        NDef("c", "NotImportant", {}, {})},
345       /*funcs=*/{});
346 
347   MutableGraphView graph(&graph_def);
348 
349   NodeDef* node = graph.GetNode(from_node_name);
350   if (node_exists) {
351     EXPECT_NE(node, nullptr);
352   } else {
353     EXPECT_EQ(node, nullptr);
354   }
355 
356   absl::flat_hash_map<string, std::vector<string>> unmodified_node_inputs =
357       GetNodeInputsFromGraph(graph_def, from_node_name);
358 
359   Status s = graph.UpdateNodeName(from_node_name, to_node_name, update_fanouts);
360   EXPECT_EQ(s.ok(), success);
361   string updated_node_name;
362   if (success) {
363     updated_node_name = string(to_node_name);
364   } else {
365     updated_node_name = string(from_node_name);
366     EXPECT_EQ(s.error_message(), error_msg);
367   }
368   if (node_exists) {
369     EXPECT_EQ(node->name(), updated_node_name);
370     CompareNodeFanins(graph, node, expected_fanins);
371   }
372 
373   CheckUnmodifiedNodeFanins(graph_def, updated_node_name,
374                             unmodified_node_inputs);
375 
376   CheckGraph(graph);
377 }
378 
TEST(MutableGraphViewTest, UpdateNodeName) {
  // TestUpdateNodeName uses the graph {a, b(a), c}: `a` has one fanout (b),
  // while `b` and `c` have none.
  string error_msg;
  // Node has no fanouts.
  TestUpdateNodeName("b", /*node_exists=*/true, "d", /*update_fanouts=*/false,
                     /*success=*/true, error_msg, {"a"});
  // Node has no fanouts and rename to self.
  TestUpdateNodeName("b", /*node_exists=*/true, "b", /*update_fanouts=*/false,
                     /*success=*/true, error_msg, {"a"});
  // Node has fanouts and rename to self.
  TestUpdateNodeName("a", /*node_exists=*/true, "a", /*update_fanouts=*/false,
                     /*success=*/true, error_msg, {});

  // New node name is in use.
  error_msg =
      "MutableGraphView::UpdateNodeName(from_node_name='c', to_node_name='b', "
      "update_fanouts=false) error: can't update node name because new node "
      "name is in use.";
  TestUpdateNodeName("c", /*node_exists=*/true, "b", /*update_fanouts=*/false,
                     /*success=*/false, error_msg, {});
  error_msg =
      "MutableGraphView::UpdateNodeName(from_node_name='a', to_node_name='b', "
      "update_fanouts=true) error: can't update node name because new node "
      "name is in use.";
  TestUpdateNodeName("a", /*node_exists=*/true, "b", /*update_fanouts=*/true,
                     /*success=*/false, error_msg, {});
  // Node has fanouts and update_fanouts is false.
  error_msg =
      "MutableGraphView::UpdateNodeName(from_node_name='a', to_node_name='d', "
      "update_fanouts=false) error: can't update node name because node has "
      "fanouts.";
  TestUpdateNodeName("a", /*node_exists=*/true, "d", /*update_fanouts=*/false,
                     /*success=*/false, error_msg, {});
  // Node does not exist.
  error_msg =
      "MutableGraphView::UpdateNodeName(from_node_name='d', to_node_name='e', "
      "update_fanouts=false) error: node 'd' was not found.";
  TestUpdateNodeName("d", /*node_exists=*/false, "e", /*update_fanouts=*/false,
                     /*success=*/false, error_msg, {});
  error_msg =
      "MutableGraphView::UpdateNodeName(from_node_name='d', to_node_name='e', "
      "update_fanouts=true) error: node 'd' was not found.";
  TestUpdateNodeName("d", /*node_exists=*/false, "e", /*update_fanouts=*/true,
                     /*success=*/false, error_msg, {});
}
423 
TEST(MutableGraphViewTest, UpdateNodeNameWithFanouts) {
  // b consumes a:2 and is consumed by c (regular), d (control) and e (regular
  // inputs b:2 and b:1).
  const std::vector<NodeDef> nodes = {
      NDef("a", "NotImportant", {}, {}),
      NDef("b", "NotImportant", {"a:2"}),
      NDef("c", "NotImportant", {"b", "^a"}),
      NDef("d", "NotImportant", {"^b", "^a"}),
      NDef("e", "NotImportant", {"b:2", "c:4", "b:1", "^a"}),
  };
  GraphDef gdef = test::function::GDef(nodes, /*funcs=*/{});
  MutableGraphView view(&gdef);

  // Rename b to f and rewrite every consumer to refer to f.
  TF_EXPECT_OK(view.UpdateNodeName("b", "f", /*update_fanouts=*/true));

  CheckNode(view, "a", "NotImportant", "", {}, {}, {"f", "^c", "^d", "^e"});
  CheckNode(view, "f", "NotImportant", "", {}, {"a:2"},
            {"c", "^d", "e", "e:2"});
  CheckNode(view, "c", "NotImportant", "", {}, {"f", "^a"}, {"e:1"});
  CheckNode(view, "d", "NotImportant", "", {}, {"^f", "^a"}, {});
  CheckNode(view, "e", "NotImportant", "", {}, {"f:2", "c:4", "f:1", "^a"},
            {});

  CheckGraph(view);
}
446 
SimpleSwapNodeNamesMutationGraph()447 GraphDef SimpleSwapNodeNamesMutationGraph() {
448   return test::function::GDef(
449       {NDef("a", "NotImportant", {}, {}), NDef("switch_1", "Switch", {"a"}),
450        NDef("identity_1", "Identity", {"switch_1:1"}),
451        NDef("b", "NotImportant", {}, {}), NDef("switch_2", "Switch", {"b"}),
452        NDef("identity_2", "Identity", {"switch_2:0"}),
453        NDef("foo_1", "NotImportant", {"identity_1", "^identity_1"}),
454        NDef("foo_2", "NotImportant", {"identity_2", "^identity_2"})},
455       /*funcs=*/{});
456 }
457 
TestSwapNodeNames(bool update_fanouts)458 void TestSwapNodeNames(bool update_fanouts) {
459   GraphDef graph_def = SimpleSwapNodeNamesMutationGraph();
460 
461   MutableGraphView graph(&graph_def);
462 
463   TF_EXPECT_OK(graph.SwapNodeNames("foo_1", "foo_2", update_fanouts));
464 
465   CheckNode(graph, "a", "NotImportant", "", {}, {}, {"switch_1"});
466   CheckNode(graph, "switch_1", "Switch", "", {}, {"a"}, {"identity_1"});
467   CheckNode(graph, "identity_1", "Identity", "", {}, {"switch_1:1"},
468             {"foo_2", "^foo_2"});
469   CheckNode(graph, "b", "NotImportant", "", {}, {}, {"switch_2"});
470   CheckNode(graph, "switch_2", "Switch", "", {}, {"b"}, {"identity_2"});
471   CheckNode(graph, "identity_2", "Identity", "", {}, {"switch_2:0"},
472             {"foo_1", "^foo_1"});
473   CheckNode(graph, "foo_2", "NotImportant", "", {},
474             {"identity_1", "^identity_1"}, {});
475   CheckNode(graph, "foo_1", "NotImportant", "", {},
476             {"identity_2", "^identity_2"}, {});
477 
478   CheckGraph(graph);
479 }
480 
TEST(MutableGraphView, SwapNodeNames) {
  // Exercise both fanout-update modes.
  for (const bool update_fanouts : {false, true}) {
    TestSwapNodeNames(update_fanouts);
  }
}
485 
TestSwapNodeNamesWithSameNames(bool update_fanouts)486 void TestSwapNodeNamesWithSameNames(bool update_fanouts) {
487   GraphDef graph_def = SimpleSwapNodeNamesMutationGraph();
488 
489   MutableGraphView graph(&graph_def);
490 
491   TF_EXPECT_OK(graph.SwapNodeNames("identity_1", "identity_1", update_fanouts));
492 
493   // No changes to graph.
494   CheckNode(graph, "a", "NotImportant", "", {}, {}, {"switch_1"});
495   CheckNode(graph, "switch_1", "Switch", "", {}, {"a"}, {"identity_1"});
496   CheckNode(graph, "identity_1", "Identity", "", {}, {"switch_1:1"},
497             {"foo_1", "^foo_1"});
498   CheckNode(graph, "b", "NotImportant", "", {}, {}, {"switch_2"});
499   CheckNode(graph, "switch_2", "Switch", "", {}, {"b"}, {"identity_2"});
500   CheckNode(graph, "identity_2", "Identity", "", {}, {"switch_2:0"},
501             {"foo_2", "^foo_2"});
502   CheckNode(graph, "foo_1", "NotImportant", "", {},
503             {"identity_1", "^identity_1"}, {});
504   CheckNode(graph, "foo_2", "NotImportant", "", {},
505             {"identity_2", "^identity_2"}, {});
506 
507   CheckGraph(graph);
508 }
509 
TEST(MutableGraphView, SwapNodeNamesSameName) {
  // Exercise both fanout-update modes.
  for (const bool update_fanouts : {false, true}) {
    TestSwapNodeNamesWithSameNames(update_fanouts);
  }
}
514 
TEST(MutableGraphView, SwapNodeNamesBetweenSwitches) {
  GraphDef gdef = SimpleSwapNodeNamesMutationGraph();
  MutableGraphView view(&gdef);

  TF_EXPECT_OK(
      view.SwapNodeNames("switch_1", "switch_2", /*update_fanouts=*/false));

  // Without updating fanouts the identities keep their input strings, so
  // identity_1 ("switch_1:1") is now fed by the Switch whose input is b, and
  // identity_2 ("switch_2:0") by the Switch whose input is a.
  CheckNode(view, "a", "NotImportant", "", {}, {}, {"switch_2"});
  CheckNode(view, "switch_2", "Switch", "", {}, {"a"}, {"identity_2"});
  CheckNode(view, "identity_1", "Identity", "", {}, {"switch_1:1"},
            {"foo_1", "^foo_1"});
  CheckNode(view, "b", "NotImportant", "", {}, {}, {"switch_1"});
  CheckNode(view, "switch_1", "Switch", "", {}, {"b"}, {"identity_1"});
  CheckNode(view, "identity_2", "Identity", "", {}, {"switch_2:0"},
            {"foo_2", "^foo_2"});
  CheckNode(view, "foo_1", "NotImportant", "", {},
            {"identity_1", "^identity_1"}, {});
  CheckNode(view, "foo_2", "NotImportant", "", {},
            {"identity_2", "^identity_2"}, {});

  CheckGraph(view);
}
538 
TEST(MutableGraphView, SwapNodeNamesBetweenSwitchesAndUpdateFanouts) {
  GraphDef gdef = SimpleSwapNodeNamesMutationGraph();
  MutableGraphView view(&gdef);

  TF_EXPECT_OK(
      view.SwapNodeNames("switch_1", "switch_2", /*update_fanouts=*/true));

  // Consumers follow the renamed nodes: identity_1 is still fed by the Switch
  // whose input is a (now called switch_2), and likewise for identity_2.
  CheckNode(view, "a", "NotImportant", "", {}, {}, {"switch_2"});
  CheckNode(view, "switch_2", "Switch", "", {}, {"a"}, {"identity_1"});
  CheckNode(view, "identity_1", "Identity", "", {}, {"switch_2:1"},
            {"foo_1", "^foo_1"});
  CheckNode(view, "b", "NotImportant", "", {}, {}, {"switch_1"});
  CheckNode(view, "switch_1", "Switch", "", {}, {"b"}, {"identity_2"});
  CheckNode(view, "identity_2", "Identity", "", {}, {"switch_1:0"},
            {"foo_2", "^foo_2"});
  CheckNode(view, "foo_1", "NotImportant", "", {},
            {"identity_1", "^identity_1"}, {});
  CheckNode(view, "foo_2", "NotImportant", "", {},
            {"identity_2", "^identity_2"}, {});

  CheckGraph(view);
}
562 
TEST(MutableGraphView, SwapNodeNamesSwitchAndNonSwitch) {
  GraphDef gdef = SimpleSwapNodeNamesMutationGraph();
  MutableGraphView view(&gdef);

  TF_EXPECT_OK(view.SwapNodeNames("a", "switch_1", /*update_fanouts=*/false));

  // The Switch (now named "a") would read from itself, so its input is
  // redirected to "switch_1"; foo_1's control input on identity_1 is deduped.
  CheckNode(view, "switch_1", "NotImportant", "", {}, {}, {"a", "identity_1"});
  CheckNode(view, "a", "Switch", "", {}, {"switch_1"}, {});
  CheckNode(view, "identity_1", "Identity", "", {}, {"switch_1:1"}, {"foo_1"});
  // The second chain is untouched.
  CheckNode(view, "b", "NotImportant", "", {}, {}, {"switch_2"});
  CheckNode(view, "switch_2", "Switch", "", {}, {"b"}, {"identity_2"});
  CheckNode(view, "identity_2", "Identity", "", {}, {"switch_2:0"},
            {"foo_2", "^foo_2"});
  CheckNode(view, "foo_1", "NotImportant", "", {}, {"identity_1"}, {});
  CheckNode(view, "foo_2", "NotImportant", "", {},
            {"identity_2", "^identity_2"}, {});

  CheckGraph(view);
}
584 
TEST(MutableGraphView, SwapNodeNamesSwitchAndNonSwitchAndUpdateFanouts) {
  GraphDef gdef = SimpleSwapNodeNamesMutationGraph();
  MutableGraphView view(&gdef);

  TF_EXPECT_OK(view.SwapNodeNames("a", "switch_1", /*update_fanouts=*/true));

  // Consumers follow the renamed nodes, so the chain's shape is preserved
  // under the new names.
  CheckNode(view, "switch_1", "NotImportant", "", {}, {}, {"a"});
  CheckNode(view, "a", "Switch", "", {}, {"switch_1"}, {"identity_1"});
  CheckNode(view, "identity_1", "Identity", "", {}, {"a:1"},
            {"foo_1", "^foo_1"});
  // The second chain is untouched.
  CheckNode(view, "b", "NotImportant", "", {}, {}, {"switch_2"});
  CheckNode(view, "switch_2", "Switch", "", {}, {"b"}, {"identity_2"});
  CheckNode(view, "identity_2", "Identity", "", {}, {"switch_2:0"},
            {"foo_2", "^foo_2"});
  CheckNode(view, "foo_1", "NotImportant", "", {},
            {"identity_1", "^identity_1"}, {});
  CheckNode(view, "foo_2", "NotImportant", "", {},
            {"identity_2", "^identity_2"}, {});

  CheckGraph(view);
}
607 
TEST(MutableGraphView, SwapNodeNamesNonSwitchAndSwitch) {
  GraphDef gdef = SimpleSwapNodeNamesMutationGraph();
  MutableGraphView view(&gdef);

  TF_EXPECT_OK(view.SwapNodeNames("switch_2", "b", /*update_fanouts=*/false));

  // The first chain is untouched.
  CheckNode(view, "a", "NotImportant", "", {}, {}, {"switch_1"});
  CheckNode(view, "switch_1", "Switch", "", {}, {"a"}, {"identity_1"});
  CheckNode(view, "identity_1", "Identity", "", {}, {"switch_1:1"},
            {"foo_1", "^foo_1"});
  // The Switch (now named "b") would read from itself, so its input is
  // redirected to "switch_2"; foo_2's control input on identity_2 is deduped.
  CheckNode(view, "switch_2", "NotImportant", "", {}, {}, {"b", "identity_2"});
  CheckNode(view, "b", "Switch", "", {}, {"switch_2"}, {});
  CheckNode(view, "identity_2", "Identity", "", {}, {"switch_2:0"}, {"foo_2"});
  CheckNode(view, "foo_1", "NotImportant", "", {},
            {"identity_1", "^identity_1"}, {});
  CheckNode(view, "foo_2", "NotImportant", "", {}, {"identity_2"}, {});

  CheckGraph(view);
}
629 
TEST(MutableGraphView, SwapNodeNamesNonSwitchAndSwitchAndUpdateFanouts) {
  GraphDef gdef = SimpleSwapNodeNamesMutationGraph();
  MutableGraphView view(&gdef);

  TF_EXPECT_OK(view.SwapNodeNames("switch_2", "b", /*update_fanouts=*/true));

  // The first chain is untouched.
  CheckNode(view, "a", "NotImportant", "", {}, {}, {"switch_1"});
  CheckNode(view, "switch_1", "Switch", "", {}, {"a"}, {"identity_1"});
  CheckNode(view, "identity_1", "Identity", "", {}, {"switch_1:1"},
            {"foo_1", "^foo_1"});
  // Consumers follow the renamed nodes, so the second chain keeps its shape
  // under the new names.
  CheckNode(view, "switch_2", "NotImportant", "", {}, {}, {"b"});
  CheckNode(view, "b", "Switch", "", {}, {"switch_2"}, {"identity_2"});
  CheckNode(view, "identity_2", "Identity", "", {}, {"b:0"},
            {"foo_2", "^foo_2"});
  CheckNode(view, "foo_1", "NotImportant", "", {},
            {"identity_1", "^identity_1"}, {});
  CheckNode(view, "foo_2", "NotImportant", "", {},
            {"identity_2", "^identity_2"}, {});

  CheckGraph(view);
}
652 
TestSwapNodeNamesSimpleSelfLoop(bool update_fanouts)653 void TestSwapNodeNamesSimpleSelfLoop(bool update_fanouts) {
654   GraphDef graph_def = test::function::GDef(
655       {NDef("a", "NotImportant", {"b:7"}), NDef("b", "NotImportant", {"a:10"})},
656       /*funcs=*/{});
657 
658   MutableGraphView graph(&graph_def);
659 
660   TF_EXPECT_OK(graph.SwapNodeNames("a", "b", update_fanouts));
661 
662   // No self loops.
663   CheckNode(graph, "a", "NotImportant", "", {}, {"b:10"}, {"b:0"});
664   CheckNode(graph, "b", "NotImportant", "", {}, {"a:7"}, {"a:0"});
665 
666   CheckGraph(graph);
667 }
668 
TEST(MutableGraphView, SwapNodeNamesSelfLoops) {
  // Exercise both fanout-update modes.
  for (const bool update_fanouts : {false, true}) {
    TestSwapNodeNamesSimpleSelfLoop(update_fanouts);
  }
}
673 
TestSwapNodeNamesError(absl::string_view from_node_name,absl::string_view to_node_name,bool update_fanouts,const string & error_msg)674 void TestSwapNodeNamesError(absl::string_view from_node_name,
675                             absl::string_view to_node_name, bool update_fanouts,
676                             const string& error_msg) {
677   GraphDef graph_def = SimpleSwapNodeNamesMutationGraph();
678 
679   MutableGraphView graph(&graph_def);
680 
681   Status s = graph.SwapNodeNames(from_node_name, to_node_name, update_fanouts);
682   EXPECT_EQ(s.ok(), false);
683   EXPECT_EQ(s.error_message(), error_msg);
684 
685   // No changes to graph.
686   CheckNode(graph, "a", "NotImportant", "", {}, {}, {"switch_1"});
687   CheckNode(graph, "switch_1", "Switch", "", {}, {"a"}, {"identity_1"});
688   CheckNode(graph, "identity_1", "Identity", "", {}, {"switch_1:1"},
689             {"foo_1", "^foo_1"});
690   CheckNode(graph, "b", "NotImportant", "", {}, {}, {"switch_2"});
691   CheckNode(graph, "switch_2", "Switch", "", {}, {"b"}, {"identity_2"});
692   CheckNode(graph, "identity_2", "Identity", "", {}, {"switch_2:0"},
693             {"foo_2", "^foo_2"});
694   CheckNode(graph, "foo_1", "NotImportant", "", {},
695             {"identity_1", "^identity_1"}, {});
696   CheckNode(graph, "foo_2", "NotImportant", "", {},
697             {"identity_2", "^identity_2"}, {});
698 
699   CheckGraph(graph);
700 }
701 
702 // TODO(lyandy): add tests with update_fanouts == true.
TEST(MutableGraphView,SwapNodeNamesError)703 TEST(MutableGraphView, SwapNodeNamesError) {
704   string error_msg;
705   // Missing nodes.
706   error_msg =
707       "MutableGraphView::SwapNodeNames(from_node_name='foo_3', "
708       "to_node_name='foo_2', update_fanouts=false) error: node 'foo_3' was not "
709       "found.";
710   TestSwapNodeNamesError("foo_3", "foo_2", /*update_fanouts=*/false, error_msg);
711   error_msg =
712       "MutableGraphView::SwapNodeNames(from_node_name='foo_3', "
713       "to_node_name='foo_2', update_fanouts=true) error: node 'foo_3' was not "
714       "found.";
715   TestSwapNodeNamesError("foo_3", "foo_2", /*update_fanouts=*/true, error_msg);
716   error_msg =
717       "MutableGraphView::SwapNodeNames(from_node_name='foo_1', "
718       "to_node_name='foo_4', update_fanouts=false) error: node 'foo_4' was not "
719       "found.";
720   TestSwapNodeNamesError("foo_1", "foo_4", /*update_fanouts=*/false, error_msg);
721   error_msg =
722       "MutableGraphView::SwapNodeNames(from_node_name='foo_1', "
723       "to_node_name='foo_4', update_fanouts=true) error: node 'foo_4' was not "
724       "found.";
725   TestSwapNodeNamesError("foo_1", "foo_4", /*update_fanouts=*/true, error_msg);
726   error_msg =
727       "MutableGraphView::SwapNodeNames(from_node_name='foo_5', "
728       "to_node_name='foo_6', update_fanouts=false) error: node 'foo_5' was not "
729       "found.";
730   TestSwapNodeNamesError("foo_5", "foo_6", /*update_fanouts=*/false, error_msg);
731   error_msg =
732       "MutableGraphView::SwapNodeNames(from_node_name='foo_5', "
733       "to_node_name='foo_6', update_fanouts=true) error: node 'foo_5' was not "
734       "found.";
735   TestSwapNodeNamesError("foo_5", "foo_6", /*update_fanouts=*/true, error_msg);
736 
737   // Switch control dependencies.
738   error_msg =
739       "MutableGraphView::SwapNodeNames(from_node_name='switch_2', "
740       "to_node_name='identity_1', update_fanouts=false) error: can't swap node "
741       "name 'switch_2' as it will become a Switch control dependency.";
742   TestSwapNodeNamesError("switch_2", "identity_1", /*update_fanouts=*/false,
743                          error_msg);
744   error_msg =
745       "MutableGraphView::SwapNodeNames(from_node_name='identity_2', "
746       "to_node_name='switch_1', update_fanouts=false) error: can't swap node "
747       "name 'switch_1' as it will become a Switch control dependency.";
748   TestSwapNodeNamesError("identity_2", "switch_1", /*update_fanouts=*/false,
749                          error_msg);
750 }
751 
// Redirects every consumer of "bar" to a freshly added "new_bar".
// Expectations: regular inputs keep their ports; a "^bar" control input is
// dropped where the consumer already reads "new_bar" as a regular input; "bar"
// is left with no fanouts.
TEST(MutableGraphViewTest, AddAndUpdateFanouts) {
  // The op types are irrelevant here; only the wiring is inspected.
  GraphDef graph_def = test::function::GDef(
      {NDef("bar", "NotImportant", {}, {}),
       NDef("other", "NotImportant", {}, {}),
       NDef("foo_1", "NotImportant", {"bar", "other", "bar:1", "^bar"}),
       NDef("foo_2", "NotImportant", {"other:1", "bar:2", "^bar"}),
       NDef("foo_3", "NotImportant", {"other:2", "^bar"})},
      /*funcs=*/{});
  MutableGraphView graph(&graph_def);

  const string replacement_name =
      graph.AddNode(NDef("new_bar", "NotImportant", {}, {}))->name();
  TF_EXPECT_OK(graph.UpdateFanouts("bar", replacement_name));

  // The replaced node and the untouched producer.
  CheckNode(graph, "bar", "NotImportant", "", {}, {}, {});
  CheckNode(graph, "other", "NotImportant", "", {}, {},
            {"foo_1:1", "foo_2", "foo_3"});
  // The replacement picks up every former consumer of "bar".
  CheckNode(graph, "new_bar", "NotImportant", "", {}, {},
            {"foo_1:0", "foo_1:2", "foo_2:1", "^foo_3"});
  // Consumers now read from "new_bar".
  CheckNode(graph, "foo_1", "NotImportant", "", {},
            {"new_bar", "other", "new_bar:1"}, {});
  CheckNode(graph, "foo_2", "NotImportant", "", {}, {"other:1", "new_bar:2"},
            {});
  CheckNode(graph, "foo_3", "NotImportant", "", {}, {"other:2", "^new_bar"},
            {});

  CheckGraph(graph);
}
783 
// Redirects the consumers of the Identity "bar_2" (fed by Switch "bar_1") to a
// new Identity "new_bar". Unlike AddAndUpdateFanouts, consumers are expected to
// keep their control dependency on the replacement alongside the regular
// inputs.
TEST(MutableGraphViewTest, AddAndUpdateFanoutsKeepControls) {
  GraphDef graph_def = test::function::GDef(
      {NDef("bar_1", "Switch", {}, {}), NDef("bar_2", "Identity", {"bar_1:1"}),
       NDef("other", "NotImportant", {}, {}),
       NDef("foo_1", "NotImportant", {"bar_2", "other", "bar_2:1", "^bar_2"}),
       NDef("foo_2", "NotImportant", {"other:1", "bar_2:2", "^bar_2"})},
      /*funcs=*/{});
  MutableGraphView graph(&graph_def);

  // The replacement reads a different output port of the same Switch.
  const NodeDef* replacement =
      graph.AddNode(NDef("new_bar", "Identity", {"bar_1:2"}));
  TF_EXPECT_OK(graph.UpdateFanouts("bar_2", replacement->name()));

  // "bar_2" keeps its own input but loses every consumer.
  CheckNode(graph, "bar_2", "Identity", "", {}, {"bar_1:1"}, {});
  CheckNode(graph, "bar_1", "Switch", "", {}, {}, {"bar_2", "new_bar"});
  CheckNode(graph, "other", "NotImportant", "", {}, {}, {"foo_1:1", "foo_2"});
  CheckNode(graph, "new_bar", "Identity", "", {}, {"bar_1:2"},
            {"foo_1", "foo_1:2", "^foo_1", "foo_2:1", "^foo_2"});
  // Both regular and control inputs are rewritten; nothing is deduplicated.
  CheckNode(graph, "foo_1", "NotImportant", "", {},
            {"new_bar", "other", "new_bar:1", "^new_bar"}, {});
  CheckNode(graph, "foo_2", "NotImportant", "", {},
            {"other:1", "new_bar:2", "^new_bar"}, {});

  CheckGraph(graph);
}
811 
// Replaces "bar" with a node that itself consumes "bar"; UpdateFanouts must not
// rewrite the replacement's own input into a self-loop.
TEST(MutableGraphViewTest, AddAndUpdateFanoutsWithoutSelfLoops) {
  // Op types are placeholders; only the wiring matters.
  GraphDef graph_def =
      test::function::GDef({NDef("bar", "NotImportant", {}, {}),
                            NDef("foo_1", "NotImportant", {"bar", "^bar"}),
                            NDef("foo_2", "NotImportant", {"^bar"})},
                           /*funcs=*/{});
  MutableGraphView graph(&graph_def);

  // "new_bar" is wired to read the original "bar" output.
  const string wrapper_name =
      graph.AddNode(NDef("new_bar", "NewBar", {"bar"}, {}))->name();
  TF_EXPECT_OK(graph.UpdateFanouts("bar", wrapper_name));

  // "new_bar" still reads "bar" and now feeds the former consumers of "bar".
  CheckNode(graph, "new_bar", "NewBar", "", {}, {"bar"}, {"foo_1", "^foo_2"});
  CheckNode(graph, "bar", "NotImportant", "", {}, {}, {"new_bar"});
  CheckNode(graph, "foo_1", "NotImportant", "", {}, {"new_bar"}, {});
  CheckNode(graph, "foo_2", "NotImportant", "", {}, {"^new_bar"}, {});

  CheckGraph(graph);
}
835 
// "a" and "d" are consumed only as control inputs of "e". Redirecting either to
// the Switch "b" would create a Switch control dependency and must be rejected
// without modifying the graph.
TEST(MutableGraphViewTest, UpdateFanoutsToSwitchWithControlFromSwitch) {
  GraphDef graph_def = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "Switch", {}, {}),
       NDef("c", "NotImportant", {}, {}), NDef("d", "NotImportant", {}, {}),
       NDef("e", "NotImportant", {"c", "b", "^a", "^d"})},
      /*funcs=*/{});
  MutableGraphView graph(&graph_def);

  // Redirects `from_node_name`'s fanouts to "b" and expects the given error.
  auto expect_switch_control_error = [&graph](absl::string_view from_node_name,
                                              const string& expected_msg) {
    Status status = graph.UpdateFanouts(from_node_name, "b");
    EXPECT_FALSE(status.ok());
    EXPECT_EQ(status.error_message(), expected_msg);
  };
  expect_switch_control_error(
      "a",
      "MutableGraphView::UpdateFanouts(from_node_name='a', to_node_name='b') "
      "error: can't update fanouts to node 'b' as it will become a Switch "
      "control dependency.");
  expect_switch_control_error(
      "d",
      "MutableGraphView::UpdateFanouts(from_node_name='d', to_node_name='b') "
      "error: can't update fanouts to node 'b' as it will become a Switch "
      "control dependency.");

  // Nothing was added, removed or rewired.
  EXPECT_EQ(graph.graph()->node_size(), 5);
  CheckNode(graph, "a", "NotImportant", "", {}, {}, {"^e"});
  CheckNode(graph, "b", "Switch", "", {}, {}, {"e:1"});
  CheckNode(graph, "c", "NotImportant", "", {}, {}, {"e:0"});
  CheckNode(graph, "d", "NotImportant", "", {}, {}, {"^e"});
  CheckNode(graph, "e", "NotImportant", "", {}, {"c", "b", "^a", "^d"}, {});

  CheckGraph(graph);
}
870 
// "c" feeds "e" as a regular input (port 0), so redirecting it to the Switch
// "b" is allowed: it does not create a control dependency on the Switch.
TEST(MutableGraphViewTest, UpdateFanoutsToSwitchWithNoControlFromSwitch) {
  GraphDef graph_def = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "Switch", {}, {}),
       NDef("c", "NotImportant", {}, {}), NDef("d", "NotImportant", {}, {}),
       NDef("e", "NotImportant", {"c", "b", "^a", "^d"})},
      /*funcs=*/{});
  MutableGraphView graph(&graph_def);

  TF_EXPECT_OK(graph.UpdateFanouts("c", "b"));

  // No nodes are added or removed.
  EXPECT_EQ(graph.graph()->node_size(), 5);

  // "e" now reads "b" on both regular ports, and "c" has no consumers left.
  CheckNode(graph, "e", "NotImportant", "", {}, {"b", "b", "^a", "^d"}, {});
  CheckNode(graph, "b", "Switch", "", {}, {}, {"e:0", "e:1"});
  CheckNode(graph, "c", "NotImportant", "", {}, {}, {});
  // The control inputs are untouched.
  CheckNode(graph, "a", "NotImportant", "", {}, {}, {"^e"});
  CheckNode(graph, "d", "NotImportant", "", {}, {}, {"^e"});

  CheckGraph(graph);
}
892 
SimpleMutateFaninGraph()893 GraphDef SimpleMutateFaninGraph() {
894   // Actual node.op() is not important in this test.
895   GraphDef graph_def = test::function::GDef(
896       {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {}),
897        NDef("c", "NotImportant", {}, {}), NDef("d", "NotImportant", {}, {}),
898        NDef("foo_1", "NotImportant", {"a"}),
899        NDef("foo_2", "NotImportant", {"b", "^a", "^c"}),
900        NDef("foo_3", "NotImportant", {"b", "a:1", "a:1"}),
901        NDef("foo_4", "NotImportant", {"a", "b:2", "b:2", "^c", "^d"}),
902        NDef("foo_5", "NotImportant", {}),
903        NDef("foo_6", "NotImportant", {"^a", "^b"})},
904       /*funcs=*/{});
905   return graph_def;
906 }
907 
TestAddRegularFanin(absl::string_view node_name,bool node_exists,const TensorId & fanin_to_add,bool success,const string & error_msg,absl::Span<const string> expected_fanins)908 void TestAddRegularFanin(absl::string_view node_name, bool node_exists,
909                          const TensorId& fanin_to_add, bool success,
910                          const string& error_msg,
911                          absl::Span<const string> expected_fanins) {
912   GraphDef graph_def = SimpleMutateFaninGraph();
913 
914   MutableGraphView graph(&graph_def);
915 
916   NodeDef* node = graph.GetNode(node_name);
917   if (node_exists) {
918     EXPECT_NE(node, nullptr);
919   } else {
920     EXPECT_EQ(node, nullptr);
921   }
922 
923   absl::flat_hash_map<string, std::vector<string>> unmodified_node_inputs =
924       GetNodeInputsFromGraph(graph_def, node_name);
925 
926   Status s = graph.AddRegularFanin(node_name, fanin_to_add);
927   EXPECT_EQ(s.ok(), success);
928   if (!success) {
929     EXPECT_EQ(s.error_message(), error_msg);
930   }
931   if (node_exists) {
932     CompareNodeFanins(graph, node, expected_fanins);
933   }
934 
935   CheckUnmodifiedNodeFanins(graph_def, node_name, unmodified_node_inputs);
936 
937   CheckGraph(graph);
938 }
939 
TEST(MutableGraphViewTest,AddRegularFanin)940 TEST(MutableGraphViewTest, AddRegularFanin) {
941   string error_msg;
942   // Add input to node with 1 input 0 controls.
943   TestAddRegularFanin("foo_1", /*node_exists=*/true, {"b", 1}, /*success=*/true,
944                       error_msg, {"a", "b:1"});
945   // Add input to node with multiple inputs and 0 controls.
946   TestAddRegularFanin("foo_3", /*node_exists=*/true, {"b", 2}, /*success=*/true,
947                       error_msg, {"b", "a:1", "a:1", "b:2"});
948   // Add input to node with 1 input multiple controls.
949   TestAddRegularFanin("foo_2", /*node_exists=*/true, {"a", 0}, /*success=*/true,
950                       error_msg, {"b", "a", "^c"});
951   // Add input to node with multiple inputs and controls.
952   TestAddRegularFanin("foo_4", /*node_exists=*/true, {"a", 1}, /*success=*/true,
953                       error_msg, {"a", "b:2", "b:2", "a:1", "^d", "^c"});
954   // Add input to node with 0 inputs 0 controls.
955   TestAddRegularFanin("foo_5", /*node_exists=*/true, {"a", 1}, /*success=*/true,
956                       error_msg, {"a:1"});
957   // Add input to node with 0 inputs multiple controls.
958   TestAddRegularFanin("foo_6", /*node_exists=*/true, {"c", 1}, /*success=*/true,
959                       error_msg, {"c:1", "^b", "^a"});
960 
961   // Add control to node with 1 input 0 controls.
962   error_msg =
963       "MutableGraphView::AddRegularFanin(node_name='foo_1', fanin='^b') error: "
964       "fanin '^b' must be a regular tensor id.";
965   TestAddRegularFanin("foo_1", /*node_exists=*/true, {"b", Graph::kControlSlot},
966                       /*success=*/false, error_msg, {"a"});
967   // Add control to node with multiple inputs and 0 controls.
968   error_msg =
969       "MutableGraphView::AddRegularFanin(node_name='foo_3', fanin='^c') error: "
970       "fanin '^c' must be a regular tensor id.";
971   TestAddRegularFanin("foo_3", /*node_exists=*/true, {"c", Graph::kControlSlot},
972                       /*success=*/false, error_msg, {"b", "a:1", "a:1"});
973   // Add control to node with 1 input multiple controls.
974   error_msg =
975       "MutableGraphView::AddRegularFanin(node_name='foo_2', fanin='^d') error: "
976       "fanin '^d' must be a regular tensor id.";
977   TestAddRegularFanin("foo_2", /*node_exists=*/true, {"d", Graph::kControlSlot},
978                       /*success=*/false, error_msg, {"b", "^a", "^c"});
979   // Add control to node with multiple input multiple controls.
980   error_msg =
981       "MutableGraphView::AddRegularFanin(node_name='foo_4', fanin='^a') error: "
982       "fanin '^a' must be a regular tensor id.";
983   TestAddRegularFanin("foo_4", /*node_exists=*/true, {"a", Graph::kControlSlot},
984                       /*success=*/false, error_msg,
985                       {"a", "b:2", "b:2", "^c", "^d"});
986   // Add control to node with 0 inputs 0 controls.
987   error_msg =
988       "MutableGraphView::AddRegularFanin(node_name='foo_5', fanin='^a') error: "
989       "fanin '^a' must be a regular tensor id.";
990   TestAddRegularFanin("foo_5", /*node_exists=*/true, {"a", Graph::kControlSlot},
991                       /*success=*/false, error_msg, {});
992   // Add control to node with 0 inputs multiple controls.
993   error_msg =
994       "MutableGraphView::AddRegularFanin(node_name='foo_6', fanin='^c') error: "
995       "fanin '^c' must be a regular tensor id.";
996   TestAddRegularFanin("foo_6", /*node_exists=*/true, {"c", Graph::kControlSlot},
997                       /*success=*/false, error_msg, {"^a", "^b"});
998   // Add control to node with control that already exists.
999   error_msg =
1000       "MutableGraphView::AddRegularFanin(node_name='foo_2', fanin='^a') error: "
1001       "fanin '^a' must be a regular tensor id.";
1002   TestAddRegularFanin("foo_2", /*node_exists=*/true, {"a", Graph::kControlSlot},
1003                       /*success=*/false, error_msg, {"b", "^a", "^c"});
1004 
1005   // Add fanin to node where node is missing.
1006   error_msg =
1007       "MutableGraphView::AddRegularFanin(node_name='foo_missing', fanin='a:0') "
1008       "error: node 'foo_missing' was not found.";
1009   TestAddRegularFanin("foo_missing", /*node_exists=*/false, {"a", 0},
1010                       /*success=*/false, error_msg, {});
1011   // Add fanin to node where fanin is missing.
1012   error_msg =
1013       "MutableGraphView::AddRegularFanin(node_name='foo_1', "
1014       "fanin='bar_missing:0') error: node 'bar_missing' was not found.";
1015   TestAddRegularFanin("foo_1", /*node_exists=*/true, {"bar_missing", 0},
1016                       /*success=*/false, error_msg, {"a"});
1017   // Add fanin to node where node and fanin are missing.
1018   error_msg =
1019       "MutableGraphView::AddRegularFanin(node_name='foo_missing', "
1020       "fanin='bar_missing:0') error: node 'foo_missing' was not found.";
1021   TestAddRegularFanin("foo_missing", /*node_exists=*/false, {"bar_missing", 0},
1022                       /*success=*/false, error_msg, {});
1023   // Add control fanin to node where node and fanin are missing.
1024   error_msg =
1025       "MutableGraphView::AddRegularFanin(node_name='foo_missing', "
1026       "fanin='^bar_missing') error: fanin '^bar_missing' must be a regular "
1027       "tensor id.";
1028   TestAddRegularFanin("foo_missing", /*node_exists=*/false,
1029                       {"bar_missing", Graph::kControlSlot},
1030                       /*success=*/false, error_msg, {});
1031 
1032   // Add self to create cycle.
1033   error_msg =
1034       "MutableGraphView::AddRegularFanin(node_name='foo_6', fanin='foo_6:2') "
1035       "error: can't add fanin 'foo_6:2' to self.";
1036   TestAddRegularFanin("foo_6", /*node_exists=*/true, {"foo_6", 2},
1037                       /*success=*/false, error_msg, {"^a", "^b"});
1038 }
1039 
TestAddRegularFaninByPort(absl::string_view node_name,bool node_exists,int port,const TensorId & fanin_to_add,bool success,const string & error_msg,absl::Span<const string> expected_fanins)1040 void TestAddRegularFaninByPort(absl::string_view node_name, bool node_exists,
1041                                int port, const TensorId& fanin_to_add,
1042                                bool success, const string& error_msg,
1043                                absl::Span<const string> expected_fanins) {
1044   GraphDef graph_def = SimpleMutateFaninGraph();
1045 
1046   MutableGraphView graph(&graph_def);
1047 
1048   NodeDef* node = graph.GetNode(node_name);
1049   if (node_exists) {
1050     EXPECT_NE(node, nullptr);
1051   } else {
1052     EXPECT_EQ(node, nullptr);
1053   }
1054 
1055   absl::flat_hash_map<string, std::vector<string>> unmodified_node_inputs =
1056       GetNodeInputsFromGraph(graph_def, node_name);
1057 
1058   Status s = graph.AddRegularFaninByPort(node_name, port, fanin_to_add);
1059   EXPECT_EQ(s.ok(), success);
1060   if (!success) {
1061     EXPECT_EQ(s.error_message(), error_msg);
1062   }
1063   if (node_exists) {
1064     CompareNodeFanins(graph, node, expected_fanins);
1065   }
1066 
1067   CheckUnmodifiedNodeFanins(graph_def, node_name, unmodified_node_inputs);
1068 
1069   CheckGraph(graph);
1070 }
1071 
TEST(MutableGraphViewTest,AddRegularFaninByPort)1072 TEST(MutableGraphViewTest, AddRegularFaninByPort) {
1073   string error_msg;
1074   // Add input at start to node with some inputs and no controls.
1075   TestAddRegularFaninByPort("foo_3", /*node_exists=*/true, /*port=*/0, {"d", 2},
1076                             /*success=*/true, error_msg,
1077                             {"d:2", "b", "a:1", "a:1"});
1078   // Add input at end to node with some inputs and no controls.
1079   TestAddRegularFaninByPort("foo_3", /*node_exists=*/true, /*port=*/3, {"d", 2},
1080                             /*success=*/true, error_msg,
1081                             {"b", "a:1", "a:1", "d:2"});
1082   // Add input in middle to node with some inputs and no controls.
1083   TestAddRegularFaninByPort("foo_3", /*node_exists=*/true, /*port=*/2, {"d", 2},
1084                             /*success=*/true, error_msg,
1085                             {"b", "a:1", "d:2", "a:1"});
1086   // Add input at start to node with some inputs and some controls.
1087   TestAddRegularFaninByPort("foo_2", /*node_exists=*/true, /*port=*/0, {"d", 2},
1088                             /*success=*/true, error_msg,
1089                             {"d:2", "b", "^c", "^a"});
1090   // Add input at end to node with some inputs and some controls.
1091   TestAddRegularFaninByPort("foo_2", /*node_exists=*/true, /*port=*/1, {"d", 2},
1092                             /*success=*/true, error_msg,
1093                             {"b", "d:2", "^c", "^a"});
1094   // Add input in middle to node with some inputs and some controls, and dedup
1095   // controls.
1096   TestAddRegularFaninByPort("foo_4", /*node_exists=*/true, /*port=*/2, {"d", 2},
1097                             /*success=*/true, error_msg,
1098                             {"a", "b:2", "d:2", "b:2", "^c"});
1099   // Add input to node with no inputs and no controls.
1100   TestAddRegularFaninByPort("foo_5", /*node_exists=*/true, /*port=*/0, {"d", 2},
1101                             /*success=*/true, error_msg, {"d:2"});
1102   // Add input to node with no inputs and some controls.
1103   TestAddRegularFaninByPort("foo_6", /*node_exists=*/true, /*port=*/0, {"d", 2},
1104                             /*success=*/true, error_msg, {"d:2", "^b", "^a"});
1105   // Add fanin should dedup control.
1106   TestAddRegularFaninByPort("foo_6", /*node_exists=*/true, /*port=*/0, {"b", 2},
1107                             /*success=*/true, error_msg, {"b:2", "^a"});
1108 
1109   // Add controlling fanin.
1110   error_msg =
1111       "MutableGraphView::AddRegularFaninByPort(node_name='foo_4', port=2, "
1112       "fanin='^d') error: fanin '^d' must be a regular tensor id.";
1113   TestAddRegularFaninByPort(
1114       "foo_4", /*node_exists=*/true, /*port=*/2, {"d", Graph::kControlSlot},
1115       /*success=*/false, error_msg, {"a", "b:2", "b:2", "^c", "^d"});
1116 
1117   // Add fanin at out of bounds port.
1118   error_msg =
1119       "MutableGraphView::AddRegularFaninByPort(node_name='foo_5', port=-1, "
1120       "fanin='d:2') error: port must be in range [0, 0].";
1121   TestAddRegularFaninByPort("foo_5", /*node_exists=*/true, /*port=*/-1,
1122                             {"d", 2},
1123                             /*success=*/false, error_msg, {});
1124   error_msg =
1125       "MutableGraphView::AddRegularFaninByPort(node_name='foo_5', port=1, "
1126       "fanin='d:2') error: port must be in range [0, 0].";
1127   TestAddRegularFaninByPort("foo_5", /*node_exists=*/true, /*port=*/1, {"d", 2},
1128                             /*success=*/false, error_msg, {});
1129   error_msg =
1130       "MutableGraphView::AddRegularFaninByPort(node_name='foo_6', port=-1, "
1131       "fanin='d:2') error: port must be in range [0, 0].";
1132   TestAddRegularFaninByPort("foo_6", /*node_exists=*/true, /*port=*/-1,
1133                             {"d", 2},
1134                             /*success=*/false, error_msg, {"^a", "^b"});
1135   error_msg =
1136       "MutableGraphView::AddRegularFaninByPort(node_name='foo_6', port=1, "
1137       "fanin='d:2') error: port must be in range [0, 0].";
1138   TestAddRegularFaninByPort("foo_6", /*node_exists=*/true, /*port=*/1, {"d", 2},
1139                             /*success=*/false, error_msg, {"^a", "^b"});
1140   error_msg =
1141       "MutableGraphView::AddRegularFaninByPort(node_name='foo_4', port=-1, "
1142       "fanin='d:2') error: port must be in range [0, 3].";
1143   TestAddRegularFaninByPort(
1144       "foo_4", /*node_exists=*/true, /*port=*/-1, {"d", 2},
1145       /*success=*/false, error_msg, {"a", "b:2", "b:2", "^c", "^d"});
1146   error_msg =
1147       "MutableGraphView::AddRegularFaninByPort(node_name='foo_4', port=4, "
1148       "fanin='d:2') error: port must be in range [0, 3].";
1149   TestAddRegularFaninByPort("foo_4", /*node_exists=*/true, /*port=*/4, {"d", 2},
1150                             /*success=*/false, error_msg,
1151                             {"a", "b:2", "b:2", "^c", "^d"});
1152 
1153   // Add fanin to node where node is missing.
1154   error_msg =
1155       "MutableGraphView::AddRegularFaninByPort(node_name='foo_missing', "
1156       "port=0, fanin='a:0') error: node 'foo_missing' was not found.";
1157   TestAddRegularFaninByPort("foo_missing", /*node_exists=*/false, /*port=*/0,
1158                             {"a", 0},
1159                             /*success=*/false, error_msg, {});
1160   // Add fanin to node where fanin is missing.
1161   error_msg =
1162       "MutableGraphView::AddRegularFaninByPort(node_name='foo_1', port=0, "
1163       "fanin='bar_missing:0') error: node 'bar_missing' was not found.";
1164   TestAddRegularFaninByPort("foo_1", /*node_exists=*/true, /*port=*/0,
1165                             {"bar_missing", 0},
1166                             /*success=*/false, error_msg, {"a"});
1167   // Add fanin to node where node and fanin are missing.
1168   error_msg =
1169       "MutableGraphView::AddRegularFaninByPort(node_name='foo_missing', "
1170       "port=0, fanin='bar_missing:0') error: node 'foo_missing' was not found.";
1171   TestAddRegularFaninByPort("foo_missing", /*node_exists=*/false, /*port=*/0,
1172                             {"bar_missing", 0},
1173                             /*success=*/false, error_msg, {});
1174 
1175   // Add self to create cycle.
1176   error_msg =
1177       "MutableGraphView::AddRegularFaninByPort(node_name='foo_6', port=0, "
1178       "fanin='foo_6:2') error: can't add fanin 'foo_6:2' to self.";
1179   TestAddRegularFaninByPort("foo_6", /*node_exists=*/true, /*port=*/0,
1180                             {"foo_6", 2},
1181                             /*success=*/false, error_msg, {"^a", "^b"});
1182 }
1183 
CheckFanoutRemoved(const MutableGraphView & graph,const TensorId & fanin,absl::string_view node_name)1184 void CheckFanoutRemoved(const MutableGraphView& graph, const TensorId& fanin,
1185                         absl::string_view node_name) {
1186   MutableGraphView::OutputPort output_port =
1187       graph.GetOutputPort(fanin.node(), fanin.index());
1188   auto fanouts = graph.GetFanout(output_port);
1189   for (auto fanout : fanouts) {
1190     EXPECT_NE(fanout.node->name(), fanin.node());
1191   }
1192 }
1193 
TestRemoveRegularFanin(absl::string_view node_name,bool node_exists,const TensorId & fanin_to_remove,bool success,const string & error_msg,absl::Span<const string> expected_fanins)1194 void TestRemoveRegularFanin(absl::string_view node_name, bool node_exists,
1195                             const TensorId& fanin_to_remove, bool success,
1196                             const string& error_msg,
1197                             absl::Span<const string> expected_fanins) {
1198   GraphDef graph_def = SimpleMutateFaninGraph();
1199 
1200   MutableGraphView graph(&graph_def);
1201 
1202   NodeDef* node = graph.GetNode(node_name);
1203   if (node_exists) {
1204     EXPECT_NE(nullptr, node);
1205   } else {
1206     EXPECT_EQ(nullptr, node);
1207   }
1208 
1209   absl::flat_hash_map<string, std::vector<string>> unmodified_node_inputs =
1210       GetNodeInputsFromGraph(graph_def, node_name);
1211 
1212   Status s = graph.RemoveRegularFanin(node_name, fanin_to_remove);
1213   EXPECT_EQ(s.ok(), success);
1214   if (!success) {
1215     EXPECT_EQ(s.error_message(), error_msg);
1216   }
1217   if (node_exists) {
1218     CompareNodeFanins(graph, node, expected_fanins);
1219     if (success) {
1220       CheckFanoutRemoved(graph, fanin_to_remove, node_name);
1221     }
1222   }
1223 
1224   CheckUnmodifiedNodeFanins(graph_def, node_name, unmodified_node_inputs);
1225 
1226   CheckGraph(graph);
1227 }
1228 
TEST(MutableGraphViewTest,RemoveRegularFanin)1229 TEST(MutableGraphViewTest, RemoveRegularFanin) {
1230   string error_msg;
1231   // Remove input from node with 1 input 0 controls.
1232   TestRemoveRegularFanin("foo_1", /*node_exists=*/true, {"a", 0},
1233                          /*success=*/true, error_msg, {});
1234   // Remove input from node with multiple inputs and 0 controls.
1235   TestRemoveRegularFanin("foo_3", /*node_exists=*/true, {"a", 1},
1236                          /*success=*/true, error_msg, {"b"});
1237   // Remove input from node with 1 input multiple controls.
1238   TestRemoveRegularFanin("foo_2", /*node_exists=*/true, {"b", 0},
1239                          /*success=*/true, error_msg, {"^a", "^c"});
1240   // Remove input from node with multiple inputs and controls.
1241   TestRemoveRegularFanin("foo_4", /*node_exists=*/true, {"b", 2},
1242                          /*success=*/true, error_msg, {"a", "^c", "^d"});
1243   // Remove input from node with multiple inputs and controls, and results in
1244   // shifting of ports.
1245   TestRemoveRegularFanin("foo_4", /*node_exists=*/true, {"a", 0},
1246                          /*success=*/true, error_msg,
1247                          {"b:2", "b:2", "^c", "^d"});
1248 
1249   // Remove control from node with 1 input multiple controls.
1250   error_msg =
1251       "MutableGraphView::RemoveRegularFanin(node_name='foo_2', fanin='^a') "
1252       "error: fanin '^a' must be a regular tensor id.";
1253   TestRemoveRegularFanin("foo_2", /*node_exists=*/true,
1254                          {"a", Graph::kControlSlot},
1255                          /*success=*/false, error_msg, {"b", "^a", "^c"});
1256   // Remove control from node with multiple input multiple controls.
1257   error_msg =
1258       "MutableGraphView::RemoveRegularFanin(node_name='foo_4', fanin='^d') "
1259       "error: fanin '^d' must be a regular tensor id.";
1260   TestRemoveRegularFanin(
1261       "foo_4", /*node_exists=*/true, {"d", Graph::kControlSlot},
1262       /*success=*/false, error_msg, {"a", "b:2", "b:2", "^c", "^d"});
1263   // Remove control from node with 0 inputs multiple controls.
1264   error_msg =
1265       "MutableGraphView::RemoveRegularFanin(node_name='foo_6', fanin='^a') "
1266       "error: fanin '^a' must be a regular tensor id.";
1267   TestRemoveRegularFanin("foo_6", /*node_exists=*/true,
1268                          {"a", Graph::kControlSlot},
1269                          /*success=*/false, error_msg, {"^a", "^b"});
1270 
1271   // Remove input from node with 0 inputs 0 controls.
1272   error_msg = "";
1273   TestRemoveRegularFanin("foo_5", /*node_exists=*/true, {"a", 1},
1274                          /*success=*/true, error_msg, {});
1275   // Remove input from node with 0 inputs multiple controls.
1276   TestRemoveRegularFanin("foo_6", /*node_exists=*/true, {"a", 1},
1277                          /*success=*/true, error_msg, {"^a", "^b"});
1278 
1279   // Remove control from node with 1 input 0 controls.
1280   error_msg =
1281       "MutableGraphView::RemoveRegularFanin(node_name='foo_1', fanin='^b') "
1282       "error: fanin '^b' must be a regular tensor id.";
1283   TestRemoveRegularFanin("foo_1", /*node_exists=*/true,
1284                          {"b", Graph::kControlSlot},
1285                          /*success=*/false, error_msg, {"a"});
1286   // Remove control from node with multiple inputs and 0 controls.
1287   error_msg =
1288       "MutableGraphView::RemoveRegularFanin(node_name='foo_3', fanin='^c') "
1289       "error: fanin '^c' must be a regular tensor id.";
1290   TestRemoveRegularFanin("foo_3", /*node_exists=*/true,
1291                          {"c", Graph::kControlSlot},
1292                          /*success=*/false, error_msg, {"b", "a:1", "a:1"});
1293   // Remove control from node with 0 inputs 0 controls.
1294   error_msg =
1295       "MutableGraphView::RemoveRegularFanin(node_name='foo_5', fanin='^a') "
1296       "error: fanin '^a' must be a regular tensor id.";
1297   TestRemoveRegularFanin("foo_5", /*node_exists=*/true,
1298                          {"a", Graph::kControlSlot},
1299                          /*success=*/false, error_msg, {});
1300 
1301   // Remove fanin from node where node is missing.
1302   error_msg =
1303       "MutableGraphView::RemoveRegularFanin(node_name='foo_missing', "
1304       "fanin='a:0') error: node 'foo_missing' was not found.";
1305   TestRemoveRegularFanin("foo_missing", /*node_exists=*/false, {"a", 0},
1306                          /*success=*/false, error_msg, {});
1307   // Remove fanin from node where fanin is missing.
1308   error_msg =
1309       "MutableGraphView::RemoveRegularFanin(node_name='foo_1', "
1310       "fanin='bar_missing:0') error: node 'bar_missing' was not found.";
1311   TestRemoveRegularFanin("foo_1", /*node_exists=*/true, {"bar_missing", 0},
1312                          /*success=*/false, error_msg, {"a"});
1313   // Remove fanin from node where node and fanin are missing.
1314   error_msg =
1315       "MutableGraphView::RemoveRegularFanin(node_name='foo_missing', "
1316       "fanin='bar_missing:0') error: node 'foo_missing' was not found.";
1317   TestRemoveRegularFanin("foo_missing", /*node_exists=*/false,
1318                          {"bar_missing", 0}, /*success=*/false, error_msg, {});
1319   // Remove control from node where node and fanin are missing.
1320   error_msg =
1321       "MutableGraphView::RemoveRegularFanin(node_name='foo_missing', "
1322       "fanin='^bar_missing') error: fanin '^bar_missing' must be a regular "
1323       "tensor id.";
1324   TestRemoveRegularFanin("foo_missing", /*node_exists=*/false,
1325                          {"bar_missing", Graph::kControlSlot},
1326                          /*success=*/false, error_msg, {});
1327 
1328   // Remove self.
1329   error_msg =
1330       "MutableGraphView::RemoveRegularFanin(node_name='foo_6', "
1331       "fanin='foo_6:2') error: can't remove fanin 'foo_6:2' from self.";
1332   TestRemoveRegularFanin("foo_6", /*node_exists=*/true, {"foo_6", 2},
1333                          /*success=*/false, error_msg, {"^a", "^b"});
1334 }
1335 
TestRemoveRegularFaninByPort(absl::string_view node_name,bool node_exists,int port,bool success,const string & error_msg,absl::Span<const string> expected_fanins)1336 void TestRemoveRegularFaninByPort(absl::string_view node_name, bool node_exists,
1337                                   int port, bool success,
1338                                   const string& error_msg,
1339                                   absl::Span<const string> expected_fanins) {
1340   GraphDef graph_def = SimpleMutateFaninGraph();
1341 
1342   MutableGraphView graph(&graph_def);
1343 
1344   NodeDef* node = graph.GetNode(node_name);
1345   if (node_exists) {
1346     EXPECT_NE(nullptr, node);
1347   } else {
1348     EXPECT_EQ(nullptr, node);
1349   }
1350 
1351   absl::flat_hash_map<string, std::vector<string>> unmodified_node_inputs =
1352       GetNodeInputsFromGraph(graph_def, node_name);
1353 
1354   Status s = graph.RemoveRegularFaninByPort(node_name, port);
1355   EXPECT_EQ(s.ok(), success);
1356   if (!success) {
1357     EXPECT_EQ(s.error_message(), error_msg);
1358   }
1359   if (node_exists) {
1360     CompareNodeFanins(graph, node, expected_fanins);
1361   }
1362 
1363   CheckUnmodifiedNodeFanins(graph_def, node_name, unmodified_node_inputs);
1364 
1365   CheckGraph(graph);
1366 }
1367 
TEST(MutableGraphViewTest,RemoveRegularFaninByPort)1368 TEST(MutableGraphViewTest, RemoveRegularFaninByPort) {
1369   string error_msg;
1370   // Remove input at start of node with some inputs and no controls.
1371   TestRemoveRegularFaninByPort("foo_3", /*node_exists=*/true, /*port=*/0,
1372                                /*success=*/true, error_msg, {"a:1", "a:1"});
1373   // Remove input at end of node with some inputs and no controls.
1374   TestRemoveRegularFaninByPort("foo_3", /*node_exists=*/true, /*port=*/2,
1375                                /*success=*/true, error_msg, {"b", "a:1"});
1376   // Remove input in middle of node with some inputs and no controls.
1377   TestRemoveRegularFaninByPort("foo_3", /*node_exists=*/true, /*port=*/1,
1378                                /*success=*/true, error_msg, {"b", "a:1"});
1379   // Remove input at start of node with some inputs and some controls.
1380   TestRemoveRegularFaninByPort("foo_4", /*node_exists=*/true, /*port=*/0,
1381                                /*success=*/true, error_msg,
1382                                {"b:2", "b:2", "^d", "^c"});
1383   // Remove input at end of node with some inputs and some controls.
1384   TestRemoveRegularFaninByPort("foo_4", /*node_exists=*/true, /*port=*/2,
1385                                /*success=*/true, error_msg,
1386                                {"a", "b:2", "^d", "^c"});
1387   // Remove input in middle of node with some inputs and some controls.
1388   TestRemoveRegularFaninByPort("foo_4", /*node_exists=*/true, /*port=*/1,
1389                                /*success=*/true, error_msg,
1390                                {"a", "b:2", "^d", "^c"});
1391 
1392   // Remove input from node with no inputs and no controls.
1393   error_msg =
1394       "MutableGraphView::RemoveRegularFaninByPort(node_name='foo_5', port=0) "
1395       "error: no available ports as node has no regular fanins.";
1396   TestRemoveRegularFaninByPort("foo_5", /*node_exists=*/true, /*port=*/0,
1397                                /*success=*/false, error_msg, {});
1398   // Remove input from node with no inputs and some controls.
1399   error_msg =
1400       "MutableGraphView::RemoveRegularFaninByPort(node_name='foo_6', port=1) "
1401       "error: no available ports as node has no regular fanins.";
1402   TestRemoveRegularFaninByPort("foo_6", /*node_exists=*/true, /*port=*/1,
1403                                /*success=*/false, error_msg, {"^a", "^b"});
1404 
1405   // Remove fanin at out of bounds port.
1406   error_msg =
1407       "MutableGraphView::RemoveRegularFaninByPort(node_name='foo_3', port=-1) "
1408       "error: port must be in range [0, 2].";
1409   TestRemoveRegularFaninByPort("foo_3", /*node_exists=*/true, /*port=*/-1,
1410                                /*success=*/false, error_msg,
1411                                {"b", "a:1", "a:1"});
1412   error_msg =
1413       "MutableGraphView::RemoveRegularFaninByPort(node_name='foo_3', port=3) "
1414       "error: port must be in range [0, 2].";
1415   TestRemoveRegularFaninByPort("foo_3", /*node_exists=*/true, /*port=*/3,
1416                                /*success=*/false, error_msg,
1417                                {"b", "a:1", "a:1"});
1418   error_msg =
1419       "MutableGraphView::RemoveRegularFaninByPort(node_name='foo_4', port=-1) "
1420       "error: port must be in range [0, 2].";
1421   TestRemoveRegularFaninByPort("foo_4", /*node_exists=*/true, /*port=*/-1,
1422                                /*success=*/false, error_msg,
1423                                {"a", "b:2", "b:2", "^c", "^d"});
1424   error_msg =
1425       "MutableGraphView::RemoveRegularFaninByPort(node_name='foo_4', port=3) "
1426       "error: port must be in range [0, 2].";
1427   TestRemoveRegularFaninByPort("foo_4", /*node_exists=*/true, /*port=*/3,
1428                                /*success=*/false, error_msg,
1429                                {"a", "b:2", "b:2", "^c", "^d"});
1430 
1431   // Remove fanin from node where node is missing.
1432   error_msg =
1433       "MutableGraphView::RemoveRegularFaninByPort(node_name='foo_missing', "
1434       "port=0) error: node 'foo_missing' was not found.";
1435   TestRemoveRegularFaninByPort("foo_missing", /*node_exists=*/false, /*port=*/0,
1436                                /*success=*/false, error_msg, {});
1437 }
1438 
TestRemoveAllFanins(absl::string_view node_name,bool node_exists,bool keep_controlling_nodes,bool success,const string & error_msg,absl::Span<const string> expected_fanins)1439 void TestRemoveAllFanins(absl::string_view node_name, bool node_exists,
1440                          bool keep_controlling_nodes, bool success,
1441                          const string& error_msg,
1442                          absl::Span<const string> expected_fanins) {
1443   GraphDef graph_def = SimpleMutateFaninGraph();
1444 
1445   MutableGraphView graph(&graph_def);
1446 
1447   NodeDef* node = graph.GetNode(node_name);
1448   absl::flat_hash_set<string> fanin_strings;
1449   if (node_exists) {
1450     EXPECT_NE(node, nullptr);
1451     fanin_strings.insert(node->input().begin(), node->input().end());
1452   } else {
1453     EXPECT_EQ(node, nullptr);
1454   }
1455 
1456   absl::flat_hash_map<string, std::vector<string>> unmodified_node_inputs =
1457       GetNodeInputsFromGraph(graph_def, node_name);
1458 
1459   Status s = graph.RemoveAllFanins(node_name, keep_controlling_nodes);
1460   EXPECT_EQ(s.ok(), success);
1461   if (!success) {
1462     EXPECT_EQ(s.error_message(), error_msg);
1463   }
1464   if (node_exists) {
1465     CompareNodeFanins(graph, node, expected_fanins);
1466     if (success) {
1467       TensorId tensor_id;
1468       auto retained_inputs = absl::flat_hash_set<string>(node->input().begin(),
1469                                                          node->input().end());
1470       for (const string& fanin : fanin_strings) {
1471         if (!retained_inputs.contains(fanin)) {
1472           tensor_id = ParseTensorName(fanin);
1473           CheckFanoutRemoved(graph, tensor_id, node_name);
1474         }
1475       }
1476     }
1477   }
1478 
1479   CheckUnmodifiedNodeFanins(graph_def, node_name, unmodified_node_inputs);
1480 
1481   CheckGraph(graph);
1482 }
1483 
TEST(MutableGraphViewTest, RemoveAllFanins) {
  // Passed for every call that is expected to succeed.
  const string no_error;

  // foo_3 only has regular fanins, so nothing is left either way.
  TestRemoveAllFanins("foo_3", /*node_exists=*/true,
                      /*keep_controlling_nodes=*/false,
                      /*success=*/true, no_error, {});
  TestRemoveAllFanins("foo_3", /*node_exists=*/true,
                      /*keep_controlling_nodes=*/true,
                      /*success=*/true, no_error, {});

  // foo_4 has regular fanins plus ^c and ^d; the controls survive only when
  // they are explicitly kept.
  TestRemoveAllFanins("foo_4", /*node_exists=*/true,
                      /*keep_controlling_nodes=*/false,
                      /*success=*/true, no_error, {});
  TestRemoveAllFanins("foo_4", /*node_exists=*/true,
                      /*keep_controlling_nodes=*/true,
                      /*success=*/true, no_error, {"^c", "^d"});

  // foo_5 has no fanins at all.
  TestRemoveAllFanins("foo_5", /*node_exists=*/true,
                      /*keep_controlling_nodes=*/false,
                      /*success=*/true, no_error, {});
  TestRemoveAllFanins("foo_5", /*node_exists=*/true,
                      /*keep_controlling_nodes=*/true,
                      /*success=*/true, no_error, {});

  // foo_6 only has control dependencies.
  TestRemoveAllFanins("foo_6", /*node_exists=*/true,
                      /*keep_controlling_nodes=*/false,
                      /*success=*/true, no_error, {});
  TestRemoveAllFanins("foo_6", /*node_exists=*/true,
                      /*keep_controlling_nodes=*/true,
                      /*success=*/true, no_error, {"^a", "^b"});

  // A node that is not in the graph is reported regardless of the flag.
  TestRemoveAllFanins(
      "foo_missing", /*node_exists=*/false,
      /*keep_controlling_nodes=*/false, /*success=*/false,
      "MutableGraphView::RemoveAllFanins(node_name='foo_missing', "
      "keep_controlling_fanins=false) error: node 'foo_missing' was not found.",
      {});
  TestRemoveAllFanins(
      "foo_missing", /*node_exists=*/false,
      /*keep_controlling_nodes=*/true, /*success=*/false,
      "MutableGraphView::RemoveAllFanins(node_name='foo_missing', "
      "keep_controlling_fanins=true) error: node 'foo_missing' was not found.",
      {});
}
1536 
TestUpdateFanin(absl::string_view node_name,bool node_exists,const TensorId & from_fanin,const TensorId & to_fanin,bool success,const string & error_msg,absl::Span<const string> expected_fanins)1537 void TestUpdateFanin(absl::string_view node_name, bool node_exists,
1538                      const TensorId& from_fanin, const TensorId& to_fanin,
1539                      bool success, const string& error_msg,
1540                      absl::Span<const string> expected_fanins) {
1541   GraphDef graph_def = SimpleMutateFaninGraph();
1542 
1543   MutableGraphView graph(&graph_def);
1544 
1545   NodeDef* node = graph.GetNode(node_name);
1546   if (node_exists) {
1547     EXPECT_NE(node, nullptr);
1548   } else {
1549     EXPECT_EQ(node, nullptr);
1550   }
1551 
1552   absl::flat_hash_map<string, std::vector<string>> unmodified_node_inputs =
1553       GetNodeInputsFromGraph(graph_def, node_name);
1554 
1555   Status s = graph.UpdateFanin(node_name, from_fanin, to_fanin);
1556   EXPECT_EQ(s.ok(), success);
1557   if (!success) {
1558     EXPECT_EQ(s.error_message(), error_msg);
1559   }
1560   if (node_exists) {
1561     CompareNodeFanins(graph, node, expected_fanins);
1562     if (success) {
1563       CheckFanoutRemoved(graph, from_fanin, node_name);
1564     }
1565   }
1566 
1567   CheckUnmodifiedNodeFanins(graph_def, node_name, unmodified_node_inputs);
1568 
1569   CheckGraph(graph);
1570 }
1571 
TEST(MutableGraphViewTest,UpdateFanin)1572 TEST(MutableGraphViewTest, UpdateFanin) {
1573   string error_msg;
1574   // Update fanin from non control to non control.
1575   TestUpdateFanin("foo_4", /*node_exists=*/true, {"b", 2}, {"b", 3},
1576                   /*success=*/true, error_msg, {"a", "b:3", "b:3", "^c", "^d"});
1577   // Update fanin from non control to control.
1578   TestUpdateFanin("foo_4", /*node_exists=*/true, {"b", 2},
1579                   {"b", Graph::kControlSlot},
1580                   /*success=*/true, error_msg, {"a", "^c", "^d", "^b"});
1581   // Update fanin from control to non control.
1582   TestUpdateFanin(
1583       "foo_4", /*node_exists=*/true, {"d", Graph::kControlSlot}, {"d", 1},
1584       /*success=*/true, error_msg, {"a", "b:2", "b:2", "d:1", "^c"});
1585   // Update fanin from control to control.
1586   TestUpdateFanin("foo_4", /*node_exists=*/true, {"c", Graph::kControlSlot},
1587                   {"b", Graph::kControlSlot}, /*success=*/true, error_msg,
1588                   {"a", "b:2", "b:2", "^d"});
1589   // Update fanin from control to existing control.
1590   TestUpdateFanin("foo_4", /*node_exists=*/true, {"c", Graph::kControlSlot},
1591                   {"d", Graph::kControlSlot}, /*success=*/true, error_msg,
1592                   {"a", "b:2", "b:2", "^d"});
1593 
1594   // Update fanin of node where from and to fanins are the same.
1595   TestUpdateFanin("foo_1", /*node_exists=*/true, {"a", -1}, {"a", -1},
1596                   /*success=*/true, error_msg, {"a"});
1597   TestUpdateFanin("foo_1", /*node_exists=*/true, {"a", 0}, {"a", 0},
1598                   /*success=*/true, error_msg, {"a"});
1599   TestUpdateFanin("foo_1", /*node_exists=*/true, {"a", 1}, {"a", 1},
1600                   /*success=*/true, error_msg, {"a"});
1601 
1602   // Update fanin of node where node is missing.
1603   error_msg =
1604       "MutableGraphView::UpdateFanin(node_name='foo_missing', "
1605       "from_fanin='a:0', to_fanin='a:1') error: node 'foo_missing' was not "
1606       "found.";
1607   TestUpdateFanin("foo_missing", /*node_exists=*/false, {"a", 0}, {"a", 1},
1608                   /*success=*/false, error_msg, {});
1609   // Update fanin of node where from fanin is missing.
1610   error_msg =
1611       "MutableGraphView::UpdateFanin(node_name='foo_1', "
1612       "from_fanin='from_bar_missing:0', to_fanin='a:1') error: node "
1613       "'from_bar_missing' was not found.";
1614   TestUpdateFanin("foo_1", /*node_exists=*/true, {"from_bar_missing", 0},
1615                   {"a", 1},
1616                   /*success=*/false, error_msg, {"a"});
1617   // Update fanin of node where to fanin is missing.
1618   error_msg =
1619       "MutableGraphView::UpdateFanin(node_name='foo_1', from_fanin='a:0', "
1620       "to_fanin='to_bar_missing:1') error: node 'to_bar_missing' was not "
1621       "found.";
1622   TestUpdateFanin("foo_1", /*node_exists=*/true, {"a", 0},
1623                   {"to_bar_missing", 1}, /*success=*/false, error_msg, {"a"});
1624   // Update fanin of node where from/to fanins and node are missing.
1625   error_msg =
1626       "MutableGraphView::UpdateFanin(node_name='foo_missing', "
1627       "from_fanin='from_bar_missing:0', to_fanin='to_bar_missing:1') error: "
1628       "node 'foo_missing' was not found.";
1629   TestUpdateFanin("foo_missing", /*node_exists=*/false, {"from_bar_missing", 0},
1630                   {"to_bar_missing", 1},
1631                   /*success=*/false, error_msg, {});
1632   // Update fanin of node where from fanin is invalid.
1633   error_msg =
1634       "MutableGraphView::UpdateFanin(node_name='foo_1', from_fanin='a:-2', "
1635       "to_fanin='a:0') error: fanin 'a:-2' must be a valid tensor id.";
1636   TestUpdateFanin("foo_1", /*node_exists=*/true, {"a", -2}, {"a", 0},
1637                   /*success=*/false, error_msg, {"a"});
1638   // Update fanin of node where to fanin is invalid.
1639   error_msg =
1640       "MutableGraphView::UpdateFanin(node_name='foo_1', from_fanin='a:0', "
1641       "to_fanin='a:-2') error: fanin 'a:-2' must be a valid tensor id.";
1642   TestUpdateFanin("foo_1", /*node_exists=*/true, {"a", 0}, {"a", -2},
1643                   /*success=*/false, error_msg, {"a"});
1644   // Update fanin of node where from/to fanins are invalid and missing and node
1645   // is missing.
1646   error_msg =
1647       "MutableGraphView::UpdateFanin(node_name='foo_missing', "
1648       "from_fanin='from_bar_missing:-2', to_fanin='to_bar_missing:-3') error: "
1649       "fanin 'from_bar_missing:-2' must be a valid tensor id.";
1650   TestUpdateFanin("foo_missing", /*node_exists=*/false,
1651                   {"from_bar_missing", -2}, {"to_bar_missing", -3},
1652                   /*success=*/false, error_msg, {});
1653 
1654   // Update to self to create cycle.
1655   error_msg =
1656       "MutableGraphView::UpdateFanin(node_name='foo_4', from_fanin='b:2', "
1657       "to_fanin='foo_4:3') error: can't update fanin to or from self.";
1658   TestUpdateFanin("foo_4", /*node_exists=*/true, {"b", 2}, {"foo_4", 3},
1659                   /*success=*/false, error_msg,
1660                   {"a", "b:2", "b:2", "^c", "^d"});
1661   error_msg =
1662       "MutableGraphView::UpdateFanin(node_name='foo_4', from_fanin='b:2', "
1663       "to_fanin='^foo_4') error: can't update fanin to or from self.";
1664   TestUpdateFanin(
1665       "foo_4", /*node_exists=*/true, {"b", 2}, {"foo_4", Graph::kControlSlot},
1666       /*success=*/false, error_msg, {"a", "b:2", "b:2", "^c", "^d"});
1667   error_msg =
1668       "MutableGraphView::UpdateFanin(node_name='foo_4', from_fanin='^c', "
1669       "to_fanin='foo_4:4') error: can't update fanin to or from self.";
1670   TestUpdateFanin(
1671       "foo_4", /*node_exists=*/true, {"c", Graph::kControlSlot}, {"foo_4", 4},
1672       /*success=*/false, error_msg, {"a", "b:2", "b:2", "^c", "^d"});
1673   error_msg =
1674       "MutableGraphView::UpdateFanin(node_name='foo_4', from_fanin='^c', "
1675       "to_fanin='^foo_4') error: can't update fanin to or from self.";
1676   TestUpdateFanin("foo_4", /*node_exists=*/true, {"c", Graph::kControlSlot},
1677                   {"foo_4", Graph::kControlSlot}, /*success=*/false, error_msg,
1678                   {"a", "b:2", "b:2", "^c", "^d"});
1679 }
1680 
TestUpdateFaninFromFaninToNodeAsSwitchControl(const TensorId & fanin)1681 void TestUpdateFaninFromFaninToNodeAsSwitchControl(const TensorId& fanin) {
1682   string tensor_id_str = TensorIdToString(fanin);
1683   GraphDef graph_def = test::function::GDef(
1684       {NDef("a", "NotImportant", {}, {}), NDef("b", "Switch", {}, {}),
1685        NDef("c", "NotImportant", {tensor_id_str})},
1686       /*funcs=*/{});
1687 
1688   MutableGraphView graph(&graph_def);
1689 
1690   Status s = graph.UpdateFanin("c", fanin, {"b", Graph::kControlSlot});
1691   EXPECT_FALSE(s.ok());
1692   string expected_msg = absl::Substitute(
1693       "MutableGraphView::UpdateFanin(node_name='c', from_fanin='$0', "
1694       "to_fanin='^b') error: can't update to fanin '^b' as it will become a "
1695       "Switch control dependency.",
1696       fanin.ToString());
1697   EXPECT_EQ(s.error_message(), expected_msg);
1698 
1699   EXPECT_EQ(graph.graph()->node_size(), 3);
1700 
1701   string fanout = IsControlInput(fanin) ? AsControlDependency("c") : "c";
1702   CheckNode(graph, "a", "NotImportant", "", {}, {}, {fanout});
1703   CheckNode(graph, "b", "Switch", "", {}, {}, {});
1704   CheckNode(graph, "c", "NotImportant", "", {}, {tensor_id_str}, {});
1705 
1706   CheckGraph(graph);
1707 }
1708 
TEST(MutableGraphViewTest, UpdateFaninToNodeAsSwitchControl) {
  // Two regular outputs of `a` and its control output, each of which must be
  // refused as the source of an update to a Switch control dependency.
  const TensorId kFanins[] = {{"a", 0}, {"a", 1}, {"a", Graph::kControlSlot}};
  for (const TensorId& fanin : kFanins) {
    TestUpdateFaninFromFaninToNodeAsSwitchControl(fanin);
  }
}
1714 
TestUpdateRegularFaninByPort(absl::string_view node_name,bool node_exists,int port,const TensorId & fanin,bool success,const string & error_msg,absl::Span<const string> expected_fanins)1715 void TestUpdateRegularFaninByPort(absl::string_view node_name, bool node_exists,
1716                                   int port, const TensorId& fanin, bool success,
1717                                   const string& error_msg,
1718                                   absl::Span<const string> expected_fanins) {
1719   GraphDef graph_def = SimpleMutateFaninGraph();
1720 
1721   MutableGraphView graph(&graph_def);
1722 
1723   NodeDef* node = graph.GetNode(node_name);
1724   if (node_exists) {
1725     EXPECT_NE(node, nullptr);
1726   } else {
1727     EXPECT_EQ(node, nullptr);
1728   }
1729 
1730   absl::flat_hash_map<string, std::vector<string>> unmodified_node_inputs =
1731       GetNodeInputsFromGraph(graph_def, node_name);
1732 
1733   Status s = graph.UpdateRegularFaninByPort(node_name, port, fanin);
1734   EXPECT_EQ(s.ok(), success);
1735   if (!success) {
1736     EXPECT_EQ(s.error_message(), error_msg);
1737   }
1738   if (node_exists) {
1739     CompareNodeFanins(graph, node, expected_fanins);
1740   }
1741 
1742   CheckUnmodifiedNodeFanins(graph_def, node_name, unmodified_node_inputs);
1743 
1744   CheckGraph(graph);
1745 }
1746 
TEST(MutableGraphViewTest,UpdateRegularFaninByPort)1747 TEST(MutableGraphViewTest, UpdateRegularFaninByPort) {
1748   string error_msg;
1749   // Update input at start to node with some inputs and no controls.
1750   TestUpdateRegularFaninByPort(
1751       "foo_3", /*node_exists=*/true, /*port=*/0, {"d", 2},
1752       /*success=*/true, error_msg, {"d:2", "a:1", "a:1"});
1753   // Update input at end to node with some inputs and no controls.
1754   TestUpdateRegularFaninByPort(
1755       "foo_3", /*node_exists=*/true, /*port=*/2, {"d", 2},
1756       /*success=*/true, error_msg, {"b", "a:1", "d:2"});
1757   // Update input in middle to node with some inputs and no controls.
1758   TestUpdateRegularFaninByPort(
1759       "foo_3", /*node_exists=*/true, /*port=*/1, {"d", 2},
1760       /*success=*/true, error_msg, {"b", "d:2", "a:1"});
1761   // Update input at start to node with some inputs and some controls, and dedup
1762   // controls.
1763   TestUpdateRegularFaninByPort(
1764       "foo_4", /*node_exists=*/true, /*port=*/0, {"d", 2},
1765       /*success=*/true, error_msg, {"d:2", "b:2", "b:2", "^c"});
1766   // Update input at end to node with some inputs and some controls, and dedup
1767   // controls.
1768   TestUpdateRegularFaninByPort(
1769       "foo_4", /*node_exists=*/true, /*port=*/2, {"d", 2},
1770       /*success=*/true, error_msg, {"a", "b:2", "d:2", "^c"});
1771   // Update input in middle to node with some inputs and some controls and
1772   // dedup controls.
1773   TestUpdateRegularFaninByPort(
1774       "foo_4", /*node_exists=*/true, /*port=*/1, {"d", 2},
1775       /*success=*/true, error_msg, {"a", "d:2", "b:2", "^c"});
1776 
1777   // Update input to controlling fanin.
1778   error_msg =
1779       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_4', port=1, "
1780       "fanin='^d') error: fanin '^d' must be a regular tensor id.";
1781   TestUpdateRegularFaninByPort(
1782       "foo_4", /*node_exists=*/true, /*port=*/1, {"d", Graph::kControlSlot},
1783       /*success=*/false, error_msg, {"a", "b:2", "b:2", "^c", "^d"});
1784 
1785   // Update fanin at out of bounds port.
1786   error_msg =
1787       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_5', port=-1, "
1788       "fanin='d:2') error: no available ports as node has no regular fanins.";
1789   TestUpdateRegularFaninByPort("foo_5", /*node_exists=*/true, /*port=*/-1,
1790                                {"d", 2},
1791                                /*success=*/false, error_msg, {});
1792   error_msg =
1793       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_5', port=0, "
1794       "fanin='d:2') error: no available ports as node has no regular fanins.";
1795   TestUpdateRegularFaninByPort("foo_5", /*node_exists=*/true, /*port=*/0,
1796                                {"d", 2},
1797                                /*success=*/false, error_msg, {});
1798   error_msg =
1799       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_5', port=1, "
1800       "fanin='d:2') error: no available ports as node has no regular fanins.";
1801   TestUpdateRegularFaninByPort("foo_5", /*node_exists=*/true, /*port=*/1,
1802                                {"d", 2},
1803                                /*success=*/false, error_msg, {});
1804   error_msg =
1805       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_6', port=-1, "
1806       "fanin='d:2') error: no available ports as node has no regular fanins.";
1807   TestUpdateRegularFaninByPort("foo_6", /*node_exists=*/true, /*port=*/-1,
1808                                {"d", 2},
1809                                /*success=*/false, error_msg, {"^a", "^b"});
1810   error_msg =
1811       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_6', port=0, "
1812       "fanin='d:2') error: no available ports as node has no regular fanins.";
1813   TestUpdateRegularFaninByPort("foo_6", /*node_exists=*/true, /*port=*/0,
1814                                {"d", 2},
1815                                /*success=*/false, error_msg, {"^a", "^b"});
1816   error_msg =
1817       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_6', port=1, "
1818       "fanin='d:2') error: no available ports as node has no regular fanins.";
1819   TestUpdateRegularFaninByPort("foo_6", /*node_exists=*/true, /*port=*/1,
1820                                {"d", 2},
1821                                /*success=*/false, error_msg, {"^a", "^b"});
1822   error_msg =
1823       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_3', port=-1, "
1824       "fanin='d:2') error: port must be in range [0, 2].";
1825   TestUpdateRegularFaninByPort(
1826       "foo_3", /*node_exists=*/true, /*port=*/-1, {"d", 2},
1827       /*success=*/false, error_msg, {"b", "a:1", "a:1"});
1828   error_msg =
1829       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_3', port=3, "
1830       "fanin='d:2') error: port must be in range [0, 2].";
1831   TestUpdateRegularFaninByPort(
1832       "foo_3", /*node_exists=*/true, /*port=*/3, {"d", 2},
1833       /*success=*/false, error_msg, {"b", "a:1", "a:1"});
1834   error_msg =
1835       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_4', port=-1, "
1836       "fanin='d:2') error: port must be in range [0, 2].";
1837   TestUpdateRegularFaninByPort(
1838       "foo_4", /*node_exists=*/true, /*port=*/-1, {"d", 2},
1839       /*success=*/false, error_msg, {"a", "b:2", "b:2", "^c", "^d"});
1840   error_msg =
1841       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_4', port=3, "
1842       "fanin='d:2') error: port must be in range [0, 2].";
1843   TestUpdateRegularFaninByPort(
1844       "foo_4", /*node_exists=*/true, /*port=*/3, {"d", 2},
1845       /*success=*/false, error_msg, {"a", "b:2", "b:2", "^c", "^d"});
1846 
1847   // Update fanin to node where node is missing.
1848   error_msg =
1849       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_missing', "
1850       "port=0, fanin='a:0') error: node 'foo_missing' was not found.";
1851   TestUpdateRegularFaninByPort("foo_missing", /*node_exists=*/false,
1852                                /*port=*/0, {"a", 0},
1853                                /*success=*/false, error_msg, {});
1854   // Update fanin to node where fanin is missing.
1855   error_msg =
1856       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_1', port=0, "
1857       "fanin='bar_missing:0') error: node 'bar_missing' was not "
1858       "found.";
1859   TestUpdateRegularFaninByPort("foo_1", /*node_exists=*/true, /*port=*/0,
1860                                {"bar_missing", 0},
1861                                /*success=*/false, error_msg, {"a"});
1862   // Update fanin to node where node and fanin are missing.
1863   error_msg =
1864       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_missing', "
1865       "port=0, fanin='bar_missing:0') error: node 'foo_missing' was not found.";
1866   TestUpdateRegularFaninByPort("foo_missing", /*node_exists=*/false,
1867                                /*port=*/0, {"bar_missing", 0},
1868                                /*success=*/false, error_msg, {});
1869 
1870   // Update self to create cycle.
1871   error_msg =
1872       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_6', port=0, "
1873       "fanin='foo_6:2') error: can't add fanin 'foo_6:2' to self.";
1874   TestUpdateRegularFaninByPort("foo_6", /*node_exists=*/true, /*port=*/0,
1875                                {"foo_6", 2},
1876                                /*success=*/false, error_msg, {"^a", "^b"});
1877 }
1878 
TestSwapRegularFaninsByPorts(absl::string_view node_name,bool node_exists,int from_port,int to_port,bool success,const string & error_msg,absl::Span<const string> expected_fanins)1879 void TestSwapRegularFaninsByPorts(absl::string_view node_name, bool node_exists,
1880                                   int from_port, int to_port, bool success,
1881                                   const string& error_msg,
1882                                   absl::Span<const string> expected_fanins) {
1883   GraphDef graph_def = SimpleMutateFaninGraph();
1884 
1885   MutableGraphView graph(&graph_def);
1886 
1887   NodeDef* node = graph.GetNode(node_name);
1888   if (node_exists) {
1889     EXPECT_NE(node, nullptr);
1890   } else {
1891     EXPECT_EQ(node, nullptr);
1892   }
1893 
1894   absl::flat_hash_map<string, std::vector<string>> unmodified_node_inputs =
1895       GetNodeInputsFromGraph(graph_def, node_name);
1896 
1897   Status s = graph.SwapRegularFaninsByPorts(node_name, from_port, to_port);
1898   EXPECT_EQ(s.ok(), success);
1899   if (!success) {
1900     EXPECT_EQ(s.error_message(), error_msg);
1901   }
1902   if (node_exists) {
1903     CompareNodeFanins(graph, node, expected_fanins);
1904   }
1905 
1906   CheckUnmodifiedNodeFanins(graph_def, node_name, unmodified_node_inputs);
1907 
1908   CheckGraph(graph);
1909 }
1910 
TEST(MutableGraphViewTest,SwapRegularFaninsByPorts)1911 TEST(MutableGraphViewTest, SwapRegularFaninsByPorts) {
1912   string error_msg;
1913   // Swapping first and last regular fanins
1914   TestSwapRegularFaninsByPorts("foo_3", /*node_exists=*/true, /*from_port=*/0,
1915                                /*to_port=*/2, /*success=*/true, error_msg,
1916                                {"a:1", "a:1", "b"});
1917   TestSwapRegularFaninsByPorts("foo_3", /*node_exists=*/true, /*from_port=*/2,
1918                                /*to_port=*/0, /*success=*/true, error_msg,
1919                                {"a:1", "a:1", "b"});
1920   // Swapping first and last regular fanins, in node with controls.
1921   TestSwapRegularFaninsByPorts("foo_4", /*node_exists=*/true, /*from_port=*/0,
1922                                /*to_port=*/2, /*success=*/true, error_msg,
1923                                {"b:2", "b:2", "a", "^c", "^d"});
1924   TestSwapRegularFaninsByPorts("foo_4", /*node_exists=*/true, /*from_port=*/2,
1925                                /*to_port=*/0, /*success=*/true, error_msg,
1926                                {"b:2", "b:2", "a", "^c", "^d"});
1927   // Swapping middle regular fanin.
1928   TestSwapRegularFaninsByPorts("foo_3", /*node_exists=*/true, /*from_port=*/0,
1929                                /*to_port=*/1, /*success=*/true, error_msg,
1930                                {"a:1", "b", "a:1"});
1931   TestSwapRegularFaninsByPorts("foo_3", /*node_exists=*/true, /*from_port=*/1,
1932                                /*to_port=*/0, /*success=*/true, error_msg,
1933                                {"a:1", "b", "a:1"});
1934   // Swapping middle regular fanin, in node with controls.
1935   TestSwapRegularFaninsByPorts("foo_4", /*node_exists=*/true, /*from_port=*/0,
1936                                /*to_port=*/1, /*success=*/true, error_msg,
1937                                {"b:2", "a", "b:2", "^c", "^d"});
1938   TestSwapRegularFaninsByPorts("foo_4", /*node_exists=*/true, /*from_port=*/1,
1939                                /*to_port=*/0, /*success=*/true, error_msg,
1940                                {"b:2", "a", "b:2", "^c", "^d"});
1941   // Swapping same port.
1942   TestSwapRegularFaninsByPorts("foo_4", /*node_exists=*/true, /*from_port=*/1,
1943                                /*to_port=*/1, /*success=*/true, error_msg,
1944                                {"a", "b:2", "b:2", "^c", "^d"});
1945   // Swapping same fanin but different port.
1946   TestSwapRegularFaninsByPorts("foo_4", /*node_exists=*/true, /*from_port=*/1,
1947                                /*to_port=*/2, /*success=*/true, error_msg,
1948                                {"a", "b:2", "b:2", "^c", "^d"});
1949 
1950   // Swapping fanins at out of bounds ports.
1951   // Node with no regular fanins and no controls.
1952   error_msg =
1953       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_5', "
1954       "from_port=-1, to_port=0) error: no available ports as node has no "
1955       "regular fanins.";
1956   TestSwapRegularFaninsByPorts("foo_5", /*node_exists=*/true, /*from_port=*/-1,
1957                                /*to_port=*/0, /*success=*/false, error_msg, {});
1958   error_msg =
1959       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_5', "
1960       "from_port=0, to_port=-1) error: no available ports as node has no "
1961       "regular fanins.";
1962   TestSwapRegularFaninsByPorts("foo_5", /*node_exists=*/true, /*from_port=*/0,
1963                                /*to_port=*/-1, /*success=*/false, error_msg,
1964                                {});
1965   error_msg =
1966       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_5', "
1967       "from_port=0, to_port=0) error: no available ports as node has no "
1968       "regular fanins.";
1969   TestSwapRegularFaninsByPorts("foo_5", /*node_exists=*/true, /*from_port=*/0,
1970                                /*to_port=*/0, /*success=*/false, error_msg, {});
1971   error_msg =
1972       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_5', "
1973       "from_port=0, to_port=1) error: no available ports as node has no "
1974       "regular fanins.";
1975   TestSwapRegularFaninsByPorts("foo_5", /*node_exists=*/true, /*from_port=*/0,
1976                                /*to_port=*/1, /*success=*/false, error_msg, {});
1977   error_msg =
1978       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_5', "
1979       "from_port=1, to_port=0) error: no available ports as node has no "
1980       "regular fanins.";
1981   TestSwapRegularFaninsByPorts("foo_5", /*node_exists=*/true, /*from_port=*/1,
1982                                /*to_port=*/0, /*success=*/false, error_msg, {});
1983   // Node with no regular fanins and some controls.
1984   error_msg =
1985       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_6', "
1986       "from_port=-1, to_port=0) error: no available ports as node has no "
1987       "regular fanins.";
1988   TestSwapRegularFaninsByPorts("foo_6", /*node_exists=*/true, /*from_port=*/-1,
1989                                /*to_port=*/0, /*success=*/false, error_msg,
1990                                {"^a", "^b"});
1991   error_msg =
1992       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_6', "
1993       "from_port=0, to_port=-1) error: no available ports as node has no "
1994       "regular fanins.";
1995   TestSwapRegularFaninsByPorts("foo_6", /*node_exists=*/true, /*from_port=*/0,
1996                                /*to_port=*/-1, /*success=*/false, error_msg,
1997                                {"^a", "^b"});
1998   error_msg =
1999       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_6', "
2000       "from_port=0, to_port=0) error: no available ports as node has no "
2001       "regular fanins.";
2002   TestSwapRegularFaninsByPorts("foo_6", /*node_exists=*/true, /*from_port=*/0,
2003                                /*to_port=*/0, /*success=*/false, error_msg,
2004                                {"^a", "^b"});
2005   error_msg =
2006       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_6', "
2007       "from_port=0, to_port=1) error: no available ports as node has no "
2008       "regular fanins.";
2009   TestSwapRegularFaninsByPorts("foo_6", /*node_exists=*/true, /*from_port=*/0,
2010                                /*to_port=*/1, /*success=*/false, error_msg,
2011                                {"^a", "^b"});
2012   error_msg =
2013       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_6', "
2014       "from_port=1, to_port=0) error: no available ports as node has no "
2015       "regular fanins.";
2016   TestSwapRegularFaninsByPorts("foo_6", /*node_exists=*/true, /*from_port=*/1,
2017                                /*to_port=*/0, /*success=*/false, error_msg,
2018                                {"^a", "^b"});
2019   // Node with regular fanins and no controls.
2020   error_msg =
2021       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_3', "
2022       "from_port=-1, to_port=0) error: port must be in range [0, 2].";
2023   TestSwapRegularFaninsByPorts("foo_3", /*node_exists=*/true, /*from_port=*/-1,
2024                                /*to_port=*/0, /*success=*/false, error_msg,
2025                                {"b", "a:1", "a:1"});
2026   error_msg =
2027       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_3', "
2028       "from_port=0, to_port=-1) error: port must be in range [0, 2].";
2029   TestSwapRegularFaninsByPorts("foo_3", /*node_exists=*/true, /*from_port=*/0,
2030                                /*to_port=*/-1, /*success=*/false, error_msg,
2031                                {"b", "a:1", "a:1"});
2032   error_msg =
2033       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_3', "
2034       "from_port=0, to_port=3) error: port must be in range [0, 2].";
2035   TestSwapRegularFaninsByPorts("foo_3", /*node_exists=*/true, /*from_port=*/0,
2036                                /*to_port=*/3, /*success=*/false, error_msg,
2037                                {"b", "a:1", "a:1"});
2038   error_msg =
2039       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_3', "
2040       "from_port=3, to_port=0) error: port must be in range [0, 2].";
2041   TestSwapRegularFaninsByPorts("foo_3", /*node_exists=*/true, /*from_port=*/3,
2042                                /*to_port=*/0, /*success=*/false, error_msg,
2043                                {"b", "a:1", "a:1"});
2044   error_msg =
2045       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_3', "
2046       "from_port=-1, to_port=3) error: port must be in range [0, 2].";
2047   TestSwapRegularFaninsByPorts("foo_3", /*node_exists=*/true, /*from_port=*/-1,
2048                                /*to_port=*/3, /*success=*/false, error_msg,
2049                                {"b", "a:1", "a:1"});
2050   error_msg =
2051       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_3', "
2052       "from_port=3, to_port=-1) error: port must be in range [0, 2].";
2053   TestSwapRegularFaninsByPorts("foo_3", /*node_exists=*/true, /*from_port=*/3,
2054                                /*to_port=*/-1, /*success=*/false, error_msg,
2055                                {"b", "a:1", "a:1"});
2056   // Node with regular fanins and controls.
2057   error_msg =
2058       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_4', "
2059       "from_port=-1, to_port=0) error: port must be in range [0, 2].";
2060   TestSwapRegularFaninsByPorts("foo_4", /*node_exists=*/true, /*from_port=*/-1,
2061                                /*to_port=*/0, /*success=*/false, error_msg,
2062                                {"a", "b:2", "b:2", "^c", "^d"});
2063   error_msg =
2064       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_4', "
2065       "from_port=0, to_port=-1) error: port must be in range [0, 2].";
2066   TestSwapRegularFaninsByPorts("foo_4", /*node_exists=*/true, /*from_port=*/0,
2067                                /*to_port=*/-1, /*success=*/false, error_msg,
2068                                {"a", "b:2", "b:2", "^c", "^d"});
2069   error_msg =
2070       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_4', "
2071       "from_port=0, to_port=3) error: port must be in range [0, 2].";
2072   TestSwapRegularFaninsByPorts("foo_4", /*node_exists=*/true, /*from_port=*/0,
2073                                /*to_port=*/3, /*success=*/false, error_msg,
2074                                {"a", "b:2", "b:2", "^c", "^d"});
2075   error_msg =
2076       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_4', "
2077       "from_port=3, to_port=0) error: port must be in range [0, 2].";
2078   TestSwapRegularFaninsByPorts("foo_4", /*node_exists=*/true, /*from_port=*/3,
2079                                /*to_port=*/0, /*success=*/false, error_msg,
2080                                {"a", "b:2", "b:2", "^c", "^d"});
2081   error_msg =
2082       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_4', "
2083       "from_port=-1, to_port=3) error: port must be in range [0, 2].";
2084   TestSwapRegularFaninsByPorts("foo_4", /*node_exists=*/true, /*from_port=*/-1,
2085                                /*to_port=*/3, /*success=*/false, error_msg,
2086                                {"a", "b:2", "b:2", "^c", "^d"});
2087   error_msg =
2088       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_4', "
2089       "from_port=3, to_port=-1) error: port must be in range [0, 2].";
2090   TestSwapRegularFaninsByPorts("foo_4", /*node_exists=*/true, /*from_port=*/3,
2091                                /*to_port=*/-1, /*success=*/false, error_msg,
2092                                {"a", "b:2", "b:2", "^c", "^d"});
2093 
2094   // Swapping fanin to node where node is missing.
2095   error_msg =
2096       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_missing', "
2097       "from_port=0, to_port=1) error: node 'foo_missing' was not found.";
2098   TestSwapRegularFaninsByPorts("foo_missing", /*node_exists=*/false,
2099                                /*from_port=*/0, /*to_port=*/1,
2100                                /*success=*/false, error_msg, {});
2101 }
2102 
TEST(MutableGraphViewTest, DedupControllingFaninsOnGraphInit) {
  // Several nodes list controlling fanins that are repeated or that name a
  // node already used as a regular fanin. "d" is an Identity reading output 1
  // of the Switch "c".
  const std::vector<NodeDef> nodes = {
      NDef("a", "NotImportant", {}, {}),
      NDef("b", "NotImportant", {}, {}),
      NDef("c", "Switch", {}, {}),
      NDef("d", "Identity", {"c:1"}),
      NDef("foo_1", "IdentityN", {"a", "b:1", "^b"}),
      NDef("foo_2", "IdentityN", {"a", "^b", "^b"}),
      NDef("foo_3", "IdentityN", {"a", "b:1", "^b", "^b"}),
      NDef("foo_4", "IdentityN", {"a:2", "b:1", "^b", "^b", "^a", "^a"}),
      NDef("foo_5", "NotImportant", {"a:2", "b:1", "^b", "^b", "^a", "^a"}),
      NDef("foo_6", "Identity", {"d", "^d"}),
      NDef("foo_7", "NotImportant",
           {"a:3", "b:2", "d", "^d", "^d", "^a", "^b", "^a", "^b"}),
  };
  GraphDef gdef = test::function::GDef(nodes, /*funcs=*/{});

  // The checks below run right after constructing the view, with no explicit
  // mutation in between.
  MutableGraphView view(&gdef);

  EXPECT_EQ(view.graph()->node_size(), 11);

  CheckNode(view, "a", "NotImportant", "", {}, {},
            {"foo_1", "foo_2", "foo_3", "foo_4", "foo_5", "foo_7"});
  CheckNode(view, "b", "NotImportant", "", {}, {},
            {"foo_1:1", "^foo_2", "foo_3:1", "foo_4:1", "foo_5:1", "foo_7:1"});
  CheckNode(view, "c", "Switch", "", {}, {}, {"d"});
  CheckNode(view, "d", "Identity", "", {}, {"c:1"},
            {"foo_6", "^foo_6", "foo_7:2", "^foo_7"});
  // Redundant and repeated controls on "a"/"b" are gone; a repeated control
  // survives only once.
  CheckNode(view, "foo_1", "IdentityN", "", {}, {"a", "b:1"}, {});
  CheckNode(view, "foo_2", "IdentityN", "", {}, {"a", "^b"}, {});
  CheckNode(view, "foo_3", "IdentityN", "", {}, {"a", "b:1"}, {});
  CheckNode(view, "foo_4", "IdentityN", "", {}, {"a:2", "b:1"}, {});
  CheckNode(view, "foo_5", "NotImportant", "", {}, {"a:2", "b:1"}, {});
  // Controls on "d" are kept (once) next to a regular fanin from "d".
  CheckNode(view, "foo_6", "Identity", "", {}, {"d", "^d"}, {});
  CheckNode(view, "foo_7", "NotImportant", "", {}, {"a:3", "b:2", "d", "^d"},
            {});

  CheckGraph(view);
}
2139 
TEST(MutableGraphViewTest, DedupControllingFaninsOnAddFanin) {
  // Op types are irrelevant here; none of the nodes is a Switch.
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {"^a"}),
       NDef("c", "NotImportant", {"a:1"})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  // A regular fanin from "a" supersedes the existing "^a" on "b".
  TF_EXPECT_OK(view.AddRegularFanin("b", {"a", 2}));
  CheckNode(view, "b", "NotImportant", "", {}, {"a:2"}, {});

  // "c" already reads a:1, so adding "^a" leaves its inputs unchanged.
  TF_EXPECT_OK(view.AddControllingFanin("c", {"a", Graph::kControlSlot}));
  CheckNode(view, "c", "NotImportant", "", {}, {"a:1"}, {});

  // Only regular fanouts remain on "a".
  CheckNode(view, "a", "NotImportant", "", {}, {}, {"b:0", "c:0"});

  CheckGraph(view);
}
2159 
TEST(MutableGraphViewTest, NoDedupControllingFaninsOnAddFanin) {
  // "b" is an Identity reading output 1 of the Switch "a"; "c" and "d" start
  // without inputs.
  GraphDef gdef = test::function::GDef(
      {NDef("a", "Switch", {}, {}), NDef("b", "Identity", {"a:1"}),
       NDef("c", "", {}, {}), NDef("d", "", {}, {})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  // On "c": "^b" is kept next to the regular fanin b:2, a second "^b" is a
  // no-op, and a later regular b:2 is placed before the control.
  TF_EXPECT_OK(view.AddRegularFanin("c", {"b", 2}));
  CheckNode(view, "c", "", "", {}, {"b:2"}, {});
  TF_EXPECT_OK(view.AddControllingFanin("c", {"b", Graph::kControlSlot}));
  CheckNode(view, "c", "", "", {}, {"b:2", "^b"}, {});
  TF_EXPECT_OK(view.AddControllingFanin("c", {"b", Graph::kControlSlot}));
  CheckNode(view, "c", "", "", {}, {"b:2", "^b"}, {});
  TF_EXPECT_OK(view.AddRegularFanin("c", {"b", 2}));
  CheckNode(view, "c", "", "", {}, {"b:2", "b:2", "^b"}, {});

  // On "d", which has no regular fanins, a repeated "^b" is also a no-op.
  TF_EXPECT_OK(view.AddControllingFanin("d", {"b", Graph::kControlSlot}));
  CheckNode(view, "d", "", "", {}, {"^b"}, {});
  TF_EXPECT_OK(view.AddControllingFanin("d", {"b", Graph::kControlSlot}));
  CheckNode(view, "d", "", "", {}, {"^b"}, {});

  CheckNode(view, "a", "Switch", "", {}, {}, {"b"});
  CheckNode(view, "b", "Identity", "", {}, {"a:1"},
            {"c:0", "c:1", "^c", "^d"});

  CheckGraph(view);
}
2188 
TEST(MutableGraphViewTest, DedupControllingFaninsOnAddFaninByPort) {
  // Op types are irrelevant here; none of the nodes is a Switch.
  GraphDef gdef =
      test::function::GDef({NDef("a", "NotImportant", {}, {}),
                            NDef("b", "NotImportant", {"c", "^a"}),
                            NDef("c", "NotImportant", {"a:1"})},
                           /*funcs=*/{});
  MutableGraphView view(&gdef);

  // Inserting a:2 at port 0 of "b" shifts "c" to port 1 and drops the now
  // redundant "^a".
  TF_EXPECT_OK(view.AddRegularFaninByPort("b", 0, {"a", 2}));
  CheckNode(view, "b", "NotImportant", "", {}, {"a:2", "c"}, {});

  // "c" already reads a:1, so "^a" is not added.
  TF_EXPECT_OK(view.AddControllingFanin("c", {"a", Graph::kControlSlot}));
  CheckNode(view, "c", "NotImportant", "", {}, {"a:1"}, {"b:1"});

  CheckNode(view, "a", "NotImportant", "", {}, {}, {"b:0", "c:0"});

  CheckGraph(view);
}
2209 
TEST(MutableGraphViewTest, NoDedupControllingFaninsOnAddFaninByPort) {
  // "b" is an Identity reading output 1 of the Switch "a"; "d" reads c:2.
  GraphDef gdef = test::function::GDef(
      {NDef("a", "Switch", {}, {}), NDef("b", "Identity", {"a:1"}),
       NDef("c", "", {}, {}), NDef("d", "", {"c:2"}, {})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  // Append b:2 after c:2, add "^b" (kept), then insert b:2 at the front.
  TF_EXPECT_OK(view.AddRegularFaninByPort("d", 1, {"b", 2}));
  CheckNode(view, "d", "", "", {}, {"c:2", "b:2"}, {});
  TF_EXPECT_OK(view.AddControllingFanin("d", {"b", Graph::kControlSlot}));
  CheckNode(view, "d", "", "", {}, {"c:2", "b:2", "^b"}, {});
  TF_EXPECT_OK(view.AddRegularFaninByPort("d", 0, {"b", 2}));
  CheckNode(view, "d", "", "", {}, {"b:2", "c:2", "b:2", "^b"}, {});

  // Fanouts of "b" and "c" follow the final port layout of "d".
  CheckNode(view, "a", "Switch", "", {}, {}, {"b:0"});
  CheckNode(view, "b", "Identity", "", {}, {"a:1"}, {"d:0", "d:2", "^d"});
  CheckNode(view, "c", "", "", {}, {}, {"d:1"});

  CheckGraph(view);
}
2231 
TEST(MutableGraphViewTest, DedupControllingFaninsOnUpdateFanin) {
  // Op types are irrelevant here; none of the nodes is a Switch.
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {}),
       NDef("c", "NotImportant", {"a:1", "^b"})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  // Replacing a:1 with b:2 makes the existing "^b" on "c" redundant.
  TF_EXPECT_OK(view.UpdateFanin("c", {"a", 1}, {"b", 2}));

  CheckNode(view, "a", "NotImportant", "", {}, {}, {});
  CheckNode(view, "b", "NotImportant", "", {}, {}, {"c"});
  CheckNode(view, "c", "NotImportant", "", {}, {"b:2"}, {});

  CheckGraph(view);
}
2249 
TEST(MutableGraphViewTest, NoDedupControllingFaninsOnUpdateFanin) {
  // "b" and "c" are Identities reading outputs 1 and 2 of the Switch "a".
  GraphDef gdef = test::function::GDef(
      {NDef("a", "Switch", {}, {}), NDef("b", "Identity", {"a:1"}),
       NDef("c", "Identity", {"a:2"}), NDef("d", "NotImportant", {"c", "^b"}),
       NDef("e", "NotImportant", {"b", "^c"})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  // Retargeting d's "^b" to "^c" keeps it next to the regular fanin "c".
  TF_EXPECT_OK(view.UpdateFanin("d", {"b", Graph::kControlSlot},
                                {"c", Graph::kControlSlot}));
  CheckNode(view, "d", "NotImportant", "", {}, {"c", "^c"}, {});

  // Retargeting e's regular fanin "b" to c:3 keeps the existing "^c".
  TF_EXPECT_OK(view.UpdateFanin("e", {"b", 0}, {"c", 3}));
  CheckNode(view, "e", "NotImportant", "", {}, {"c:3", "^c"}, {});

  // Turning c:3 into a control merges it with the existing "^c".
  TF_EXPECT_OK(view.UpdateFanin("e", {"c", 3}, {"c", Graph::kControlSlot}));
  CheckNode(view, "e", "NotImportant", "", {}, {"^c"}, {});

  CheckNode(view, "a", "Switch", "", {}, {}, {"b:0", "c:0"});
  CheckNode(view, "b", "Identity", "", {}, {"a:1"}, {});
  CheckNode(view, "c", "Identity", "", {}, {"a:2"}, {"d:0", "^d", "^e"});

  CheckGraph(view);
}
2275 
TEST(MutableGraphViewTest, DedupControllingFaninsOnUpdateFaninByPort) {
  // Op types are irrelevant here; none of the nodes is a Switch.
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {}),
       NDef("c", "NotImportant", {"a:1", "^b"})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  // Rewriting port 0 of "c" to b:2 makes its existing "^b" redundant.
  TF_EXPECT_OK(view.UpdateRegularFaninByPort("c", 0, {"b", 2}));

  CheckNode(view, "a", "NotImportant", "", {}, {}, {});
  CheckNode(view, "b", "NotImportant", "", {}, {}, {"c"});
  CheckNode(view, "c", "NotImportant", "", {}, {"b:2"}, {});

  CheckGraph(view);
}
2293 
TEST(MutableGraphViewTest, NoDedupControllingFaninsOnUpdateFaninByPort) {
  // "b" and "c" are Identities reading outputs 1 and 2 of the Switch "a".
  GraphDef gdef = test::function::GDef(
      {NDef("a", "Switch", {}, {}), NDef("b", "Identity", {"a:1"}),
       NDef("c", "Identity", {"a:2"}), NDef("d", "NotImportant", {"c", "^b"}),
       NDef("e", "NotImportant", {"b", "^c"})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  // Rewriting port 0 of "d" to b:1 keeps the existing "^b".
  TF_EXPECT_OK(view.UpdateRegularFaninByPort("d", 0, {"b", 1}));
  CheckNode(view, "d", "NotImportant", "", {}, {"b:1", "^b"}, {});

  // Rewriting port 0 of "e" to c:2 keeps the existing "^c".
  TF_EXPECT_OK(view.UpdateRegularFaninByPort("e", 0, {"c", 2}));
  CheckNode(view, "e", "NotImportant", "", {}, {"c:2", "^c"}, {});

  CheckNode(view, "a", "Switch", "", {}, {}, {"b:0", "c:0"});
  CheckNode(view, "b", "Identity", "", {}, {"a:1"}, {"d:0", "^d"});
  CheckNode(view, "c", "Identity", "", {}, {"a:2"}, {"e:0", "^e"});

  CheckGraph(view);
}
2315 
TEST(MutableGraphViewTest, UpdateMaxRegularOutputPortOnAddFanin) {
  // Op types are irrelevant here.
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {"a:1"}),
       NDef("c", "NotImportant", {"^b"})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  // "a" is initially consumed only at output 1; add a consumer of output 3.
  // CheckGraph is expected to validate the view's bookkeeping afterwards
  // (presumably including the highest used output port -- confirm).
  TF_EXPECT_OK(view.AddRegularFanin("c", {"a", 3}));

  CheckNode(view, "a", "NotImportant", "", {}, {}, {"b", "c"});
  CheckNode(view, "b", "NotImportant", "", {}, {"a:1"}, {"^c"});
  CheckNode(view, "c", "NotImportant", "", {}, {"a:3", "^b"}, {});

  CheckGraph(view);
}
2333 
TEST(MutableGraphViewTest, UpdateMaxRegularOutputPortOnRemoveFanin) {
  // Op types are irrelevant here.
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {"a:1"}),
       NDef("c", "NotImportant", {"a:2"})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  // Drop the only consumer of a:2, so a:1 becomes the highest used output.
  TF_EXPECT_OK(view.RemoveRegularFanin("c", {"a", 2}));

  CheckNode(view, "a", "NotImportant", "", {}, {}, {"b"});
  CheckNode(view, "b", "NotImportant", "", {}, {"a:1"}, {});
  CheckNode(view, "c", "NotImportant", "", {}, {}, {});

  CheckGraph(view);
}
2350 
TEST(MutableGraphViewTest, KeepMaxRegularOutputPortOnRemoveFanin) {
  // Op types are irrelevant here.
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {"a:1"}),
       NDef("c", "NotImportant", {"a:2"})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  // Drop the consumer of a:1; a:2 is still consumed by "c".
  TF_EXPECT_OK(view.RemoveRegularFanin("b", {"a", 1}));

  CheckNode(view, "a", "NotImportant", "", {}, {}, {"c"});
  CheckNode(view, "b", "NotImportant", "", {}, {}, {});
  CheckNode(view, "c", "NotImportant", "", {}, {"a:2"}, {});

  CheckGraph(view);
}
2368 
TEST(MutableGraphViewTest, UpdateMaxRegularOutputPortOnUpdateFanin) {
  // Op types are irrelevant here.
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {"a:1"}),
       NDef("c", "NotImportant", {"a:2"})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  // Move c's input from a:2 to b:3: "a" keeps only its output-1 consumer and
  // "b" gains a consumer of output 3.
  TF_EXPECT_OK(view.UpdateFanin("c", {"a", 2}, {"b", 3}));

  CheckNode(view, "a", "NotImportant", "", {}, {}, {"b"});
  CheckNode(view, "b", "NotImportant", "", {}, {"a:1"}, {"c"});
  CheckNode(view, "c", "NotImportant", "", {}, {"b:3"}, {});

  CheckGraph(view);
}
2386 
TEST(MutableGraphViewTest, AddControllingFaninMissing) {
  // Op types are irrelevant here.
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  // The fanin node "c" does not exist.
  Status status = view.AddControllingFanin("a", {"c", Graph::kControlSlot});
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "MutableGraphView::AddControllingFanin(node_name='a', fanin='^c') "
            "error: node 'c' was not found.");

  // The target node "d" does not exist.
  status = view.AddControllingFanin("d", {"a", Graph::kControlSlot});
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "MutableGraphView::AddControllingFanin(node_name='d', fanin='^a') "
            "error: node 'd' was not found.");

  // Neither node exists; the missing target node is the one reported.
  status = view.AddControllingFanin("c", {"d", Graph::kControlSlot});
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "MutableGraphView::AddControllingFanin(node_name='c', fanin='^d') "
            "error: node 'c' was not found.");

  // The failed calls must not have touched the graph.
  ASSERT_EQ(view.graph()->node_size(), 2);

  CheckNode(view, "a", "NotImportant", "", {}, {}, {});
  CheckNode(view, "b", "NotImportant", "", {}, {}, {});

  CheckGraph(view);
}
2423 
TEST(MutableGraphViewTest, AddControllingFaninExistingControl) {
  // Op types are irrelevant here.
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  // Adding "^b" twice must result in a single control input on "a".
  for (int attempt = 0; attempt < 2; ++attempt) {
    TF_EXPECT_OK(view.AddControllingFanin("a", {"b", Graph::kControlSlot}));
  }

  ASSERT_EQ(view.graph()->node_size(), 2);

  CheckNode(view, "a", "NotImportant", "", {}, {"^b"}, {});
  CheckNode(view, "b", "NotImportant", "", {}, {}, {"^a"});

  CheckGraph(view);
}
2441 
TEST(MutableGraphViewTest, AddControllingFaninNotSwitch) {
  // Op types are irrelevant here; "b" is not a Switch.
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  // A regular tensor id (b:2) passed as a controlling fanin becomes "^b";
  // repeating the call does not add a second control.
  for (int attempt = 0; attempt < 2; ++attempt) {
    TF_EXPECT_OK(view.AddControllingFanin("a", {"b", 2}));
  }

  ASSERT_EQ(view.graph()->node_size(), 2);

  CheckNode(view, "a", "NotImportant", "", {}, {"^b"}, {});
  CheckNode(view, "b", "NotImportant", "", {}, {}, {"^a"});

  CheckGraph(view);
}
2459 
TEST(MutableGraphViewTest, AddControllingFaninSwitch) {
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "Switch", {}, {})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  // A plain "^b" on a Switch is rejected rather than added.
  Status status = view.AddControllingFanin("a", {"b", Graph::kControlSlot});
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "MutableGraphView::AddControllingFanin(node_name='a', fanin='^b') "
            "error: can't add fanin '^b' as it will become a Switch control "
            "dependency.");

  // The rejected call must not have touched the graph.
  ASSERT_EQ(view.graph()->node_size(), 2);

  CheckNode(view, "a", "NotImportant", "", {}, {}, {});
  CheckNode(view, "b", "Switch", "", {}, {}, {});

  CheckGraph(view);
}
2481 
TEST(MutableGraphViewTest, AddControllingFaninSwitchWithIdentity) {
  // "identity" already consumes output 0 of "switch".
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("switch", "Switch", {}, {}),
       NDef("identity", "Identity", {"switch"})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  // The control is routed through the existing Identity; repeating the call
  // neither duplicates it nor creates new nodes.
  for (int attempt = 0; attempt < 2; ++attempt) {
    TF_EXPECT_OK(view.AddControllingFanin("a", {"switch", 0}));
  }

  ASSERT_EQ(view.graph()->node_size(), 3);

  CheckNode(view, "a", "NotImportant", "", {}, {"^identity"}, {});
  CheckNode(view, "switch", "Switch", "", {}, {}, {"identity"});
  CheckNode(view, "identity", "Identity", "", {}, {"switch"}, {"^a"});

  CheckGraph(view);
}
2501 
TEST(MutableGraphViewTest, AddControllingFaninSwitchWithNoExistingIdentity) {
  constexpr char kDevice[] = "/device:foo:0";
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}),
       NDef("switch", "Switch", {}, {{"T", DT_FLOAT}}, kDevice)},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  // With no Identity on switch:0, one named "ConstantFoldingCtrl/switch_0" is
  // expected to be created on the Switch's device with the same "T" attr; the
  // second call reuses it.
  for (int attempt = 0; attempt < 2; ++attempt) {
    TF_EXPECT_OK(view.AddControllingFanin("a", {"switch", 0}));
  }

  ASSERT_EQ(view.graph()->node_size(), 3);

  CheckNode(view, "a", "NotImportant", "", {},
            {"^ConstantFoldingCtrl/switch_0"}, {});
  CheckNode(view, "switch", "Switch", kDevice, {{"T", DT_FLOAT}}, {},
            {"ConstantFoldingCtrl/switch_0"});
  CheckNode(view, "ConstantFoldingCtrl/switch_0", "Identity", kDevice,
            {{"T", DT_FLOAT}}, {"switch"}, {"^a"});

  CheckGraph(view);
}
2525 
// An Identity whose name matches the generated-pin naming scheme
// (`ConstantFoldingCtrl/switch_0`) already consumes `switch:0`. It must be
// reused as the control-dependency pin. Two identical adds must leave a single
// control input on `a` and no extra nodes.
TEST(MutableGraphViewTest, AddControllingFaninSwitchWithExistingAddedIdentity) {
  GraphDef graph_def = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("switch", "Switch", {}, {}),
       NDef("ConstantFoldingCtrl/switch_0", "Identity", {"switch"})},
      /*funcs=*/{});
  MutableGraphView view(&graph_def);

  for (int attempt = 0; attempt < 2; ++attempt) {
    TF_EXPECT_OK(view.AddControllingFanin("a", {"switch", 0}));
  }

  ASSERT_EQ(view.graph()->node_size(), 3);

  CheckNode(view, "a", "NotImportant", "", {},
            {"^ConstantFoldingCtrl/switch_0"}, {});
  CheckNode(view, "switch", "Switch", "", {}, {},
            {"ConstantFoldingCtrl/switch_0"});
  CheckNode(view, "ConstantFoldingCtrl/switch_0", "Identity", "", {},
            {"switch"}, {"^a"});

  CheckGraph(view);
}
2548 
TestAddControllingFaninSelfLoops(absl::string_view node_name,const TensorId & fanin,const string & error_msg)2549 void TestAddControllingFaninSelfLoops(absl::string_view node_name,
2550                                       const TensorId& fanin,
2551                                       const string& error_msg) {
2552   GraphDef graph_def = test::function::GDef(
2553       {NDef("a", "NotImportant", {}, {}),
2554        NDef("b", "Switch", {}, {{"T", DT_FLOAT}}),
2555        NDef("c", "Identity", {"b:0"}), NDef("d", "Identity", {"b:1"}),
2556        NDef("e", "NotImportant", {"^a"})},
2557       /*funcs=*/{});
2558 
2559   MutableGraphView graph(&graph_def);
2560 
2561   Status s = graph.AddControllingFanin(node_name, fanin);
2562   EXPECT_FALSE(s.ok());
2563   EXPECT_EQ(s.error_message(), error_msg);
2564 
2565   EXPECT_EQ(graph.graph()->node_size(), 5);
2566 
2567   CheckNode(graph, "a", "NotImportant", "", {}, {}, {"^e"});
2568   CheckNode(graph, "b", "Switch", "", {{"T", DT_FLOAT}}, {}, {"c", "d"});
2569   CheckNode(graph, "c", "Identity", "", {}, {"b"}, {});
2570   CheckNode(graph, "d", "Identity", "", {}, {"b:1"}, {});
2571   CheckNode(graph, "e", "NotImportant", "", {}, {"^a"}, {});
2572 
2573   CheckGraph(graph);
2574 }
2575 
// Checks that AddControllingFanin rejects self loops in two forms:
//  - direct: adding `^a` to `a`;
//  - indirect: pinning a Switch output through an Identity consumer that is
//    the target node itself.
// All cases share TestAddControllingFaninSelfLoops, which reports failures at
// its own source lines. Each case therefore runs under its own SCOPED_TRACE so
// a failure identifies which case produced it.
TEST(MutableGraphViewTest, AddControllingFaninSelfLoops) {
  {
    SCOPED_TRACE("direct self loop: node_name='a', fanin='^a'");
    const string error_msg =
        "MutableGraphView::AddControllingFanin(node_name='a', fanin='^a') "
        "error: can't add fanin '^a' to self.";
    TestAddControllingFaninSelfLoops("a", {"a", Graph::kControlSlot},
                                     error_msg);
  }

  {
    // Adding Switch control dependency to Identity consumer. Node `c` is
    // consuming `b:0`, so adding `b:0` as a control dependency, because it is
    // a Switch, should trigger a lookup of outputs. As `c` is a consumer and
    // an Identity, this will introduce a self loop, so no control dependency
    // should be added.
    SCOPED_TRACE("found Identity self loop: node_name='c', fanin='b:0'");
    const string error_msg =
        "MutableGraphView::AddControllingFanin(node_name='c', fanin='b:0') "
        "error: can't add found fanin '^c' to self.";
    TestAddControllingFaninSelfLoops("c", {"b", 0}, error_msg);
  }

  {
    // Same as above for the other Switch output: `d` consumes `b:1`, so
    // pinning `b:1` would resolve to `^d` and introduce a self loop.
    SCOPED_TRACE("found Identity self loop: node_name='d', fanin='b:1'");
    const string error_msg =
        "MutableGraphView::AddControllingFanin(node_name='d', fanin='b:1') "
        "error: can't add found fanin '^d' to self.";
    TestAddControllingFaninSelfLoops("d", {"b", 1}, error_msg);
  }
}
2602 
// `b:1` has no consumers, so pinning it as a control dependency would require
// generating an Identity named `ConstantFoldingCtrl/b_1`. A node with that name
// already exists and is the target of the add, so the generated fanin would be
// a self loop. The call must fail and the graph must stay unchanged.
TEST(MutableGraphViewTest, AddControllingFaninSelfLoopsGeneratedIdentity) {
  GraphDef graph_def =
      test::function::GDef({NDef("a", "NotImportant", {}, {}),
                            NDef("b", "Switch", {}, {{"T", DT_FLOAT}}),
                            NDef("c", "NotImportant", {}),
                            NDef("ConstantFoldingCtrl/b_1", "Identity", {})},
                           /*funcs=*/{});
  MutableGraphView view(&graph_def);

  const Status status =
      view.AddControllingFanin("ConstantFoldingCtrl/b_1", {"b", 1});
  EXPECT_FALSE(status.ok());
  const string expected_msg =
      "MutableGraphView::AddControllingFanin(node_name='ConstantFoldingCtrl/"
      "b_1', fanin='b:1') error: can't add generated fanin "
      "'^ConstantFoldingCtrl/b_1' to self.";
  EXPECT_EQ(status.error_message(), expected_msg);

  EXPECT_EQ(view.graph()->node_size(), 4);

  CheckNode(view, "a", "NotImportant", "", {}, {}, {});
  CheckNode(view, "b", "Switch", "", {{"T", DT_FLOAT}}, {}, {});
  CheckNode(view, "c", "NotImportant", "", {}, {}, {});
  CheckNode(view, "ConstantFoldingCtrl/b_1", "Identity", "", {}, {}, {});

  CheckGraph(view);
}
2636 
// Removing a control dependency that is not present (`d` has no `^c`) succeeds
// and leaves the graph unchanged.
TEST(MutableGraphViewTest, RemoveControllingFaninMissing) {
  // Actual node.op() is not important in this test.
  GraphDef graph_def = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {}),
       NDef("c", "NotImportant", {}, {}),
       NDef("d", "NotImportant", {"^a", "^b"})},
      /*funcs=*/{});
  MutableGraphView view(&graph_def);

  TF_EXPECT_OK(view.RemoveControllingFanin("d", "c"));

  ASSERT_EQ(view.graph()->node_size(), 4);

  // Existing control edges a->d and b->d are preserved; c stays isolated.
  CheckNode(view, "a", "NotImportant", "", {}, {}, {"^d"});
  CheckNode(view, "b", "NotImportant", "", {}, {}, {"^d"});
  CheckNode(view, "c", "NotImportant", "", {}, {}, {});
  CheckNode(view, "d", "NotImportant", "", {}, {"^a", "^b"}, {});

  CheckGraph(view);
}
2658 
// Removing an existing control dependency (`^a` from `d`) succeeds, and a
// second identical removal is a successful no-op. Note the expected order of
// the surviving inputs (`^c`, `^b`): removal does not preserve the original
// relative order of the remaining control inputs.
TEST(MutableGraphViewTest, RemoveControllingFaninExisting) {
  // Actual node.op() is not important in this test.
  GraphDef graph_def = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {}),
       NDef("c", "NotImportant", {}, {}),
       NDef("d", "NotImportant", {"^a", "^b", "^c"})},
      /*funcs=*/{});
  MutableGraphView view(&graph_def);

  for (int attempt = 0; attempt < 2; ++attempt) {
    TF_EXPECT_OK(view.RemoveControllingFanin("d", "a"));
  }

  ASSERT_EQ(view.graph()->node_size(), 4);

  CheckNode(view, "a", "NotImportant", "", {}, {}, {});
  CheckNode(view, "b", "NotImportant", "", {}, {}, {"^d"});
  CheckNode(view, "c", "NotImportant", "", {}, {}, {"^d"});
  CheckNode(view, "d", "NotImportant", "", {}, {"^c", "^b"}, {});

  CheckGraph(view);
}
2681 
// RemoveControllingFanin only affects control inputs. Asking it to remove `a`
// and `b`, which are regular fanins of `c`, succeeds without modifying the
// graph.
TEST(MutableGraphViewTest, RemoveControllingFaninOnRegularFanin) {
  // Actual node.op() is not important in this test.
  GraphDef graph_def = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {"a"}),
       NDef("c", "NotImportant", {"a", "b"})},
      /*funcs=*/{});
  MutableGraphView view(&graph_def);

  for (const char* fanin_node : {"a", "b"}) {
    TF_EXPECT_OK(view.RemoveControllingFanin("c", fanin_node));
  }

  ASSERT_EQ(view.graph()->node_size(), 3);

  CheckNode(view, "a", "NotImportant", "", {}, {}, {"b", "c"});
  CheckNode(view, "b", "NotImportant", "", {}, {"a"}, {"c:1"});
  CheckNode(view, "c", "NotImportant", "", {}, {"a", "b"}, {});

  CheckGraph(view);
}
2702 
// Removing a node's control dependency on itself is rejected with a
// descriptive error, and the graph is left unchanged.
TEST(MutableGraphViewTest, RemoveControllingFaninSelfLoop) {
  // Actual node.op() is not important in this test.
  GraphDef graph_def = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {"a"}),
       NDef("c", "NotImportant", {"a", "b"})},
      /*funcs=*/{});
  MutableGraphView view(&graph_def);

  const Status status = view.RemoveControllingFanin("c", "c");
  EXPECT_FALSE(status.ok());
  const string expected_msg =
      "MutableGraphView::RemoveControllingFanin(node_name='c', "
      "fanin_node_name='c') error: can't remove fanin '^c' from "
      "self.";
  EXPECT_EQ(status.error_message(), expected_msg);

  ASSERT_EQ(view.graph()->node_size(), 3);

  CheckNode(view, "a", "NotImportant", "", {}, {}, {"b", "c"});
  CheckNode(view, "b", "NotImportant", "", {}, {"a"}, {"c:1"});
  CheckNode(view, "c", "NotImportant", "", {}, {"a", "b"}, {});

  CheckGraph(view);
}
2728 
TestUpdateAllRegularFaninsToControlling(absl::string_view node_name,bool node_exists,bool success,const string & error_msg,absl::Span<const string> expected_fanins)2729 void TestUpdateAllRegularFaninsToControlling(
2730     absl::string_view node_name, bool node_exists, bool success,
2731     const string& error_msg, absl::Span<const string> expected_fanins) {
2732   constexpr char kDevice[] = "/device:foo:0";
2733   GraphDef graph_def = test::function::GDef(
2734       {NDef("a", "NotImportant", {}, {}),
2735        NDef("switch", "Switch", {}, {{"T", DT_FLOAT}}, kDevice),
2736        NDef("b", "NotImportant", {"switch:1"}, {}),
2737        NDef("ConstantFoldingCtrl/switch_1", "Identity", {"switch:1"},
2738             {{"T", DT_FLOAT}}, kDevice),
2739        NDef("c", "NotImportant", {"a", "^b"}, {}),
2740        NDef("d", "NotImportant", {"b", "c"}, {}),
2741        NDef("e", "NotImportant", {"^d"}, {})},
2742       /*funcs=*/{});
2743 
2744   MutableGraphView graph(&graph_def);
2745 
2746   NodeDef* node = graph.GetNode(node_name);
2747   if (node_exists) {
2748     EXPECT_NE(node, nullptr);
2749   } else {
2750     EXPECT_EQ(node, nullptr);
2751   }
2752 
2753   absl::flat_hash_map<string, std::vector<string>> unmodified_node_inputs =
2754       GetNodeInputsFromGraph(graph_def, node_name);
2755 
2756   Status s = graph.UpdateAllRegularFaninsToControlling(node_name);
2757   EXPECT_EQ(s.ok(), success);
2758   if (!success) {
2759     EXPECT_EQ(s.error_message(), error_msg);
2760   }
2761   if (node_exists) {
2762     CompareNodeFanins(graph, node, expected_fanins);
2763   }
2764 
2765   CheckUnmodifiedNodeFanins(graph_def, node_name, unmodified_node_inputs);
2766 
2767   CheckGraph(graph);
2768 }
2769 
// Exercises UpdateAllRegularFaninsToControlling on each node of the graph
// built by TestUpdateAllRegularFaninsToControlling, plus a missing node and a
// self-loop error case. The helper reports failures at its own source lines,
// so every case runs under a SCOPED_TRACE naming the node under test.
TEST(MutableGraphViewTest, UpdateAllRegularFaninsToControlling) {
  // Stays empty for the successful cases; the helper only compares it on
  // failure.
  string error_msg;

  // Nodes covering no fanins (a), mixed regular + control fanins (c), only
  // regular fanins (d), and only control fanins (e).
  {
    SCOPED_TRACE("node_name='a'");
    TestUpdateAllRegularFaninsToControlling("a", /*node_exists=*/true,
                                            /*success=*/true, error_msg, {});
  }
  {
    SCOPED_TRACE("node_name='c'");
    TestUpdateAllRegularFaninsToControlling("c", /*node_exists=*/true,
                                            /*success=*/true, error_msg,
                                            {"^a", "^b"});
  }
  {
    SCOPED_TRACE("node_name='d'");
    TestUpdateAllRegularFaninsToControlling("d", /*node_exists=*/true,
                                            /*success=*/true, error_msg,
                                            {"^b", "^c"});
  }
  {
    SCOPED_TRACE("node_name='e'");
    TestUpdateAllRegularFaninsToControlling("e", /*node_exists=*/true,
                                            /*success=*/true, error_msg,
                                            {"^d"});
  }

  // Use existing Identity to pin control dependency of Switch.
  {
    SCOPED_TRACE("node_name='b'");
    TestUpdateAllRegularFaninsToControlling(
        "b", /*node_exists=*/true,
        /*success=*/true, error_msg, {"^ConstantFoldingCtrl/switch_1"});
  }

  // Missing node.
  {
    SCOPED_TRACE("node_name='f' (missing)");
    error_msg =
        "MutableGraphView::UpdateAllRegularFaninsToControlling(node_name='f') "
        "error: node 'f' was not found.";
    TestUpdateAllRegularFaninsToControlling("f", /*node_exists=*/false,
                                            /*success=*/false, error_msg, {});
  }

  // Error in getting controlling fanin: the pinning Identity for `switch:1`
  // is the node being updated, which would create a self loop.
  {
    SCOPED_TRACE("node_name='ConstantFoldingCtrl/switch_1' (self loop)");
    error_msg =
        "MutableGraphView::UpdateAllRegularFaninsToControlling(node_name='"
        "ConstantFoldingCtrl/switch_1') error: can't add found fanin "
        "'^ConstantFoldingCtrl/switch_1' to self.";
    TestUpdateAllRegularFaninsToControlling("ConstantFoldingCtrl/switch_1",
                                            /*node_exists=*/true,
                                            /*success=*/false, error_msg,
                                            {"switch:1"});
  }
}
2806 
// `b` regularly consumes `switch:1` and no pinning Identity exists yet.
// Converting `b`'s regular fanins to controlling ones must generate
// `ConstantFoldingCtrl/switch_1` (same device and `T` as the Switch), and `b`
// must end up depending on it via a control edge.
TEST(MutableGraphViewTest, UpdateAllRegularFaninsToControllingConsumingSwitch) {
  constexpr char kDevice[] = "/device:foo:0";
  GraphDef graph_def = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}),
       NDef("switch", "Switch", {}, {{"T", DT_FLOAT}}, kDevice),
       NDef("b", "NotImportant", {"switch:1"}, {})},
      /*funcs=*/{});
  MutableGraphView view(&graph_def);

  TF_EXPECT_OK(view.UpdateAllRegularFaninsToControlling("b"));

  // One generated Identity on top of the three original nodes.
  EXPECT_EQ(view.graph()->node_size(), 4);

  CheckNode(view, "a", "NotImportant", "", {}, {}, {});
  CheckNode(view, "switch", "Switch", kDevice, {{"T", DT_FLOAT}}, {},
            {"ConstantFoldingCtrl/switch_1"});
  CheckNode(view, "b", "NotImportant", "", {},
            {"^ConstantFoldingCtrl/switch_1"}, {});
  CheckNode(view, "ConstantFoldingCtrl/switch_1", "Identity", kDevice,
            {{"T", DT_FLOAT}}, {"switch:1"}, {"^b"});

  CheckGraph(view);
}
2831 
// Deleting `foo_1`, a node with no fanouts, removes it and every fanout entry
// it contributed to `bar` and `other`.
TEST(MutableGraphViewTest, DeleteNodes) {
  // Actual node.op() is not important in this test.
  GraphDef graph_def = test::function::GDef(
      {NDef("bar", "NotImportant", {}, {}),
       NDef("other", "NotImportant", {}, {}),
       NDef("foo_1", "NotImportant", {"bar", "other", "bar:1", "^bar"}),
       NDef("foo_2", "NotImportant", {"other:1", "bar:2", "^bar"})},
      /*funcs=*/{});
  MutableGraphView view(&graph_def);

  EXPECT_NE(view.GetNode("foo_1"), nullptr);
  TF_EXPECT_OK(view.DeleteNodes({"foo_1"}));

  EXPECT_EQ(view.graph()->node_size(), 3);
  EXPECT_EQ(view.GetNode("foo_1"), nullptr);

  // foo_2's `^bar` is not expected here. Presumably the view dropped it as
  // redundant with the regular fanin `bar:2`.
  CheckNode(view, "bar", "NotImportant", "", {}, {}, {"foo_2:1"});
  CheckNode(view, "other", "NotImportant", "", {}, {}, {"foo_2"});
  CheckNode(view, "foo_2", "NotImportant", "", {}, {"other:1", "bar:2"}, {});

  CheckGraph(view);
}
2855 
SimpleDeleteNodeGraph()2856 GraphDef SimpleDeleteNodeGraph() {
2857   // Actual node.op() is not important in this test.
2858   GraphDef graph_def = test::function::GDef(
2859       {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {"a:2"}),
2860        NDef("c", "NotImportant", {"a:5", "^b"}), NDef("d", "NotImportant", {}),
2861        NDef("e", "NotImportant", {"d:2"}),
2862        NDef("f", "NotImportant", {"d:3", "^e"})},
2863       /*funcs=*/{});
2864   return graph_def;
2865 }
2866 
// Deleting a, b and c together succeeds even though a and b have fanouts,
// because all of those fanouts are in the deleted set. The d/e/f component
// is untouched.
TEST(MutableGraphViewTest, DeleteNodesWithFanoutsBeingDeleted) {
  GraphDef graph_def = SimpleDeleteNodeGraph();
  MutableGraphView view(&graph_def);

  for (const char* name : {"a", "b", "c"}) {
    EXPECT_NE(view.GetNode(name), nullptr) << name;
  }
  TF_EXPECT_OK(view.DeleteNodes({"c", "a", "b"}));

  EXPECT_EQ(view.graph()->node_size(), 3);
  for (const char* name : {"a", "b", "c"}) {
    EXPECT_EQ(view.GetNode(name), nullptr) << name;
  }

  CheckNode(view, "d", "NotImportant", "", {}, {}, {"e", "f"});
  CheckNode(view, "e", "NotImportant", "", {}, {"d:2"}, {"^f"});
  CheckNode(view, "f", "NotImportant", "", {}, {"d:3", "^e"}, {});

  CheckGraph(view);
}
2887 
// Deleting nodes that do not exist (`g`, `h`) is a successful no-op.
TEST(MutableGraphViewTest, DeleteMissingNodes) {
  GraphDef graph_def = SimpleDeleteNodeGraph();
  MutableGraphView view(&graph_def);

  for (const char* name : {"g", "h"}) {
    EXPECT_EQ(view.GetNode(name), nullptr) << name;
  }
  TF_EXPECT_OK(view.DeleteNodes({"g", "h"}));

  EXPECT_EQ(view.graph()->node_size(), 6);
  for (const char* name : {"g", "h"}) {
    EXPECT_EQ(view.GetNode(name), nullptr) << name;
  }

  // Every original node and edge is preserved.
  CheckNode(view, "a", "NotImportant", "", {}, {}, {"b", "c"});
  CheckNode(view, "b", "NotImportant", "", {}, {"a:2"}, {"^c"});
  CheckNode(view, "c", "NotImportant", "", {}, {"a:5", "^b"}, {});
  CheckNode(view, "d", "NotImportant", "", {}, {}, {"e", "f"});
  CheckNode(view, "e", "NotImportant", "", {}, {"d:2"}, {"^f"});
  CheckNode(view, "f", "NotImportant", "", {}, {"d:3", "^e"}, {});

  CheckGraph(view);
}
2910 
// Mixing missing nodes (g, h) into a deletion that is otherwise valid (the
// d/e/f component, whose fanouts are all being deleted) still succeeds. Only
// d, e and f are removed.
TEST(MutableGraphViewTest, DeleteMissingNodesAndNodesWithFanoutsBeingDeleted) {
  GraphDef graph_def = SimpleDeleteNodeGraph();
  MutableGraphView view(&graph_def);

  for (const char* name : {"d", "e", "f"}) {
    EXPECT_NE(view.GetNode(name), nullptr) << name;
  }
  TF_EXPECT_OK(view.DeleteNodes({"d", "e", "f", "g", "h"}));

  EXPECT_EQ(view.graph()->node_size(), 3);
  for (const char* name : {"d", "e", "f"}) {
    EXPECT_EQ(view.GetNode(name), nullptr) << name;
  }

  CheckNode(view, "a", "NotImportant", "", {}, {}, {"b", "c"});
  CheckNode(view, "b", "NotImportant", "", {}, {"a:2"}, {"^c"});
  CheckNode(view, "c", "NotImportant", "", {}, {"a:5", "^b"}, {});

  CheckGraph(view);
}
2932 
// Deleting a and b while c (not deleted) still consumes both must fail. The
// message lists the requested set and the offending nodes in sorted order, and
// nothing is deleted.
TEST(MutableGraphViewTest, DeleteNodesWithError) {
  GraphDef graph_def = SimpleDeleteNodeGraph();
  MutableGraphView view(&graph_def);

  const Status status = view.DeleteNodes({"b", "a"});
  EXPECT_FALSE(status.ok());
  const string error_msg =
      "MutableGraphView::DeleteNodes(nodes_to_delete={a, b}) error: can't "
      "delete node(s) with retained fanouts(s) [a, b].";
  EXPECT_EQ(status.error_message(), error_msg);

  EXPECT_EQ(view.graph()->node_size(), 6);

  CheckNode(view, "a", "NotImportant", "", {}, {}, {"b", "c"});
  CheckNode(view, "b", "NotImportant", "", {}, {"a:2"}, {"^c"});
  CheckNode(view, "c", "NotImportant", "", {}, {"a:5", "^b"}, {});
  CheckNode(view, "d", "NotImportant", "", {}, {}, {"e", "f"});
  CheckNode(view, "e", "NotImportant", "", {}, {"d:2"}, {"^f"});
  CheckNode(view, "f", "NotImportant", "", {}, {"d:3", "^e"}, {});

  CheckGraph(view);
}
2956 
// Deleting a..f while each of them keeps at least one fanout outside the set
// (g..m) must fail. With six requested and six offending nodes, both lists in
// the error message are truncated to the first five names followed by "...".
// The graph must be unchanged.
TEST(MutableGraphViewTest, DeleteNodesWithLargeError) {
  // Actual node.op() is not important in this test.
  GraphDef graph_def = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {"a:2"}),
       NDef("c", "NotImportant", {"^b"}), NDef("d", "NotImportant", {"c:6"}),
       NDef("e", "NotImportant", {"d:2"}),
       NDef("f", "NotImportant", {"d:3", "^e"}),
       NDef("g", "NotImportant", {"f"}), NDef("h", "NotImportant", {"a"}),
       NDef("i", "NotImportant", {"b"}), NDef("j", "NotImportant", {"c"}),
       NDef("k", "NotImportant", {"d"}), NDef("l", "NotImportant", {"e"}),
       NDef("m", "NotImportant", {"f"})},
      /*funcs=*/{});
  MutableGraphView view(&graph_def);

  const Status status = view.DeleteNodes({"a", "b", "c", "d", "e", "f"});
  EXPECT_FALSE(status.ok());
  const string error_msg =
      "MutableGraphView::DeleteNodes(nodes_to_delete={a, b, c, d, e, ...}) "
      "error: can't delete node(s) with retained fanouts(s) [a, b, c, d, e, "
      "...].";
  EXPECT_EQ(status.error_message(), error_msg);

  EXPECT_EQ(view.graph()->node_size(), 13);

  CheckNode(view, "a", "NotImportant", "", {}, {}, {"b", "h"});
  CheckNode(view, "b", "NotImportant", "", {}, {"a:2"}, {"^c", "i"});
  CheckNode(view, "c", "NotImportant", "", {}, {"^b"}, {"d", "j"});
  CheckNode(view, "d", "NotImportant", "", {}, {"c:6"}, {"e", "f", "k"});
  CheckNode(view, "e", "NotImportant", "", {}, {"d:2"}, {"^f", "l"});
  CheckNode(view, "f", "NotImportant", "", {}, {"d:3", "^e"}, {"g", "m"});
  CheckNode(view, "g", "NotImportant", "", {}, {"f"}, {});
  CheckNode(view, "h", "NotImportant", "", {}, {"a"}, {});
  CheckNode(view, "i", "NotImportant", "", {}, {"b"}, {});
  CheckNode(view, "j", "NotImportant", "", {}, {"c"}, {});
  CheckNode(view, "k", "NotImportant", "", {}, {"d"}, {});
  CheckNode(view, "l", "NotImportant", "", {}, {"e"}, {});
  CheckNode(view, "m", "NotImportant", "", {}, {"f"}, {});

  CheckGraph(view);
}
2998 
2999 }  // namespace
3000 }  // namespace grappler
3001 }  // namespace tensorflow
3002