1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/mutable_graph_view.h"
17 #include "absl/strings/substitute.h"
18 #include "absl/types/span.h"
19 #include "tensorflow/cc/ops/standard_ops.h"
20 #include "tensorflow/core/framework/function_testlib.h"
21 #include "tensorflow/core/framework/types.pb.h"
22 #include "tensorflow/core/graph/tensor_id.h"
23 #include "tensorflow/core/grappler/grappler_item.h"
24 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
25 #include "tensorflow/core/grappler/utils.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/platform/test.h"
28
29 namespace tensorflow {
30 namespace grappler {
31 namespace {
32
33 using ::tensorflow::test::function::NDef;
34 using FDH = FunctionDefHelper;
35
CompareNodeFanins(const MutableGraphView & graph,NodeDef * node,absl::Span<const string> fanins)36 void CompareNodeFanins(const MutableGraphView& graph, NodeDef* node,
37 absl::Span<const string> fanins) {
38 ASSERT_EQ(node->input_size(), fanins.size());
39 for (int i = 0; i < node->input_size(); ++i) {
40 TensorId tensor_id = ParseTensorName(fanins[i]);
41 EXPECT_EQ(ParseTensorName(node->input(i)), tensor_id);
42 int port;
43 if (tensor_id.index() == Graph::kControlSlot) {
44 port = Graph::kControlSlot;
45 } else {
46 port = i;
47 }
48 MutableGraphView::InputPort input_port(node, port);
49 MutableGraphView::OutputPort output_port =
50 graph.GetOutputPort(tensor_id.node(), tensor_id.index());
51 EXPECT_TRUE(graph.GetFanin(input_port).contains(output_port));
52 EXPECT_TRUE(graph.GetFanout(output_port).contains(input_port));
53 }
54 }
55
CompareNodeFanouts(const MutableGraphView & graph,NodeDef * node,absl::Span<const string> fanouts)56 void CompareNodeFanouts(const MutableGraphView& graph, NodeDef* node,
57 absl::Span<const string> fanouts) {
58 auto node_fanouts =
59 graph.GetFanouts(*node, /*include_controlled_nodes=*/true);
60 EXPECT_EQ(node_fanouts.size(), fanouts.size());
61 for (const string& fanout : fanouts) {
62 TensorId tensor_id = ParseTensorName(fanout);
63 MutableGraphView::InputPort input_port(graph.GetNode(tensor_id.node()),
64 tensor_id.index());
65 EXPECT_TRUE(node_fanouts.contains(input_port));
66 }
67 }
68
CheckNode(const MutableGraphView & graph,absl::string_view node_name,absl::string_view op,absl::string_view device,absl::Span<const std::pair<string,FDH::AttrValueWrapper>> attrs,absl::Span<const string> fanins,absl::Span<const string> fanouts)69 void CheckNode(const MutableGraphView& graph, absl::string_view node_name,
70 absl::string_view op, absl::string_view device,
71 absl::Span<const std::pair<string, FDH::AttrValueWrapper>> attrs,
72 absl::Span<const string> fanins,
73 absl::Span<const string> fanouts) {
74 NodeDef* node = graph.GetNode(node_name);
75 ASSERT_NE(node, nullptr);
76 EXPECT_EQ(node->op(), op);
77 EXPECT_EQ(node->device(), device);
78 EXPECT_EQ(node->attr_size(), attrs.size());
79 for (const auto& attr : attrs) {
80 auto it = node->attr().find(attr.first);
81 ASSERT_NE(it, node->attr().end());
82 EXPECT_TRUE(AreAttrValuesEqual(it->second, attr.second.proto));
83 }
84 CompareNodeFanins(graph, node, fanins);
85 CompareNodeFanouts(graph, node, fanouts);
86 }
87
CheckGraph(const MutableGraphView & mutable_graph)88 void CheckGraph(const MutableGraphView& mutable_graph) {
89 GraphView immutable_graph(mutable_graph.graph());
90 EXPECT_EQ(mutable_graph.graph()->node_size(),
91 immutable_graph.graph()->node_size());
92 EXPECT_EQ(mutable_graph.graph(), immutable_graph.graph());
93
94 auto check_edges =
95 [](const absl::flat_hash_set<MutableGraphView::Edge>& mutable_edges,
96 const absl::flat_hash_set<GraphView::Edge>& immutable_edges) {
97 EXPECT_EQ(mutable_edges.size(), immutable_edges.size());
98 for (const auto& fanin_edge : mutable_edges) {
99 GraphView::Edge immutable_edge(
100 {fanin_edge.src.node, fanin_edge.src.port_id},
101 {fanin_edge.dst.node, fanin_edge.dst.port_id});
102 EXPECT_TRUE(immutable_edges.contains(immutable_edge));
103 }
104 };
105
106 // Check graph connectivity.
107 for (auto& node : *mutable_graph.graph()->mutable_node()) {
108 EXPECT_EQ(&node, immutable_graph.GetNode(node.name()));
109
110 auto mutable_fanins =
111 mutable_graph.GetFanins(node, /*include_controlling_nodes=*/true);
112 auto immutable_fanins =
113 immutable_graph.GetFanins(node, /*include_controlling_nodes=*/true);
114 EXPECT_EQ(mutable_fanins.size(), immutable_fanins.size());
115 for (const auto& fanin : mutable_fanins) {
116 GraphView::OutputPort immutable_fanin(fanin.node, fanin.port_id);
117 EXPECT_TRUE(immutable_fanins.contains(immutable_fanin));
118 }
119
120 auto mutable_fanouts =
121 mutable_graph.GetFanouts(node, /*include_controlled_nodes=*/true);
122 auto immutable_fanouts =
123 immutable_graph.GetFanouts(node, /*include_controlled_nodes=*/true);
124 EXPECT_EQ(mutable_fanouts.size(), immutable_fanouts.size());
125 for (const auto& fanout : mutable_fanouts) {
126 GraphView::InputPort immutable_fanout(fanout.node, fanout.port_id);
127 EXPECT_TRUE(immutable_fanouts.contains(immutable_fanout));
128 }
129
130 auto mutable_fanin_edges =
131 mutable_graph.GetFaninEdges(node, /*include_controlling_edges=*/true);
132 auto immutable_fanin_edges =
133 immutable_graph.GetFaninEdges(node, /*include_controlling_edges=*/true);
134 check_edges(mutable_fanin_edges, immutable_fanin_edges);
135
136 auto mutable_fanout_edges =
137 mutable_graph.GetFanoutEdges(node, /*include_controlled_edges=*/true);
138 auto immutable_fanout_edges =
139 immutable_graph.GetFanoutEdges(node, /*include_controlled_edges=*/true);
140 check_edges(mutable_fanout_edges, immutable_fanout_edges);
141 }
142 }
143
// Adds a two-node subgraph to an existing graph and checks that fanins and
// fanouts are updated for both the pre-existing and the newly added nodes.
TEST(MutableGraphViewTest, AddSubgraph) {
  GraphDef base_def = test::function::GDef(
      {
          NDef("foo", "NotImportant", {}, {}),
          NDef("bar", "NotImportant", {}, {}),
          NDef("baz", "NotImportant", {"foo", "bar"}),
      },
      /*funcs=*/{});
  MutableGraphView view(&base_def);

  // `s/n1` reads from `bar`, which is only defined in the original graph, so
  // its inputs are valid only once the subgraph is added to that graph.
  GraphDef extra_def = test::function::GDef(
      {
          NDef("s/n0", "NotImportant", {}, {}),
          NDef("s/n1", "NotImportant", {"bar", "s/n0"}, {}),
      },
      /*funcs=*/{});

  TF_EXPECT_OK(view.AddSubgraph(std::move(extra_def)));

  // `bar` gains `s/n1` as a fanout; `s/n1` sees both of its fanins.
  CheckNode(view, "bar", "NotImportant", "", {}, {}, {"baz:1", "s/n1"});
  CheckNode(view, "s/n1", "NotImportant", "", {}, {"bar", "s/n0"}, {});
  CheckGraph(view);
}
171
// A function carried by the added subgraph must end up in the library of an
// initially empty graph.
TEST(MutableGraphViewTest, AddSubgraphAndAddFunction) {
  GraphDef base_def;
  MutableGraphView view(&base_def);

  FunctionDef times_two_fn = test::function::XTimesTwo();
  GraphDef extra_def = test::function::GDef({}, {times_two_fn});

  TF_EXPECT_OK(view.AddSubgraph(std::move(extra_def)));
  EXPECT_EQ(base_def.library().function_size(), 1);
}
182
// Adding a subgraph whose library holds the very same function as the graph
// must succeed and leave the library with a single function.
TEST(MutableGraphViewTest, AddSubgraphAndSkipSameFunction) {
  FunctionDef times_two_fn = test::function::XTimesTwo();

  GraphDef base_def = test::function::GDef({}, {times_two_fn});
  MutableGraphView view(&base_def);

  GraphDef extra_def = test::function::GDef({}, {times_two_fn});

  TF_EXPECT_OK(view.AddSubgraph(std::move(extra_def)));
  EXPECT_EQ(base_def.library().function_size(), 1);
}
194
// Adding a subgraph must fail when its library defines a function whose name
// already exists in the graph with a different definition.
TEST(MutableGraphViewTest, AddSubgraphAndFailIfFunctionDifferent) {
  // Existing graph: the XTimesFour definition registered under the name
  // "XTimesTwo".
  FunctionDef renamed_times_four = test::function::XTimesFour();
  renamed_times_four.mutable_signature()->set_name("XTimesTwo");

  GraphDef base_def = test::function::GDef({}, {renamed_times_four});
  MutableGraphView view(&base_def);

  // Subgraph: the XTimesTwo definition as returned by the test library.
  FunctionDef times_two_fn = test::function::XTimesTwo();
  GraphDef extra_def = test::function::GDef({}, {times_two_fn});

  Status add_status = view.AddSubgraph(std::move(extra_def));
  EXPECT_FALSE(add_status.ok());
  EXPECT_EQ(add_status.error_message(),
            "MutableGraphView::AddSubgraph(function_size=1) error: Found "
            "different function definition with the same name: XTimesTwo.");
}
211
// Updates `bar_2` from Identity to IdentityN (with a new device and a list
// attribute) and expects every edge, including the `^bar_2` control
// dependencies of `foo_1` and `foo_2`, to be kept as is.
TEST(MutableGraphViewTest, UpdateNodeNoDedupControlDependency) {
  constexpr char kDevice[] = "/device:foo:0";
  GraphDef def = test::function::GDef(
      {
          NDef("bar_1", "Switch", {}, {}),
          NDef("bar_2", "Identity", {"bar_1:1"}),
          NDef("other", "NotImportant", {}, {}),
          NDef("foo_1", "NotImportant",
               {"bar_2", "other", "bar_2:1", "^bar_2"}),
          NDef("foo_2", "NotImportant", {"other:1", "bar_2:2", "^bar_2"}),
      },
      /*funcs=*/{});

  MutableGraphView view(&def);

  AttrValue type_list;
  type_list.mutable_list()->add_type(DT_FLOAT);
  TF_EXPECT_OK(
      view.UpdateNode("bar_2", "IdentityN", kDevice, {{"T", type_list}}));

  CheckNode(view, "bar_1", "Switch", "", {}, {}, {"bar_2"});
  CheckNode(view, "bar_2", "IdentityN", kDevice, {{"T", type_list}},
            {"bar_1:1"}, {"foo_1", "foo_1:2", "^foo_1", "foo_2:1", "^foo_2"});
  CheckNode(view, "other", "NotImportant", "", {}, {}, {"foo_1:1", "foo_2"});
  CheckNode(view, "foo_1", "NotImportant", "", {},
            {"bar_2", "other", "bar_2:1", "^bar_2"}, {});
  CheckNode(view, "foo_2", "NotImportant", "", {},
            {"other:1", "bar_2:2", "^bar_2"}, {});

  CheckGraph(view);
}
239
// Updates `bar_2` from Identity to NotImportant and expects the `^bar_2`
// control dependencies of `foo_1` and `foo_2` to be gone afterwards, since
// both nodes also consume regular outputs of `bar_2`.
TEST(MutableGraphViewTest, UpdateNodeDedupControlDependency) {
  constexpr char kDevice[] = "/device:foo:0";
  GraphDef def = test::function::GDef(
      {
          NDef("bar_1", "Switch", {}, {}),
          NDef("bar_2", "Identity", {"bar_1:1"}),
          NDef("other", "NotImportant", {}, {}),
          NDef("foo_1", "NotImportant",
               {"bar_2", "other", "bar_2:1", "^bar_2"}),
          NDef("foo_2", "NotImportant", {"other:1", "bar_2:2", "^bar_2"}),
      },
      /*funcs=*/{});

  MutableGraphView view(&def);

  TF_EXPECT_OK(view.UpdateNode("bar_2", "NotImportant", kDevice, {}));

  CheckNode(view, "bar_1", "Switch", "", {}, {}, {"bar_2"});
  CheckNode(view, "bar_2", "NotImportant", kDevice, {}, {"bar_1:1"},
            {"foo_1", "foo_1:2", "foo_2:1"});
  CheckNode(view, "other", "NotImportant", "", {}, {}, {"foo_1:1", "foo_2"});
  CheckNode(view, "foo_1", "NotImportant", "", {},
            {"bar_2", "other", "bar_2:1"}, {});
  CheckNode(view, "foo_2", "NotImportant", "", {}, {"other:1", "bar_2:2"}, {});

  CheckGraph(view);
}
263
// Changing a node's op to Switch succeeds when the node only drives regular
// (non-control) inputs.
TEST(MutableGraphViewTest, UpdateNodeSwitchNoControlDependency) {
  constexpr char kDevice[] = "/device:foo:0";
  GraphDef def = test::function::GDef(
      {
          NDef("foo", "NotImportant", {}, {}),
          NDef("bar", "NotImportant", {"foo:1"}),
      },
      /*funcs=*/{});

  MutableGraphView view(&def);

  TF_EXPECT_OK(view.UpdateNode("foo", "Switch", kDevice, {}));

  CheckNode(view, "foo", "Switch", kDevice, {}, {}, {"bar"});
  CheckNode(view, "bar", "NotImportant", "", {}, {"foo:1"}, {});

  CheckGraph(view);
}
280
// Changing a node's op to Switch must fail when the node drives a control
// dependency, and the graph must be left untouched.
TEST(MutableGraphViewTest, UpdateNodeSwitchControlDependency) {
  constexpr char kDevice[] = "/device:foo:0";
  GraphDef def = test::function::GDef(
      {
          NDef("foo", "NotImportant", {}, {}),
          NDef("bar", "NotImportant", {"^foo"}),
      },
      /*funcs=*/{});

  MutableGraphView view(&def);

  AttrValue type_attr;
  type_attr.set_type(DT_FLOAT);
  Status update_status =
      view.UpdateNode("foo", "Switch", kDevice, {{"T", type_attr}});
  EXPECT_FALSE(update_status.ok());
  const string expected_msg =
      "MutableGraphView::UpdateNodeOp(node_name='foo', op='Switch', "
      "device='/device:foo:0', attrs={('T', type: DT_FLOAT)}) error: can't "
      "change node op to Switch when node drives a control dependency "
      "(alternatively, we could add the identity node needed, but it seems "
      "like an unlikely event and probably a mistake).";
  EXPECT_EQ(update_status.error_message(), expected_msg);

  // Op, device and attributes of `foo` are unchanged.
  CheckNode(view, "foo", "NotImportant", "", {}, {}, {"^bar"});
  CheckNode(view, "bar", "NotImportant", "", {}, {"^foo"}, {});

  CheckGraph(view);
}
307
GetNodeInputsFromGraph(const GraphDef & graph,absl::string_view node_to_exclude)308 absl::flat_hash_map<string, std::vector<string>> GetNodeInputsFromGraph(
309 const GraphDef& graph, absl::string_view node_to_exclude) {
310 absl::flat_hash_map<string, std::vector<string>> node_inputs;
311 for (const auto& node : graph.node()) {
312 if (node.name() == node_to_exclude) {
313 continue;
314 }
315 node_inputs[node.name()] =
316 std::vector<string>(node.input().begin(), node.input().end());
317 }
318 return node_inputs;
319 }
320
CheckUnmodifiedNodeFanins(const GraphDef & graph,absl::string_view node_to_exclude,const absl::flat_hash_map<string,std::vector<string>> & unmodified_node_inputs)321 void CheckUnmodifiedNodeFanins(
322 const GraphDef& graph, absl::string_view node_to_exclude,
323 const absl::flat_hash_map<string, std::vector<string>>&
324 unmodified_node_inputs) {
325 for (const auto& node : graph.node()) {
326 if (node.name() == node_to_exclude) {
327 continue;
328 }
329 auto it = unmodified_node_inputs.find(node.name());
330 ASSERT_NE(it, unmodified_node_inputs.end());
331 ASSERT_EQ(it->second.size(), node.input_size());
332 for (int i = 0; i < node.input_size(); ++i) {
333 EXPECT_EQ(node.input(i), it->second[i]);
334 }
335 }
336 }
337
TestUpdateNodeName(absl::string_view from_node_name,bool node_exists,absl::string_view to_node_name,bool update_fanouts,bool success,const string & error_msg,absl::Span<const string> expected_fanins)338 void TestUpdateNodeName(absl::string_view from_node_name, bool node_exists,
339 absl::string_view to_node_name, bool update_fanouts,
340 bool success, const string& error_msg,
341 absl::Span<const string> expected_fanins) {
342 GraphDef graph_def = test::function::GDef(
343 {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {"a"}),
344 NDef("c", "NotImportant", {}, {})},
345 /*funcs=*/{});
346
347 MutableGraphView graph(&graph_def);
348
349 NodeDef* node = graph.GetNode(from_node_name);
350 if (node_exists) {
351 EXPECT_NE(node, nullptr);
352 } else {
353 EXPECT_EQ(node, nullptr);
354 }
355
356 absl::flat_hash_map<string, std::vector<string>> unmodified_node_inputs =
357 GetNodeInputsFromGraph(graph_def, from_node_name);
358
359 Status s = graph.UpdateNodeName(from_node_name, to_node_name, update_fanouts);
360 EXPECT_EQ(s.ok(), success);
361 string updated_node_name;
362 if (success) {
363 updated_node_name = string(to_node_name);
364 } else {
365 updated_node_name = string(from_node_name);
366 EXPECT_EQ(s.error_message(), error_msg);
367 }
368 if (node_exists) {
369 EXPECT_EQ(node->name(), updated_node_name);
370 CompareNodeFanins(graph, node, expected_fanins);
371 }
372
373 CheckUnmodifiedNodeFanins(graph_def, updated_node_name,
374 unmodified_node_inputs);
375
376 CheckGraph(graph);
377 }
378
// Exercises UpdateNodeName on the helper graph (`a`, `b` with input `a`, and
// `c`) for successful renames and for each expected failure.
TEST(MutableGraphViewTest, UpdateNodeName) {
  // `b` has no fanouts, so it can be renamed without updating fanouts.
  TestUpdateNodeName("b", /*node_exists=*/true, "d", /*update_fanouts=*/false,
                     /*success=*/true, /*error_msg=*/"", {"a"});
  // `b` has no fanouts and is renamed to itself.
  TestUpdateNodeName("b", /*node_exists=*/true, "b", /*update_fanouts=*/false,
                     /*success=*/true, /*error_msg=*/"", {"a"});
  // `a` has a fanout (`b`) and is renamed to itself.
  TestUpdateNodeName("a", /*node_exists=*/true, "a", /*update_fanouts=*/false,
                     /*success=*/true, /*error_msg=*/"", {});

  // The new node name is already in use.
  TestUpdateNodeName(
      "c", /*node_exists=*/true, "b", /*update_fanouts=*/false,
      /*success=*/false,
      "MutableGraphView::UpdateNodeName(from_node_name='c', to_node_name='b', "
      "update_fanouts=false) error: can't update node name because new node "
      "name is in use.",
      {});
  TestUpdateNodeName(
      "a", /*node_exists=*/true, "b", /*update_fanouts=*/true,
      /*success=*/false,
      "MutableGraphView::UpdateNodeName(from_node_name='a', to_node_name='b', "
      "update_fanouts=true) error: can't update node name because new node "
      "name is in use.",
      {});

  // The node has fanouts and fanouts are not being updated.
  TestUpdateNodeName(
      "a", /*node_exists=*/true, "d", /*update_fanouts=*/false,
      /*success=*/false,
      "MutableGraphView::UpdateNodeName(from_node_name='a', to_node_name='d', "
      "update_fanouts=false) error: can't update node name because node has "
      "fanouts.",
      {});

  // The node to rename does not exist.
  TestUpdateNodeName(
      "d", /*node_exists=*/false, "e", /*update_fanouts=*/false,
      /*success=*/false,
      "MutableGraphView::UpdateNodeName(from_node_name='d', to_node_name='e', "
      "update_fanouts=false) error: node 'd' was not found.",
      {});
  TestUpdateNodeName(
      "d", /*node_exists=*/false, "e", /*update_fanouts=*/true,
      /*success=*/false,
      "MutableGraphView::UpdateNodeName(from_node_name='d', to_node_name='e', "
      "update_fanouts=true) error: node 'd' was not found.",
      {});
}
423
// Renames `b` to `f` with update_fanouts=true and expects every consumer of
// `b` (regular and control inputs alike) to reference `f` afterwards.
TEST(MutableGraphViewTest, UpdateNodeNameWithFanouts) {
  GraphDef def = test::function::GDef(
      {
          NDef("a", "NotImportant", {}, {}),
          NDef("b", "NotImportant", {"a:2"}),
          NDef("c", "NotImportant", {"b", "^a"}),
          NDef("d", "NotImportant", {"^b", "^a"}),
          NDef("e", "NotImportant", {"b:2", "c:4", "b:1", "^a"}),
      },
      /*funcs=*/{});

  MutableGraphView view(&def);

  TF_EXPECT_OK(view.UpdateNodeName("b", "f", /*update_fanouts=*/true));

  CheckNode(view, "a", "NotImportant", "", {}, {}, {"f", "^c", "^d", "^e"});
  CheckNode(view, "f", "NotImportant", "", {}, {"a:2"},
            {"c", "^d", "e", "e:2"});
  CheckNode(view, "c", "NotImportant", "", {}, {"f", "^a"}, {"e:1"});
  CheckNode(view, "d", "NotImportant", "", {}, {"^f", "^a"}, {});
  CheckNode(view, "e", "NotImportant", "", {}, {"f:2", "c:4", "f:1", "^a"},
            {});

  CheckGraph(view);
}
446
SimpleSwapNodeNamesMutationGraph()447 GraphDef SimpleSwapNodeNamesMutationGraph() {
448 return test::function::GDef(
449 {NDef("a", "NotImportant", {}, {}), NDef("switch_1", "Switch", {"a"}),
450 NDef("identity_1", "Identity", {"switch_1:1"}),
451 NDef("b", "NotImportant", {}, {}), NDef("switch_2", "Switch", {"b"}),
452 NDef("identity_2", "Identity", {"switch_2:0"}),
453 NDef("foo_1", "NotImportant", {"identity_1", "^identity_1"}),
454 NDef("foo_2", "NotImportant", {"identity_2", "^identity_2"})},
455 /*funcs=*/{});
456 }
457
TestSwapNodeNames(bool update_fanouts)458 void TestSwapNodeNames(bool update_fanouts) {
459 GraphDef graph_def = SimpleSwapNodeNamesMutationGraph();
460
461 MutableGraphView graph(&graph_def);
462
463 TF_EXPECT_OK(graph.SwapNodeNames("foo_1", "foo_2", update_fanouts));
464
465 CheckNode(graph, "a", "NotImportant", "", {}, {}, {"switch_1"});
466 CheckNode(graph, "switch_1", "Switch", "", {}, {"a"}, {"identity_1"});
467 CheckNode(graph, "identity_1", "Identity", "", {}, {"switch_1:1"},
468 {"foo_2", "^foo_2"});
469 CheckNode(graph, "b", "NotImportant", "", {}, {}, {"switch_2"});
470 CheckNode(graph, "switch_2", "Switch", "", {}, {"b"}, {"identity_2"});
471 CheckNode(graph, "identity_2", "Identity", "", {}, {"switch_2:0"},
472 {"foo_1", "^foo_1"});
473 CheckNode(graph, "foo_2", "NotImportant", "", {},
474 {"identity_1", "^identity_1"}, {});
475 CheckNode(graph, "foo_1", "NotImportant", "", {},
476 {"identity_2", "^identity_2"}, {});
477
478 CheckGraph(graph);
479 }
480
// Runs the basic swap scenario without and then with fanout updates.
TEST(MutableGraphView, SwapNodeNames) {
  for (const bool update_fanouts : {false, true}) {
    TestSwapNodeNames(update_fanouts);
  }
}
485
TestSwapNodeNamesWithSameNames(bool update_fanouts)486 void TestSwapNodeNamesWithSameNames(bool update_fanouts) {
487 GraphDef graph_def = SimpleSwapNodeNamesMutationGraph();
488
489 MutableGraphView graph(&graph_def);
490
491 TF_EXPECT_OK(graph.SwapNodeNames("identity_1", "identity_1", update_fanouts));
492
493 // No changes to graph.
494 CheckNode(graph, "a", "NotImportant", "", {}, {}, {"switch_1"});
495 CheckNode(graph, "switch_1", "Switch", "", {}, {"a"}, {"identity_1"});
496 CheckNode(graph, "identity_1", "Identity", "", {}, {"switch_1:1"},
497 {"foo_1", "^foo_1"});
498 CheckNode(graph, "b", "NotImportant", "", {}, {}, {"switch_2"});
499 CheckNode(graph, "switch_2", "Switch", "", {}, {"b"}, {"identity_2"});
500 CheckNode(graph, "identity_2", "Identity", "", {}, {"switch_2:0"},
501 {"foo_2", "^foo_2"});
502 CheckNode(graph, "foo_1", "NotImportant", "", {},
503 {"identity_1", "^identity_1"}, {});
504 CheckNode(graph, "foo_2", "NotImportant", "", {},
505 {"identity_2", "^identity_2"}, {});
506
507 CheckGraph(graph);
508 }
509
// Runs the swap-with-self scenario without and then with fanout updates.
TEST(MutableGraphView, SwapNodeNamesSameName) {
  for (const bool update_fanouts : {false, true}) {
    TestSwapNodeNamesWithSameNames(update_fanouts);
  }
}
514
// Swaps the names of the two Switch nodes without updating fanouts: the
// identity nodes keep their input strings, so each one ends up consuming the
// Switch that now carries the name it referenced.
TEST(MutableGraphView, SwapNodeNamesBetweenSwitches) {
  GraphDef def = SimpleSwapNodeNamesMutationGraph();

  MutableGraphView view(&def);

  TF_EXPECT_OK(
      view.SwapNodeNames("switch_1", "switch_2", /*update_fanouts=*/false));

  CheckNode(view, "a", "NotImportant", "", {}, {}, {"switch_2"});
  CheckNode(view, "switch_2", "Switch", "", {}, {"a"}, {"identity_2"});
  CheckNode(view, "identity_1", "Identity", "", {}, {"switch_1:1"},
            {"foo_1", "^foo_1"});
  CheckNode(view, "b", "NotImportant", "", {}, {}, {"switch_1"});
  CheckNode(view, "switch_1", "Switch", "", {}, {"b"}, {"identity_1"});
  CheckNode(view, "identity_2", "Identity", "", {}, {"switch_2:0"},
            {"foo_2", "^foo_2"});
  CheckNode(view, "foo_1", "NotImportant", "", {},
            {"identity_1", "^identity_1"}, {});
  CheckNode(view, "foo_2", "NotImportant", "", {},
            {"identity_2", "^identity_2"}, {});

  CheckGraph(view);
}
538
// Swaps the names of the two Switch nodes and updates fanouts: the identity
// nodes have their inputs rewritten to the new Switch names.
TEST(MutableGraphView, SwapNodeNamesBetweenSwitchesAndUpdateFanouts) {
  GraphDef def = SimpleSwapNodeNamesMutationGraph();

  MutableGraphView view(&def);

  TF_EXPECT_OK(
      view.SwapNodeNames("switch_1", "switch_2", /*update_fanouts=*/true));

  CheckNode(view, "a", "NotImportant", "", {}, {}, {"switch_2"});
  CheckNode(view, "switch_2", "Switch", "", {}, {"a"}, {"identity_1"});
  CheckNode(view, "identity_1", "Identity", "", {}, {"switch_2:1"},
            {"foo_1", "^foo_1"});
  CheckNode(view, "b", "NotImportant", "", {}, {}, {"switch_1"});
  CheckNode(view, "switch_1", "Switch", "", {}, {"b"}, {"identity_2"});
  CheckNode(view, "identity_2", "Identity", "", {}, {"switch_1:0"},
            {"foo_2", "^foo_2"});
  CheckNode(view, "foo_1", "NotImportant", "", {},
            {"identity_1", "^identity_1"}, {});
  CheckNode(view, "foo_2", "NotImportant", "", {},
            {"identity_2", "^identity_2"}, {});

  CheckGraph(view);
}
562
// Swaps the names of `a` (non-Switch) and `switch_1` (Switch) without updating
// fanouts. The expectations below encode that no self loop is left behind and
// that `foo_1` no longer has its `^identity_1` control input.
TEST(MutableGraphView, SwapNodeNamesSwitchAndNonSwitch) {
  GraphDef def = SimpleSwapNodeNamesMutationGraph();

  MutableGraphView view(&def);

  TF_EXPECT_OK(view.SwapNodeNames("a", "switch_1", /*update_fanouts=*/false));

  // Controls are deduplicated and the self loop is fixed.
  CheckNode(view, "switch_1", "NotImportant", "", {}, {}, {"a", "identity_1"});
  CheckNode(view, "a", "Switch", "", {}, {"switch_1"}, {});
  CheckNode(view, "identity_1", "Identity", "", {}, {"switch_1:1"}, {"foo_1"});
  CheckNode(view, "b", "NotImportant", "", {}, {}, {"switch_2"});
  CheckNode(view, "switch_2", "Switch", "", {}, {"b"}, {"identity_2"});
  CheckNode(view, "identity_2", "Identity", "", {}, {"switch_2:0"},
            {"foo_2", "^foo_2"});
  CheckNode(view, "foo_1", "NotImportant", "", {}, {"identity_1"}, {});
  CheckNode(view, "foo_2", "NotImportant", "", {},
            {"identity_2", "^identity_2"}, {});

  CheckGraph(view);
}
584
// Swaps the names of `a` (non-Switch) and `switch_1` (Switch) and updates
// fanouts: `identity_1` is rewired to `a:1` and all control inputs are kept.
TEST(MutableGraphView, SwapNodeNamesSwitchAndNonSwitchAndUpdateFanouts) {
  GraphDef def = SimpleSwapNodeNamesMutationGraph();

  MutableGraphView view(&def);

  TF_EXPECT_OK(view.SwapNodeNames("a", "switch_1", /*update_fanouts=*/true));

  CheckNode(view, "switch_1", "NotImportant", "", {}, {}, {"a"});
  CheckNode(view, "a", "Switch", "", {}, {"switch_1"}, {"identity_1"});
  CheckNode(view, "identity_1", "Identity", "", {}, {"a:1"},
            {"foo_1", "^foo_1"});
  CheckNode(view, "b", "NotImportant", "", {}, {}, {"switch_2"});
  CheckNode(view, "switch_2", "Switch", "", {}, {"b"}, {"identity_2"});
  CheckNode(view, "identity_2", "Identity", "", {}, {"switch_2:0"},
            {"foo_2", "^foo_2"});
  CheckNode(view, "foo_1", "NotImportant", "", {},
            {"identity_1", "^identity_1"}, {});
  CheckNode(view, "foo_2", "NotImportant", "", {},
            {"identity_2", "^identity_2"}, {});

  CheckGraph(view);
}
607
// Same scenario as SwapNodeNamesSwitchAndNonSwitch, but on the second chain
// and with the Switch passed as the first argument; fanouts are not updated.
TEST(MutableGraphView, SwapNodeNamesNonSwitchAndSwitch) {
  GraphDef def = SimpleSwapNodeNamesMutationGraph();

  MutableGraphView view(&def);

  TF_EXPECT_OK(view.SwapNodeNames("switch_2", "b", /*update_fanouts=*/false));

  // Controls are deduplicated and the self loop is fixed.
  CheckNode(view, "a", "NotImportant", "", {}, {}, {"switch_1"});
  CheckNode(view, "switch_1", "Switch", "", {}, {"a"}, {"identity_1"});
  CheckNode(view, "identity_1", "Identity", "", {}, {"switch_1:1"},
            {"foo_1", "^foo_1"});
  CheckNode(view, "switch_2", "NotImportant", "", {}, {}, {"b", "identity_2"});
  CheckNode(view, "b", "Switch", "", {}, {"switch_2"}, {});
  CheckNode(view, "identity_2", "Identity", "", {}, {"switch_2:0"}, {"foo_2"});
  CheckNode(view, "foo_1", "NotImportant", "", {},
            {"identity_1", "^identity_1"}, {});
  CheckNode(view, "foo_2", "NotImportant", "", {}, {"identity_2"}, {});

  CheckGraph(view);
}
629
// Swaps the names of `switch_2` (Switch) and `b` (non-Switch) and updates
// fanouts: `identity_2` is rewired to `b:0` and all control inputs are kept.
TEST(MutableGraphView, SwapNodeNamesNonSwitchAndSwitchAndUpdateFanouts) {
  GraphDef def = SimpleSwapNodeNamesMutationGraph();

  MutableGraphView view(&def);

  TF_EXPECT_OK(view.SwapNodeNames("switch_2", "b", /*update_fanouts=*/true));

  CheckNode(view, "a", "NotImportant", "", {}, {}, {"switch_1"});
  CheckNode(view, "switch_1", "Switch", "", {}, {"a"}, {"identity_1"});
  CheckNode(view, "identity_1", "Identity", "", {}, {"switch_1:1"},
            {"foo_1", "^foo_1"});
  CheckNode(view, "switch_2", "NotImportant", "", {}, {}, {"b"});
  CheckNode(view, "b", "Switch", "", {}, {"switch_2"}, {"identity_2"});
  CheckNode(view, "identity_2", "Identity", "", {}, {"b:0"},
            {"foo_2", "^foo_2"});
  CheckNode(view, "foo_1", "NotImportant", "", {},
            {"identity_1", "^identity_1"}, {});
  CheckNode(view, "foo_2", "NotImportant", "", {},
            {"identity_2", "^identity_2"}, {});

  CheckGraph(view);
}
652
TestSwapNodeNamesSimpleSelfLoop(bool update_fanouts)653 void TestSwapNodeNamesSimpleSelfLoop(bool update_fanouts) {
654 GraphDef graph_def = test::function::GDef(
655 {NDef("a", "NotImportant", {"b:7"}), NDef("b", "NotImportant", {"a:10"})},
656 /*funcs=*/{});
657
658 MutableGraphView graph(&graph_def);
659
660 TF_EXPECT_OK(graph.SwapNodeNames("a", "b", update_fanouts));
661
662 // No self loops.
663 CheckNode(graph, "a", "NotImportant", "", {}, {"b:10"}, {"b:0"});
664 CheckNode(graph, "b", "NotImportant", "", {}, {"a:7"}, {"a:0"});
665
666 CheckGraph(graph);
667 }
668
// Runs the mutual-dependency swap scenario without and then with fanout
// updates.
TEST(MutableGraphView, SwapNodeNamesSelfLoops) {
  for (const bool update_fanouts : {false, true}) {
    TestSwapNodeNamesSimpleSelfLoop(update_fanouts);
  }
}
673
TestSwapNodeNamesError(absl::string_view from_node_name,absl::string_view to_node_name,bool update_fanouts,const string & error_msg)674 void TestSwapNodeNamesError(absl::string_view from_node_name,
675 absl::string_view to_node_name, bool update_fanouts,
676 const string& error_msg) {
677 GraphDef graph_def = SimpleSwapNodeNamesMutationGraph();
678
679 MutableGraphView graph(&graph_def);
680
681 Status s = graph.SwapNodeNames(from_node_name, to_node_name, update_fanouts);
682 EXPECT_EQ(s.ok(), false);
683 EXPECT_EQ(s.error_message(), error_msg);
684
685 // No changes to graph.
686 CheckNode(graph, "a", "NotImportant", "", {}, {}, {"switch_1"});
687 CheckNode(graph, "switch_1", "Switch", "", {}, {"a"}, {"identity_1"});
688 CheckNode(graph, "identity_1", "Identity", "", {}, {"switch_1:1"},
689 {"foo_1", "^foo_1"});
690 CheckNode(graph, "b", "NotImportant", "", {}, {}, {"switch_2"});
691 CheckNode(graph, "switch_2", "Switch", "", {}, {"b"}, {"identity_2"});
692 CheckNode(graph, "identity_2", "Identity", "", {}, {"switch_2:0"},
693 {"foo_2", "^foo_2"});
694 CheckNode(graph, "foo_1", "NotImportant", "", {},
695 {"identity_1", "^identity_1"}, {});
696 CheckNode(graph, "foo_2", "NotImportant", "", {},
697 {"identity_2", "^identity_2"}, {});
698
699 CheckGraph(graph);
700 }
701
702 // TODO(lyandy): add tests with update_fanouts == true.
// Covers the failure modes of SwapNodeNames: a missing first node, a missing
// second node, both missing, and swaps that would turn a Switch into the
// source of a control dependency.
TEST(MutableGraphView, SwapNodeNamesError) {
  // First node is missing.
  TestSwapNodeNamesError(
      "foo_3", "foo_2", /*update_fanouts=*/false,
      "MutableGraphView::SwapNodeNames(from_node_name='foo_3', "
      "to_node_name='foo_2', update_fanouts=false) error: node 'foo_3' was not "
      "found.");
  TestSwapNodeNamesError(
      "foo_3", "foo_2", /*update_fanouts=*/true,
      "MutableGraphView::SwapNodeNames(from_node_name='foo_3', "
      "to_node_name='foo_2', update_fanouts=true) error: node 'foo_3' was not "
      "found.");

  // Second node is missing.
  TestSwapNodeNamesError(
      "foo_1", "foo_4", /*update_fanouts=*/false,
      "MutableGraphView::SwapNodeNames(from_node_name='foo_1', "
      "to_node_name='foo_4', update_fanouts=false) error: node 'foo_4' was not "
      "found.");
  TestSwapNodeNamesError(
      "foo_1", "foo_4", /*update_fanouts=*/true,
      "MutableGraphView::SwapNodeNames(from_node_name='foo_1', "
      "to_node_name='foo_4', update_fanouts=true) error: node 'foo_4' was not "
      "found.");

  // Both nodes are missing; the first one is the one reported.
  TestSwapNodeNamesError(
      "foo_5", "foo_6", /*update_fanouts=*/false,
      "MutableGraphView::SwapNodeNames(from_node_name='foo_5', "
      "to_node_name='foo_6', update_fanouts=false) error: node 'foo_5' was not "
      "found.");
  TestSwapNodeNamesError(
      "foo_5", "foo_6", /*update_fanouts=*/true,
      "MutableGraphView::SwapNodeNames(from_node_name='foo_5', "
      "to_node_name='foo_6', update_fanouts=true) error: node 'foo_5' was not "
      "found.");

  // Swaps that would make a Switch drive a control dependency.
  TestSwapNodeNamesError(
      "switch_2", "identity_1", /*update_fanouts=*/false,
      "MutableGraphView::SwapNodeNames(from_node_name='switch_2', "
      "to_node_name='identity_1', update_fanouts=false) error: can't swap node "
      "name 'switch_2' as it will become a Switch control dependency.");
  TestSwapNodeNamesError(
      "identity_2", "switch_1", /*update_fanouts=*/false,
      "MutableGraphView::SwapNodeNames(from_node_name='identity_2', "
      "to_node_name='switch_1', update_fanouts=false) error: can't swap node "
      "name 'switch_1' as it will become a Switch control dependency.");
}
751
// Adds a new node to the view and redirects every fanout of `bar` to it, then
// verifies the resulting fanins and fanouts of all nodes.
TEST(MutableGraphViewTest, AddAndUpdateFanouts) {
  // The op type of each node plays no role in this test.
  GraphDef source_graph = test::function::GDef(
      {
          NDef("bar", "NotImportant", {}, {}),
          NDef("other", "NotImportant", {}, {}),
          NDef("foo_1", "NotImportant", {"bar", "other", "bar:1", "^bar"}),
          NDef("foo_2", "NotImportant", {"other:1", "bar:2", "^bar"}),
          NDef("foo_3", "NotImportant", {"other:2", "^bar"}),
      },
      /*funcs=*/{});

  MutableGraphView view(&source_graph);

  NodeDef* replacement = view.AddNode(NDef("new_bar", "NotImportant", {}, {}));

  TF_EXPECT_OK(view.UpdateFanouts("bar", replacement->name()));

  // `bar` ends up with no consumers; everything that read from it now reads
  // from `new_bar`. Where a consumer had both a regular and a control edge,
  // only the regular edges are expected to remain.
  CheckNode(view, "bar", "NotImportant", "", {}, {}, {});
  CheckNode(view, "other", "NotImportant", "", {}, {},
            {"foo_1:1", "foo_2", "foo_3"});
  CheckNode(view, "foo_1", "NotImportant", "", {},
            {"new_bar", "other", "new_bar:1"}, {});
  CheckNode(view, "foo_2", "NotImportant", "", {}, {"other:1", "new_bar:2"},
            {});
  CheckNode(view, "foo_3", "NotImportant", "", {}, {"other:2", "^new_bar"},
            {});
  CheckNode(view, "new_bar", "NotImportant", "", {}, {},
            {"foo_1:0", "foo_1:2", "foo_2:1", "^foo_3"});

  CheckGraph(view);
}
783
// Redirects the fanouts of an Identity node fed by a Switch to a newly added
// Identity node. In this setup the consumers are expected to keep their
// control edges, now pointing at `new_bar`, next to the regular edges.
TEST(MutableGraphViewTest, AddAndUpdateFanoutsKeepControls) {
  GraphDef source_graph = test::function::GDef(
      {
          NDef("bar_1", "Switch", {}, {}),
          NDef("bar_2", "Identity", {"bar_1:1"}),
          NDef("other", "NotImportant", {}, {}),
          NDef("foo_1", "NotImportant",
               {"bar_2", "other", "bar_2:1", "^bar_2"}),
          NDef("foo_2", "NotImportant", {"other:1", "bar_2:2", "^bar_2"}),
      },
      /*funcs=*/{});

  MutableGraphView view(&source_graph);

  NodeDef* replacement = view.AddNode(NDef("new_bar", "Identity", {"bar_1:2"}));

  TF_EXPECT_OK(view.UpdateFanouts("bar_2", replacement->name()));

  // Expected fanins and fanouts after the redirect.
  CheckNode(view, "bar_1", "Switch", "", {}, {}, {"bar_2", "new_bar"});
  CheckNode(view, "bar_2", "Identity", "", {}, {"bar_1:1"}, {});
  CheckNode(view, "other", "NotImportant", "", {}, {}, {"foo_1:1", "foo_2"});
  CheckNode(view, "foo_1", "NotImportant", "", {},
            {"new_bar", "other", "new_bar:1", "^new_bar"}, {});
  CheckNode(view, "foo_2", "NotImportant", "", {},
            {"other:1", "new_bar:2", "^new_bar"}, {});
  CheckNode(view, "new_bar", "Identity", "", {}, {"bar_1:2"},
            {"foo_1", "foo_1:2", "^foo_1", "foo_2:1", "^foo_2"});

  CheckGraph(view);
}
811
// Redirects the fanouts of `bar` to a node that itself consumes `bar`, and
// verifies that the new node keeps reading from `bar` instead of from itself.
TEST(MutableGraphViewTest, AddAndUpdateFanoutsWithoutSelfLoops) {
  // The op type of each node plays no role in this test.
  GraphDef source_graph = test::function::GDef(
      {
          NDef("bar", "NotImportant", {}, {}),
          NDef("foo_1", "NotImportant", {"bar", "^bar"}),
          NDef("foo_2", "NotImportant", {"^bar"}),
      },
      /*funcs=*/{});

  MutableGraphView view(&source_graph);

  // The replacement node takes the output of the original `bar` as its input.
  NodeDef* replacement = view.AddNode(NDef("new_bar", "NewBar", {"bar"}, {}));

  TF_EXPECT_OK(view.UpdateFanouts("bar", replacement->name()));

  // `bar` must now feed only `new_bar`; the former consumers read `new_bar`.
  CheckNode(view, "bar", "NotImportant", "", {}, {}, {"new_bar"});
  CheckNode(view, "foo_1", "NotImportant", "", {}, {"new_bar"}, {});
  CheckNode(view, "foo_2", "NotImportant", "", {}, {"^new_bar"}, {});
  CheckNode(view, "new_bar", "NewBar", "", {}, {"bar"}, {"foo_1", "^foo_2"});

  CheckGraph(view);
}
835
// Redirecting the fanouts of `a` or `d` to the Switch node `b` would turn `b`
// into a control dependency of `e`; both requests must fail and leave the
// graph untouched.
TEST(MutableGraphViewTest, UpdateFanoutsToSwitchWithControlFromSwitch) {
  GraphDef source_graph = test::function::GDef(
      {
          NDef("a", "NotImportant", {}, {}),
          NDef("b", "Switch", {}, {}),
          NDef("c", "NotImportant", {}, {}),
          NDef("d", "NotImportant", {}, {}),
          NDef("e", "NotImportant", {"c", "b", "^a", "^d"}),
      },
      /*funcs=*/{});

  MutableGraphView view(&source_graph);

  // First rejected request: a -> b.
  const Status from_a_status = view.UpdateFanouts("a", "b");
  EXPECT_FALSE(from_a_status.ok());
  const string from_a_expected =
      "MutableGraphView::UpdateFanouts(from_node_name='a', to_node_name='b') "
      "error: can't update fanouts to node 'b' as it will become a Switch "
      "control dependency.";
  EXPECT_EQ(from_a_status.error_message(), from_a_expected);

  // Second rejected request: d -> b.
  const Status from_d_status = view.UpdateFanouts("d", "b");
  EXPECT_FALSE(from_d_status.ok());
  const string from_d_expected =
      "MutableGraphView::UpdateFanouts(from_node_name='d', to_node_name='b') "
      "error: can't update fanouts to node 'b' as it will become a Switch "
      "control dependency.";
  EXPECT_EQ(from_d_status.error_message(), from_d_expected);

  // Neither failed call may have altered the graph.
  EXPECT_EQ(view.graph()->node_size(), 5);

  CheckNode(view, "a", "NotImportant", "", {}, {}, {"^e"});
  CheckNode(view, "b", "Switch", "", {}, {}, {"e:1"});
  CheckNode(view, "c", "NotImportant", "", {}, {}, {"e:0"});
  CheckNode(view, "d", "NotImportant", "", {}, {}, {"^e"});
  CheckNode(view, "e", "NotImportant", "", {}, {"c", "b", "^a", "^d"}, {});

  CheckGraph(view);
}
870
// Redirecting the fanouts of `c` to the Switch node `b` only rewires a regular
// input of `e`, so the update must succeed.
TEST(MutableGraphViewTest, UpdateFanoutsToSwitchWithNoControlFromSwitch) {
  GraphDef source_graph = test::function::GDef(
      {
          NDef("a", "NotImportant", {}, {}),
          NDef("b", "Switch", {}, {}),
          NDef("c", "NotImportant", {}, {}),
          NDef("d", "NotImportant", {}, {}),
          NDef("e", "NotImportant", {"c", "b", "^a", "^d"}),
      },
      /*funcs=*/{});

  MutableGraphView view(&source_graph);

  TF_EXPECT_OK(view.UpdateFanouts("c", "b"));

  // No node is added or removed by the update.
  EXPECT_EQ(view.graph()->node_size(), 5);

  // `e` now reads `b` on both regular ports; `c` has no consumers left.
  CheckNode(view, "a", "NotImportant", "", {}, {}, {"^e"});
  CheckNode(view, "b", "Switch", "", {}, {}, {"e:0", "e:1"});
  CheckNode(view, "c", "NotImportant", "", {}, {}, {});
  CheckNode(view, "d", "NotImportant", "", {}, {}, {"^e"});
  CheckNode(view, "e", "NotImportant", "", {}, {"b", "b", "^a", "^d"}, {});

  CheckGraph(view);
}
892
SimpleMutateFaninGraph()893 GraphDef SimpleMutateFaninGraph() {
894 // Actual node.op() is not important in this test.
895 GraphDef graph_def = test::function::GDef(
896 {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {}),
897 NDef("c", "NotImportant", {}, {}), NDef("d", "NotImportant", {}, {}),
898 NDef("foo_1", "NotImportant", {"a"}),
899 NDef("foo_2", "NotImportant", {"b", "^a", "^c"}),
900 NDef("foo_3", "NotImportant", {"b", "a:1", "a:1"}),
901 NDef("foo_4", "NotImportant", {"a", "b:2", "b:2", "^c", "^d"}),
902 NDef("foo_5", "NotImportant", {}),
903 NDef("foo_6", "NotImportant", {"^a", "^b"})},
904 /*funcs=*/{});
905 return graph_def;
906 }
907
TestAddRegularFanin(absl::string_view node_name,bool node_exists,const TensorId & fanin_to_add,bool success,const string & error_msg,absl::Span<const string> expected_fanins)908 void TestAddRegularFanin(absl::string_view node_name, bool node_exists,
909 const TensorId& fanin_to_add, bool success,
910 const string& error_msg,
911 absl::Span<const string> expected_fanins) {
912 GraphDef graph_def = SimpleMutateFaninGraph();
913
914 MutableGraphView graph(&graph_def);
915
916 NodeDef* node = graph.GetNode(node_name);
917 if (node_exists) {
918 EXPECT_NE(node, nullptr);
919 } else {
920 EXPECT_EQ(node, nullptr);
921 }
922
923 absl::flat_hash_map<string, std::vector<string>> unmodified_node_inputs =
924 GetNodeInputsFromGraph(graph_def, node_name);
925
926 Status s = graph.AddRegularFanin(node_name, fanin_to_add);
927 EXPECT_EQ(s.ok(), success);
928 if (!success) {
929 EXPECT_EQ(s.error_message(), error_msg);
930 }
931 if (node_exists) {
932 CompareNodeFanins(graph, node, expected_fanins);
933 }
934
935 CheckUnmodifiedNodeFanins(graph_def, node_name, unmodified_node_inputs);
936
937 CheckGraph(graph);
938 }
939
// Covers MutableGraphView::AddRegularFanin on the SimpleMutateFaninGraph()
// fixture: valid additions, control fanins (rejected), missing nodes and
// self-loops. Each case runs on its own fresh copy of the fixture.
TEST(MutableGraphViewTest, AddRegularFanin) {
  // Successful calls carry no expected error message.
  const string no_error;

  // Add input to node with 1 input 0 controls.
  TestAddRegularFanin("foo_1", /*node_exists=*/true, {"b", 1},
                      /*success=*/true, no_error, {"a", "b:1"});
  // Add input to node with multiple inputs and 0 controls.
  TestAddRegularFanin("foo_3", /*node_exists=*/true, {"b", 2},
                      /*success=*/true, no_error,
                      {"b", "a:1", "a:1", "b:2"});
  // Add input to node with 1 input multiple controls.
  TestAddRegularFanin("foo_2", /*node_exists=*/true, {"a", 0},
                      /*success=*/true, no_error, {"b", "a", "^c"});
  // Add input to node with multiple inputs and controls.
  TestAddRegularFanin("foo_4", /*node_exists=*/true, {"a", 1},
                      /*success=*/true, no_error,
                      {"a", "b:2", "b:2", "a:1", "^d", "^c"});
  // Add input to node with 0 inputs 0 controls.
  TestAddRegularFanin("foo_5", /*node_exists=*/true, {"a", 1},
                      /*success=*/true, no_error, {"a:1"});
  // Add input to node with 0 inputs multiple controls.
  TestAddRegularFanin("foo_6", /*node_exists=*/true, {"c", 1},
                      /*success=*/true, no_error, {"c:1", "^b", "^a"});

  // Add control to node with 1 input 0 controls.
  TestAddRegularFanin(
      "foo_1", /*node_exists=*/true, {"b", Graph::kControlSlot},
      /*success=*/false,
      "MutableGraphView::AddRegularFanin(node_name='foo_1', fanin='^b') error: "
      "fanin '^b' must be a regular tensor id.",
      {"a"});
  // Add control to node with multiple inputs and 0 controls.
  TestAddRegularFanin(
      "foo_3", /*node_exists=*/true, {"c", Graph::kControlSlot},
      /*success=*/false,
      "MutableGraphView::AddRegularFanin(node_name='foo_3', fanin='^c') error: "
      "fanin '^c' must be a regular tensor id.",
      {"b", "a:1", "a:1"});
  // Add control to node with 1 input multiple controls.
  TestAddRegularFanin(
      "foo_2", /*node_exists=*/true, {"d", Graph::kControlSlot},
      /*success=*/false,
      "MutableGraphView::AddRegularFanin(node_name='foo_2', fanin='^d') error: "
      "fanin '^d' must be a regular tensor id.",
      {"b", "^a", "^c"});
  // Add control to node with multiple input multiple controls.
  TestAddRegularFanin(
      "foo_4", /*node_exists=*/true, {"a", Graph::kControlSlot},
      /*success=*/false,
      "MutableGraphView::AddRegularFanin(node_name='foo_4', fanin='^a') error: "
      "fanin '^a' must be a regular tensor id.",
      {"a", "b:2", "b:2", "^c", "^d"});
  // Add control to node with 0 inputs 0 controls.
  TestAddRegularFanin(
      "foo_5", /*node_exists=*/true, {"a", Graph::kControlSlot},
      /*success=*/false,
      "MutableGraphView::AddRegularFanin(node_name='foo_5', fanin='^a') error: "
      "fanin '^a' must be a regular tensor id.",
      {});
  // Add control to node with 0 inputs multiple controls.
  TestAddRegularFanin(
      "foo_6", /*node_exists=*/true, {"c", Graph::kControlSlot},
      /*success=*/false,
      "MutableGraphView::AddRegularFanin(node_name='foo_6', fanin='^c') error: "
      "fanin '^c' must be a regular tensor id.",
      {"^a", "^b"});
  // Add control to node with control that already exists.
  TestAddRegularFanin(
      "foo_2", /*node_exists=*/true, {"a", Graph::kControlSlot},
      /*success=*/false,
      "MutableGraphView::AddRegularFanin(node_name='foo_2', fanin='^a') error: "
      "fanin '^a' must be a regular tensor id.",
      {"b", "^a", "^c"});

  // Add fanin to node where node is missing.
  TestAddRegularFanin(
      "foo_missing", /*node_exists=*/false, {"a", 0},
      /*success=*/false,
      "MutableGraphView::AddRegularFanin(node_name='foo_missing', fanin='a:0') "
      "error: node 'foo_missing' was not found.",
      {});
  // Add fanin to node where fanin is missing.
  TestAddRegularFanin(
      "foo_1", /*node_exists=*/true, {"bar_missing", 0},
      /*success=*/false,
      "MutableGraphView::AddRegularFanin(node_name='foo_1', "
      "fanin='bar_missing:0') error: node 'bar_missing' was not found.",
      {"a"});
  // Add fanin to node where node and fanin are missing.
  TestAddRegularFanin(
      "foo_missing", /*node_exists=*/false, {"bar_missing", 0},
      /*success=*/false,
      "MutableGraphView::AddRegularFanin(node_name='foo_missing', "
      "fanin='bar_missing:0') error: node 'foo_missing' was not found.",
      {});
  // Add control fanin to node where node and fanin are missing.
  TestAddRegularFanin(
      "foo_missing", /*node_exists=*/false,
      {"bar_missing", Graph::kControlSlot},
      /*success=*/false,
      "MutableGraphView::AddRegularFanin(node_name='foo_missing', "
      "fanin='^bar_missing') error: fanin '^bar_missing' must be a regular "
      "tensor id.",
      {});

  // Add self to create cycle.
  TestAddRegularFanin(
      "foo_6", /*node_exists=*/true, {"foo_6", 2},
      /*success=*/false,
      "MutableGraphView::AddRegularFanin(node_name='foo_6', fanin='foo_6:2') "
      "error: can't add fanin 'foo_6:2' to self.",
      {"^a", "^b"});
}
1039
TestAddRegularFaninByPort(absl::string_view node_name,bool node_exists,int port,const TensorId & fanin_to_add,bool success,const string & error_msg,absl::Span<const string> expected_fanins)1040 void TestAddRegularFaninByPort(absl::string_view node_name, bool node_exists,
1041 int port, const TensorId& fanin_to_add,
1042 bool success, const string& error_msg,
1043 absl::Span<const string> expected_fanins) {
1044 GraphDef graph_def = SimpleMutateFaninGraph();
1045
1046 MutableGraphView graph(&graph_def);
1047
1048 NodeDef* node = graph.GetNode(node_name);
1049 if (node_exists) {
1050 EXPECT_NE(node, nullptr);
1051 } else {
1052 EXPECT_EQ(node, nullptr);
1053 }
1054
1055 absl::flat_hash_map<string, std::vector<string>> unmodified_node_inputs =
1056 GetNodeInputsFromGraph(graph_def, node_name);
1057
1058 Status s = graph.AddRegularFaninByPort(node_name, port, fanin_to_add);
1059 EXPECT_EQ(s.ok(), success);
1060 if (!success) {
1061 EXPECT_EQ(s.error_message(), error_msg);
1062 }
1063 if (node_exists) {
1064 CompareNodeFanins(graph, node, expected_fanins);
1065 }
1066
1067 CheckUnmodifiedNodeFanins(graph_def, node_name, unmodified_node_inputs);
1068
1069 CheckGraph(graph);
1070 }
1071
// Covers MutableGraphView::AddRegularFaninByPort on the
// SimpleMutateFaninGraph() fixture: insertion at the start, middle and end of
// the regular inputs, control fanins (rejected), out-of-range ports, missing
// nodes and self-loops. Each case runs on its own fresh copy of the fixture.
TEST(MutableGraphViewTest, AddRegularFaninByPort) {
  // Successful calls carry no expected error message.
  const string no_error;

  // Add input at start to node with some inputs and no controls.
  TestAddRegularFaninByPort("foo_3", /*node_exists=*/true, /*port=*/0,
                            {"d", 2}, /*success=*/true, no_error,
                            {"d:2", "b", "a:1", "a:1"});
  // Add input at end to node with some inputs and no controls.
  TestAddRegularFaninByPort("foo_3", /*node_exists=*/true, /*port=*/3,
                            {"d", 2}, /*success=*/true, no_error,
                            {"b", "a:1", "a:1", "d:2"});
  // Add input in middle to node with some inputs and no controls.
  TestAddRegularFaninByPort("foo_3", /*node_exists=*/true, /*port=*/2,
                            {"d", 2}, /*success=*/true, no_error,
                            {"b", "a:1", "d:2", "a:1"});
  // Add input at start to node with some inputs and some controls.
  TestAddRegularFaninByPort("foo_2", /*node_exists=*/true, /*port=*/0,
                            {"d", 2}, /*success=*/true, no_error,
                            {"d:2", "b", "^c", "^a"});
  // Add input at end to node with some inputs and some controls.
  TestAddRegularFaninByPort("foo_2", /*node_exists=*/true, /*port=*/1,
                            {"d", 2}, /*success=*/true, no_error,
                            {"b", "d:2", "^c", "^a"});
  // Add input in middle to node with some inputs and some controls, and dedup
  // controls.
  TestAddRegularFaninByPort("foo_4", /*node_exists=*/true, /*port=*/2,
                            {"d", 2}, /*success=*/true, no_error,
                            {"a", "b:2", "d:2", "b:2", "^c"});
  // Add input to node with no inputs and no controls.
  TestAddRegularFaninByPort("foo_5", /*node_exists=*/true, /*port=*/0,
                            {"d", 2}, /*success=*/true, no_error, {"d:2"});
  // Add input to node with no inputs and some controls.
  TestAddRegularFaninByPort("foo_6", /*node_exists=*/true, /*port=*/0,
                            {"d", 2}, /*success=*/true, no_error,
                            {"d:2", "^b", "^a"});
  // Add fanin should dedup control.
  TestAddRegularFaninByPort("foo_6", /*node_exists=*/true, /*port=*/0,
                            {"b", 2}, /*success=*/true, no_error,
                            {"b:2", "^a"});

  // Add controlling fanin.
  TestAddRegularFaninByPort(
      "foo_4", /*node_exists=*/true, /*port=*/2, {"d", Graph::kControlSlot},
      /*success=*/false,
      "MutableGraphView::AddRegularFaninByPort(node_name='foo_4', port=2, "
      "fanin='^d') error: fanin '^d' must be a regular tensor id.",
      {"a", "b:2", "b:2", "^c", "^d"});

  // Add fanin at out of bounds port.
  TestAddRegularFaninByPort(
      "foo_5", /*node_exists=*/true, /*port=*/-1, {"d", 2},
      /*success=*/false,
      "MutableGraphView::AddRegularFaninByPort(node_name='foo_5', port=-1, "
      "fanin='d:2') error: port must be in range [0, 0].",
      {});
  TestAddRegularFaninByPort(
      "foo_5", /*node_exists=*/true, /*port=*/1, {"d", 2},
      /*success=*/false,
      "MutableGraphView::AddRegularFaninByPort(node_name='foo_5', port=1, "
      "fanin='d:2') error: port must be in range [0, 0].",
      {});
  TestAddRegularFaninByPort(
      "foo_6", /*node_exists=*/true, /*port=*/-1, {"d", 2},
      /*success=*/false,
      "MutableGraphView::AddRegularFaninByPort(node_name='foo_6', port=-1, "
      "fanin='d:2') error: port must be in range [0, 0].",
      {"^a", "^b"});
  TestAddRegularFaninByPort(
      "foo_6", /*node_exists=*/true, /*port=*/1, {"d", 2},
      /*success=*/false,
      "MutableGraphView::AddRegularFaninByPort(node_name='foo_6', port=1, "
      "fanin='d:2') error: port must be in range [0, 0].",
      {"^a", "^b"});
  TestAddRegularFaninByPort(
      "foo_4", /*node_exists=*/true, /*port=*/-1, {"d", 2},
      /*success=*/false,
      "MutableGraphView::AddRegularFaninByPort(node_name='foo_4', port=-1, "
      "fanin='d:2') error: port must be in range [0, 3].",
      {"a", "b:2", "b:2", "^c", "^d"});
  TestAddRegularFaninByPort(
      "foo_4", /*node_exists=*/true, /*port=*/4, {"d", 2},
      /*success=*/false,
      "MutableGraphView::AddRegularFaninByPort(node_name='foo_4', port=4, "
      "fanin='d:2') error: port must be in range [0, 3].",
      {"a", "b:2", "b:2", "^c", "^d"});

  // Add fanin to node where node is missing.
  TestAddRegularFaninByPort(
      "foo_missing", /*node_exists=*/false, /*port=*/0, {"a", 0},
      /*success=*/false,
      "MutableGraphView::AddRegularFaninByPort(node_name='foo_missing', "
      "port=0, fanin='a:0') error: node 'foo_missing' was not found.",
      {});
  // Add fanin to node where fanin is missing.
  TestAddRegularFaninByPort(
      "foo_1", /*node_exists=*/true, /*port=*/0, {"bar_missing", 0},
      /*success=*/false,
      "MutableGraphView::AddRegularFaninByPort(node_name='foo_1', port=0, "
      "fanin='bar_missing:0') error: node 'bar_missing' was not found.",
      {"a"});
  // Add fanin to node where node and fanin are missing.
  TestAddRegularFaninByPort(
      "foo_missing", /*node_exists=*/false, /*port=*/0, {"bar_missing", 0},
      /*success=*/false,
      "MutableGraphView::AddRegularFaninByPort(node_name='foo_missing', "
      "port=0, fanin='bar_missing:0') error: node 'foo_missing' was not found.",
      {});

  // Add self to create cycle.
  TestAddRegularFaninByPort(
      "foo_6", /*node_exists=*/true, /*port=*/0, {"foo_6", 2},
      /*success=*/false,
      "MutableGraphView::AddRegularFaninByPort(node_name='foo_6', port=0, "
      "fanin='foo_6:2') error: can't add fanin 'foo_6:2' to self.",
      {"^a", "^b"});
}
1183
CheckFanoutRemoved(const MutableGraphView & graph,const TensorId & fanin,absl::string_view node_name)1184 void CheckFanoutRemoved(const MutableGraphView& graph, const TensorId& fanin,
1185 absl::string_view node_name) {
1186 MutableGraphView::OutputPort output_port =
1187 graph.GetOutputPort(fanin.node(), fanin.index());
1188 auto fanouts = graph.GetFanout(output_port);
1189 for (auto fanout : fanouts) {
1190 EXPECT_NE(fanout.node->name(), fanin.node());
1191 }
1192 }
1193
TestRemoveRegularFanin(absl::string_view node_name,bool node_exists,const TensorId & fanin_to_remove,bool success,const string & error_msg,absl::Span<const string> expected_fanins)1194 void TestRemoveRegularFanin(absl::string_view node_name, bool node_exists,
1195 const TensorId& fanin_to_remove, bool success,
1196 const string& error_msg,
1197 absl::Span<const string> expected_fanins) {
1198 GraphDef graph_def = SimpleMutateFaninGraph();
1199
1200 MutableGraphView graph(&graph_def);
1201
1202 NodeDef* node = graph.GetNode(node_name);
1203 if (node_exists) {
1204 EXPECT_NE(nullptr, node);
1205 } else {
1206 EXPECT_EQ(nullptr, node);
1207 }
1208
1209 absl::flat_hash_map<string, std::vector<string>> unmodified_node_inputs =
1210 GetNodeInputsFromGraph(graph_def, node_name);
1211
1212 Status s = graph.RemoveRegularFanin(node_name, fanin_to_remove);
1213 EXPECT_EQ(s.ok(), success);
1214 if (!success) {
1215 EXPECT_EQ(s.error_message(), error_msg);
1216 }
1217 if (node_exists) {
1218 CompareNodeFanins(graph, node, expected_fanins);
1219 if (success) {
1220 CheckFanoutRemoved(graph, fanin_to_remove, node_name);
1221 }
1222 }
1223
1224 CheckUnmodifiedNodeFanins(graph_def, node_name, unmodified_node_inputs);
1225
1226 CheckGraph(graph);
1227 }
1228
// Covers MutableGraphView::RemoveRegularFanin on the SimpleMutateFaninGraph()
// fixture: valid removals, removals of fanins the node does not have, control
// fanins (rejected), missing nodes and self references. Each case runs on its
// own fresh copy of the fixture.
TEST(MutableGraphViewTest, RemoveRegularFanin) {
  // Successful calls carry no expected error message.
  const string no_error;

  // Remove input from node with 1 input 0 controls.
  TestRemoveRegularFanin("foo_1", /*node_exists=*/true, {"a", 0},
                         /*success=*/true, no_error, {});
  // Remove input from node with multiple inputs and 0 controls.
  TestRemoveRegularFanin("foo_3", /*node_exists=*/true, {"a", 1},
                         /*success=*/true, no_error, {"b"});
  // Remove input from node with 1 input multiple controls.
  TestRemoveRegularFanin("foo_2", /*node_exists=*/true, {"b", 0},
                         /*success=*/true, no_error, {"^a", "^c"});
  // Remove input from node with multiple inputs and controls.
  TestRemoveRegularFanin("foo_4", /*node_exists=*/true, {"b", 2},
                         /*success=*/true, no_error, {"a", "^c", "^d"});
  // Remove input from node with multiple inputs and controls, and results in
  // shifting of ports.
  TestRemoveRegularFanin("foo_4", /*node_exists=*/true, {"a", 0},
                         /*success=*/true, no_error,
                         {"b:2", "b:2", "^c", "^d"});

  // Remove control from node with 1 input multiple controls.
  TestRemoveRegularFanin(
      "foo_2", /*node_exists=*/true, {"a", Graph::kControlSlot},
      /*success=*/false,
      "MutableGraphView::RemoveRegularFanin(node_name='foo_2', fanin='^a') "
      "error: fanin '^a' must be a regular tensor id.",
      {"b", "^a", "^c"});
  // Remove control from node with multiple input multiple controls.
  TestRemoveRegularFanin(
      "foo_4", /*node_exists=*/true, {"d", Graph::kControlSlot},
      /*success=*/false,
      "MutableGraphView::RemoveRegularFanin(node_name='foo_4', fanin='^d') "
      "error: fanin '^d' must be a regular tensor id.",
      {"a", "b:2", "b:2", "^c", "^d"});
  // Remove control from node with 0 inputs multiple controls.
  TestRemoveRegularFanin(
      "foo_6", /*node_exists=*/true, {"a", Graph::kControlSlot},
      /*success=*/false,
      "MutableGraphView::RemoveRegularFanin(node_name='foo_6', fanin='^a') "
      "error: fanin '^a' must be a regular tensor id.",
      {"^a", "^b"});

  // Remove input from node with 0 inputs 0 controls.
  TestRemoveRegularFanin("foo_5", /*node_exists=*/true, {"a", 1},
                         /*success=*/true, no_error, {});
  // Remove input from node with 0 inputs multiple controls.
  TestRemoveRegularFanin("foo_6", /*node_exists=*/true, {"a", 1},
                         /*success=*/true, no_error, {"^a", "^b"});

  // Remove control from node with 1 input 0 controls.
  TestRemoveRegularFanin(
      "foo_1", /*node_exists=*/true, {"b", Graph::kControlSlot},
      /*success=*/false,
      "MutableGraphView::RemoveRegularFanin(node_name='foo_1', fanin='^b') "
      "error: fanin '^b' must be a regular tensor id.",
      {"a"});
  // Remove control from node with multiple inputs and 0 controls.
  TestRemoveRegularFanin(
      "foo_3", /*node_exists=*/true, {"c", Graph::kControlSlot},
      /*success=*/false,
      "MutableGraphView::RemoveRegularFanin(node_name='foo_3', fanin='^c') "
      "error: fanin '^c' must be a regular tensor id.",
      {"b", "a:1", "a:1"});
  // Remove control from node with 0 inputs 0 controls.
  TestRemoveRegularFanin(
      "foo_5", /*node_exists=*/true, {"a", Graph::kControlSlot},
      /*success=*/false,
      "MutableGraphView::RemoveRegularFanin(node_name='foo_5', fanin='^a') "
      "error: fanin '^a' must be a regular tensor id.",
      {});

  // Remove fanin from node where node is missing.
  TestRemoveRegularFanin(
      "foo_missing", /*node_exists=*/false, {"a", 0},
      /*success=*/false,
      "MutableGraphView::RemoveRegularFanin(node_name='foo_missing', "
      "fanin='a:0') error: node 'foo_missing' was not found.",
      {});
  // Remove fanin from node where fanin is missing.
  TestRemoveRegularFanin(
      "foo_1", /*node_exists=*/true, {"bar_missing", 0},
      /*success=*/false,
      "MutableGraphView::RemoveRegularFanin(node_name='foo_1', "
      "fanin='bar_missing:0') error: node 'bar_missing' was not found.",
      {"a"});
  // Remove fanin from node where node and fanin are missing.
  TestRemoveRegularFanin(
      "foo_missing", /*node_exists=*/false, {"bar_missing", 0},
      /*success=*/false,
      "MutableGraphView::RemoveRegularFanin(node_name='foo_missing', "
      "fanin='bar_missing:0') error: node 'foo_missing' was not found.",
      {});
  // Remove control from node where node and fanin are missing.
  TestRemoveRegularFanin(
      "foo_missing", /*node_exists=*/false,
      {"bar_missing", Graph::kControlSlot},
      /*success=*/false,
      "MutableGraphView::RemoveRegularFanin(node_name='foo_missing', "
      "fanin='^bar_missing') error: fanin '^bar_missing' must be a regular "
      "tensor id.",
      {});

  // Remove self.
  TestRemoveRegularFanin(
      "foo_6", /*node_exists=*/true, {"foo_6", 2},
      /*success=*/false,
      "MutableGraphView::RemoveRegularFanin(node_name='foo_6', "
      "fanin='foo_6:2') error: can't remove fanin 'foo_6:2' from self.",
      {"^a", "^b"});
}
1335
TestRemoveRegularFaninByPort(absl::string_view node_name,bool node_exists,int port,bool success,const string & error_msg,absl::Span<const string> expected_fanins)1336 void TestRemoveRegularFaninByPort(absl::string_view node_name, bool node_exists,
1337 int port, bool success,
1338 const string& error_msg,
1339 absl::Span<const string> expected_fanins) {
1340 GraphDef graph_def = SimpleMutateFaninGraph();
1341
1342 MutableGraphView graph(&graph_def);
1343
1344 NodeDef* node = graph.GetNode(node_name);
1345 if (node_exists) {
1346 EXPECT_NE(nullptr, node);
1347 } else {
1348 EXPECT_EQ(nullptr, node);
1349 }
1350
1351 absl::flat_hash_map<string, std::vector<string>> unmodified_node_inputs =
1352 GetNodeInputsFromGraph(graph_def, node_name);
1353
1354 Status s = graph.RemoveRegularFaninByPort(node_name, port);
1355 EXPECT_EQ(s.ok(), success);
1356 if (!success) {
1357 EXPECT_EQ(s.error_message(), error_msg);
1358 }
1359 if (node_exists) {
1360 CompareNodeFanins(graph, node, expected_fanins);
1361 }
1362
1363 CheckUnmodifiedNodeFanins(graph_def, node_name, unmodified_node_inputs);
1364
1365 CheckGraph(graph);
1366 }
1367
// Covers MutableGraphView::RemoveRegularFaninByPort on the
// SimpleMutateFaninGraph() fixture: removal at the start, middle and end of
// the regular inputs, nodes without regular inputs, out-of-range ports and a
// missing node. Each case runs on its own fresh copy of the fixture.
TEST(MutableGraphViewTest, RemoveRegularFaninByPort) {
  // Successful calls carry no expected error message.
  const string no_error;

  // Remove input at start of node with some inputs and no controls.
  TestRemoveRegularFaninByPort("foo_3", /*node_exists=*/true, /*port=*/0,
                               /*success=*/true, no_error, {"a:1", "a:1"});
  // Remove input at end of node with some inputs and no controls.
  TestRemoveRegularFaninByPort("foo_3", /*node_exists=*/true, /*port=*/2,
                               /*success=*/true, no_error, {"b", "a:1"});
  // Remove input in middle of node with some inputs and no controls.
  TestRemoveRegularFaninByPort("foo_3", /*node_exists=*/true, /*port=*/1,
                               /*success=*/true, no_error, {"b", "a:1"});
  // Remove input at start of node with some inputs and some controls.
  TestRemoveRegularFaninByPort("foo_4", /*node_exists=*/true, /*port=*/0,
                               /*success=*/true, no_error,
                               {"b:2", "b:2", "^d", "^c"});
  // Remove input at end of node with some inputs and some controls.
  TestRemoveRegularFaninByPort("foo_4", /*node_exists=*/true, /*port=*/2,
                               /*success=*/true, no_error,
                               {"a", "b:2", "^d", "^c"});
  // Remove input in middle of node with some inputs and some controls.
  TestRemoveRegularFaninByPort("foo_4", /*node_exists=*/true, /*port=*/1,
                               /*success=*/true, no_error,
                               {"a", "b:2", "^d", "^c"});

  // Remove input from node with no inputs and no controls.
  TestRemoveRegularFaninByPort(
      "foo_5", /*node_exists=*/true, /*port=*/0,
      /*success=*/false,
      "MutableGraphView::RemoveRegularFaninByPort(node_name='foo_5', port=0) "
      "error: no available ports as node has no regular fanins.",
      {});
  // Remove input from node with no inputs and some controls.
  TestRemoveRegularFaninByPort(
      "foo_6", /*node_exists=*/true, /*port=*/1,
      /*success=*/false,
      "MutableGraphView::RemoveRegularFaninByPort(node_name='foo_6', port=1) "
      "error: no available ports as node has no regular fanins.",
      {"^a", "^b"});

  // Remove fanin at out of bounds port.
  TestRemoveRegularFaninByPort(
      "foo_3", /*node_exists=*/true, /*port=*/-1,
      /*success=*/false,
      "MutableGraphView::RemoveRegularFaninByPort(node_name='foo_3', port=-1) "
      "error: port must be in range [0, 2].",
      {"b", "a:1", "a:1"});
  TestRemoveRegularFaninByPort(
      "foo_3", /*node_exists=*/true, /*port=*/3,
      /*success=*/false,
      "MutableGraphView::RemoveRegularFaninByPort(node_name='foo_3', port=3) "
      "error: port must be in range [0, 2].",
      {"b", "a:1", "a:1"});
  TestRemoveRegularFaninByPort(
      "foo_4", /*node_exists=*/true, /*port=*/-1,
      /*success=*/false,
      "MutableGraphView::RemoveRegularFaninByPort(node_name='foo_4', port=-1) "
      "error: port must be in range [0, 2].",
      {"a", "b:2", "b:2", "^c", "^d"});
  TestRemoveRegularFaninByPort(
      "foo_4", /*node_exists=*/true, /*port=*/3,
      /*success=*/false,
      "MutableGraphView::RemoveRegularFaninByPort(node_name='foo_4', port=3) "
      "error: port must be in range [0, 2].",
      {"a", "b:2", "b:2", "^c", "^d"});

  // Remove fanin from node where node is missing.
  TestRemoveRegularFaninByPort(
      "foo_missing", /*node_exists=*/false, /*port=*/0,
      /*success=*/false,
      "MutableGraphView::RemoveRegularFaninByPort(node_name='foo_missing', "
      "port=0) error: node 'foo_missing' was not found.",
      {});
}
1438
TestRemoveAllFanins(absl::string_view node_name,bool node_exists,bool keep_controlling_nodes,bool success,const string & error_msg,absl::Span<const string> expected_fanins)1439 void TestRemoveAllFanins(absl::string_view node_name, bool node_exists,
1440 bool keep_controlling_nodes, bool success,
1441 const string& error_msg,
1442 absl::Span<const string> expected_fanins) {
1443 GraphDef graph_def = SimpleMutateFaninGraph();
1444
1445 MutableGraphView graph(&graph_def);
1446
1447 NodeDef* node = graph.GetNode(node_name);
1448 absl::flat_hash_set<string> fanin_strings;
1449 if (node_exists) {
1450 EXPECT_NE(node, nullptr);
1451 fanin_strings.insert(node->input().begin(), node->input().end());
1452 } else {
1453 EXPECT_EQ(node, nullptr);
1454 }
1455
1456 absl::flat_hash_map<string, std::vector<string>> unmodified_node_inputs =
1457 GetNodeInputsFromGraph(graph_def, node_name);
1458
1459 Status s = graph.RemoveAllFanins(node_name, keep_controlling_nodes);
1460 EXPECT_EQ(s.ok(), success);
1461 if (!success) {
1462 EXPECT_EQ(s.error_message(), error_msg);
1463 }
1464 if (node_exists) {
1465 CompareNodeFanins(graph, node, expected_fanins);
1466 if (success) {
1467 TensorId tensor_id;
1468 auto retained_inputs = absl::flat_hash_set<string>(node->input().begin(),
1469 node->input().end());
1470 for (const string& fanin : fanin_strings) {
1471 if (!retained_inputs.contains(fanin)) {
1472 tensor_id = ParseTensorName(fanin);
1473 CheckFanoutRemoved(graph, tensor_id, node_name);
1474 }
1475 }
1476 }
1477 }
1478
1479 CheckUnmodifiedNodeFanins(graph_def, node_name, unmodified_node_inputs);
1480
1481 CheckGraph(graph);
1482 }
1483
// Exercises RemoveAllFanins on nodes with and without regular/controlling
// fanins, for both values of keep_controlling_nodes, and on a missing node.
TEST(MutableGraphViewTest, RemoveAllFanins) {
  // One row per TestRemoveAllFanins invocation; fields mirror its parameters.
  struct RemoveAllFaninsCase {
    const char* node_name;
    bool node_exists;
    bool keep_controlling_nodes;
    bool success;
    string error_msg;  // Only inspected by the helper when !success.
    std::vector<string> expected_fanins;
  };

  const RemoveAllFaninsCase kCases[] = {
      // Node with no control dependencies, dropping controls.
      {"foo_3", true, false, true, "", {}},
      // Node with control dependencies, dropping controls.
      {"foo_4", true, false, true, "", {}},
      // Node with no control dependencies, preserving controls.
      {"foo_3", true, true, true, "", {}},
      // Node with control dependencies, preserving controls.
      {"foo_4", true, true, true, "", {"^c", "^d"}},
      // Node with no fanins at all.
      {"foo_5", true, false, true, "", {}},
      {"foo_5", true, true, true, "", {}},
      // Node with only control dependencies.
      {"foo_6", true, false, true, "", {}},
      {"foo_6", true, true, true, "", {"^a", "^b"}},
      // Node is missing from the graph.
      {"foo_missing", false, false, false,
       "MutableGraphView::RemoveAllFanins(node_name='foo_missing', "
       "keep_controlling_fanins=false) error: node 'foo_missing' was not "
       "found.",
       {}},
      {"foo_missing", false, true, false,
       "MutableGraphView::RemoveAllFanins(node_name='foo_missing', "
       "keep_controlling_fanins=true) error: node 'foo_missing' was not "
       "found.",
       {}},
  };

  for (const RemoveAllFaninsCase& test_case : kCases) {
    TestRemoveAllFanins(test_case.node_name, test_case.node_exists,
                        test_case.keep_controlling_nodes, test_case.success,
                        test_case.error_msg, test_case.expected_fanins);
  }
}
1536
TestUpdateFanin(absl::string_view node_name,bool node_exists,const TensorId & from_fanin,const TensorId & to_fanin,bool success,const string & error_msg,absl::Span<const string> expected_fanins)1537 void TestUpdateFanin(absl::string_view node_name, bool node_exists,
1538 const TensorId& from_fanin, const TensorId& to_fanin,
1539 bool success, const string& error_msg,
1540 absl::Span<const string> expected_fanins) {
1541 GraphDef graph_def = SimpleMutateFaninGraph();
1542
1543 MutableGraphView graph(&graph_def);
1544
1545 NodeDef* node = graph.GetNode(node_name);
1546 if (node_exists) {
1547 EXPECT_NE(node, nullptr);
1548 } else {
1549 EXPECT_EQ(node, nullptr);
1550 }
1551
1552 absl::flat_hash_map<string, std::vector<string>> unmodified_node_inputs =
1553 GetNodeInputsFromGraph(graph_def, node_name);
1554
1555 Status s = graph.UpdateFanin(node_name, from_fanin, to_fanin);
1556 EXPECT_EQ(s.ok(), success);
1557 if (!success) {
1558 EXPECT_EQ(s.error_message(), error_msg);
1559 }
1560 if (node_exists) {
1561 CompareNodeFanins(graph, node, expected_fanins);
1562 if (success) {
1563 CheckFanoutRemoved(graph, from_fanin, node_name);
1564 }
1565 }
1566
1567 CheckUnmodifiedNodeFanins(graph_def, node_name, unmodified_node_inputs);
1568
1569 CheckGraph(graph);
1570 }
1571
// Exercises UpdateFanin across regular/control combinations, no-op updates,
// missing or invalid nodes and fanins, and attempted self-cycles.
TEST(MutableGraphViewTest, UpdateFanin) {
  // One row per TestUpdateFanin invocation; fields mirror its parameters.
  struct UpdateFaninCase {
    const char* node_name;
    bool node_exists;
    TensorId from_fanin;
    TensorId to_fanin;
    bool success;
    string error_msg;  // Only inspected by the helper when !success.
    std::vector<string> expected_fanins;
  };

  const UpdateFaninCase kCases[] = {
      // Non control to non control.
      {"foo_4", true, {"b", 2}, {"b", 3}, true, "",
       {"a", "b:3", "b:3", "^c", "^d"}},
      // Non control to control.
      {"foo_4", true, {"b", 2}, {"b", Graph::kControlSlot}, true, "",
       {"a", "^c", "^d", "^b"}},
      // Control to non control.
      {"foo_4", true, {"d", Graph::kControlSlot}, {"d", 1}, true, "",
       {"a", "b:2", "b:2", "d:1", "^c"}},
      // Control to control.
      {"foo_4", true, {"c", Graph::kControlSlot}, {"b", Graph::kControlSlot},
       true, "", {"a", "b:2", "b:2", "^d"}},
      // Control to an already existing control.
      {"foo_4", true, {"c", Graph::kControlSlot}, {"d", Graph::kControlSlot},
       true, "", {"a", "b:2", "b:2", "^d"}},

      // From and to fanins are the same.
      {"foo_1", true, {"a", -1}, {"a", -1}, true, "", {"a"}},
      {"foo_1", true, {"a", 0}, {"a", 0}, true, "", {"a"}},
      {"foo_1", true, {"a", 1}, {"a", 1}, true, "", {"a"}},

      // Node is missing.
      {"foo_missing", false, {"a", 0}, {"a", 1}, false,
       "MutableGraphView::UpdateFanin(node_name='foo_missing', "
       "from_fanin='a:0', to_fanin='a:1') error: node 'foo_missing' was not "
       "found.",
       {}},
      // From fanin is missing.
      {"foo_1", true, {"from_bar_missing", 0}, {"a", 1}, false,
       "MutableGraphView::UpdateFanin(node_name='foo_1', "
       "from_fanin='from_bar_missing:0', to_fanin='a:1') error: node "
       "'from_bar_missing' was not found.",
       {"a"}},
      // To fanin is missing.
      {"foo_1", true, {"a", 0}, {"to_bar_missing", 1}, false,
       "MutableGraphView::UpdateFanin(node_name='foo_1', from_fanin='a:0', "
       "to_fanin='to_bar_missing:1') error: node 'to_bar_missing' was not "
       "found.",
       {"a"}},
      // From/to fanins and node are all missing.
      {"foo_missing", false, {"from_bar_missing", 0}, {"to_bar_missing", 1},
       false,
       "MutableGraphView::UpdateFanin(node_name='foo_missing', "
       "from_fanin='from_bar_missing:0', to_fanin='to_bar_missing:1') error: "
       "node 'foo_missing' was not found.",
       {}},
      // From fanin is invalid.
      {"foo_1", true, {"a", -2}, {"a", 0}, false,
       "MutableGraphView::UpdateFanin(node_name='foo_1', from_fanin='a:-2', "
       "to_fanin='a:0') error: fanin 'a:-2' must be a valid tensor id.",
       {"a"}},
      // To fanin is invalid.
      {"foo_1", true, {"a", 0}, {"a", -2}, false,
       "MutableGraphView::UpdateFanin(node_name='foo_1', from_fanin='a:0', "
       "to_fanin='a:-2') error: fanin 'a:-2' must be a valid tensor id.",
       {"a"}},
      // From/to fanins are invalid and missing, and node is missing.
      {"foo_missing", false, {"from_bar_missing", -2}, {"to_bar_missing", -3},
       false,
       "MutableGraphView::UpdateFanin(node_name='foo_missing', "
       "from_fanin='from_bar_missing:-2', to_fanin='to_bar_missing:-3') error: "
       "fanin 'from_bar_missing:-2' must be a valid tensor id.",
       {}},

      // Updating to self would create a cycle.
      {"foo_4", true, {"b", 2}, {"foo_4", 3}, false,
       "MutableGraphView::UpdateFanin(node_name='foo_4', from_fanin='b:2', "
       "to_fanin='foo_4:3') error: can't update fanin to or from self.",
       {"a", "b:2", "b:2", "^c", "^d"}},
      {"foo_4", true, {"b", 2}, {"foo_4", Graph::kControlSlot}, false,
       "MutableGraphView::UpdateFanin(node_name='foo_4', from_fanin='b:2', "
       "to_fanin='^foo_4') error: can't update fanin to or from self.",
       {"a", "b:2", "b:2", "^c", "^d"}},
      {"foo_4", true, {"c", Graph::kControlSlot}, {"foo_4", 4}, false,
       "MutableGraphView::UpdateFanin(node_name='foo_4', from_fanin='^c', "
       "to_fanin='foo_4:4') error: can't update fanin to or from self.",
       {"a", "b:2", "b:2", "^c", "^d"}},
      {"foo_4", true, {"c", Graph::kControlSlot},
       {"foo_4", Graph::kControlSlot}, false,
       "MutableGraphView::UpdateFanin(node_name='foo_4', from_fanin='^c', "
       "to_fanin='^foo_4') error: can't update fanin to or from self.",
       {"a", "b:2", "b:2", "^c", "^d"}},
  };

  for (const UpdateFaninCase& test_case : kCases) {
    TestUpdateFanin(test_case.node_name, test_case.node_exists,
                    test_case.from_fanin, test_case.to_fanin,
                    test_case.success, test_case.error_msg,
                    test_case.expected_fanins);
  }
}
1680
TestUpdateFaninFromFaninToNodeAsSwitchControl(const TensorId & fanin)1681 void TestUpdateFaninFromFaninToNodeAsSwitchControl(const TensorId& fanin) {
1682 string tensor_id_str = TensorIdToString(fanin);
1683 GraphDef graph_def = test::function::GDef(
1684 {NDef("a", "NotImportant", {}, {}), NDef("b", "Switch", {}, {}),
1685 NDef("c", "NotImportant", {tensor_id_str})},
1686 /*funcs=*/{});
1687
1688 MutableGraphView graph(&graph_def);
1689
1690 Status s = graph.UpdateFanin("c", fanin, {"b", Graph::kControlSlot});
1691 EXPECT_FALSE(s.ok());
1692 string expected_msg = absl::Substitute(
1693 "MutableGraphView::UpdateFanin(node_name='c', from_fanin='$0', "
1694 "to_fanin='^b') error: can't update to fanin '^b' as it will become a "
1695 "Switch control dependency.",
1696 fanin.ToString());
1697 EXPECT_EQ(s.error_message(), expected_msg);
1698
1699 EXPECT_EQ(graph.graph()->node_size(), 3);
1700
1701 string fanout = IsControlInput(fanin) ? AsControlDependency("c") : "c";
1702 CheckNode(graph, "a", "NotImportant", "", {}, {}, {fanout});
1703 CheckNode(graph, "b", "Switch", "", {}, {}, {});
1704 CheckNode(graph, "c", "NotImportant", "", {}, {tensor_id_str}, {});
1705
1706 CheckGraph(graph);
1707 }
1708
// Updating any fanin (regular at port 0 or 1, or control) to a Switch control
// dependency must fail.
TEST(MutableGraphViewTest, UpdateFaninToNodeAsSwitchControl) {
  const TensorId kFromFanins[] = {
      {"a", 0},
      {"a", 1},
      {"a", Graph::kControlSlot},
  };
  for (const TensorId& from_fanin : kFromFanins) {
    TestUpdateFaninFromFaninToNodeAsSwitchControl(from_fanin);
  }
}
1714
TestUpdateRegularFaninByPort(absl::string_view node_name,bool node_exists,int port,const TensorId & fanin,bool success,const string & error_msg,absl::Span<const string> expected_fanins)1715 void TestUpdateRegularFaninByPort(absl::string_view node_name, bool node_exists,
1716 int port, const TensorId& fanin, bool success,
1717 const string& error_msg,
1718 absl::Span<const string> expected_fanins) {
1719 GraphDef graph_def = SimpleMutateFaninGraph();
1720
1721 MutableGraphView graph(&graph_def);
1722
1723 NodeDef* node = graph.GetNode(node_name);
1724 if (node_exists) {
1725 EXPECT_NE(node, nullptr);
1726 } else {
1727 EXPECT_EQ(node, nullptr);
1728 }
1729
1730 absl::flat_hash_map<string, std::vector<string>> unmodified_node_inputs =
1731 GetNodeInputsFromGraph(graph_def, node_name);
1732
1733 Status s = graph.UpdateRegularFaninByPort(node_name, port, fanin);
1734 EXPECT_EQ(s.ok(), success);
1735 if (!success) {
1736 EXPECT_EQ(s.error_message(), error_msg);
1737 }
1738 if (node_exists) {
1739 CompareNodeFanins(graph, node, expected_fanins);
1740 }
1741
1742 CheckUnmodifiedNodeFanins(graph_def, node_name, unmodified_node_inputs);
1743
1744 CheckGraph(graph);
1745 }
1746
// Exercises UpdateRegularFaninByPort for valid ports (with control dedup),
// control fanins, out of bounds ports, missing nodes/fanins and self-cycles.
TEST(MutableGraphViewTest, UpdateRegularFaninByPort) {
  // One row per TestUpdateRegularFaninByPort invocation; fields mirror its
  // parameters.
  struct UpdateByPortCase {
    const char* node_name;
    bool node_exists;
    int port;
    TensorId fanin;
    bool success;
    string error_msg;  // Only inspected by the helper when !success.
    std::vector<string> expected_fanins;
  };

  const UpdateByPortCase kCases[] = {
      // Node with some inputs and no controls: start, end, middle.
      {"foo_3", true, 0, {"d", 2}, true, "", {"d:2", "a:1", "a:1"}},
      {"foo_3", true, 2, {"d", 2}, true, "", {"b", "a:1", "d:2"}},
      {"foo_3", true, 1, {"d", 2}, true, "", {"b", "d:2", "a:1"}},
      // Node with some inputs and some controls (controls get deduped):
      // start, end, middle.
      {"foo_4", true, 0, {"d", 2}, true, "", {"d:2", "b:2", "b:2", "^c"}},
      {"foo_4", true, 2, {"d", 2}, true, "", {"a", "b:2", "d:2", "^c"}},
      {"foo_4", true, 1, {"d", 2}, true, "", {"a", "d:2", "b:2", "^c"}},

      // Update input to controlling fanin.
      {"foo_4", true, 1, {"d", Graph::kControlSlot}, false,
       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_4', port=1, "
       "fanin='^d') error: fanin '^d' must be a regular tensor id.",
       {"a", "b:2", "b:2", "^c", "^d"}},

      // Out of bounds port: node with no fanins at all.
      {"foo_5", true, -1, {"d", 2}, false,
       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_5', port=-1, "
       "fanin='d:2') error: no available ports as node has no regular fanins.",
       {}},
      {"foo_5", true, 0, {"d", 2}, false,
       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_5', port=0, "
       "fanin='d:2') error: no available ports as node has no regular fanins.",
       {}},
      {"foo_5", true, 1, {"d", 2}, false,
       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_5', port=1, "
       "fanin='d:2') error: no available ports as node has no regular fanins.",
       {}},
      // Out of bounds port: node with only controls.
      {"foo_6", true, -1, {"d", 2}, false,
       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_6', port=-1, "
       "fanin='d:2') error: no available ports as node has no regular fanins.",
       {"^a", "^b"}},
      {"foo_6", true, 0, {"d", 2}, false,
       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_6', port=0, "
       "fanin='d:2') error: no available ports as node has no regular fanins.",
       {"^a", "^b"}},
      {"foo_6", true, 1, {"d", 2}, false,
       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_6', port=1, "
       "fanin='d:2') error: no available ports as node has no regular fanins.",
       {"^a", "^b"}},
      // Out of bounds port: node with regular fanins and no controls.
      {"foo_3", true, -1, {"d", 2}, false,
       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_3', port=-1, "
       "fanin='d:2') error: port must be in range [0, 2].",
       {"b", "a:1", "a:1"}},
      {"foo_3", true, 3, {"d", 2}, false,
       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_3', port=3, "
       "fanin='d:2') error: port must be in range [0, 2].",
       {"b", "a:1", "a:1"}},
      // Out of bounds port: node with regular fanins and controls.
      {"foo_4", true, -1, {"d", 2}, false,
       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_4', port=-1, "
       "fanin='d:2') error: port must be in range [0, 2].",
       {"a", "b:2", "b:2", "^c", "^d"}},
      {"foo_4", true, 3, {"d", 2}, false,
       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_4', port=3, "
       "fanin='d:2') error: port must be in range [0, 2].",
       {"a", "b:2", "b:2", "^c", "^d"}},

      // Node is missing.
      {"foo_missing", false, 0, {"a", 0}, false,
       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_missing', "
       "port=0, fanin='a:0') error: node 'foo_missing' was not found.",
       {}},
      // Fanin is missing.
      {"foo_1", true, 0, {"bar_missing", 0}, false,
       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_1', port=0, "
       "fanin='bar_missing:0') error: node 'bar_missing' was not "
       "found.",
       {"a"}},
      // Node and fanin are both missing.
      {"foo_missing", false, 0, {"bar_missing", 0}, false,
       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_missing', "
       "port=0, fanin='bar_missing:0') error: node 'foo_missing' was not "
       "found.",
       {}},

      // Updating to self would create a cycle.
      {"foo_6", true, 0, {"foo_6", 2}, false,
       "MutableGraphView::UpdateRegularFaninByPort(node_name='foo_6', port=0, "
       "fanin='foo_6:2') error: can't add fanin 'foo_6:2' to self.",
       {"^a", "^b"}},
  };

  for (const UpdateByPortCase& test_case : kCases) {
    TestUpdateRegularFaninByPort(test_case.node_name, test_case.node_exists,
                                 test_case.port, test_case.fanin,
                                 test_case.success, test_case.error_msg,
                                 test_case.expected_fanins);
  }
}
1878
TestSwapRegularFaninsByPorts(absl::string_view node_name,bool node_exists,int from_port,int to_port,bool success,const string & error_msg,absl::Span<const string> expected_fanins)1879 void TestSwapRegularFaninsByPorts(absl::string_view node_name, bool node_exists,
1880 int from_port, int to_port, bool success,
1881 const string& error_msg,
1882 absl::Span<const string> expected_fanins) {
1883 GraphDef graph_def = SimpleMutateFaninGraph();
1884
1885 MutableGraphView graph(&graph_def);
1886
1887 NodeDef* node = graph.GetNode(node_name);
1888 if (node_exists) {
1889 EXPECT_NE(node, nullptr);
1890 } else {
1891 EXPECT_EQ(node, nullptr);
1892 }
1893
1894 absl::flat_hash_map<string, std::vector<string>> unmodified_node_inputs =
1895 GetNodeInputsFromGraph(graph_def, node_name);
1896
1897 Status s = graph.SwapRegularFaninsByPorts(node_name, from_port, to_port);
1898 EXPECT_EQ(s.ok(), success);
1899 if (!success) {
1900 EXPECT_EQ(s.error_message(), error_msg);
1901 }
1902 if (node_exists) {
1903 CompareNodeFanins(graph, node, expected_fanins);
1904 }
1905
1906 CheckUnmodifiedNodeFanins(graph_def, node_name, unmodified_node_inputs);
1907
1908 CheckGraph(graph);
1909 }
1910
// Exercises SwapRegularFaninsByPorts for valid swaps (in both directions),
// no-op swaps, out of bounds ports on every node flavor, and a missing node.
TEST(MutableGraphViewTest, SwapRegularFaninsByPorts) {
  // One row per TestSwapRegularFaninsByPorts invocation; fields mirror its
  // parameters.
  struct SwapCase {
    const char* node_name;
    bool node_exists;
    int from_port;
    int to_port;
    bool success;
    string error_msg;  // Only inspected by the helper when !success.
    std::vector<string> expected_fanins;
  };

  const SwapCase kCases[] = {
      // Swapping first and last regular fanins.
      {"foo_3", true, 0, 2, true, "", {"a:1", "a:1", "b"}},
      {"foo_3", true, 2, 0, true, "", {"a:1", "a:1", "b"}},
      // Swapping first and last regular fanins, in node with controls.
      {"foo_4", true, 0, 2, true, "", {"b:2", "b:2", "a", "^c", "^d"}},
      {"foo_4", true, 2, 0, true, "", {"b:2", "b:2", "a", "^c", "^d"}},
      // Swapping middle regular fanin.
      {"foo_3", true, 0, 1, true, "", {"a:1", "b", "a:1"}},
      {"foo_3", true, 1, 0, true, "", {"a:1", "b", "a:1"}},
      // Swapping middle regular fanin, in node with controls.
      {"foo_4", true, 0, 1, true, "", {"b:2", "a", "b:2", "^c", "^d"}},
      {"foo_4", true, 1, 0, true, "", {"b:2", "a", "b:2", "^c", "^d"}},
      // Swapping same port.
      {"foo_4", true, 1, 1, true, "", {"a", "b:2", "b:2", "^c", "^d"}},
      // Swapping same fanin but different port.
      {"foo_4", true, 1, 2, true, "", {"a", "b:2", "b:2", "^c", "^d"}},

      // Out of bounds ports: node with no regular fanins and no controls.
      {"foo_5", true, -1, 0, false,
       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_5', "
       "from_port=-1, to_port=0) error: no available ports as node has no "
       "regular fanins.",
       {}},
      {"foo_5", true, 0, -1, false,
       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_5', "
       "from_port=0, to_port=-1) error: no available ports as node has no "
       "regular fanins.",
       {}},
      {"foo_5", true, 0, 0, false,
       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_5', "
       "from_port=0, to_port=0) error: no available ports as node has no "
       "regular fanins.",
       {}},
      {"foo_5", true, 0, 1, false,
       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_5', "
       "from_port=0, to_port=1) error: no available ports as node has no "
       "regular fanins.",
       {}},
      {"foo_5", true, 1, 0, false,
       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_5', "
       "from_port=1, to_port=0) error: no available ports as node has no "
       "regular fanins.",
       {}},

      // Out of bounds ports: node with no regular fanins and some controls.
      {"foo_6", true, -1, 0, false,
       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_6', "
       "from_port=-1, to_port=0) error: no available ports as node has no "
       "regular fanins.",
       {"^a", "^b"}},
      {"foo_6", true, 0, -1, false,
       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_6', "
       "from_port=0, to_port=-1) error: no available ports as node has no "
       "regular fanins.",
       {"^a", "^b"}},
      {"foo_6", true, 0, 0, false,
       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_6', "
       "from_port=0, to_port=0) error: no available ports as node has no "
       "regular fanins.",
       {"^a", "^b"}},
      {"foo_6", true, 0, 1, false,
       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_6', "
       "from_port=0, to_port=1) error: no available ports as node has no "
       "regular fanins.",
       {"^a", "^b"}},
      {"foo_6", true, 1, 0, false,
       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_6', "
       "from_port=1, to_port=0) error: no available ports as node has no "
       "regular fanins.",
       {"^a", "^b"}},

      // Out of bounds ports: node with regular fanins and no controls.
      {"foo_3", true, -1, 0, false,
       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_3', "
       "from_port=-1, to_port=0) error: port must be in range [0, 2].",
       {"b", "a:1", "a:1"}},
      {"foo_3", true, 0, -1, false,
       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_3', "
       "from_port=0, to_port=-1) error: port must be in range [0, 2].",
       {"b", "a:1", "a:1"}},
      {"foo_3", true, 0, 3, false,
       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_3', "
       "from_port=0, to_port=3) error: port must be in range [0, 2].",
       {"b", "a:1", "a:1"}},
      {"foo_3", true, 3, 0, false,
       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_3', "
       "from_port=3, to_port=0) error: port must be in range [0, 2].",
       {"b", "a:1", "a:1"}},
      {"foo_3", true, -1, 3, false,
       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_3', "
       "from_port=-1, to_port=3) error: port must be in range [0, 2].",
       {"b", "a:1", "a:1"}},
      {"foo_3", true, 3, -1, false,
       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_3', "
       "from_port=3, to_port=-1) error: port must be in range [0, 2].",
       {"b", "a:1", "a:1"}},

      // Out of bounds ports: node with regular fanins and controls.
      {"foo_4", true, -1, 0, false,
       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_4', "
       "from_port=-1, to_port=0) error: port must be in range [0, 2].",
       {"a", "b:2", "b:2", "^c", "^d"}},
      {"foo_4", true, 0, -1, false,
       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_4', "
       "from_port=0, to_port=-1) error: port must be in range [0, 2].",
       {"a", "b:2", "b:2", "^c", "^d"}},
      {"foo_4", true, 0, 3, false,
       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_4', "
       "from_port=0, to_port=3) error: port must be in range [0, 2].",
       {"a", "b:2", "b:2", "^c", "^d"}},
      {"foo_4", true, 3, 0, false,
       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_4', "
       "from_port=3, to_port=0) error: port must be in range [0, 2].",
       {"a", "b:2", "b:2", "^c", "^d"}},
      {"foo_4", true, -1, 3, false,
       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_4', "
       "from_port=-1, to_port=3) error: port must be in range [0, 2].",
       {"a", "b:2", "b:2", "^c", "^d"}},
      {"foo_4", true, 3, -1, false,
       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_4', "
       "from_port=3, to_port=-1) error: port must be in range [0, 2].",
       {"a", "b:2", "b:2", "^c", "^d"}},

      // Node is missing.
      {"foo_missing", false, 0, 1, false,
       "MutableGraphView::SwapRegularFaninsByPorts(node_name='foo_missing', "
       "from_port=0, to_port=1) error: node 'foo_missing' was not found.",
       {}},
  };

  for (const SwapCase& test_case : kCases) {
    TestSwapRegularFaninsByPorts(test_case.node_name, test_case.node_exists,
                                 test_case.from_port, test_case.to_port,
                                 test_case.success, test_case.error_msg,
                                 test_case.expected_fanins);
  }
}
2102
// Constructing a MutableGraphView over a graph whose nodes carry redundant
// controlling fanins must leave each node with the deduplicated inputs
// checked below.
TEST(MutableGraphViewTest, DedupControllingFaninsOnGraphInit) {
  // Nodes foo_1..foo_7 mix regular fanins with repeated control dependencies
  // on the same producers.
  const std::vector<NodeDef> nodes = {
      NDef("a", "NotImportant", {}, {}),
      NDef("b", "NotImportant", {}, {}),
      NDef("c", "Switch", {}, {}),
      NDef("d", "Identity", {"c:1"}),
      NDef("foo_1", "IdentityN", {"a", "b:1", "^b"}),
      NDef("foo_2", "IdentityN", {"a", "^b", "^b"}),
      NDef("foo_3", "IdentityN", {"a", "b:1", "^b", "^b"}),
      NDef("foo_4", "IdentityN", {"a:2", "b:1", "^b", "^b", "^a", "^a"}),
      NDef("foo_5", "NotImportant", {"a:2", "b:1", "^b", "^b", "^a", "^a"}),
      NDef("foo_6", "Identity", {"d", "^d"}),
      NDef("foo_7", "NotImportant",
           {"a:3", "b:2", "d", "^d", "^d", "^a", "^b", "^a", "^b"}),
  };
  GraphDef graph_def = test::function::GDef(nodes, /*funcs=*/{});

  MutableGraphView graph(&graph_def);

  EXPECT_EQ(graph.graph()->node_size(), 11);

  // Producers: only the deduplicated edges show up as fanouts.
  CheckNode(graph, "a", "NotImportant", "", {}, {},
            {"foo_1", "foo_2", "foo_3", "foo_4", "foo_5", "foo_7"});
  CheckNode(graph, "b", "NotImportant", "", {}, {},
            {"foo_1:1", "^foo_2", "foo_3:1", "foo_4:1", "foo_5:1", "foo_7:1"});
  CheckNode(graph, "c", "Switch", "", {}, {}, {"d"});
  // "d" keeps both its regular and control fanouts to foo_6 and foo_7.
  CheckNode(graph, "d", "Identity", "", {}, {"c:1"},
            {"foo_6", "^foo_6", "foo_7:2", "^foo_7"});

  // Consumers: a control on "a" or "b" is dropped when a regular fanin from
  // the same node exists, and repeated controls collapse to one.
  CheckNode(graph, "foo_1", "IdentityN", "", {}, {"a", "b:1"}, {});
  CheckNode(graph, "foo_2", "IdentityN", "", {}, {"a", "^b"}, {});
  CheckNode(graph, "foo_3", "IdentityN", "", {}, {"a", "b:1"}, {});
  CheckNode(graph, "foo_4", "IdentityN", "", {}, {"a:2", "b:1"}, {});
  CheckNode(graph, "foo_5", "NotImportant", "", {}, {"a:2", "b:1"}, {});
  // A single "^d" survives next to the regular fanin "d" (presumably because
  // "d" is an Identity fed by the Switch "c" -- confirm against
  // MutableGraphView's dedup rules).
  CheckNode(graph, "foo_6", "Identity", "", {}, {"d", "^d"}, {});
  CheckNode(graph, "foo_7", "NotImportant", "", {}, {"a:3", "b:2", "d", "^d"},
            {});

  CheckGraph(graph);
}
2139
// Adds fanins to nodes that already depend on the same producer and checks
// that only the regular edge is reported afterwards.
TEST(MutableGraphViewTest, DedupControllingFaninsOnAddFanin) {
  // The op type of the nodes is irrelevant here.
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {"^a"}),
       NDef("c", "NotImportant", {"a:1"})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  // `b` starts with `^a`; after adding `a:2` only the regular fanin is expected.
  TF_EXPECT_OK(view.AddRegularFanin("b", {"a", 2}));
  CheckNode(view, "b", "NotImportant", "", {}, {"a:2"}, {});

  // `c` already consumes `a:1`; adding `^a` is expected to leave it unchanged.
  TF_EXPECT_OK(view.AddControllingFanin("c", {"a", Graph::kControlSlot}));
  CheckNode(view, "c", "NotImportant", "", {}, {"a:1"}, {});

  CheckNode(view, "a", "NotImportant", "", {}, {}, {"b:0", "c:0"});

  CheckGraph(view);
}
2159
// Same scenario as the dedup test, but the producer `b` is an Identity fed by
// a Switch: the expectations keep `^b` next to the regular `b:2` fanins.
TEST(MutableGraphViewTest, NoDedupControllingFaninsOnAddFanin) {
  GraphDef gdef = test::function::GDef(
      {NDef("a", "Switch", {}, {}), NDef("b", "Identity", {"a:1"}),
       NDef("c", "", {}, {}), NDef("d", "", {}, {})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  // Interleave regular and controlling additions on `c`.
  TF_EXPECT_OK(view.AddRegularFanin("c", {"b", 2}));
  CheckNode(view, "c", "", "", {}, {"b:2"}, {});
  TF_EXPECT_OK(view.AddControllingFanin("c", {"b", Graph::kControlSlot}));
  CheckNode(view, "c", "", "", {}, {"b:2", "^b"}, {});
  // A second identical control dependency is expected to be a no-op.
  TF_EXPECT_OK(view.AddControllingFanin("c", {"b", Graph::kControlSlot}));
  CheckNode(view, "c", "", "", {}, {"b:2", "^b"}, {});
  // Regular fanins, on the other hand, may repeat.
  TF_EXPECT_OK(view.AddRegularFanin("c", {"b", 2}));
  CheckNode(view, "c", "", "", {}, {"b:2", "b:2", "^b"}, {});

  // `d` only ever receives the control dependency.
  TF_EXPECT_OK(view.AddControllingFanin("d", {"b", Graph::kControlSlot}));
  CheckNode(view, "d", "", "", {}, {"^b"}, {});
  TF_EXPECT_OK(view.AddControllingFanin("d", {"b", Graph::kControlSlot}));
  CheckNode(view, "d", "", "", {}, {"^b"}, {});

  CheckNode(view, "a", "Switch", "", {}, {}, {"b"});
  CheckNode(view, "b", "Identity", "", {}, {"a:1"},
            {"c:0", "c:1", "^c", "^d"});

  CheckGraph(view);
}
2188
// Inserts a regular fanin at a specific port of a node that already has a
// control dependency on the same producer, and checks the resulting edges.
TEST(MutableGraphViewTest, DedupControllingFaninsOnAddFaninByPort) {
  // The op type of the nodes is irrelevant here.
  GraphDef gdef =
      test::function::GDef({NDef("a", "NotImportant", {}, {}),
                            NDef("b", "NotImportant", {"c", "^a"}),
                            NDef("c", "NotImportant", {"a:1"})},
                           /*funcs=*/{});
  MutableGraphView view(&gdef);

  // Inserting `a:2` at port 0 of `b` shifts `c` to port 1; `^a` is expected to
  // be gone afterwards.
  TF_EXPECT_OK(view.AddRegularFaninByPort("b", 0, {"a", 2}));
  CheckNode(view, "b", "NotImportant", "", {}, {"a:2", "c"}, {});

  // `c` already consumes `a:1`, so its fanins are expected to stay as they are.
  TF_EXPECT_OK(view.AddControllingFanin("c", {"a", Graph::kControlSlot}));
  CheckNode(view, "c", "NotImportant", "", {}, {"a:1"}, {"b:1"});

  CheckNode(view, "a", "NotImportant", "", {}, {}, {"b:0", "c:0"});

  CheckGraph(view);
}
2209
// By-port variant of the no-dedup scenario: `b` is an Identity fed by a
// Switch, and the expectations keep `^b` alongside the regular `b:2` fanins.
TEST(MutableGraphViewTest, NoDedupControllingFaninsOnAddFaninByPort) {
  GraphDef gdef = test::function::GDef(
      {NDef("a", "Switch", {}, {}), NDef("b", "Identity", {"a:1"}),
       NDef("c", "", {}, {}), NDef("d", "", {"c:2"}, {})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  // Append at port 1, add the control dependency, then prepend at port 0.
  TF_EXPECT_OK(view.AddRegularFaninByPort("d", 1, {"b", 2}));
  CheckNode(view, "d", "", "", {}, {"c:2", "b:2"}, {});
  TF_EXPECT_OK(view.AddControllingFanin("d", {"b", Graph::kControlSlot}));
  CheckNode(view, "d", "", "", {}, {"c:2", "b:2", "^b"}, {});
  TF_EXPECT_OK(view.AddRegularFaninByPort("d", 0, {"b", 2}));
  CheckNode(view, "d", "", "", {}, {"b:2", "c:2", "b:2", "^b"}, {});

  CheckNode(view, "a", "Switch", "", {}, {}, {"b:0"});
  CheckNode(view, "b", "Identity", "", {}, {"a:1"}, {"d:0", "d:2", "^d"});
  CheckNode(view, "c", "", "", {}, {}, {"d:1"});

  CheckGraph(view);
}
2231
// Redirects a regular fanin of `c` from `a` to `b` while `c` already has
// `^b`; afterwards only `b:2` is expected among the fanins of `c`.
TEST(MutableGraphViewTest, DedupControllingFaninsOnUpdateFanin) {
  // The op type of the nodes is irrelevant here.
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {}),
       NDef("c", "NotImportant", {"a:1", "^b"})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  TF_EXPECT_OK(view.UpdateFanin("c", {"a", 1}, {"b", 2}));

  CheckNode(view, "a", "NotImportant", "", {}, {}, {});
  CheckNode(view, "b", "NotImportant", "", {}, {}, {"c"});
  CheckNode(view, "c", "NotImportant", "", {}, {"b:2"}, {});

  CheckGraph(view);
}
2249
// Updates fanins so that a node ends up with both a regular and a control
// edge from `c`, an Identity fed by a Switch; the expectations keep both.
TEST(MutableGraphViewTest, NoDedupControllingFaninsOnUpdateFanin) {
  GraphDef gdef = test::function::GDef(
      {NDef("a", "Switch", {}, {}), NDef("b", "Identity", {"a:1"}),
       NDef("c", "Identity", {"a:2"}), NDef("d", "NotImportant", {"c", "^b"}),
       NDef("e", "NotImportant", {"b", "^c"})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  // Control -> control: `^b` on `d` becomes `^c`.
  TF_EXPECT_OK(view.UpdateFanin("d", {"b", Graph::kControlSlot},
                                {"c", Graph::kControlSlot}));
  CheckNode(view, "d", "NotImportant", "", {}, {"c", "^c"}, {});

  // Regular -> regular: `b:0` on `e` becomes `c:3`.
  TF_EXPECT_OK(view.UpdateFanin("e", {"b", 0}, {"c", 3}));
  CheckNode(view, "e", "NotImportant", "", {}, {"c:3", "^c"}, {});

  // Regular -> control: `c:3` on `e` folds into the existing `^c`.
  TF_EXPECT_OK(view.UpdateFanin("e", {"c", 3}, {"c", Graph::kControlSlot}));
  CheckNode(view, "e", "NotImportant", "", {}, {"^c"}, {});

  CheckNode(view, "a", "Switch", "", {}, {}, {"b:0", "c:0"});
  CheckNode(view, "b", "Identity", "", {}, {"a:1"}, {});
  CheckNode(view, "c", "Identity", "", {}, {"a:2"}, {"d:0", "^d", "^e"});

  CheckGraph(view);
}
2275
// By-port variant: replaces port 0 of `c` with `b:2` while `c` already has
// `^b`; afterwards only `b:2` is expected among the fanins of `c`.
TEST(MutableGraphViewTest, DedupControllingFaninsOnUpdateFaninByPort) {
  // The op type of the nodes is irrelevant here.
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {}),
       NDef("c", "NotImportant", {"a:1", "^b"})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  TF_EXPECT_OK(view.UpdateRegularFaninByPort("c", 0, {"b", 2}));

  CheckNode(view, "a", "NotImportant", "", {}, {}, {});
  CheckNode(view, "b", "NotImportant", "", {}, {}, {"c"});
  CheckNode(view, "c", "NotImportant", "", {}, {"b:2"}, {});

  CheckGraph(view);
}
2293
// By-port variant of the no-dedup update scenario: `b` and `c` are Identity
// nodes fed by a Switch, and the expectations keep the control edge next to
// the newly written regular edge.
TEST(MutableGraphViewTest, NoDedupControllingFaninsOnUpdateFaninByPort) {
  GraphDef gdef = test::function::GDef(
      {NDef("a", "Switch", {}, {}), NDef("b", "Identity", {"a:1"}),
       NDef("c", "Identity", {"a:2"}), NDef("d", "NotImportant", {"c", "^b"}),
       NDef("e", "NotImportant", {"b", "^c"})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  TF_EXPECT_OK(view.UpdateRegularFaninByPort("d", 0, {"b", 1}));
  CheckNode(view, "d", "NotImportant", "", {}, {"b:1", "^b"}, {});

  TF_EXPECT_OK(view.UpdateRegularFaninByPort("e", 0, {"c", 2}));
  CheckNode(view, "e", "NotImportant", "", {}, {"c:2", "^c"}, {});

  CheckNode(view, "a", "Switch", "", {}, {}, {"b:0", "c:0"});
  CheckNode(view, "b", "Identity", "", {}, {"a:1"}, {"d:0", "^d"});
  CheckNode(view, "c", "Identity", "", {}, {"a:2"}, {"e:0", "^e"});

  CheckGraph(view);
}
2315
// Adds a consumer of `a:3`, a port above the one previously consumed
// (`a:1`), and checks the resulting fanins and fanouts.
TEST(MutableGraphViewTest, UpdateMaxRegularOutputPortOnAddFanin) {
  // The op type of the nodes is irrelevant here.
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {"a:1"}),
       NDef("c", "NotImportant", {"^b"})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  TF_EXPECT_OK(view.AddRegularFanin("c", {"a", 3}));

  CheckNode(view, "a", "NotImportant", "", {}, {}, {"b", "c"});
  CheckNode(view, "b", "NotImportant", "", {}, {"a:1"}, {"^c"});
  CheckNode(view, "c", "NotImportant", "", {}, {"a:3", "^b"}, {});

  CheckGraph(view);
}
2333
// Removes the consumer of the highest consumed port of `a` (`a:2`) and checks
// that only the `a:1` consumer remains.
TEST(MutableGraphViewTest, UpdateMaxRegularOutputPortOnRemoveFanin) {
  // The op type of the nodes is irrelevant here.
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {"a:1"}),
       NDef("c", "NotImportant", {"a:2"})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  TF_EXPECT_OK(view.RemoveRegularFanin("c", {"a", 2}));

  CheckNode(view, "a", "NotImportant", "", {}, {}, {"b"});
  CheckNode(view, "b", "NotImportant", "", {}, {"a:1"}, {});
  CheckNode(view, "c", "NotImportant", "", {}, {}, {});

  CheckGraph(view);
}
2350
// Removes the consumer of a lower port of `a` (`a:1`) and checks that the
// consumer of the highest port (`a:2`) is still reported.
TEST(MutableGraphViewTest, KeepMaxRegularOutputPortOnRemoveFanin) {
  // The op type of the nodes is irrelevant here.
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {"a:1"}),
       NDef("c", "NotImportant", {"a:2"})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  TF_EXPECT_OK(view.RemoveRegularFanin("b", {"a", 1}));

  CheckNode(view, "a", "NotImportant", "", {}, {}, {"c"});
  CheckNode(view, "b", "NotImportant", "", {}, {}, {});
  CheckNode(view, "c", "NotImportant", "", {}, {"a:2"}, {});

  CheckGraph(view);
}
2368
// Rewires `c` from `a:2` to `b:3`, so `a` loses its highest consumed port and
// `b` gains a consumed port, then checks the resulting edges.
TEST(MutableGraphViewTest, UpdateMaxRegularOutputPortOnUpdateFanin) {
  // The op type of the nodes is irrelevant here.
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {"a:1"}),
       NDef("c", "NotImportant", {"a:2"})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  TF_EXPECT_OK(view.UpdateFanin("c", {"a", 2}, {"b", 3}));

  CheckNode(view, "a", "NotImportant", "", {}, {}, {"b"});
  CheckNode(view, "b", "NotImportant", "", {}, {"a:1"}, {"c"});
  CheckNode(view, "c", "NotImportant", "", {}, {"b:3"}, {});

  CheckGraph(view);
}
2386
// AddControllingFanin must fail with a descriptive message, and leave the
// graph untouched, when the node, the fanin, or both do not exist.
TEST(MutableGraphViewTest, AddControllingFaninMissing) {
  // The op type of the nodes is irrelevant here.
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  // Missing fanin.
  const Status missing_fanin =
      view.AddControllingFanin("a", {"c", Graph::kControlSlot});
  EXPECT_FALSE(missing_fanin.ok());
  const string missing_fanin_msg =
      "MutableGraphView::AddControllingFanin(node_name='a', fanin='^c') error: "
      "node 'c' was not found.";
  EXPECT_EQ(missing_fanin.error_message(), missing_fanin_msg);

  // Missing node.
  const Status missing_node =
      view.AddControllingFanin("d", {"a", Graph::kControlSlot});
  EXPECT_FALSE(missing_node.ok());
  const string missing_node_msg =
      "MutableGraphView::AddControllingFanin(node_name='d', fanin='^a') error: "
      "node 'd' was not found.";
  EXPECT_EQ(missing_node.error_message(), missing_node_msg);

  // Missing node and fanin: the message reports the missing node.
  const Status missing_both =
      view.AddControllingFanin("c", {"d", Graph::kControlSlot});
  EXPECT_FALSE(missing_both.ok());
  const string missing_both_msg =
      "MutableGraphView::AddControllingFanin(node_name='c', fanin='^d') error: "
      "node 'c' was not found.";
  EXPECT_EQ(missing_both.error_message(), missing_both_msg);

  // None of the failed calls may have changed the graph.
  ASSERT_EQ(view.graph()->node_size(), 2);

  CheckNode(view, "a", "NotImportant", "", {}, {}, {});
  CheckNode(view, "b", "NotImportant", "", {}, {}, {});

  CheckGraph(view);
}
2423
// Adding the same control dependency twice must succeed both times and leave
// exactly one `^b` on `a`.
TEST(MutableGraphViewTest, AddControllingFaninExistingControl) {
  // The op type of the nodes is irrelevant here.
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  for (int attempt = 0; attempt < 2; ++attempt) {
    TF_EXPECT_OK(view.AddControllingFanin("a", {"b", Graph::kControlSlot}));
  }

  ASSERT_EQ(view.graph()->node_size(), 2);

  CheckNode(view, "a", "NotImportant", "", {}, {"^b"}, {});
  CheckNode(view, "b", "NotImportant", "", {}, {}, {"^a"});

  CheckGraph(view);
}
2441
// Passing a regular output (`b:2`) of a non-Switch node to
// AddControllingFanin is expected to produce a plain `^b` dependency, and
// repeating the call must not duplicate it.
TEST(MutableGraphViewTest, AddControllingFaninNotSwitch) {
  // The op type of the nodes is irrelevant here.
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  for (int attempt = 0; attempt < 2; ++attempt) {
    TF_EXPECT_OK(view.AddControllingFanin("a", {"b", 2}));
  }

  ASSERT_EQ(view.graph()->node_size(), 2);

  CheckNode(view, "a", "NotImportant", "", {}, {"^b"}, {});
  CheckNode(view, "b", "NotImportant", "", {}, {}, {"^a"});

  CheckGraph(view);
}
2459
// A direct control dependency on a Switch node must be rejected and the graph
// left unchanged.
TEST(MutableGraphViewTest, AddControllingFaninSwitch) {
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "Switch", {}, {})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  const Status status =
      view.AddControllingFanin("a", {"b", Graph::kControlSlot});
  EXPECT_FALSE(status.ok());
  const string want_msg =
      "MutableGraphView::AddControllingFanin(node_name='a', fanin='^b') error: "
      "can't add fanin '^b' as it will become a Switch control dependency.";
  EXPECT_EQ(status.error_message(), want_msg);

  ASSERT_EQ(view.graph()->node_size(), 2);

  CheckNode(view, "a", "NotImportant", "", {}, {}, {});
  CheckNode(view, "b", "Switch", "", {}, {}, {});

  CheckGraph(view);
}
2481
// When the requested Switch output is already consumed by an Identity node,
// the control dependency is expected to be attached to that Identity, and a
// repeated request must not add anything further.
TEST(MutableGraphViewTest, AddControllingFaninSwitchWithIdentity) {
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("switch", "Switch", {}, {}),
       NDef("identity", "Identity", {"switch"})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  for (int attempt = 0; attempt < 2; ++attempt) {
    TF_EXPECT_OK(view.AddControllingFanin("a", {"switch", 0}));
  }

  // No helper node should have been created.
  ASSERT_EQ(view.graph()->node_size(), 3);

  CheckNode(view, "a", "NotImportant", "", {}, {"^identity"}, {});
  CheckNode(view, "switch", "Switch", "", {}, {}, {"identity"});
  CheckNode(view, "identity", "Identity", "", {}, {"switch"}, {"^a"});

  CheckGraph(view);
}
2501
// When the requested Switch output has no Identity consumer, a new Identity
// named `ConstantFoldingCtrl/switch_0` is expected to appear, carrying the
// Switch's device and `T` attribute, and to serve as the control dependency.
TEST(MutableGraphViewTest, AddControllingFaninSwitchWithNoExistingIdentity) {
  constexpr char kSwitchDevice[] = "/device:foo:0";
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}),
       NDef("switch", "Switch", {}, {{"T", DT_FLOAT}}, kSwitchDevice)},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  // The second call is expected to reuse the node generated by the first.
  for (int attempt = 0; attempt < 2; ++attempt) {
    TF_EXPECT_OK(view.AddControllingFanin("a", {"switch", 0}));
  }

  // Two original nodes plus the generated Identity.
  ASSERT_EQ(view.graph()->node_size(), 3);

  CheckNode(view, "a", "NotImportant", "", {},
            {"^ConstantFoldingCtrl/switch_0"}, {});
  CheckNode(view, "switch", "Switch", kSwitchDevice, {{"T", DT_FLOAT}}, {},
            {"ConstantFoldingCtrl/switch_0"});
  CheckNode(view, "ConstantFoldingCtrl/switch_0", "Identity", kSwitchDevice,
            {{"T", DT_FLOAT}}, {"switch"}, {"^a"});

  CheckGraph(view);
}
2525
// When an Identity with the generated name `ConstantFoldingCtrl/switch_0`
// already consumes the Switch output, it is expected to be reused rather than
// a new node being created.
TEST(MutableGraphViewTest, AddControllingFaninSwitchWithExistingAddedIdentity) {
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("switch", "Switch", {}, {}),
       NDef("ConstantFoldingCtrl/switch_0", "Identity", {"switch"})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  for (int attempt = 0; attempt < 2; ++attempt) {
    TF_EXPECT_OK(view.AddControllingFanin("a", {"switch", 0}));
  }

  ASSERT_EQ(view.graph()->node_size(), 3);

  CheckNode(view, "a", "NotImportant", "", {},
            {"^ConstantFoldingCtrl/switch_0"}, {});
  CheckNode(view, "switch", "Switch", "", {}, {},
            {"ConstantFoldingCtrl/switch_0"});
  CheckNode(view, "ConstantFoldingCtrl/switch_0", "Identity", "", {},
            {"switch"}, {"^a"});

  CheckGraph(view);
}
2548
TestAddControllingFaninSelfLoops(absl::string_view node_name,const TensorId & fanin,const string & error_msg)2549 void TestAddControllingFaninSelfLoops(absl::string_view node_name,
2550 const TensorId& fanin,
2551 const string& error_msg) {
2552 GraphDef graph_def = test::function::GDef(
2553 {NDef("a", "NotImportant", {}, {}),
2554 NDef("b", "Switch", {}, {{"T", DT_FLOAT}}),
2555 NDef("c", "Identity", {"b:0"}), NDef("d", "Identity", {"b:1"}),
2556 NDef("e", "NotImportant", {"^a"})},
2557 /*funcs=*/{});
2558
2559 MutableGraphView graph(&graph_def);
2560
2561 Status s = graph.AddControllingFanin(node_name, fanin);
2562 EXPECT_FALSE(s.ok());
2563 EXPECT_EQ(s.error_message(), error_msg);
2564
2565 EXPECT_EQ(graph.graph()->node_size(), 5);
2566
2567 CheckNode(graph, "a", "NotImportant", "", {}, {}, {"^e"});
2568 CheckNode(graph, "b", "Switch", "", {{"T", DT_FLOAT}}, {}, {"c", "d"});
2569 CheckNode(graph, "c", "Identity", "", {}, {"b"}, {});
2570 CheckNode(graph, "d", "Identity", "", {}, {"b:1"}, {});
2571 CheckNode(graph, "e", "NotImportant", "", {}, {"^a"}, {});
2572
2573 CheckGraph(graph);
2574 }
2575
// Covers the ways AddControllingFanin could create a self loop: directly, and
// indirectly through the Identity consumer of a Switch output.
TEST(MutableGraphViewTest, AddControllingFaninSelfLoops) {
  // A node cannot control itself.
  const string direct_self_loop_msg =
      "MutableGraphView::AddControllingFanin(node_name='a', fanin='^a') error: "
      "can't add fanin '^a' to self.";
  TestAddControllingFaninSelfLoops("a", {"a", Graph::kControlSlot},
                                   direct_self_loop_msg);

  // Adding Switch control dependency to Identity consumer. Node `c` is
  // consuming `b:0`, so adding `b:0` as a control dependency, because it is a
  // Switch, should trigger a lookup of outputs. As `c` is a consumer and an
  // Identity, this will introduce a self loop, so no control dependency should
  // be added.
  const string switch_port0_msg =
      "MutableGraphView::AddControllingFanin(node_name='c', fanin='b:0') "
      "error: can't add found fanin '^c' to self.";
  TestAddControllingFaninSelfLoops("c", {"b", 0}, switch_port0_msg);

  // Adding Switch control dependency to Identity consumer. Node `d` is
  // consuming `b:1`, so adding `b:1` as a control dependency, because it is a
  // Switch, should trigger a lookup of outputs. As `d` is a consumer and an
  // Identity, this will introduce a self loop, so no control dependency should
  // be added.
  const string switch_port1_msg =
      "MutableGraphView::AddControllingFanin(node_name='d', fanin='b:1') "
      "error: can't add found fanin '^d' to self.";
  TestAddControllingFaninSelfLoops("d", {"b", 1}, switch_port1_msg);
}
2602
// The target node carries the exact name that would be generated for the
// Identity pinning `b:1`, so the request must be rejected as a self loop.
TEST(MutableGraphViewTest, AddControllingFaninSelfLoopsGeneratedIdentity) {
  GraphDef gdef =
      test::function::GDef({NDef("a", "NotImportant", {}, {}),
                            NDef("b", "Switch", {}, {{"T", DT_FLOAT}}),
                            NDef("c", "NotImportant", {}),
                            NDef("ConstantFoldingCtrl/b_1", "Identity", {})},
                           /*funcs=*/{});
  MutableGraphView view(&gdef);

  // Adding Switch control dependency to Identity node of the same name as a
  // generated Identity node for pinning the control dependency. Because there
  // are no consumers of `b:1`, there will be an attempt to generate an Identity
  // node, with name `ConstantFoldingCtrl/b_1`. As the input node is of the same
  // name, we will introduce a self loop, so no control dependency should be
  // added.
  const Status status =
      view.AddControllingFanin("ConstantFoldingCtrl/b_1", {"b", 1});
  EXPECT_FALSE(status.ok());
  const string want_msg =
      "MutableGraphView::AddControllingFanin(node_name='ConstantFoldingCtrl/"
      "b_1', fanin='b:1') error: can't add generated fanin "
      "'^ConstantFoldingCtrl/b_1' to self.";
  EXPECT_EQ(status.error_message(), want_msg);

  // Nothing may have been generated or rewired.
  EXPECT_EQ(view.graph()->node_size(), 4);

  CheckNode(view, "a", "NotImportant", "", {}, {}, {});
  CheckNode(view, "b", "Switch", "", {{"T", DT_FLOAT}}, {}, {});
  CheckNode(view, "c", "NotImportant", "", {}, {}, {});
  CheckNode(view, "ConstantFoldingCtrl/b_1", "Identity", "", {}, {}, {});

  CheckGraph(view);
}
2636
// Removing a control dependency that the node does not have must succeed and
// leave the existing control dependencies untouched.
TEST(MutableGraphViewTest, RemoveControllingFaninMissing) {
  // The op type of the nodes is irrelevant here.
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {}),
       NDef("c", "NotImportant", {}, {}),
       NDef("d", "NotImportant", {"^a", "^b"})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  // `d` has no `^c`.
  TF_EXPECT_OK(view.RemoveControllingFanin("d", "c"));

  ASSERT_EQ(view.graph()->node_size(), 4);

  CheckNode(view, "a", "NotImportant", "", {}, {}, {"^d"});
  CheckNode(view, "b", "NotImportant", "", {}, {}, {"^d"});
  CheckNode(view, "c", "NotImportant", "", {}, {}, {});
  CheckNode(view, "d", "NotImportant", "", {}, {"^a", "^b"}, {});

  CheckGraph(view);
}
2658
// Removing an existing control dependency drops it from the node; repeating
// the removal must still succeed.
TEST(MutableGraphViewTest, RemoveControllingFaninExisting) {
  // The op type of the nodes is irrelevant here.
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {}),
       NDef("c", "NotImportant", {}, {}),
       NDef("d", "NotImportant", {"^a", "^b", "^c"})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  for (int attempt = 0; attempt < 2; ++attempt) {
    TF_EXPECT_OK(view.RemoveControllingFanin("d", "a"));
  }

  ASSERT_EQ(view.graph()->node_size(), 4);

  CheckNode(view, "a", "NotImportant", "", {}, {}, {});
  CheckNode(view, "b", "NotImportant", "", {}, {}, {"^d"});
  CheckNode(view, "c", "NotImportant", "", {}, {}, {"^d"});
  // Expected order of the remaining control fanins is `^c` then `^b`.
  CheckNode(view, "d", "NotImportant", "", {}, {"^c", "^b"}, {});

  CheckGraph(view);
}
2681
// RemoveControllingFanin must not touch regular fanins coming from the named
// node: `c` keeps both `a` and `b` as regular inputs.
TEST(MutableGraphViewTest, RemoveControllingFaninOnRegularFanin) {
  // The op type of the nodes is irrelevant here.
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {"a"}),
       NDef("c", "NotImportant", {"a", "b"})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  TF_EXPECT_OK(view.RemoveControllingFanin("c", "a"));
  TF_EXPECT_OK(view.RemoveControllingFanin("c", "b"));

  ASSERT_EQ(view.graph()->node_size(), 3);

  CheckNode(view, "a", "NotImportant", "", {}, {}, {"b", "c"});
  CheckNode(view, "b", "NotImportant", "", {}, {"a"}, {"c:1"});
  CheckNode(view, "c", "NotImportant", "", {}, {"a", "b"}, {});

  CheckGraph(view);
}
2702
// Asking a node to remove a control dependency on itself must fail with a
// descriptive message and leave the graph unchanged.
TEST(MutableGraphViewTest, RemoveControllingFaninSelfLoop) {
  // The op type of the nodes is irrelevant here.
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {"a"}),
       NDef("c", "NotImportant", {"a", "b"})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  const Status status = view.RemoveControllingFanin("c", "c");
  EXPECT_FALSE(status.ok());
  const string want_msg =
      "MutableGraphView::RemoveControllingFanin(node_name='c', "
      "fanin_node_name='c') error: can't remove fanin '^c' from "
      "self.";
  EXPECT_EQ(status.error_message(), want_msg);

  ASSERT_EQ(view.graph()->node_size(), 3);

  CheckNode(view, "a", "NotImportant", "", {}, {}, {"b", "c"});
  CheckNode(view, "b", "NotImportant", "", {}, {"a"}, {"c:1"});
  CheckNode(view, "c", "NotImportant", "", {}, {"a", "b"}, {});

  CheckGraph(view);
}
2728
TestUpdateAllRegularFaninsToControlling(absl::string_view node_name,bool node_exists,bool success,const string & error_msg,absl::Span<const string> expected_fanins)2729 void TestUpdateAllRegularFaninsToControlling(
2730 absl::string_view node_name, bool node_exists, bool success,
2731 const string& error_msg, absl::Span<const string> expected_fanins) {
2732 constexpr char kDevice[] = "/device:foo:0";
2733 GraphDef graph_def = test::function::GDef(
2734 {NDef("a", "NotImportant", {}, {}),
2735 NDef("switch", "Switch", {}, {{"T", DT_FLOAT}}, kDevice),
2736 NDef("b", "NotImportant", {"switch:1"}, {}),
2737 NDef("ConstantFoldingCtrl/switch_1", "Identity", {"switch:1"},
2738 {{"T", DT_FLOAT}}, kDevice),
2739 NDef("c", "NotImportant", {"a", "^b"}, {}),
2740 NDef("d", "NotImportant", {"b", "c"}, {}),
2741 NDef("e", "NotImportant", {"^d"}, {})},
2742 /*funcs=*/{});
2743
2744 MutableGraphView graph(&graph_def);
2745
2746 NodeDef* node = graph.GetNode(node_name);
2747 if (node_exists) {
2748 EXPECT_NE(node, nullptr);
2749 } else {
2750 EXPECT_EQ(node, nullptr);
2751 }
2752
2753 absl::flat_hash_map<string, std::vector<string>> unmodified_node_inputs =
2754 GetNodeInputsFromGraph(graph_def, node_name);
2755
2756 Status s = graph.UpdateAllRegularFaninsToControlling(node_name);
2757 EXPECT_EQ(s.ok(), success);
2758 if (!success) {
2759 EXPECT_EQ(s.error_message(), error_msg);
2760 }
2761 if (node_exists) {
2762 CompareNodeFanins(graph, node, expected_fanins);
2763 }
2764
2765 CheckUnmodifiedNodeFanins(graph_def, node_name, unmodified_node_inputs);
2766
2767 CheckGraph(graph);
2768 }
2769
// Drives TestUpdateAllRegularFaninsToControlling over every node of its
// fixture graph, plus a missing node and a self-loop failure.
TEST(MutableGraphViewTest, UpdateAllRegularFaninsToControlling) {
  // Successful cases do not compare the error message; pass an empty one.
  const string no_error;

  // Nodes with some regular fanins and some controls.
  TestUpdateAllRegularFaninsToControlling("a", /*node_exists=*/true,
                                          /*success=*/true, no_error, {});
  TestUpdateAllRegularFaninsToControlling("c", /*node_exists=*/true,
                                          /*success=*/true, no_error,
                                          {"^a", "^b"});
  TestUpdateAllRegularFaninsToControlling("d", /*node_exists=*/true,
                                          /*success=*/true, no_error,
                                          {"^b", "^c"});
  TestUpdateAllRegularFaninsToControlling("e", /*node_exists=*/true,
                                          /*success=*/true, no_error, {"^d"});

  // Use existing Identity to pin control dependency of Switch.
  TestUpdateAllRegularFaninsToControlling("b", /*node_exists=*/true,
                                          /*success=*/true, no_error,
                                          {"^ConstantFoldingCtrl/switch_1"});

  // Missing node.
  const string missing_node_msg =
      "MutableGraphView::UpdateAllRegularFaninsToControlling(node_name='f') "
      "error: node 'f' was not found.";
  TestUpdateAllRegularFaninsToControlling("f", /*node_exists=*/false,
                                          /*success=*/false, missing_node_msg,
                                          {});

  // Error in getting controlling fanin.
  const string self_loop_msg =
      "MutableGraphView::UpdateAllRegularFaninsToControlling(node_name='"
      "ConstantFoldingCtrl/switch_1') error: can't add found fanin "
      "'^ConstantFoldingCtrl/switch_1' to self.";
  TestUpdateAllRegularFaninsToControlling("ConstantFoldingCtrl/switch_1",
                                          /*node_exists=*/true,
                                          /*success=*/false, self_loop_msg,
                                          {"switch:1"});
}
2806
// `b` consumes `switch:1` directly and no Identity exists for that output, so
// converting its fanins is expected to generate
// `ConstantFoldingCtrl/switch_1` and make `b` depend on it instead.
TEST(MutableGraphViewTest, UpdateAllRegularFaninsToControllingConsumingSwitch) {
  constexpr char kSwitchDevice[] = "/device:foo:0";
  GraphDef gdef = test::function::GDef(
      {NDef("a", "NotImportant", {}, {}),
       NDef("switch", "Switch", {}, {{"T", DT_FLOAT}}, kSwitchDevice),
       NDef("b", "NotImportant", {"switch:1"}, {})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  TF_EXPECT_OK(view.UpdateAllRegularFaninsToControlling("b"));

  // Three original nodes plus the generated Identity.
  EXPECT_EQ(view.graph()->node_size(), 4);

  CheckNode(view, "a", "NotImportant", "", {}, {}, {});
  CheckNode(view, "switch", "Switch", kSwitchDevice, {{"T", DT_FLOAT}}, {},
            {"ConstantFoldingCtrl/switch_1"});
  CheckNode(view, "b", "NotImportant", "", {},
            {"^ConstantFoldingCtrl/switch_1"}, {});
  CheckNode(view, "ConstantFoldingCtrl/switch_1", "Identity", kSwitchDevice,
            {{"T", DT_FLOAT}}, {"switch:1"}, {"^b"});

  CheckGraph(view);
}
2831
// Deletes a leaf consumer (`foo_1`) and checks that it disappears from the
// graph and from the fanouts of its producers.
TEST(MutableGraphViewTest, DeleteNodes) {
  // The op type of the nodes is irrelevant here.
  GraphDef gdef = test::function::GDef(
      {NDef("bar", "NotImportant", {}, {}),
       NDef("other", "NotImportant", {}, {}),
       NDef("foo_1", "NotImportant", {"bar", "other", "bar:1", "^bar"}),
       NDef("foo_2", "NotImportant", {"other:1", "bar:2", "^bar"})},
      /*funcs=*/{});
  MutableGraphView view(&gdef);

  // Present before the deletion...
  EXPECT_NE(view.GetNode("foo_1"), nullptr);
  TF_EXPECT_OK(view.DeleteNodes({"foo_1"}));

  // ...and gone afterwards.
  EXPECT_EQ(view.graph()->node_size(), 3);
  EXPECT_EQ(view.GetNode("foo_1"), nullptr);

  CheckNode(view, "bar", "NotImportant", "", {}, {}, {"foo_2:1"});
  CheckNode(view, "other", "NotImportant", "", {}, {}, {"foo_2"});
  CheckNode(view, "foo_2", "NotImportant", "", {}, {"other:1", "bar:2"}, {});

  CheckGraph(view);
}
2855
SimpleDeleteNodeGraph()2856 GraphDef SimpleDeleteNodeGraph() {
2857 // Actual node.op() is not important in this test.
2858 GraphDef graph_def = test::function::GDef(
2859 {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {"a:2"}),
2860 NDef("c", "NotImportant", {"a:5", "^b"}), NDef("d", "NotImportant", {}),
2861 NDef("e", "NotImportant", {"d:2"}),
2862 NDef("f", "NotImportant", {"d:3", "^e"})},
2863 /*funcs=*/{});
2864 return graph_def;
2865 }
2866
// Nodes that have fanouts can be deleted as long as every one of those
// fanouts is part of the same DeleteNodes request.
TEST(MutableGraphViewTest, DeleteNodesWithFanoutsBeingDeleted) {
  GraphDef def = SimpleDeleteNodeGraph();
  MutableGraphView view(&def);

  // Precondition: the whole {a, b, c} component is present.
  for (const char* name : {"a", "b", "c"}) {
    EXPECT_NE(view.GetNode(name), nullptr) << name;
  }

  TF_EXPECT_OK(view.DeleteNodes({"c", "a", "b"}));

  // Half of the six nodes are gone and none of them can be looked up.
  EXPECT_EQ(view.graph()->node_size(), 3);
  for (const char* name : {"a", "b", "c"}) {
    EXPECT_EQ(view.GetNode(name), nullptr) << name;
  }

  // The {d, e, f} component must be left exactly as it was built.
  CheckNode(view, "d", "NotImportant", "", {}, {}, {"e", "f"});
  CheckNode(view, "e", "NotImportant", "", {}, {"d:2"}, {"^f"});
  CheckNode(view, "f", "NotImportant", "", {}, {"d:3", "^e"}, {});

  CheckGraph(view);
}
2887
// Asking to delete names that are not in the graph returns OK and leaves the
// graph unchanged.
TEST(MutableGraphViewTest, DeleteMissingNodes) {
  GraphDef def = SimpleDeleteNodeGraph();
  MutableGraphView view(&def);

  // Precondition: neither name exists in the fixture.
  for (const char* name : {"g", "h"}) {
    EXPECT_EQ(view.GetNode(name), nullptr) << name;
  }

  TF_EXPECT_OK(view.DeleteNodes({"g", "h"}));

  // Nothing was removed, and the missing names are still missing.
  EXPECT_EQ(view.graph()->node_size(), 6);
  for (const char* name : {"g", "h"}) {
    EXPECT_EQ(view.GetNode(name), nullptr) << name;
  }

  // Both components keep their original fanins and fanouts.
  CheckNode(view, "a", "NotImportant", "", {}, {}, {"b", "c"});
  CheckNode(view, "b", "NotImportant", "", {}, {"a:2"}, {"^c"});
  CheckNode(view, "c", "NotImportant", "", {}, {"a:5", "^b"}, {});
  CheckNode(view, "d", "NotImportant", "", {}, {}, {"e", "f"});
  CheckNode(view, "e", "NotImportant", "", {}, {"d:2"}, {"^f"});
  CheckNode(view, "f", "NotImportant", "", {}, {"d:3", "^e"}, {});

  CheckGraph(view);
}
2910
// A request that mixes existing nodes (whose fanouts are all in the request)
// with names that are not in the graph still succeeds and removes the
// existing ones.
TEST(MutableGraphViewTest, DeleteMissingNodesAndNodesWithFanoutsBeingDeleted) {
  GraphDef def = SimpleDeleteNodeGraph();
  MutableGraphView view(&def);

  // Precondition: the whole {d, e, f} component is present.
  for (const char* name : {"d", "e", "f"}) {
    EXPECT_NE(view.GetNode(name), nullptr) << name;
  }

  // "g" and "h" are not part of the fixture graph.
  TF_EXPECT_OK(view.DeleteNodes({"d", "e", "f", "g", "h"}));

  // Only the three existing nodes were removed.
  EXPECT_EQ(view.graph()->node_size(), 3);
  for (const char* name : {"d", "e", "f"}) {
    EXPECT_EQ(view.GetNode(name), nullptr) << name;
  }

  // The {a, b, c} component must be left exactly as it was built.
  CheckNode(view, "a", "NotImportant", "", {}, {}, {"b", "c"});
  CheckNode(view, "b", "NotImportant", "", {}, {"a:2"}, {"^c"});
  CheckNode(view, "c", "NotImportant", "", {}, {"a:5", "^b"}, {});

  CheckGraph(view);
}
2932
// Deleting nodes whose fanouts are not all part of the request must fail with
// a descriptive error and leave the graph unchanged.
TEST(MutableGraphViewTest, DeleteNodesWithError) {
  GraphDef def = SimpleDeleteNodeGraph();
  MutableGraphView view(&def);

  // "c" consumes both "a" and "b" but is not in the request, so it would be
  // left with dangling fanins.
  const Status status = view.DeleteNodes({"b", "a"});
  EXPECT_FALSE(status.ok());

  // The expected message lists the names as `a, b` although they were passed
  // as `b, a`.
  const string expected_error =
      "MutableGraphView::DeleteNodes(nodes_to_delete={a, b}) error: can't "
      "delete node(s) with retained fanouts(s) [a, b].";
  EXPECT_EQ(status.error_message(), expected_error);

  // The failed request must not have removed or rewired anything.
  EXPECT_EQ(view.graph()->node_size(), 6);

  CheckNode(view, "a", "NotImportant", "", {}, {}, {"b", "c"});
  CheckNode(view, "b", "NotImportant", "", {}, {"a:2"}, {"^c"});
  CheckNode(view, "c", "NotImportant", "", {}, {"a:5", "^b"}, {});
  CheckNode(view, "d", "NotImportant", "", {}, {}, {"e", "f"});
  CheckNode(view, "e", "NotImportant", "", {}, {"d:2"}, {"^f"});
  CheckNode(view, "f", "NotImportant", "", {}, {"d:3", "^e"}, {});

  CheckGraph(view);
}
2956
// Same failure mode as DeleteNodesWithError, but with more offending nodes
// than the error message spells out: six names are requested and the expected
// message shows the first five followed by "...".
TEST(MutableGraphViewTest, DeleteNodesWithLargeError) {
  // Actual node.op() is not important in this test.
  GraphDef def = test::function::GDef(
      {// Nodes that the test attempts to delete.
       NDef("a", "NotImportant", {}, {}),
       NDef("b", "NotImportant", {"a:2"}),
       NDef("c", "NotImportant", {"^b"}),
       NDef("d", "NotImportant", {"c:6"}),
       NDef("e", "NotImportant", {"d:2"}),
       NDef("f", "NotImportant", {"d:3", "^e"}),
       // Consumers that are not in the request; each of a..f has at least one.
       NDef("g", "NotImportant", {"f"}),
       NDef("h", "NotImportant", {"a"}),
       NDef("i", "NotImportant", {"b"}),
       NDef("j", "NotImportant", {"c"}),
       NDef("k", "NotImportant", {"d"}),
       NDef("l", "NotImportant", {"e"}),
       NDef("m", "NotImportant", {"f"})},
      /*funcs=*/{});

  MutableGraphView view(&def);

  const Status status = view.DeleteNodes({"a", "b", "c", "d", "e", "f"});
  EXPECT_FALSE(status.ok());

  const string expected_error =
      "MutableGraphView::DeleteNodes(nodes_to_delete={a, b, c, d, e, ...}) "
      "error: can't delete node(s) with retained fanouts(s) [a, b, c, d, e, "
      "...].";
  EXPECT_EQ(status.error_message(), expected_error);

  // The failed request must not have removed or rewired anything.
  EXPECT_EQ(view.graph()->node_size(), 13);

  // Nodes named in the request.
  CheckNode(view, "a", "NotImportant", "", {}, {}, {"b", "h"});
  CheckNode(view, "b", "NotImportant", "", {}, {"a:2"}, {"^c", "i"});
  CheckNode(view, "c", "NotImportant", "", {}, {"^b"}, {"d", "j"});
  CheckNode(view, "d", "NotImportant", "", {}, {"c:6"}, {"e", "f", "k"});
  CheckNode(view, "e", "NotImportant", "", {}, {"d:2"}, {"^f", "l"});
  CheckNode(view, "f", "NotImportant", "", {}, {"d:3", "^e"}, {"g", "m"});
  // Consumers outside the request.
  CheckNode(view, "g", "NotImportant", "", {}, {"f"}, {});
  CheckNode(view, "h", "NotImportant", "", {}, {"a"}, {});
  CheckNode(view, "i", "NotImportant", "", {}, {"b"}, {});
  CheckNode(view, "j", "NotImportant", "", {}, {"c"}, {});
  CheckNode(view, "k", "NotImportant", "", {}, {"d"}, {});
  CheckNode(view, "l", "NotImportant", "", {}, {"e"}, {});
  CheckNode(view, "m", "NotImportant", "", {}, {"f"}, {});

  CheckGraph(view);
}
2998
2999 } // namespace
3000 } // namespace grappler
3001 } // namespace tensorflow
3002