1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
17 #include "tensorflow/cc/ops/standard_ops.h"
18 #include "tensorflow/core/framework/node_def.pb.h"
19 #include "tensorflow/core/framework/tensor_testutil.h"
20 #include "tensorflow/core/grappler/grappler_item.h"
21 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
22 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
23 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
24 #include "tensorflow/core/grappler/utils.h"
25 #include "tensorflow/core/grappler/utils/grappler_test.h"
26 #include "tensorflow/core/grappler/utils/topological_sort.h"
27 #include "tensorflow/core/lib/core/status_test_util.h"
28 #include "tensorflow/core/platform/test.h"
29 
30 namespace tensorflow {
31 namespace grappler {
32 namespace {
33 
34 class DependencyOptimizerTest : public GrapplerTest {};
35 
VerifyGraphsEqual(const GraphDef & original_graph,const GraphDef & optimized_graph,const string & func)36 void VerifyGraphsEqual(const GraphDef& original_graph,
37                        const GraphDef& optimized_graph, const string& func) {
38   EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << func;
39   for (int i = 0; i < original_graph.node_size(); ++i) {
40     const NodeDef& original = original_graph.node(i);
41     const NodeDef& optimized = optimized_graph.node(i);
42     EXPECT_EQ(original.name(), optimized.name()) << func;
43     EXPECT_EQ(original.op(), optimized.op()) << func;
44     EXPECT_EQ(original.input_size(), optimized.input_size()) << func;
45     for (int j = 0; j < original.input_size(); ++j) {
46       EXPECT_EQ(original.input(j), optimized.input(j)) << func;
47     }
48   }
49 }
50 
TEST_F(DependencyOptimizerTest, NoOp) {
  // The trivial test graph gives the optimizer nothing to rewrite, so the
  // output must match the input node-for-node.
  GrapplerItem item;
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
  CHECK(fake_input.NextItem(&item));

  GraphDef output;
  DependencyOptimizer optimizer;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  VerifyGraphsEqual(item.graph, output, __FUNCTION__);
}
64 
TEST_F(DependencyOptimizerTest, DependenciesDrivenByConstants) {
  // id1 and id2 carry control dependencies on Const nodes; the optimizer is
  // expected to drop those, leaving only the data input "add".
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output y = ops::Const(s.WithOpName("y"), {1.0f, 2.0f}, {1, 2});
  Output z = ops::Const(s.WithOpName("z"), {1.0f, 2.0f}, {1, 2});
  Output add = ops::Add(s.WithOpName("add"), x, y);
  Output id1 =
      ops::Identity(s.WithOpName("id1").WithControlDependencies(x), add);
  Output id2 = ops::Identity(
      s.WithOpName("id2").WithControlDependencies(y).WithControlDependencies(z),
      add);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch.push_back("id1");
  item.fetch.push_back("id2");

  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  // Run the optimizer twice to make sure the rewrite is idempotent.
  item.graph.Swap(&output);
  status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // The 'z' node should have been optimized away leaving only 5 nodes.
  EXPECT_EQ(5, output.node_size());

  // Inspect the final (second-pass) graph, which is the one whose size was
  // just checked. Both fetched nodes must be present; otherwise the
  // per-node checks would silently never run.
  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "id1" || node.name() == "id2") {
      // Fatal: input(0) below must not be read out of range.
      ASSERT_EQ(1, node.input_size()) << node.name();
      EXPECT_EQ("add", node.input(0));
      ++found;
    }
  }
  EXPECT_EQ(2, found);
}
101 
TEST_F(DependencyOptimizerTest, ChangeToNoop) {
  // "add" is only used as a control dependency by id1/id2, so it should be
  // turned into a NoOp and removed. Its inputs x/y are forwarded as control
  // dependencies to the consumers.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::RandomUniform(s.WithOpName("x"), {1, 2}, DT_FLOAT);
  Output y = ops::RandomUniform(s.WithOpName("y"), {1, 2}, DT_FLOAT);
  Output add = ops::Add(s.WithOpName("add"), x, y);
  Output id1 =
      ops::Identity(s.WithOpName("id1").WithControlDependencies(add), x);
  Output id2 =
      ops::Identity(s.WithOpName("id2").WithControlDependencies(add), y);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch.push_back("id1");
  item.fetch.push_back("id2");

  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  // Run the optimizer twice to make sure the rewrite is idempotent.
  item.graph.Swap(&output);
  status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // After the swap, item.graph holds the first-pass result and output the
  // second-pass result.
  EXPECT_EQ(item.graph.node_size(), output.node_size());
  int found = 0;
  for (int i = 0; i < item.graph.node_size(); ++i) {
    const NodeDef& node = item.graph.node(i);
    // "add" should get turned into a NoOp and removed.
    EXPECT_NE("add", node.name());
    if (node.name() == "id1") {
      EXPECT_EQ("Identity", node.op());
      // Fatal so that input(1) is never read out of range.
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("^y", node.input(1));
      ++found;
    } else if (node.name() == "id2") {
      EXPECT_EQ("Identity", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("y", node.input(0));
      EXPECT_EQ("^x", node.input(1));
      ++found;
    }
  }
  EXPECT_EQ(2, found);
}
148 
TEST_F(DependencyOptimizerTest, ChangeToNoop_RepeatedInput) {
  // Same as ChangeToNoop, but "add" reads x twice. Once "add" is removed,
  // id1 should keep only its single data input x.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::RandomUniform(s.WithOpName("x"), {1, 2}, DT_FLOAT);
  Output add = ops::Add(s.WithOpName("add"), x, x);
  Output id1 =
      ops::Identity(s.WithOpName("id1").WithControlDependencies(add), x);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"id1"};

  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  // Run the optimizer twice to make sure the rewrite is idempotent.
  item.graph.Swap(&output);
  status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // After the swap, item.graph holds the first-pass result and output the
  // second-pass result.
  EXPECT_EQ(item.graph.node_size(), output.node_size());
  int found = 0;
  for (int i = 0; i < item.graph.node_size(); ++i) {
    const NodeDef& node = item.graph.node(i);
    // "add" should get turned into a NoOp and removed.
    EXPECT_NE("add", node.name());
    if (node.name() == "id1") {
      EXPECT_EQ("Identity", node.op());
      // Fatal so that input(0) is never read out of range.
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("x", node.input(0));
      ++found;
    }
  }
  EXPECT_EQ(1, found);
}
184 
TEST_F(DependencyOptimizerTest, ChangeToNoop_SwitchIdentity) {
  // This tests that we don't try to repeatedly add Identity nodes
  // with names like "ConstantFoldingCtrl/foo/bar/switch_$port" when
  // multiple nodes reading the same output of a Switch node get
  // optimized (e.g. constant folded or turned into NoOps).
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
  ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL);
  ops::Switch s(scope.WithOpName("switch"), v_in, v_ctrl);
  // "neg" should be turned into a NoOp with a control dependency from
  // the existing Identity node "ConstantFoldingCtrl/switch_1" and
  // subsequently eliminated completely from the graph.
  Output neg = ops::Neg(scope.WithOpName("neg"), s.output_true);
  // c1 could be a result of constant folding some node fed by neg.
  Output c1 = ops::Const(scope.WithOpName("c1").WithControlDependencies(neg),
                         {1.0f, 2.0f}, {1, 2});
  Output ctrl_dep_id = ops::Identity(
      scope.WithOpName("ConstantFoldingCtrl/switch_1"), s.output_true);
  // c2 could be a result of constant folding a node fed by s, which also
  // added the ctrl_dep_id node.
  Output c2 =
      ops::Const(scope.WithOpName("c2").WithControlDependencies(ctrl_dep_id),
                 {1.0f, 2.0f}, {1, 2});
  Output neg1 = ops::Neg(scope.WithOpName("neg1"), s.output_false);
  Output neg2 = ops::Neg(scope.WithOpName("neg2"), ctrl_dep_id);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch.push_back("c1");
  item.fetch.push_back("c2");
  item.fetch.push_back("neg1");
  item.fetch.push_back("neg2");

  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(item.graph.node_size() - 1, output.node_size());
  bool found_c1 = false;
  for (const NodeDef& node : output.node()) {
    // "neg" should be eliminated.
    EXPECT_NE("neg", node.name());
    // A control dep from "^ConstantFoldingCtrl/switch_1"
    // should be attached to "c1".
    if (node.name() == "c1") {
      EXPECT_EQ("Const", node.op());
      // Fatal so that input(0) is never read out of range.
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("^ConstantFoldingCtrl/switch_1", node.input(0));
      found_c1 = true;
    }
  }
  // c1 is fetched and must survive; without this the checks above could
  // pass vacuously.
  EXPECT_TRUE(found_c1);
}
237 
238 // TODO(rmlarsen): Add test to make sure we skip Switch and Merge.
TEST_F(DependencyOptimizerTest, ChangeToNoop_NoFetch) {
  // Same graph as ChangeToNoop, but no fetch nodes are registered. The
  // optimizer is expected to leave the graph untouched. The input graph is
  // topologically sorted before the node-by-node comparison.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output rand_x = ops::RandomUniform(root.WithOpName("x"), {1, 2}, DT_FLOAT);
  Output rand_y = ops::RandomUniform(root.WithOpName("y"), {1, 2}, DT_FLOAT);
  Output sum = ops::Add(root.WithOpName("add"), rand_x, rand_y);
  Output ident1 =
      ops::Identity(root.WithOpName("id1").WithControlDependencies(sum), rand_x);
  Output ident2 =
      ops::Identity(root.WithOpName("id2").WithControlDependencies(sum), rand_y);

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  GraphDef output;
  DependencyOptimizer optimizer;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  TF_CHECK_OK(TopologicalSort(&item.graph));
  VerifyGraphsEqual(item.graph, output, __FUNCTION__);
}
260 
TEST_F(DependencyOptimizerTest, RemoveNoOps_EmptyInputOrOutput) {
  // "NoOp" has no inputs and "NoOp_1" has no consumers. Both should end up
  // input-free, and "Identity" should keep only its data input.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::RandomUniform(s, {1, 2}, DT_FLOAT);
  auto noop1 = ops::NoOp(s);
  auto noop2 = ops::NoOp(s.WithControlDependencies(x));
  Output id = ops::Identity(s.WithControlDependencies({noop1.operation}), x);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch.push_back("Identity");

  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  // Run the optimizer twice to make sure the rewrite is idempotent.
  item.graph.Swap(&output);
  status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(item.graph.node_size(), output.node_size());
  bool found_identity = false;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "NoOp" || node.name() == "NoOp_1") {
      EXPECT_EQ(0, node.input_size());
    } else if (node.name() == "Identity") {
      // Fatal so that input(0) is never read out of range.
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("RandomUniform", node.input(0));
      found_identity = true;
    }
  }
  // "Identity" is fetched, so it must be present.
  EXPECT_TRUE(found_identity);
}
291 
TEST_F(DependencyOptimizerTest, RemoveNoOps_DeviceBoundaries) {
  // Producers x/y live on /CPU:0 and both Identity consumers on /CPU:1, with
  // NoOps placed in between. Removing the NoOps would increase the number of
  // nodes crossing device boundaries, so the graph must come back unchanged.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output rand_x = ops::RandomUniform(
      root.WithOpName("x").WithDevice("/CPU:0"), {1, 2}, DT_FLOAT);
  Output rand_y = ops::RandomUniform(
      root.WithOpName("y").WithDevice("/CPU:0"), {1, 2}, DT_FLOAT);

  // "NoOp": one control input, two control consumers.
  auto fan_out_noop =
      ops::NoOp(root.WithControlDependencies(rand_x).WithDevice("/CPU:1"));
  // "NoOp_1": two control inputs, one control consumer.
  auto fan_in_noop = ops::NoOp(root.WithControlDependencies(rand_x)
                                   .WithControlDependencies(rand_y)
                                   .WithDevice("/CPU:0"));

  Output ident = ops::Identity(
      root.WithControlDependencies({fan_out_noop.operation})
          .WithDevice("/CPU:1"),
      rand_x);
  Output ident_1 = ops::Identity(
      root.WithControlDependencies(
              {fan_out_noop.operation, fan_in_noop.operation})
          .WithDevice("/CPU:1"),
      rand_y);

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  item.fetch = {"Identity", "Identity_1"};

  GraphDef output;
  DependencyOptimizer optimizer;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  TF_CHECK_OK(TopologicalSort(&item.graph));
  VerifyGraphsEqual(item.graph, output, __FUNCTION__);
}
326 
TEST_F(DependencyOptimizerTest, RemoveIdentityOps_DeviceBoundaries) {
  // Identity nodes bridge /CPU:0 producers and /CPU:1 consumers. Bypassing
  // them would increase the number of nodes crossing device boundaries, so
  // the optimizer must leave the graph unchanged.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output rand_x = ops::RandomUniform(
      root.WithOpName("x").WithDevice("/CPU:0"), {1, 2}, DT_FLOAT);
  Output rand_y = ops::RandomUniform(
      root.WithOpName("y").WithDevice("/CPU:0"), {1, 2}, DT_FLOAT);

  // id_a: one input, two consumers (one control, one data).
  auto id_a = ops::Identity(root.WithOpName("id_a").WithDevice("/CPU:1"), rand_x);
  // id_b: one data input plus a control input, one consumer.
  auto id_b = ops::Identity(
      root.WithOpName("id_b").WithControlDependencies(rand_y).WithDevice(
          "/CPU:0"),
      rand_x);

  Output ident = ops::Identity(
      root.WithControlDependencies(id_a).WithDevice("/CPU:1"), id_b);
  Output ident_1 = ops::Identity(root.WithDevice("/CPU:1"), id_a);

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  item.fetch = {"Identity", "Identity_1"};

  GraphDef output;
  DependencyOptimizer optimizer;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  TF_CHECK_OK(TopologicalSort(&item.graph));
  VerifyGraphsEqual(item.graph, output, __FUNCTION__);
}
358 
TEST_F(DependencyOptimizerTest, RemoveIdentityOps_IdenticalDevices) {
  // x and the fetched "Identity" both live on /CPU:0, while id_a sits on
  // /CPU:1. id_a is expected to be removed and "Identity" rewired to read x
  // directly.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::RandomUniform(s.WithOpName("x").WithDevice("/CPU:0"), {1, 2},
                                DT_FLOAT);
  auto id_a = ops::Identity(s.WithOpName("id_a").WithDevice("/CPU:1"), x);
  Output id =
      ops::Identity(s.WithControlDependencies(id_a).WithDevice("/CPU:0"), id_a);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch.push_back("Identity");

  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(item.graph.node_size() - 1, output.node_size());
  bool found_identity = false;
  for (const NodeDef& node : output.node()) {
    EXPECT_NE("id_a", node.name());
    if (node.name() == "Identity") {
      // Fatal so that input(0) is never read out of range.
      ASSERT_GE(node.input_size(), 1);
      EXPECT_EQ("x", node.input(0));
      found_identity = true;
    }
  }
  // "Identity" is fetched, so it must be present.
  EXPECT_TRUE(found_identity);
}
384 
TEST_F(DependencyOptimizerTest, RemoveNoOps_SingleInputOrOutput) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::RandomUniform(s.WithOpName("x"), {1, 2}, DT_FLOAT);
  Output y = ops::RandomUniform(s.WithOpName("y"), {1, 2}, DT_FLOAT);
  // NoOp with a single input- and two output dependencies.
  auto noop = ops::NoOp(s.WithControlDependencies(x));
  // NoOp with a two input- and a single output dependency.
  auto noop_1 =
      ops::NoOp(s.WithControlDependencies(x).WithControlDependencies(y));
  Output id = ops::Identity(s.WithControlDependencies({noop.operation}), x);
  Output id_1 = ops::Identity(
      s.WithControlDependencies({noop.operation, noop_1.operation}), y);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch.push_back("Identity");
  item.fetch.push_back("Identity_1");

  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  // Run the optimizer twice to make sure the rewrite is idempotent.
  item.graph.Swap(&output);
  status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(item.graph.node_size(), output.node_size());
  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "NoOp" || node.name() == "NoOp_1") {
      EXPECT_EQ(0, node.input_size());
    } else if (node.name() == "Identity") {
      // Fatal guards keep the indexed reads below in range.
      ASSERT_GE(node.input_size(), 1);
      EXPECT_EQ("x", node.input(0));
      ++found;
    } else if (node.name() == "Identity_1") {
      ASSERT_GE(node.input_size(), 2);
      EXPECT_EQ("y", node.input(0));
      EXPECT_EQ("^x", node.input(1));
      ++found;
    }
  }
  // Both Identity nodes are fetched and must be present.
  EXPECT_EQ(2, found);
}
424 
TEST_F(DependencyOptimizerTest,RemoveIdentity)425 TEST_F(DependencyOptimizerTest, RemoveIdentity) {
426   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
427   Output x = ops::RandomUniform(s.WithOpName("x"), {1, 2}, DT_FLOAT);
428   Output y = ops::RandomUniform(s.WithOpName("y"), {1, 2}, DT_FLOAT);
429   Output z = ops::RandomUniform(s.WithOpName("z"), {1, 2}, DT_FLOAT);
430 
431   // Identity nodes to be removed.
432   // Case a) with a single input- and multiple outputs.
433   auto id_a = ops::Identity(s.WithOpName("id_a"), x);
434   // Case b) with multiple inputs and a single output.
435   auto id_b = ops::Identity(
436       s.WithOpName("id_b").WithControlDependencies(y).WithControlDependencies(
437           z),
438       x);
439   // Case c) with two inputs and two outputs.
440   auto id_c = ops::Identity(s.WithOpName("id_c").WithControlDependencies(y), x);
441 
442   // Output for Case a.
443   Output a_a = ops::Identity(s.WithOpName("a_a"), id_a);
444   Output a_b = ops::Identity(s.WithOpName("a_b"), id_a);
445   Output a_c =
446       ops::Identity(s.WithOpName("a_c").WithControlDependencies(id_a), z);
447   Output a_d =
448       ops::Identity(s.WithOpName("a_d").WithControlDependencies(id_a), z);
449   // Output for Case b.
450   Output b_a = ops::Identity(s.WithOpName("b_a"), id_b);
451   // Output for Case c.
452   Output c_a = ops::Identity(s.WithOpName("c_a"), id_c);
453   Output c_b =
454       ops::Identity(s.WithOpName("c_b").WithControlDependencies(id_c), z);
455 
456   GrapplerItem item;
457   TF_CHECK_OK(s.ToGraphDef(&item.graph));
458   item.fetch = {"a_a", "a_b", "a_c", "a_d", "b_a", "c_a", "c_b"};
459 
460   DependencyOptimizer optimizer;
461   GraphDef output;
462   Status status = optimizer.Optimize(nullptr, item, &output);
463   TF_EXPECT_OK(status);
464 
465   EXPECT_EQ(item.graph.node_size() - 3, output.node_size());
466   int found = 0;
467   for (const NodeDef& node : output.node()) {
468     EXPECT_NE("id_a", node.name());
469     EXPECT_NE("id_b", node.name());
470     EXPECT_NE("id_c", node.name());
471     if (node.name() == "a_a" || node.name() == "a_b") {
472       EXPECT_EQ(1, node.input_size());
473       EXPECT_EQ("x", node.input(0));
474       ++found;
475     }
476     if (node.name() == "a_c" || node.name() == "a_d") {
477       EXPECT_EQ(2, node.input_size());
478       EXPECT_EQ("z", node.input(0));
479       EXPECT_EQ("^x", node.input(1));
480       ++found;
481     }
482     if (node.name() == "b_a") {
483       EXPECT_EQ(3, node.input_size());
484       EXPECT_EQ("x", node.input(0));
485       EXPECT_EQ("^y", node.input(1));
486       EXPECT_EQ("^z", node.input(2));
487       ++found;
488     }
489     if (node.name() == "c_a") {
490       EXPECT_EQ(2, node.input_size());
491       EXPECT_EQ("x", node.input(0));
492       EXPECT_EQ("^y", node.input(1));
493       ++found;
494     }
495     if (node.name() == "c_b") {
496       EXPECT_EQ(3, node.input_size());
497       EXPECT_EQ("z", node.input(0));
498       EXPECT_EQ("^x", node.input(1));
499       EXPECT_EQ("^y", node.input(2));
500       ++found;
501     }
502   }
503   EXPECT_EQ(found, 7);
504 }
505 
TEST_F(DependencyOptimizerTest, RemoveIdentity_RepeatedInputs) {
  // Corner cases with repeated inputs.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  ops::Variable x(scope.WithOpName("x"), {}, DT_BOOL);
  ops::Variable y(scope.WithOpName("y"), {}, DT_BOOL);
  ops::Switch sw(scope.WithOpName("switch"), x, x);
  // id0 should be removed.
  Output id0 = ops::Identity(scope.WithOpName("id0"), sw.output_true);
  // id1 should not be removed, since it would anchor a control dependency
  // on the switch.
  Output id1 = ops::Identity(scope.WithOpName("id1"), sw.output_false);
  Output or0 = ops::LogicalOr(scope.WithOpName("or0"), id0, id0);
  Output or1 = ops::LogicalOr(scope.WithOpName("or1"), id0, y);
  Output or2 = ops::LogicalOr(
      scope.WithOpName("or2").WithControlDependencies(id1), y, y);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch.push_back("or0");
  item.fetch.push_back("or1");
  item.fetch.push_back("or2");
  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(item.graph.node_size() - 1, output.node_size());
  int found = 0;
  for (const NodeDef& node : output.node()) {
    EXPECT_NE("id0", node.name());
    // Each size check is fatal so the indexed reads after it stay in range.
    if (node.name() == "or0") {
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("switch:1", node.input(0));
      EXPECT_EQ("switch:1", node.input(1));
      ++found;
    }
    if (node.name() == "or1") {
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("switch:1", node.input(0));
      EXPECT_EQ("y", node.input(1));
      ++found;
    }
    if (node.name() == "or2") {
      // or2 should be unchanged.
      ASSERT_EQ(3, node.input_size());
      EXPECT_EQ("y", node.input(0));
      EXPECT_EQ("y", node.input(1));
      EXPECT_EQ("^id1", node.input(2));
      ++found;
    }
  }
  // All three fetched nodes must be present.
  EXPECT_EQ(3, found);
}
559 
TEST_F(DependencyOptimizerTest, Transitive_Reduction_Simple) {
  // neg2 depends on x both through neg1 (data) and directly (control). The
  // direct control edge is redundant and should be dropped.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  Output x = ops::Square(s.WithOpName("x"), c);
  Output neg1 = ops::Neg(s.WithOpName("neg1"), x);
  Output neg2 =
      ops::Neg(s.WithOpName("neg2").WithControlDependencies({x}), neg1);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch.push_back("neg2");
  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  // Fatal: node(3) below must exist.
  ASSERT_EQ(4, output.node_size());
  const NodeDef& neg2_node = output.node(3);
  EXPECT_EQ("neg2", neg2_node.name());
  ASSERT_EQ(1, neg2_node.input_size());
  EXPECT_EQ("neg1", neg2_node.input(0));
}
580 
TEST_F(DependencyOptimizerTest, ChangeToNoop_Identity) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
  Output id_after_var = ops::Identity(scope.WithOpName("id_after_var"), v_in);
  ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL);
  ops::Switch s(
      scope.WithOpName("switch").WithControlDependencies(id_after_var), v_in,
      v_ctrl);
  Output id0 = ops::Identity(scope.WithOpName("id0"), s.output_true);
  Output grappler_added_id = ops::Identity(
      scope.WithOpName("ConstantFoldingCtrl/switch_1"), s.output_true);
  Output c1 = ops::Const(scope.WithOpName("c1")
                             .WithControlDependencies(id_after_var)
                             .WithControlDependencies(grappler_added_id),
                         {1.0f, 2.0f}, {1, 2});
  Output id1 = ops::Identity(scope.WithOpName("id1"), c1);
  Output id2 = ops::Identity(scope.WithOpName("id2"), id0);
  Output fetch =
      ops::Identity(scope.WithOpName("fetch").WithControlDependencies(id1), c1);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch.push_back("c1");
  item.fetch.push_back("id2");
  item.fetch.push_back("fetch");

  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(item.graph.node_size() - 2, output.node_size());
  bool found = false;
  for (const NodeDef& node : output.node()) {
    // "id0" and "id1" should be eliminated, but not
    // "ConstantFoldingCtrl/switch_1", "id_after_var", or "id2".
    EXPECT_NE("id0", node.name());
    EXPECT_NE("id1", node.name());
    if (node.name() == "c1") {
      EXPECT_EQ("Const", node.op());
      // Fatal so that input(0) is never read out of range.
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("^ConstantFoldingCtrl/switch_1", node.input(0));
      found = true;
    }
  }
  EXPECT_TRUE(found);
}
629 
TEST_F(DependencyOptimizerTest, IdentityInputs) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output b = ops::Placeholder(scope.WithOpName("b"), DT_BOOL);
  Output x = ops::RandomUniform(scope.WithOpName("x"), {1, 2}, DT_FLOAT);
  auto s = ops::Switch(scope.WithOpName("s"), x, b);

  // Identity nodes to be removed.
  auto id_f = ops::Identity(scope.WithOpName("id_f"), s.output_false);
  auto id_t = ops::Identity(scope.WithOpName("id_t"), s.output_true);

  // Output
  Output out1 = ops::Identity(scope.WithOpName("out1"), id_f);
  Output out2 = ops::Identity(scope.WithOpName("out2"), id_t);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"out1", "out2"};

  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // Fatal: node(4) and node(5) below must exist.
  ASSERT_EQ(6, output.node_size());
  const NodeDef& out1_node = output.node(4);
  EXPECT_EQ("out1", out1_node.name());
  ASSERT_EQ(1, out1_node.input_size());
  EXPECT_EQ("s", out1_node.input(0));

  const NodeDef& out2_node = output.node(5);
  EXPECT_EQ("out2", out2_node.name());
  ASSERT_EQ(1, out2_node.input_size());
  EXPECT_EQ("s:1", out2_node.input(0));
}
662 
TEST_F(DependencyOptimizerTest, RemoveIdentityN_SwitchInput) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output b = ops::Placeholder(scope.WithOpName("b"), DT_BOOL);
  Output x = ops::RandomUniform(scope.WithOpName("x"), {1, 2}, DT_FLOAT);
  auto s = ops::Switch(scope.WithOpName("s"), x, b);

  // IdentityN nodes to be removed.
  auto id_f = ops::IdentityN(scope.WithOpName("id_f"), {s.output_false});
  auto id_t = ops::IdentityN(scope.WithOpName("id_t"), {s.output_true});
  auto id_b =
      ops::IdentityN(scope.WithOpName("id_b"), {s.output_false, s.output_true});

  // Outputs
  Output out1 = ops::Identity(scope.WithOpName("out1"), id_f[0]);
  Output out2 = ops::Identity(scope.WithOpName("out2"), id_t[0]);
  Output out3 = ops::Identity(scope.WithOpName("out3"), id_b[0]);
  Output out4 = ops::Identity(scope.WithOpName("out4"), id_b[1]);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"out1", "out2", "out3", "out4"};

  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // Fatal: the positional node() reads below must be in range.
  ASSERT_EQ(8, output.node_size());

  // Bind by reference to avoid copying each NodeDef.
  const NodeDef& out1_node = output.node(7);
  EXPECT_EQ("out1", out1_node.name());
  ASSERT_EQ(1, out1_node.input_size());
  EXPECT_EQ("s", out1_node.input(0));

  const NodeDef& out2_node = output.node(4);
  EXPECT_EQ("out2", out2_node.name());
  ASSERT_EQ(1, out2_node.input_size());
  EXPECT_EQ("s:1", out2_node.input(0));

  const NodeDef& out3_node = output.node(5);
  EXPECT_EQ("out3", out3_node.name());
  ASSERT_EQ(1, out3_node.input_size());
  EXPECT_EQ("s", out3_node.input(0));

  const NodeDef& out4_node = output.node(6);
  EXPECT_EQ("out4", out4_node.name());
  ASSERT_EQ(1, out4_node.input_size());
  EXPECT_EQ("s:1", out4_node.input(0));
}
712 
TEST_F(DependencyOptimizerTest, DoNotRemoveIdentityNWithControlDependency) {
  // id_n feeds out1/out2 through data edges, and out3 through a control edge
  // on its second output. All six nodes are expected to survive.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output flag = ops::Placeholder(root.WithOpName("input1"), DT_BOOL);
  Output pair = ops::Const(root.WithOpName("input2"), {1, 2});

  auto identity_n = ops::IdentityN(root.WithOpName("id_n"), {flag, pair});
  Output first = ops::Identity(root.WithOpName("out1"), identity_n[0]);
  Output second = ops::Identity(root.WithOpName("out2"), identity_n[1]);
  auto ctrl_sink =
      ops::NoOp(root.WithOpName("out3").WithControlDependencies(identity_n[1]));

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  item.fetch = {"out1", "out2", "out3"};

  GraphDef optimized;
  DependencyOptimizer optimizer;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized));

  EXPECT_EQ(6, optimized.node_size());
}
735 
TEST_F(DependencyOptimizerTest,
       Identity_DeviceCrossing_ConsumerOnDifferentDevice) {
  // x_on_2 copies a /gpu:1 value onto /gpu:2, while its consumer lives on
  // /gpu:3. The optimizer is expected to leave this graph exactly as built.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output src =
      ops::Const(root.WithOpName("x_on_1").WithDevice("/gpu:1"), {1.0f}, {});
  Output one =
      ops::Const(root.WithOpName("one_on_3").WithDevice("/gpu:3"), {1.0f}, {});
  Output forwarded =
      ops::Identity(root.WithOpName("x_on_2").WithDevice("/gpu:2"), src);
  Output sum =
      ops::Add(root.WithOpName("result").WithDevice("/gpu:3"), forwarded, one);

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  item.fetch = {"result"};

  GraphDef output;
  DependencyOptimizer optimizer;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  VerifyGraphsEqual(item.graph, output, __FUNCTION__);
}
758 
// Identity "x_on_2" is placed on /gpu:2, the same device as its only consumer
// "result". The test expects the optimizer to remove the Identity and rewire
// "result" to read "x_on_1" directly, leaving three nodes.
TEST_F(DependencyOptimizerTest, Identity_DeviceCrossing_ConsumerOnSameDevice) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x_on_1 =
      ops::Const(s.WithOpName("x_on_1").WithDevice("/gpu:1"), {1.0f}, {});
  Output one_on_2 =
      ops::Const(s.WithOpName("one_on_2").WithDevice("/gpu:2"), {1.0f}, {});
  Output x_on_2 =
      ops::Identity(s.WithOpName("x_on_2").WithDevice("/gpu:2"), x_on_1);
  Output result =
      ops::Add(s.WithOpName("result").WithDevice("/gpu:2"), x_on_2, one_on_2);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"result"};
  DependencyOptimizer optimizer;
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  EXPECT_EQ(3, output.node_size());
  bool found_result = false;
  for (const auto& node : output.node()) {
    // The forwarding Identity must be gone.
    EXPECT_NE("x_on_2", node.name());
    if (node.name() == "result") {
      found_result = true;
      // Guard the index: input(0) on an empty repeated field is out of range.
      ASSERT_GE(node.input_size(), 1);
      EXPECT_EQ("x_on_1", node.input(0));
    }
  }
  // Without this, a missing fetch node would skip the rewiring check silently.
  EXPECT_TRUE(found_result);
}
786 
// "GreaterEqual" is only a control input of "NoOp", which is in turn only a
// control input of the fetch node "z". The test expects both to be removed,
// leaving exactly x, y and z, with z reading x and y as its data inputs.
TEST_F(DependencyOptimizerTest, RemoveGreaterEqualWithNoOp) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                              ops::Placeholder::Shape({}));
  Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
                              ops::Placeholder::Shape({}));
  auto greaterequal = ops::GreaterEqual(s.WithOpName("GreaterEqual"), x, y);
  auto noop =
      ops::NoOp(s.WithOpName("NoOp").WithControlDependencies(greaterequal));
  Output add = ops::Add(
      s.WithOpName("z").WithControlDependencies({noop.operation}), x, y);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  DependencyOptimizer optimizer;
  GraphDef output;
  item.fetch.push_back("z");
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  int count = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "x" || node.name() == "y") {
      count++;
      EXPECT_EQ("Placeholder", node.op());
      EXPECT_EQ(0, node.input_size());
    } else if (node.name() == "GreaterEqual" || node.name() == "NoOp") {
      // Report the surviving node by name instead of only via the count.
      ADD_FAILURE() << "Node " << node.name() << " should have been removed";
    } else if (node.name() == "z") {
      count++;
      EXPECT_EQ("Add", node.op());
      // Fatal check: the two input(i) reads below must stay in range.
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("y", node.input(1));
    }
  }
  EXPECT_EQ(3, count);
}
830 
TEST_F(DependencyOptimizerTest,GroupCrossDeviceControlDeps)831 TEST_F(DependencyOptimizerTest, GroupCrossDeviceControlDeps) {
832   GrapplerItem item;
833   {
834     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
835     Output a = ops::RandomUniform(s.WithOpName("a").WithDevice("/CPU:1"),
836                                   {1, 2}, DT_FLOAT);
837     Output b = ops::RandomUniform(s.WithOpName("b").WithDevice("/CPU:2"),
838                                   {1, 2}, DT_FLOAT);
839     Output c = ops::RandomUniform(s.WithOpName("c").WithDevice("/CPU:1"),
840                                   {1, 2}, DT_FLOAT);
841     Output d = ops::RandomUniform(s.WithOpName("d").WithDevice("/CPU:3"),
842                                   {1, 2}, DT_FLOAT);
843     Output e = ops::RandomUniform(s.WithOpName("e").WithDevice("/CPU:0"),
844                                   {1, 2}, DT_FLOAT);
845     // Node with cross-device dependencies.
846     auto fetch = ops::Identity(
847         s.WithOpName("f")
848             .WithControlDependencies({a.op(), b.op(), c.op(), d.op()})
849             .WithDevice("/GPU:0"),
850         {e});
851 
852     TF_CHECK_OK(s.ToGraphDef(&item.graph));
853     item.fetch.push_back("f");
854   }
855 
856   GraphDef expected;
857   {
858     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
859     Output a = ops::RandomUniform(s.WithOpName("a").WithDevice("/CPU:1"),
860                                   {1, 2}, DT_FLOAT);
861     Output b = ops::RandomUniform(s.WithOpName("b").WithDevice("/CPU:2"),
862                                   {1, 2}, DT_FLOAT);
863     Output c = ops::RandomUniform(s.WithOpName("c").WithDevice("/CPU:1"),
864                                   {1, 2}, DT_FLOAT);
865     Output d = ops::RandomUniform(s.WithOpName("d").WithDevice("/CPU:3"),
866                                   {1, 2}, DT_FLOAT);
867     Output e = ops::RandomUniform(s.WithOpName("e").WithDevice("/CPU:0"),
868                                   {1, 2}, DT_FLOAT);
869     auto noop = ops::NoOp(s.WithOpName("GroupCrossDeviceControlEdges_0/f")
870                               .WithDevice("/CPU:1")
871                               .WithControlDependencies({a.op(), c.op()}));
872     auto fetch =
873         ops::Identity(s.WithOpName("f")
874                           .WithControlDependencies({b.op(), d.op(), noop})
875                           .WithDevice("/GPU:0"),
876                       {e});
877 
878     TF_CHECK_OK(s.ToGraphDef(&expected));
879   }
880 
881   DependencyOptimizer optimizer;
882   GraphDef output;
883   TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
884   CompareGraphs(expected, output);
885 
886   // Run the optimizer again to verify idempotence.
887   item.graph.Swap(&output);
888   output.Clear();
889   TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
890   CompareGraphs(expected, output);
891 }
892 
893 }  // namespace
894 }  // namespace grappler
895 }  // namespace tensorflow
896