1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"

#include <algorithm>

#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
29
30 namespace tensorflow {
31 namespace grappler {
32 namespace {
33
34 class DependencyOptimizerTest : public GrapplerTest {};
35
VerifyGraphsEqual(const GraphDef & original_graph,const GraphDef & optimized_graph,const string & func)36 void VerifyGraphsEqual(const GraphDef& original_graph,
37 const GraphDef& optimized_graph, const string& func) {
38 EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << func;
39 for (int i = 0; i < original_graph.node_size(); ++i) {
40 const NodeDef& original = original_graph.node(i);
41 const NodeDef& optimized = optimized_graph.node(i);
42 EXPECT_EQ(original.name(), optimized.name()) << func;
43 EXPECT_EQ(original.op(), optimized.op()) << func;
44 EXPECT_EQ(original.input_size(), optimized.input_size()) << func;
45 for (int j = 0; j < original.input_size(); ++j) {
46 EXPECT_EQ(original.input(j), optimized.input(j)) << func;
47 }
48 }
49 }
50
// The generated trivial graph gives the optimizer nothing to rewrite, so the
// output is expected to match the input exactly.
TEST_F(DependencyOptimizerTest, NoOp) {
  TrivialTestGraphInputYielder yielder(4, 1, 10, false, {"CPU:0"});
  GrapplerItem item;
  CHECK(yielder.NextItem(&item));

  DependencyOptimizer optimizer;
  GraphDef optimized;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized));

  VerifyGraphsEqual(item.graph, optimized, __FUNCTION__);
}
64
// id1 and id2 carry control dependencies on Const nodes. The test expects the
// optimizer to drop those control inputs, leaving "add" as the only input of
// each, and to prune "z", which is reachable only through such a control edge.
TEST_F(DependencyOptimizerTest, DependenciesDrivenByConstants) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output y = ops::Const(s.WithOpName("y"), {1.0f, 2.0f}, {1, 2});
  Output z = ops::Const(s.WithOpName("z"), {1.0f, 2.0f}, {1, 2});
  Output add = ops::Add(s.WithOpName("add"), x, y);
  Output id1 =
      ops::Identity(s.WithOpName("id1").WithControlDependencies(x), add);
  Output id2 = ops::Identity(
      s.WithOpName("id2").WithControlDependencies(y).WithControlDependencies(z),
      add);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch.push_back("id1");
  item.fetch.push_back("id2");

  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  // Run the optimizer twice to make sure the rewrite is idempotent.
  item.graph.Swap(&output);
  status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // The 'z' node should have been optimized away leaving only 5 nodes.
  EXPECT_EQ(5, output.node_size());

  // Inspect the final (second-pass) graph. After the Swap above, item.graph
  // holds only the first-pass result.
  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "id1" || node.name() == "id2") {
      // ASSERT so a wrong size cannot lead to an out-of-range input(0) read.
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("add", node.input(0));
      ++found;
    }
  }
  // Both fetch nodes must be present; otherwise the checks above never ran.
  EXPECT_EQ(2, found);
}
101
// "add" is consumed only through control edges. The test expects it to be
// removed, with each consumer picking up the inputs of "add" as control
// dependencies. A dependency already present as a data input is not repeated.
TEST_F(DependencyOptimizerTest, ChangeToNoop) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::RandomUniform(s.WithOpName("x"), {1, 2}, DT_FLOAT);
  Output y = ops::RandomUniform(s.WithOpName("y"), {1, 2}, DT_FLOAT);
  Output add = ops::Add(s.WithOpName("add"), x, y);
  Output id1 =
      ops::Identity(s.WithOpName("id1").WithControlDependencies(add), x);
  Output id2 =
      ops::Identity(s.WithOpName("id2").WithControlDependencies(add), y);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch.push_back("id1");
  item.fetch.push_back("id2");

  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  // Run the optimizer twice to make sure the rewrite is idempotent.
  item.graph.Swap(&output);
  status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // item.graph now holds the first-pass result; the second pass must not
  // change the node count.
  EXPECT_EQ(item.graph.node_size(), output.node_size());
  int found = 0;
  for (const NodeDef& node : item.graph.node()) {
    // "add" should get turned into a NoOp and removed.
    EXPECT_NE("add", node.name());
    if (node.name() == "id1") {
      EXPECT_EQ("Identity", node.op());
      // ASSERT: the input(k) reads below are out of range if this fails.
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("^y", node.input(1));
      ++found;
    } else if (node.name() == "id2") {
      EXPECT_EQ("Identity", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("y", node.input(0));
      EXPECT_EQ("^x", node.input(1));
      ++found;
    }
  }
  EXPECT_EQ(2, found);
}
148
// Like ChangeToNoop, except "add" reads the same tensor twice. id1 already has
// "x" as a data input, so after "add" is removed it should end up with "x" as
// its only input.
TEST_F(DependencyOptimizerTest, ChangeToNoop_RepeatedInput) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::RandomUniform(s.WithOpName("x"), {1, 2}, DT_FLOAT);
  Output add = ops::Add(s.WithOpName("add"), x, x);
  Output id1 =
      ops::Identity(s.WithOpName("id1").WithControlDependencies(add), x);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"id1"};

  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  // Run the optimizer twice to make sure the rewrite is idempotent.
  item.graph.Swap(&output);
  status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // item.graph now holds the first-pass result; the second pass must not
  // change the node count.
  EXPECT_EQ(item.graph.node_size(), output.node_size());
  int found = 0;
  for (const NodeDef& node : item.graph.node()) {
    // "add" should get turned into a NoOp and removed.
    EXPECT_NE("add", node.name());
    if (node.name() == "id1") {
      EXPECT_EQ("Identity", node.op());
      // ASSERT: input(0) below is out of range if this fails.
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("x", node.input(0));
      ++found;
    }
  }
  EXPECT_EQ(1, found);
}
184
TEST_F(DependencyOptimizerTest, ChangeToNoop_SwitchIdentity) {
  // This tests that we don't try to repeatedly add Identity nodes
  // with names like "ConstantFoldingCtrl/foo/bar/switch_$port" when
  // multiple nodes reading the same output of a Switch node get
  // optimized (e.g. constant folded or turned into NoOps).
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
  ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL);
  ops::Switch s(scope.WithOpName("switch"), v_in, v_ctrl);
  // "neg" should be turned into a NoOp with a control dependency from
  // the existing Identity node "ConstantFoldingCtrl/switch_1" and
  // subsequently eliminated completely from the graph.
  Output neg = ops::Neg(scope.WithOpName("neg"), s.output_true);
  // c1 could be a result of constant folding some node fed by neg.
  Output c1 = ops::Const(scope.WithOpName("c1").WithControlDependencies(neg),
                         {1.0f, 2.0f}, {1, 2});
  Output ctrl_dep_id = ops::Identity(
      scope.WithOpName("ConstantFoldingCtrl/switch_1"), s.output_true);
  // c2 could be a result of constant folding a node fed by s, which also
  // added the ctrl_dep_id node.
  Output c2 =
      ops::Const(scope.WithOpName("c2").WithControlDependencies(ctrl_dep_id),
                 {1.0f, 2.0f}, {1, 2});
  Output neg1 = ops::Neg(scope.WithOpName("neg1"), s.output_false);
  Output neg2 = ops::Neg(scope.WithOpName("neg2"), ctrl_dep_id);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch.push_back("c1");
  item.fetch.push_back("c2");
  item.fetch.push_back("neg1");
  item.fetch.push_back("neg2");

  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(item.graph.node_size() - 1, output.node_size());
  bool found_c1 = false;
  for (const NodeDef& node : output.node()) {
    // "neg" should be eliminated.
    EXPECT_NE("neg", node.name());
    // A control dep from "^ConstantFoldingCtrl/switch_1"
    // should be attached to "c1".
    if (node.name() == "c1") {
      EXPECT_EQ("Const", node.op());
      // ASSERT: input(0) below is out of range if this fails.
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("^ConstantFoldingCtrl/switch_1", node.input(0));
      found_c1 = true;
    }
  }
  // c1 is a fetch node, so it must survive. Without this check a missing c1
  // would silently skip the assertions above.
  EXPECT_TRUE(found_c1);
}
237
238 // TODO(rmlarsen): Add test to make sure we skip Switch and Merge.
// Same graph as ChangeToNoop, but with no fetch nodes. The test expects the
// optimized graph to equal the original once the original has been
// topologically sorted.
TEST_F(DependencyOptimizerTest, ChangeToNoop_NoFetch) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output rand_x = ops::RandomUniform(root.WithOpName("x"), {1, 2}, DT_FLOAT);
  Output rand_y = ops::RandomUniform(root.WithOpName("y"), {1, 2}, DT_FLOAT);
  Output sum = ops::Add(root.WithOpName("add"), rand_x, rand_y);
  Output ident1 =
      ops::Identity(root.WithOpName("id1").WithControlDependencies(sum), rand_x);
  Output ident2 =
      ops::Identity(root.WithOpName("id2").WithControlDependencies(sum), rand_y);

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  // item.fetch is intentionally left empty.

  DependencyOptimizer optimizer;
  GraphDef optimized;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized));

  TF_CHECK_OK(TopologicalSort(&item.graph));
  VerifyGraphsEqual(item.graph, optimized, __FUNCTION__);
}
260
// One NoOp ("NoOp") has no inputs and one has no outputs ("NoOp_1"). The test
// expects neither to end up with inputs and "Identity" to keep only its data
// input. Node names here are auto-generated, so creation order matters.
TEST_F(DependencyOptimizerTest, RemoveNoOps_EmptyInputOrOutput) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::RandomUniform(s, {1, 2}, DT_FLOAT);
  auto noop1 = ops::NoOp(s);
  auto noop2 = ops::NoOp(s.WithControlDependencies(x));
  Output id = ops::Identity(s.WithControlDependencies({noop1.operation}), x);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch.push_back("Identity");

  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  // Run the optimizer twice to make sure the rewrite is idempotent.
  item.graph.Swap(&output);
  status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(item.graph.node_size(), output.node_size());
  for (const NodeDef& node : output.node()) {
    if (node.name() == "NoOp" || node.name() == "NoOp_1") {
      EXPECT_EQ(0, node.input_size());
    } else if (node.name() == "Identity") {
      // ASSERT: input(0) below is out of range if this fails.
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("RandomUniform", node.input(0));
    }
  }
}
291
// Each NoOp has at least one neighbour on a different device. The test
// expects the optimizer to leave the graph unchanged (up to topological order)
// rather than create more edges that cross device boundaries.
TEST_F(DependencyOptimizerTest, RemoveNoOps_DeviceBoundaries) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output rand_x = ops::RandomUniform(root.WithOpName("x").WithDevice("/CPU:0"),
                                     {1, 2}, DT_FLOAT);
  Output rand_y = ops::RandomUniform(root.WithOpName("y").WithDevice("/CPU:0"),
                                     {1, 2}, DT_FLOAT);
  // One incoming control edge (from x) and two outgoing ones.
  auto fan_out =
      ops::NoOp(root.WithControlDependencies(rand_x).WithDevice("/CPU:1"));
  // Two incoming control edges (from x and y) and one outgoing one.
  auto fan_in = ops::NoOp(root.WithControlDependencies(rand_x)
                              .WithControlDependencies(rand_y)
                              .WithDevice("/CPU:0"));
  Output ident = ops::Identity(
      root.WithControlDependencies({fan_out.operation}).WithDevice("/CPU:1"),
      rand_x);
  Output ident_1 = ops::Identity(
      root.WithControlDependencies({fan_out.operation, fan_in.operation})
          .WithDevice("/CPU:1"),
      rand_y);

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  item.fetch.push_back("Identity");
  item.fetch.push_back("Identity_1");

  DependencyOptimizer optimizer;
  GraphDef optimized;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized));

  // Apart from node order, nothing should have changed.
  TF_CHECK_OK(TopologicalSort(&item.graph));
  VerifyGraphsEqual(item.graph, optimized, __FUNCTION__);
}
326
// Identity-node counterpart of RemoveNoOps_DeviceBoundaries. Bypassing id_a or
// id_b would add cross-device edges, so the graph is expected to stay
// unchanged (up to topological order).
TEST_F(DependencyOptimizerTest, RemoveIdentityOps_DeviceBoundaries) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output rand_x = ops::RandomUniform(root.WithOpName("x").WithDevice("/CPU:0"),
                                     {1, 2}, DT_FLOAT);
  Output rand_y = ops::RandomUniform(root.WithOpName("y").WithDevice("/CPU:0"),
                                     {1, 2}, DT_FLOAT);
  // One input (x) and two consumers.
  auto hop_a = ops::Identity(root.WithOpName("id_a").WithDevice("/CPU:1"),
                             rand_x);
  // Two inputs (x, ^y) and a single consumer.
  auto hop_b = ops::Identity(
      root.WithOpName("id_b").WithControlDependencies(rand_y).WithDevice(
          "/CPU:0"),
      rand_x);

  Output ident = ops::Identity(
      root.WithControlDependencies(hop_a).WithDevice("/CPU:1"), hop_b);
  Output ident_1 = ops::Identity(root.WithDevice("/CPU:1"), hop_a);

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  item.fetch.push_back("Identity");
  item.fetch.push_back("Identity_1");

  DependencyOptimizer optimizer;
  GraphDef optimized;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized));

  // Apart from node order, nothing should have changed.
  TF_CHECK_OK(TopologicalSort(&item.graph));
  VerifyGraphsEqual(item.graph, optimized, __FUNCTION__);
}
358
// id_a moves x from /CPU:0 to /CPU:1, but its only consumer is back on /CPU:0,
// so removing it creates no cross-device edge. The test expects id_a to be
// removed and "Identity" to read "x" directly.
TEST_F(DependencyOptimizerTest, RemoveIdentityOps_IdenticalDevices) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::RandomUniform(s.WithOpName("x").WithDevice("/CPU:0"), {1, 2},
                                DT_FLOAT);
  auto id_a = ops::Identity(s.WithOpName("id_a").WithDevice("/CPU:1"), x);
  Output id =
      ops::Identity(s.WithControlDependencies(id_a).WithDevice("/CPU:0"), id_a);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch.push_back("Identity");

  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(item.graph.node_size() - 1, output.node_size());
  bool found_identity = false;
  for (const NodeDef& node : output.node()) {
    EXPECT_NE("id_a", node.name());
    if (node.name() == "Identity") {
      // Guard the input(0) read below.
      ASSERT_GE(node.input_size(), 1);
      EXPECT_EQ("x", node.input(0));
      found_identity = true;
    }
  }
  // "Identity" is fetched, so it must survive.
  EXPECT_TRUE(found_identity);
}
384
// "NoOp" has one input and two outputs; "NoOp_1" has two inputs and one
// output. The test expects their dependencies to be forwarded onto the
// Identity consumers (e.g. Identity_1 gains ^x) and the NoOps to lose their
// inputs.
TEST_F(DependencyOptimizerTest, RemoveNoOps_SingleInputOrOutput) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::RandomUniform(s.WithOpName("x"), {1, 2}, DT_FLOAT);
  Output y = ops::RandomUniform(s.WithOpName("y"), {1, 2}, DT_FLOAT);
  // NoOp with a single input- and two output dependencies.
  auto noop = ops::NoOp(s.WithControlDependencies(x));
  // NoOp with a two input- and a single output dependency.
  auto noop_1 =
      ops::NoOp(s.WithControlDependencies(x).WithControlDependencies(y));
  Output id = ops::Identity(s.WithControlDependencies({noop.operation}), x);
  Output id_1 = ops::Identity(
      s.WithControlDependencies({noop.operation, noop_1.operation}), y);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch.push_back("Identity");
  item.fetch.push_back("Identity_1");

  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  // Run the optimizer twice to make sure the rewrite is idempotent.
  item.graph.Swap(&output);
  status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(item.graph.node_size(), output.node_size());
  for (const NodeDef& node : output.node()) {
    if (node.name() == "NoOp" || node.name() == "NoOp_1") {
      EXPECT_EQ(0, node.input_size());
    } else if (node.name() == "Identity") {
      // Guard the input(k) reads so a regression fails cleanly.
      ASSERT_GE(node.input_size(), 1);
      EXPECT_EQ("x", node.input(0));
    } else if (node.name() == "Identity_1") {
      ASSERT_GE(node.input_size(), 2);
      EXPECT_EQ("y", node.input(0));
      EXPECT_EQ("^x", node.input(1));
    }
  }
}
424
// Removal of Identity nodes in three shapes. Consumers should be rewired to
// "x", and the removed node's control inputs should be forwarded as control
// dependencies. A consumer that only had a control edge on the Identity should
// get "^x".
TEST_F(DependencyOptimizerTest, RemoveIdentity) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::RandomUniform(s.WithOpName("x"), {1, 2}, DT_FLOAT);
  Output y = ops::RandomUniform(s.WithOpName("y"), {1, 2}, DT_FLOAT);
  Output z = ops::RandomUniform(s.WithOpName("z"), {1, 2}, DT_FLOAT);

  // Identity nodes to be removed.
  // Case a) with a single input- and multiple outputs.
  auto id_a = ops::Identity(s.WithOpName("id_a"), x);
  // Case b) with multiple inputs and a single output.
  auto id_b = ops::Identity(
      s.WithOpName("id_b").WithControlDependencies(y).WithControlDependencies(
          z),
      x);
  // Case c) with two inputs and two outputs.
  auto id_c = ops::Identity(s.WithOpName("id_c").WithControlDependencies(y), x);

  // Output for Case a.
  Output a_a = ops::Identity(s.WithOpName("a_a"), id_a);
  Output a_b = ops::Identity(s.WithOpName("a_b"), id_a);
  Output a_c =
      ops::Identity(s.WithOpName("a_c").WithControlDependencies(id_a), z);
  Output a_d =
      ops::Identity(s.WithOpName("a_d").WithControlDependencies(id_a), z);
  // Output for Case b.
  Output b_a = ops::Identity(s.WithOpName("b_a"), id_b);
  // Output for Case c.
  Output c_a = ops::Identity(s.WithOpName("c_a"), id_c);
  Output c_b =
      ops::Identity(s.WithOpName("c_b").WithControlDependencies(id_c), z);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"a_a", "a_b", "a_c", "a_d", "b_a", "c_a", "c_b"};

  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(item.graph.node_size() - 3, output.node_size());
  int found = 0;
  for (const NodeDef& node : output.node()) {
    EXPECT_NE("id_a", node.name());
    EXPECT_NE("id_b", node.name());
    EXPECT_NE("id_c", node.name());
    // Every ASSERT_EQ on input_size() guards the input(k) reads after it.
    if (node.name() == "a_a" || node.name() == "a_b") {
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("x", node.input(0));
      ++found;
    } else if (node.name() == "a_c" || node.name() == "a_d") {
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("z", node.input(0));
      EXPECT_EQ("^x", node.input(1));
      ++found;
    } else if (node.name() == "b_a") {
      ASSERT_EQ(3, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("^y", node.input(1));
      EXPECT_EQ("^z", node.input(2));
      ++found;
    } else if (node.name() == "c_a") {
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("^y", node.input(1));
      ++found;
    } else if (node.name() == "c_b") {
      ASSERT_EQ(3, node.input_size());
      EXPECT_EQ("z", node.input(0));
      EXPECT_EQ("^x", node.input(1));
      EXPECT_EQ("^y", node.input(2));
      ++found;
    }
  }
  EXPECT_EQ(7, found);
}
505
TEST_F(DependencyOptimizerTest, RemoveIdentity_RepeatedInputs) {
  // Corner cases with repeated inputs.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  ops::Variable x(scope.WithOpName("x"), {}, DT_BOOL);
  ops::Variable y(scope.WithOpName("y"), {}, DT_BOOL);
  ops::Switch sw(scope.WithOpName("switch"), x, x);
  // id0 should be removed.
  Output id0 = ops::Identity(scope.WithOpName("id0"), sw.output_true);
  // id1 should not be removed, since it would anchor a control dependency
  // on the switch.
  Output id1 = ops::Identity(scope.WithOpName("id1"), sw.output_false);
  Output or0 = ops::LogicalOr(scope.WithOpName("or0"), id0, id0);
  Output or1 = ops::LogicalOr(scope.WithOpName("or1"), id0, y);
  Output or2 = ops::LogicalOr(
      scope.WithOpName("or2").WithControlDependencies(id1), y, y);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch.push_back("or0");
  item.fetch.push_back("or1");
  item.fetch.push_back("or2");
  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(item.graph.node_size() - 1, output.node_size());
  int found = 0;
  for (const NodeDef& node : output.node()) {
    EXPECT_NE("id0", node.name());
    // Every ASSERT_EQ on input_size() guards the input(k) reads after it.
    if (node.name() == "or0") {
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("switch:1", node.input(0));
      EXPECT_EQ("switch:1", node.input(1));
      ++found;
    } else if (node.name() == "or1") {
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("switch:1", node.input(0));
      EXPECT_EQ("y", node.input(1));
      ++found;
    } else if (node.name() == "or2") {
      // or2 should be unchanged.
      ASSERT_EQ(3, node.input_size());
      EXPECT_EQ("y", node.input(0));
      EXPECT_EQ("y", node.input(1));
      EXPECT_EQ("^id1", node.input(2));
      ++found;
    }
  }
  EXPECT_EQ(3, found);
}
559
// neg2 has a control dependency on x that is already implied by the data path
// x -> neg1 -> neg2. The test expects transitive reduction to drop it, leaving
// neg1 as neg2's only input.
TEST_F(DependencyOptimizerTest, Transitive_Reduction_Simple) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  Output x = ops::Square(s.WithOpName("x"), c);
  Output neg1 = ops::Neg(s.WithOpName("neg1"), x);
  Output neg2 =
      ops::Neg(s.WithOpName("neg2").WithControlDependencies({x}), neg1);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch.push_back("neg2");
  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  // ASSERT: output.node(3) below is out of range otherwise.
  ASSERT_EQ(4, output.node_size());
  const NodeDef& neg2_node = output.node(3);
  EXPECT_EQ("neg2", neg2_node.name());
  ASSERT_EQ(1, neg2_node.input_size());
  EXPECT_EQ("neg1", neg2_node.input(0));
}
580
// Mixes Identity nodes that may be removed (id0, id1) with ones that must be
// kept (the grappler-added "ConstantFoldingCtrl/switch_1", "id_after_var" and
// the fetched "id2"). Afterwards c1 should depend only on the
// ConstantFoldingCtrl node.
TEST_F(DependencyOptimizerTest, ChangeToNoop_Identity) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
  Output id_after_var = ops::Identity(scope.WithOpName("id_after_var"), v_in);
  ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL);
  ops::Switch s(
      scope.WithOpName("switch").WithControlDependencies(id_after_var), v_in,
      v_ctrl);
  Output id0 = ops::Identity(scope.WithOpName("id0"), s.output_true);
  Output grappler_added_id = ops::Identity(
      scope.WithOpName("ConstantFoldingCtrl/switch_1"), s.output_true);
  Output c1 = ops::Const(scope.WithOpName("c1")
                             .WithControlDependencies(id_after_var)
                             .WithControlDependencies(grappler_added_id),
                         {1.0f, 2.0f}, {1, 2});
  Output id1 = ops::Identity(scope.WithOpName("id1"), c1);
  Output id2 = ops::Identity(scope.WithOpName("id2"), id0);
  Output fetch =
      ops::Identity(scope.WithOpName("fetch").WithControlDependencies(id1), c1);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch.push_back("c1");
  item.fetch.push_back("id2");
  item.fetch.push_back("fetch");

  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(item.graph.node_size() - 2, output.node_size());
  bool found = false;
  for (const NodeDef& node : output.node()) {
    // "id0" and "id1" should be eliminated, but not
    // "ConstantFoldingCtrl/switch_1", "id_after_var", or "id2".
    EXPECT_NE("id0", node.name());
    EXPECT_NE("id1", node.name());
    if (node.name() == "c1") {
      EXPECT_EQ("Const", node.op());
      // ASSERT: input(0) below is out of range if this fails.
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("^ConstantFoldingCtrl/switch_1", node.input(0));
      found = true;
    }
  }
  EXPECT_TRUE(found);
}
629
// Identity nodes reading either output port of a Switch should be removed, and
// their consumers rewired to the matching port ("s" for port 0, "s:1" for
// port 1).
TEST_F(DependencyOptimizerTest, IdentityInputs) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output b = ops::Placeholder(scope.WithOpName("b"), DT_BOOL);
  Output x = ops::RandomUniform(scope.WithOpName("x"), {1, 2}, DT_FLOAT);
  auto s = ops::Switch(scope.WithOpName("s"), x, b);

  // Identity nodes to be removed.
  auto id_f = ops::Identity(scope.WithOpName("id_f"), s.output_false);
  auto id_t = ops::Identity(scope.WithOpName("id_t"), s.output_true);

  // Output
  Output out1 = ops::Identity(scope.WithOpName("out1"), id_f);
  Output out2 = ops::Identity(scope.WithOpName("out2"), id_t);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"out1", "out2"};

  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // ASSERT: the positional node(4)/node(5) reads below need 6 nodes.
  ASSERT_EQ(6, output.node_size());
  const NodeDef& out1_node = output.node(4);
  EXPECT_EQ("out1", out1_node.name());
  ASSERT_EQ(1, out1_node.input_size());
  EXPECT_EQ("s", out1_node.input(0));

  const NodeDef& out2_node = output.node(5);
  EXPECT_EQ("out2", out2_node.name());
  ASSERT_EQ(1, out2_node.input_size());
  EXPECT_EQ("s:1", out2_node.input(0));
}
662
// IdentityN variant of IdentityInputs. Every IdentityN that forwards Switch
// outputs should be removed, and each consumer should read the Switch port
// directly.
TEST_F(DependencyOptimizerTest, RemoveIdentityN_SwitchInput) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output b = ops::Placeholder(scope.WithOpName("b"), DT_BOOL);
  Output x = ops::RandomUniform(scope.WithOpName("x"), {1, 2}, DT_FLOAT);
  auto s = ops::Switch(scope.WithOpName("s"), x, b);

  // IdentityN nodes to be removed.
  auto id_f = ops::IdentityN(scope.WithOpName("id_f"), {s.output_false});
  auto id_t = ops::IdentityN(scope.WithOpName("id_t"), {s.output_true});
  auto id_b =
      ops::IdentityN(scope.WithOpName("id_b"), {s.output_false, s.output_true});

  // Outputs
  Output out1 = ops::Identity(scope.WithOpName("out1"), id_f[0]);
  Output out2 = ops::Identity(scope.WithOpName("out2"), id_t[0]);
  Output out3 = ops::Identity(scope.WithOpName("out3"), id_b[0]);
  Output out4 = ops::Identity(scope.WithOpName("out4"), id_b[1]);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"out1", "out2", "out3", "out4"};

  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // ASSERT: the positional node(4..7) reads below need 8 nodes.
  ASSERT_EQ(8, output.node_size());

  // Bind by const reference; the original code copied each NodeDef.
  const NodeDef& out1_node = output.node(7);
  EXPECT_EQ("out1", out1_node.name());
  ASSERT_EQ(1, out1_node.input_size());
  EXPECT_EQ("s", out1_node.input(0));

  const NodeDef& out2_node = output.node(4);
  EXPECT_EQ("out2", out2_node.name());
  ASSERT_EQ(1, out2_node.input_size());
  EXPECT_EQ("s:1", out2_node.input(0));

  const NodeDef& out3_node = output.node(5);
  EXPECT_EQ("out3", out3_node.name());
  ASSERT_EQ(1, out3_node.input_size());
  EXPECT_EQ("s", out3_node.input(0));

  const NodeDef& out4_node = output.node(6);
  EXPECT_EQ("out4", out4_node.name());
  ASSERT_EQ(1, out4_node.input_size());
  EXPECT_EQ("s:1", out4_node.input(0));
}
712
// "out3" holds a control dependency on an output of the IdentityN "id_n". The
// test expects the optimizer to keep id_n, so all six nodes remain.
TEST_F(DependencyOptimizerTest, DoNotRemoveIdentityNWithControlDependency) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output flag_in = ops::Placeholder(root.WithOpName("input1"), DT_BOOL);
  Output const_in = ops::Const(root.WithOpName("input2"), {1, 2});

  auto passthrough =
      ops::IdentityN(root.WithOpName("id_n"), {flag_in, const_in});
  Output first = ops::Identity(root.WithOpName("out1"), passthrough[0]);
  Output second = ops::Identity(root.WithOpName("out2"), passthrough[1]);
  auto ctrl_consumer = ops::NoOp(
      root.WithOpName("out3").WithControlDependencies(passthrough[1]));

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  item.fetch = {"out1", "out2", "out3"};

  DependencyOptimizer optimizer;
  GraphDef optimized;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized));

  EXPECT_EQ(6, optimized.node_size());
}
735
// x_on_2 copies x_on_1 from /gpu:1 to /gpu:2, and its consumer lives on a
// third device (/gpu:3). The test expects the Identity to be kept, i.e. the
// output graph equals the input graph.
TEST_F(DependencyOptimizerTest,
       Identity_DeviceCrossing_ConsumerOnDifferentDevice) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output src_gpu1 =
      ops::Const(root.WithOpName("x_on_1").WithDevice("/gpu:1"), {1.0f}, {});
  Output one_gpu3 =
      ops::Const(root.WithOpName("one_on_3").WithDevice("/gpu:3"), {1.0f}, {});
  Output copy_gpu2 =
      ops::Identity(root.WithOpName("x_on_2").WithDevice("/gpu:2"), src_gpu1);
  Output total = ops::Add(root.WithOpName("result").WithDevice("/gpu:3"),
                          copy_gpu2, one_gpu3);

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  item.fetch = {"result"};
  DependencyOptimizer optimizer;
  GraphDef optimized;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized));

  VerifyGraphsEqual(item.graph, optimized, __FUNCTION__);
}
758
// x_on_2 copies x_on_1 onto /gpu:2, where its consumer also lives. The test
// expects the Identity to be removed and "result" to read x_on_1 directly.
TEST_F(DependencyOptimizerTest, Identity_DeviceCrossing_ConsumerOnSameDevice) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x_on_1 =
      ops::Const(s.WithOpName("x_on_1").WithDevice("/gpu:1"), {1.0f}, {});
  Output one_on_2 =
      ops::Const(s.WithOpName("one_on_2").WithDevice("/gpu:2"), {1.0f}, {});
  Output x_on_2 =
      ops::Identity(s.WithOpName("x_on_2").WithDevice("/gpu:2"), x_on_1);
  Output result =
      ops::Add(s.WithOpName("result").WithDevice("/gpu:2"), x_on_2, one_on_2);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"result"};
  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  EXPECT_EQ(3, output.node_size());
  bool found_result = false;
  for (const auto& node : output.node()) {
    EXPECT_NE("x_on_2", node.name());
    if (node.name() == "result") {
      // Guard the input(0) read below.
      ASSERT_GE(node.input_size(), 1);
      EXPECT_EQ("x_on_1", node.input(0));
      found_result = true;
    }
  }
  // "result" is the fetch node, so it must survive.
  EXPECT_TRUE(found_result);
}
786
// "GreaterEqual" feeds only a NoOp, and that NoOp feeds only a control edge
// into "z". The test expects both to disappear, leaving z = Add(x, y) with no
// control inputs.
TEST_F(DependencyOptimizerTest, RemoveGreaterEqualWithNoOp) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                              ops::Placeholder::Shape({}));
  Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
                              ops::Placeholder::Shape({}));
  auto greaterequal = ops::GreaterEqual(s.WithOpName("GreaterEqual"), x, y);
  auto noop =
      ops::NoOp(s.WithOpName("NoOp").WithControlDependencies(greaterequal));
  Output add = ops::Add(
      s.WithOpName("z").WithControlDependencies({noop.operation}), x, y);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  DependencyOptimizer optimizer;
  GraphDef output;
  item.fetch.push_back("z");
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // All five original names are counted. The fetch "z" and its data inputs
  // "x" and "y" are expected to survive, so a total of exactly 3 means
  // "GreaterEqual" and "NoOp" were both removed.
  int count = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "x" || node.name() == "y") {
      count++;
      EXPECT_EQ("Placeholder", node.op());
      EXPECT_EQ(0, node.input_size());
    } else if (node.name() == "GreaterEqual") {
      count++;
    } else if (node.name() == "NoOp") {
      count++;
    } else if (node.name() == "z") {
      count++;
      EXPECT_EQ("Add", node.op());
      // ASSERT: the input(k) reads below are out of range if this fails.
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("y", node.input(1));
    }
  }
  EXPECT_EQ(3, count);
}
830
// "f" on /GPU:0 has control dependencies on a and c (both /CPU:1), b (/CPU:2)
// and d (/CPU:3). The expected graph routes a and c through a single NoOp
// "GroupCrossDeviceControlEdges_0/f" placed on /CPU:1, while b and d remain
// direct control inputs of f. A second optimizer pass must reproduce the same
// graph.
TEST_F(DependencyOptimizerTest, GroupCrossDeviceControlDeps) {
  GrapplerItem item;
  {
    tensorflow::Scope root = tensorflow::Scope::NewRootScope();
    Output rand_a = ops::RandomUniform(root.WithOpName("a").WithDevice("/CPU:1"),
                                       {1, 2}, DT_FLOAT);
    Output rand_b = ops::RandomUniform(root.WithOpName("b").WithDevice("/CPU:2"),
                                       {1, 2}, DT_FLOAT);
    Output rand_c = ops::RandomUniform(root.WithOpName("c").WithDevice("/CPU:1"),
                                       {1, 2}, DT_FLOAT);
    Output rand_d = ops::RandomUniform(root.WithOpName("d").WithDevice("/CPU:3"),
                                       {1, 2}, DT_FLOAT);
    Output rand_e = ops::RandomUniform(root.WithOpName("e").WithDevice("/CPU:0"),
                                       {1, 2}, DT_FLOAT);
    // The node whose control inputs span several devices.
    auto f_node = ops::Identity(
        root.WithOpName("f")
            .WithControlDependencies(
                {rand_a.op(), rand_b.op(), rand_c.op(), rand_d.op()})
            .WithDevice("/GPU:0"),
        {rand_e});

    TF_CHECK_OK(root.ToGraphDef(&item.graph));
    item.fetch.push_back("f");
  }

  GraphDef want;
  {
    tensorflow::Scope root = tensorflow::Scope::NewRootScope();
    Output rand_a = ops::RandomUniform(root.WithOpName("a").WithDevice("/CPU:1"),
                                       {1, 2}, DT_FLOAT);
    Output rand_b = ops::RandomUniform(root.WithOpName("b").WithDevice("/CPU:2"),
                                       {1, 2}, DT_FLOAT);
    Output rand_c = ops::RandomUniform(root.WithOpName("c").WithDevice("/CPU:1"),
                                       {1, 2}, DT_FLOAT);
    Output rand_d = ops::RandomUniform(root.WithOpName("d").WithDevice("/CPU:3"),
                                       {1, 2}, DT_FLOAT);
    Output rand_e = ops::RandomUniform(root.WithOpName("e").WithDevice("/CPU:0"),
                                       {1, 2}, DT_FLOAT);
    // The two /CPU:1 control edges collapse into a single NoOp on /CPU:1.
    auto group = ops::NoOp(root.WithOpName("GroupCrossDeviceControlEdges_0/f")
                               .WithDevice("/CPU:1")
                               .WithControlDependencies(
                                   {rand_a.op(), rand_c.op()}));
    auto f_node = ops::Identity(
        root.WithOpName("f")
            .WithControlDependencies({rand_b.op(), rand_d.op(), group})
            .WithDevice("/GPU:0"),
        {rand_e});

    TF_CHECK_OK(root.ToGraphDef(&want));
  }

  DependencyOptimizer optimizer;
  GraphDef got;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &got));
  CompareGraphs(want, got);

  // Feed the result back in: a second pass must be a no-op.
  item.graph.Swap(&got);
  got.Clear();
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &got));
  CompareGraphs(want, got);
}
892
893 } // namespace
894 } // namespace grappler
895 } // namespace tensorflow
896