1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
17 #include "tensorflow/cc/ops/math_ops.h"
18 #include "tensorflow/cc/ops/standard_ops.h"
19 #include "tensorflow/core/framework/node_def.pb.h"
20 #include "tensorflow/core/framework/tensor_testutil.h"
21 #include "tensorflow/core/grappler/grappler_item.h"
22 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
23 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h"
24 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
25 #include "tensorflow/core/grappler/utils.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/platform/test.h"
28 
29 namespace tensorflow {
30 namespace grappler {
31 
32 namespace {
33 
34 constexpr char kHoistFactorOptimizerDiv[] =
35     "ArithmeticOptimizer/HoistCommonFactor_Div_";
36 
37 constexpr char kHoistFactorOptimizerMul[] =
38     "ArithmeticOptimizer/HoistCommonFactor_Mul_";
39 
40 constexpr char kHoistFactorOptimizerAdd[] =
41     "ArithmeticOptimizer/HoistCommonFactor_Add_";
42 
43 constexpr char kSimplifyAggregationConst[] =
44     "ArithmeticOptimizer/SimplifyAggregation_Const_";
45 
46 constexpr char kSimplifyAggregationMul[] =
47     "ArithmeticOptimizer/SimplifyAggregation_Mul_";
48 
49 // Optimized name of outer Mul node by HoistCommonFactorOutOfAggregation.
HoistMulName(const string & name)50 string HoistMulName(const string& name) {
51   return AddPrefixToNodeName(name, kHoistFactorOptimizerMul, "");
52 }
53 
54 // Optimized name of outer Div node by HoistCommonFactorOutOfAggregation.
HoistDivName(const string & name)55 string HoistDivName(const string& name) {
56   return AddPrefixToNodeName(name, kHoistFactorOptimizerDiv, "");
57 }
58 
59 // Optimized name of inner Add node by HoistCommonFactorOutOfAggregation.
HoistAddName(const string & name)60 string HoistAddName(const string& name) {
61   return AddPrefixToNodeName(name, kHoistFactorOptimizerAdd, "");
62 }
63 
64 // Optimized name of Const node by SimplifyAggregation.
AggregationConstName(const string & name)65 string AggregationConstName(const string& name) {
66   return AddPrefixToNodeName(name, kSimplifyAggregationConst, "");
67 }
68 
69 // Optimized name of Mul node by SimplifyAggregation.
AggregationMulName(const string & name)70 string AggregationMulName(const string& name) {
71   return AddPrefixToNodeName(name, kSimplifyAggregationMul, "");
72 }
73 
VerifyGraphsMatch(const GraphDef & original_graph,const GraphDef & optimized_graph,int line)74 void VerifyGraphsMatch(const GraphDef& original_graph,
75                        const GraphDef& optimized_graph, int line) {
76   EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << line;
77   for (int i = 0; i < original_graph.node_size(); ++i) {
78     const NodeDef& original = original_graph.node(i);
79     const NodeDef& optimized = optimized_graph.node(i);
80     EXPECT_EQ(original.name(), optimized.name()) << line;
81     EXPECT_EQ(original.op(), optimized.op()) << line;
82     EXPECT_EQ(original.input_size(), optimized.input_size()) << line;
83     for (int j = 0; j < original.input_size(); ++j) {
84       EXPECT_EQ(original.input(j), optimized.input(j)) << line;
85     }
86   }
87 }
88 }  // namespace
89 
// The optimizer must leave a graph with nothing to optimize unchanged.
TEST_F(ArithmeticOptimizerTest, NoOp) {
  // This trivial graph is so basic there's nothing to optimize.
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
  GrapplerItem item;
  // ASSERT instead of CHECK: a failure is reported against this test rather
  // than aborting the entire test binary.
  ASSERT_TRUE(fake_input.NextItem(&item));

  ArithmeticOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  VerifyGraphsMatch(item.graph, output, __LINE__);
}
102 
// Two Const nodes with identical values are deduplicated, so `div` ends up
// reading `c1` for both operands and the result is numerically unchanged.
TEST_F(ArithmeticOptimizerTest, OpDedupping) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c1 = ops::Const(s.WithOpName("c1"), {3.14, 2.7}, {1, 2});
  Output c2 = ops::Const(s.WithOpName("c2"), {3.14, 2.7}, {1, 2});
  Output div = ops::Div(s.WithOpName("div"), c1, c2);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"div"};

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // ASSERT (not EXPECT): tensors_expected[0] is indexed at the end.
  ASSERT_EQ(1, tensors_expected.size());

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);
  NodeMap node_map(&output);
  EXPECT_EQ(2, output.node_size());
  const NodeDef* new_c1 = node_map.GetNode("c1");
  ASSERT_NE(new_c1, nullptr);

  const NodeDef* new_div = node_map.GetNode("div");
  ASSERT_NE(new_div, nullptr);
  // Both inputs are indexed below, so the count must hold before reading.
  ASSERT_EQ(2, new_div->input_size());
  EXPECT_EQ("c1", new_div->input(0));
  EXPECT_EQ("c1", new_div->input(1));

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorNear<double>(tensors_expected[0], tensors[0], 1e-6);
}
133 
// Identical CheckNumerics and Assert nodes are deduplicated: `div` keeps
// reading only `check1` and keeps control edges only from `assert1`.
TEST_F(ArithmeticOptimizerTest, OpDeduppingAssertAndCheckNumerics) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // Default-named "Placeholder"; fed by name in EvaluateNodes below.
  Output p = ops::Placeholder(s, DT_BOOL, ops::Placeholder::Shape({}));
  Output c = ops::Const(s.WithOpName("c"), {3.14, 2.7}, {1, 2});
  auto check1 = ops::CheckNumerics(s.WithOpName("check1"), c, "foo");
  auto check2 = ops::CheckNumerics(s.WithOpName("check2"), c, "foo");
  auto assert1 = ops::Assert(s.WithOpName("assert1"), p, {c});
  auto assert2 = ops::Assert(s.WithOpName("assert2"), p, {c});
  Output div = ops::Div(s.WithOpName("div").WithControlDependencies(
                            {assert1.operation, assert2.operation}),
                        check1, check2);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"div"};
  Tensor bool_t(DT_BOOL, TensorShape({}));
  bool_t.scalar<bool>().setConstant(true);
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", bool_t}});
  // ASSERT (not EXPECT): tensors_expected[0] is indexed at the end.
  ASSERT_EQ(1, tensors_expected.size());

  ArithmeticOptimizer optimizer;
  GraphDef output;

  OptimizeTwice(&optimizer, &item, &output);
  NodeMap node_map(&output);

  EXPECT_EQ(5, output.node_size());
  const NodeDef* new_div = node_map.GetNode("div");
  ASSERT_NE(new_div, nullptr);
  // All four inputs are indexed below.
  ASSERT_EQ(4, new_div->input_size());
  EXPECT_EQ("check1", new_div->input(0));
  EXPECT_EQ("check1", new_div->input(1));
  EXPECT_EQ("^assert1", new_div->input(2));
  EXPECT_EQ("^assert1", new_div->input(3));

  auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", bool_t}});
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorNear<double>(tensors_expected[0], tensors[0], 1e-6);
}
173 
// Mul(c1, c2) and Mul(c2, c1) are recognized as the same commutative op and
// deduplicated, so `div1` reads `mul1` for both operands.
TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c1 = ops::Const(s.WithOpName("c1"), {1.0f, 2.0f}, {1, 2});
  Output c2 = ops::Const(s.WithOpName("c2"), {3.0f, 4.0f}, {1, 2});
  Output mul1 = ops::Mul(s.WithOpName("mul1"), c1, c2);
  Output mul2 = ops::Mul(s.WithOpName("mul2"), c2, c1);
  Output div1 = ops::Div(s.WithOpName("div1"), mul1, mul2);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"div1"};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // ASSERT (not EXPECT): tensors_expected[0] is indexed at the end.
  ASSERT_EQ(1, tensors_expected.size());

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);
  NodeMap node_map(&output);

  EXPECT_EQ(4, output.node_size());
  const NodeDef* new_c1 = node_map.GetNode("c1");
  ASSERT_NE(new_c1, nullptr);
  const NodeDef* new_c2 = node_map.GetNode("c2");
  ASSERT_NE(new_c2, nullptr);
  const NodeDef* new_mul1 = node_map.GetNode("mul1");
  ASSERT_NE(new_mul1, nullptr);
  ASSERT_EQ(2, new_mul1->input_size());
  EXPECT_EQ("c1", new_mul1->input(0));
  EXPECT_EQ("c2", new_mul1->input(1));
  const NodeDef* new_div1 = node_map.GetNode("div1");
  ASSERT_NE(new_div1, nullptr);
  ASSERT_EQ(2, new_div1->input_size());
  EXPECT_EQ("mul1", new_div1->input(0));
  EXPECT_EQ("mul1", new_div1->input(1));

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
212 
// Mul(c, c) is rewritten into Square(c), and the control dependency on `d`
// is carried over to the new node.
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithSquare) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  Output d = ops::Const(s.WithOpName("d"), {3.0f, 4.0f}, {1, 2});
  Output mul = ops::Mul(s.WithControlDependencies(d).WithOpName("mul"), c, c);
  Output id = ops::Identity(s.WithOpName("id"), mul);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // ASSERT (not EXPECT): tensors_expected[0] is indexed at the end.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithSquare(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  EXPECT_EQ(4, output.node_size());

  NodeMap node_map(&output);
  const string p = "ArithmeticOptimizer/ReplaceMulWithSquare";
  const NodeDef* square_node = node_map.GetNode(strings::StrCat(p, "_", "mul"));

  ASSERT_NE(square_node, nullptr);
  EXPECT_EQ("Square", square_node->op());
  // input(1) is read below; guard it so a missing input fails cleanly.
  ASSERT_GE(square_node->input_size(), 2);
  EXPECT_EQ("c", square_node->input(0));
  EXPECT_EQ("^d", square_node->input(1));

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
246 
// Adjacent Neg/Neg and Reciprocal/Reciprocal pairs cancel out, so after
// pruning `id` reads the constant `c` directly.
TEST_F(ArithmeticOptimizerTest, RemoveInvolution_AdjacentNodes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  auto neg1 = ops::Neg(s.WithOpName("neg1"), c);
  auto neg2 = ops::Neg(s.WithOpName("neg2"), neg1);
  auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), neg2);
  auto recip2 = ops::Reciprocal(s.WithOpName("recip2"), recip1);
  auto id = ops::Identity(s.WithOpName("id"), recip2);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // ASSERT (not EXPECT): tensors_expected[0] is indexed at the end.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveInvolution(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // Negation and Reciprocal nodes cancelled each other.
  EXPECT_EQ(2, output.node_size());
  // Look `id` up by name instead of assuming it sits at position 1. That way
  // the check does not depend on node order and never indexes past the end
  // of a smaller graph.
  NodeMap node_map(&output);
  const NodeDef* id_node = node_map.GetNode("id");
  ASSERT_NE(id_node, nullptr);
  ASSERT_GE(id_node->input_size(), 1);
  EXPECT_EQ("c", id_node->input(0));

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
277 
// Two Reciprocal nodes separated only by value-preserving ops (Identity,
// Squeeze) cancel out, so the constant flows straight into `squeeze`.
TEST_F(ArithmeticOptimizerTest, RemoveInvolution_AroundValuePreservingChain) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), c);
  auto id1 = ops::Identity(s.WithOpName("id1"), recip1);
  auto squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1);
  auto recip2 = ops::Reciprocal(s.WithOpName("recip2"), squeeze);
  auto id2 = ops::Identity(s.WithOpName("id2"), recip2);

  std::vector<string> fetch = {"id2"};

  GrapplerItem item;
  item.fetch = fetch;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  // ASSERT (not EXPECT): tensors_expected[0] is indexed at the end.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveInvolution(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  // Check that Reciprocal nodes were removed from the graph.
  EXPECT_EQ(3, output.node_size());

  // And const directly flows into squeeze.
  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "squeeze") {
      ASSERT_GE(node.input_size(), 1);
      EXPECT_EQ("c", node.input(0));
      found++;
    } else if (node.name() == "id2") {
      ASSERT_GE(node.input_size(), 1);
      EXPECT_EQ("squeeze", node.input(0));
      found++;
    }
  }
  EXPECT_EQ(2, found);

  auto tensors = EvaluateNodes(output, fetch);
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
321 
// `recip2` depends on the `squeeze` chain only through a control edge. Its
// data input is `c`, so the two Reciprocals are not an involution pair and
// the graph must stay unchanged.
TEST_F(ArithmeticOptimizerTest, RemoveInvolution_SkipControlDependencies) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), c);
  auto id1 = ops::Identity(s.WithOpName("id1"), recip1);
  auto squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1);
  auto recip2 = ops::Reciprocal(
      s.WithOpName("recip2").WithControlDependencies(squeeze), c);
  auto id2 = ops::Identity(s.WithOpName("id2"), recip2);

  std::vector<string> fetch = {"id2"};

  GrapplerItem item;
  item.fetch = fetch;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  // ASSERT (not EXPECT): tensors_expected[0] is indexed at the end.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveInvolution(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);  // do not prune in this test

  // The optimizer should be a noop.
  VerifyGraphsMatch(item.graph, output, __LINE__);

  auto tensors = EvaluateNodes(output, fetch);
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
354 
// Add(x, x) is rewritten by SimplifyAggregation into Mul(Const(2), x). The
// new constant carries a control dependency on `x`.
TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output add = ops::Add(s.WithOpName("add"), x, x);
  Output id = ops::Identity(s.WithOpName("id"), add);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // ASSERT (not EXPECT): tensors_expected[0] is indexed at the end.
  ASSERT_EQ(1, tensors_expected.size());

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);
  NodeMap node_map(&output);

  EXPECT_EQ(5, output.node_size());

  const string optimized_const_name = AggregationConstName("add");
  const string optimized_mul_name = AggregationMulName("add");

  const NodeDef* new_const = node_map.GetNode(optimized_const_name);
  ASSERT_NE(new_const, nullptr);
  ASSERT_GE(new_const->input_size(), 1);
  EXPECT_EQ("^x", new_const->input(0));
  // attr().at() hard-fails on a missing key; check presence first so a
  // missing "value" attr is reported as a test failure.
  ASSERT_EQ(1, new_const->attr().count("value"));
  // "\0\0\0@" is the 4-byte little-endian encoding of the float 2.0f.
  EXPECT_EQ(string("\0\0\0@", 4),
            new_const->attr().at("value").tensor().tensor_content());

  const NodeDef* new_mul = node_map.GetNode(optimized_mul_name);
  ASSERT_NE(new_mul, nullptr);
  ASSERT_GE(new_mul->input_size(), 2);
  EXPECT_EQ(optimized_const_name, new_mul->input(0));
  EXPECT_EQ("x", new_mul->input(1));

  const NodeDef* new_id = node_map.GetNode("id");
  ASSERT_NE(new_id, nullptr);
  ASSERT_GE(new_id->input_size(), 1);
  EXPECT_EQ(optimized_mul_name, new_id->input(0));

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
397 
// Same rewrite as TrivialSumsSimple, but `add` has a control dependency on
// `y`. That dependency must be forwarded to the rewritten Mul node.
TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output y = ops::Const(s.WithOpName("y"), {1.0f, 2.0f}, {1, 2});
  Output x = ops::Const(s.WithOpName("x"), {3.0f, 4.0f}, {1, 2});
  Output add = ops::Add(s.WithOpName("add").WithControlDependencies(y), x, x);
  Output id = ops::Identity(s.WithOpName("id"), add);

  GrapplerItem item;
  // Register the fetch node on the item (as TrivialSumsSimple does) so the
  // optimizer sees the same fetch set that is evaluated below.
  item.fetch = {"id"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // ASSERT (not EXPECT): tensors_expected[0] is indexed at the end.
  ASSERT_EQ(1, tensors_expected.size());

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);
  NodeMap node_map(&output);

  EXPECT_EQ(6, output.node_size());

  const string optimized_const_name = AggregationConstName("add");
  const string optimized_mul_name = AggregationMulName("add");

  const NodeDef* new_const = node_map.GetNode(optimized_const_name);
  ASSERT_NE(new_const, nullptr);
  ASSERT_GE(new_const->input_size(), 1);
  EXPECT_EQ("^x", new_const->input(0));
  // attr().at() hard-fails on a missing key; check presence first.
  ASSERT_EQ(1, new_const->attr().count("value"));
  // "\0\0\0@" is the 4-byte little-endian encoding of the float 2.0f.
  EXPECT_EQ(string("\0\0\0@", 4),
            new_const->attr().at("value").tensor().tensor_content());

  const NodeDef* new_mul = node_map.GetNode(optimized_mul_name);
  ASSERT_NE(new_mul, nullptr);
  ASSERT_GE(new_mul->input_size(), 3);
  EXPECT_EQ(optimized_const_name, new_mul->input(0));
  EXPECT_EQ("x", new_mul->input(1));
  EXPECT_EQ("^y", new_mul->input(2));

  const NodeDef* new_id = node_map.GetNode("id");
  ASSERT_NE(new_id, nullptr);
  ASSERT_GE(new_id->input_size(), 1);
  EXPECT_EQ(optimized_mul_name, new_id->input(0));

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
442 
TEST_F(ArithmeticOptimizerTest,TrivialSumsRepeatedAdd)443 TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
444   // Test case from b/69059093.
445   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
446   Output p = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({10, 10}));
447   Output add = ops::Add(s.WithOpName("Add"), p, p);
448   Output add1 = ops::Add(s.WithOpName("Add_1"), p, p);
449   Output add4 = ops::Add(s.WithOpName("Add_4"), add, add1);
450   Output add5 = ops::Add(s.WithOpName("Add_5"), add, add1);
451   Output add6 = ops::Add(s.WithOpName("Add_6"), add4, add5);
452   Output id = ops::Identity(s.WithOpName("id"), add6);
453 
454   GrapplerItem item;
455   TF_CHECK_OK(s.ToGraphDef(&item.graph));
456 
457   const std::vector<string> devices{
458       "/device:CPU:0", "/device:GPU:0", "/device:CPU:0", "/device:GPU:1",
459       "/device:CPU:0", "/device:CPU:0", "/device:CPU:0",
460   };
461   for (int i = 0; i < item.graph.node_size(); ++i) {
462     item.graph.mutable_node(i)->set_device(devices[i]);
463   }
464 
465   ArithmeticOptimizer optimizer;
466   DisableAddToAddNCombining(&optimizer);
467 
468   GraphDef output;
469   OptimizeTwice(&optimizer, &item, &output);
470 
471   // We expect the following rewrite(s) to occur:
472   //
473   // Mul(p,
474   //     Add_6(Add_4(Const(2), Const(2)),
475   //           Add_5(Const(2), Const(2))))
476   NodeMap node_map(&output);
477 
478   EXPECT_EQ(17, output.node_size());
479 
480   const NodeDef* id_node = node_map.GetNode("id");
481   ASSERT_NE(id_node, nullptr);
482   EXPECT_EQ(1, id_node->input_size());
483   EXPECT_EQ(HoistMulName("Add_6"), id_node->input(0));
484 
485   const NodeDef* mul_node = node_map.GetNode(HoistMulName("Add_6"));
486   ASSERT_NE(mul_node, nullptr);
487   EXPECT_EQ(2, mul_node->input_size());
488   EXPECT_EQ("Placeholder", mul_node->input(0));
489   EXPECT_EQ(HoistAddName("Add_6"), mul_node->input(1));
490 
491   const NodeDef* add_6_node = node_map.GetNode(HoistAddName("Add_6"));
492   ASSERT_NE(add_6_node, nullptr);
493   EXPECT_EQ(2, add_6_node->input_size());
494   EXPECT_EQ(HoistAddName("Add_4"), add_6_node->input(0));
495   EXPECT_EQ(HoistAddName("Add_5"), add_6_node->input(1));
496 
497   const NodeDef* add_4_node = node_map.GetNode(HoistAddName("Add_4"));
498   ASSERT_NE(add_4_node, nullptr);
499   EXPECT_EQ("Add", add_4_node->op());
500   EXPECT_EQ(2, add_4_node->input_size());
501   EXPECT_EQ(AggregationConstName("Add"), add_4_node->input(0));
502   EXPECT_EQ(AggregationConstName("Add_1"), add_4_node->input(1));
503 
504   const NodeDef* add_5_node = node_map.GetNode(HoistAddName("Add_5"));
505   ASSERT_NE(add_5_node, nullptr);
506   EXPECT_EQ("Add", add_5_node->op());
507   EXPECT_EQ(2, add_5_node->input_size());
508   EXPECT_EQ(AggregationConstName("Add"), add_5_node->input(0));
509   EXPECT_EQ(AggregationConstName("Add_1"), add_5_node->input(1));
510 
511   const NodeDef* add_const_node = node_map.GetNode(AggregationConstName("Add"));
512   ASSERT_NE(add_const_node, nullptr);
513   EXPECT_EQ("Const", add_const_node->op());
514   EXPECT_EQ(1, add_const_node->input_size());
515   EXPECT_EQ("^Placeholder", add_const_node->input(0));
516 
517   const NodeDef* add_1_const_node =
518       node_map.GetNode(AggregationConstName("Add_1"));
519   ASSERT_NE(add_1_const_node, nullptr);
520   EXPECT_EQ("Const", add_1_const_node->op());
521   EXPECT_EQ(1, add_1_const_node->input_size());
522   EXPECT_EQ("^Placeholder", add_1_const_node->input(0));
523 }
524 
TEST_F(ArithmeticOptimizerTest,HoistFactorMul)525 TEST_F(ArithmeticOptimizerTest, HoistFactorMul) {
526   for (bool matching_shapes : {true, false}) {
527     for (bool use_addn : {true, false}) {
528       tensorflow::Scope s = tensorflow::Scope::NewRootScope();
529       Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
530       Output y1 = ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
531       Output y2 = matching_shapes
532                       ? ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2})
533                       : ops::Const(s.WithOpName("y2"), {5.0f}, {1, 1});
534       Output mul1 = ops::Mul(s.WithOpName("mul1"), x, y1);
535       Output mul2 = ops::Mul(s.WithOpName("mul2"), y2, x);
536       Output id =
537           use_addn ? ops::Identity(s.WithOpName("id"),
538                                    ops::AddN(s.WithOpName("add"), {mul1, mul2}))
539                    : ops::Identity(s.WithOpName("id"),
540                                    ops::Add(s.WithOpName("add"), mul1, mul2));
541 
542       GrapplerItem item;
543       item.fetch = {"id"};
544       TF_CHECK_OK(s.ToGraphDef(&item.graph));
545       auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
546       EXPECT_EQ(1, tensors_expected.size());
547       ArithmeticOptimizer optimizer;
548       EnableOnlyHoistCommonFactor(&optimizer);
549 
550       GraphDef output;
551       OptimizeTwice(&optimizer, &item, &output);
552 
553       // We expect the following rewrite(s) to occur:
554       //
555       //        Add                 Mul
556       //      /    \               /   \
557       //    Mul    Mul       ->   x    Add
558       //    / \    / \                 / \
559       //   x  y1  y2  x              y1   y2
560       //
561       // If "root" op is AddN and shapes does not match, this rewrite is not
562       // possible and graph should stay intact.
563       NodeMap node_map(&output);
564 
565       if (use_addn && !matching_shapes) {
566         VerifyGraphsMatch(item.graph, output, __LINE__);
567       } else {
568         EXPECT_EQ(9, output.node_size());
569 
570         const NodeDef* new_add_node = node_map.GetNode(HoistAddName("add"));
571         ASSERT_NE(new_add_node, nullptr) << "Hoisted Add node not found";
572         EXPECT_EQ("y1", new_add_node->input(0));
573         EXPECT_EQ("y2", new_add_node->input(1));
574 
575         const NodeDef* new_mul_node = node_map.GetNode(HoistMulName("add"));
576         ASSERT_NE(new_mul_node, nullptr) << "Hoisted Mul node not found";
577         EXPECT_EQ("x", new_mul_node->input(0));
578         EXPECT_EQ(new_add_node->name(), new_mul_node->input(1));
579 
580         const NodeDef* id_node = node_map.GetNode("id");
581         ASSERT_NE(id_node, nullptr) << "Id node not found";
582         EXPECT_EQ("id", id_node->name());
583         EXPECT_EQ(HoistMulName("add"), id_node->input(0));
584       }
585       auto tensors = EvaluateNodes(output, item.fetch);
586       EXPECT_EQ(1, tensors.size());
587       test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
588     }
589   }
590 }
591 
TEST_F(ArithmeticOptimizerTest,HoistFactorDiv)592 TEST_F(ArithmeticOptimizerTest, HoistFactorDiv) {
593   for (bool matching_shapes : {true, false}) {
594     for (bool use_addn : {true, false}) {
595       for (bool use_ints : {true, false}) {
596         tensorflow::Scope s = tensorflow::Scope::NewRootScope();
597         Output x = use_ints
598                        ? ops::Const(s.WithOpName("x"), {1, 2}, {1, 2})
599                        : ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
600         Output y1 = use_ints
601                         ? ops::Const(s.WithOpName("y1"), {3, 4}, {1, 2})
602                         : ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
603         Output y2;
604         if (matching_shapes) {
605           y2 = use_ints ? ops::Const(s.WithOpName("y2"), {5, 6}, {1, 2})
606                         : ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2});
607         } else {
608           y2 = use_ints ? ops::Const(s.WithOpName("y2"), {5}, {1, 1})
609                         : ops::Const(s.WithOpName("y2"), {5.0f}, {1, 1});
610         }
611         Output div1 = ops::Div(s.WithOpName("div1"), y1, x);
612         Output div2 = ops::Div(s.WithOpName("div2"), y2, x);
613         Output id =
614             use_addn
615                 ? ops::Identity(s.WithOpName("id"),
616                                 ops::AddN(s.WithOpName("add"), {div1, div2}))
617                 : ops::Identity(s.WithOpName("id"),
618                                 ops::Add(s.WithOpName("add"), div1, div2));
619 
620         GrapplerItem item;
621         item.fetch = {"id"};
622         TF_CHECK_OK(s.ToGraphDef(&item.graph));
623 
624         auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
625         EXPECT_EQ(1, tensors_expected.size());
626 
627         ArithmeticOptimizer optimizer;
628         EnableOnlyHoistCommonFactor(&optimizer);
629 
630         GraphDef output;
631         OptimizeTwice(&optimizer, &item, &output);
632 
633         // We expect the following rewrite(s) to occur:
634         //
635         //        Add                 Div
636         //      /    \               /   \
637         //    Div    Div       ->  Add    x
638         //    / \    / \           / \
639         //   y1  x  y2  x         y1  y2
640         //
641         // If "root" op is AddN and shapes does not match, this rewrite is not
642         // possible and graph should stay intact.
643         NodeMap node_map(&output);
644 
645         if ((use_addn && !matching_shapes) || use_ints) {
646           VerifyGraphsMatch(item.graph, output, __LINE__);
647         } else {
648           EXPECT_EQ(9, output.node_size());
649 
650           const NodeDef* new_add_node = node_map.GetNode(HoistAddName("add"));
651           ASSERT_TRUE(new_add_node != nullptr) << "Hoisted Add node not found";
652           EXPECT_EQ("y1", new_add_node->input(0));
653           EXPECT_EQ("y2", new_add_node->input(1));
654 
655           const NodeDef* new_div_node = node_map.GetNode(HoistDivName("add"));
656           ASSERT_TRUE(new_div_node != nullptr) << "Hoisted Div node not found";
657           EXPECT_EQ(new_add_node->name(), new_div_node->input(0));
658           EXPECT_EQ("x", new_div_node->input(1));
659 
660           const NodeDef* id_node = node_map.GetNode("id");
661           ASSERT_TRUE(id_node != nullptr) << "Id node not found";
662           EXPECT_EQ("id", id_node->name());
663           EXPECT_EQ(HoistDivName("add"), id_node->input(0));
664         }
665         auto tensors = EvaluateNodes(output, item.fetch);
666         EXPECT_EQ(1, tensors.size());
667         if (use_ints) {
668           test::ExpectTensorEqual<int32>(tensors_expected[0], tensors[0]);
669         } else {
670           test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
671         }
672       }
673     }
674   }
675 }
676 
// Transpose(Conj(z)) is fused into a single ConjugateTranspose(z, perm).
TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output re = ops::Const(s.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output im = ops::Const(s.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output z = ops::Complex(s.WithOpName("z"), re, im);
  Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
  Output conj = ops::Conj(s.WithOpName("conj"), z);
  Output transp = ops::Transpose(s.WithOpName("trans"), conj, perm);

  GrapplerItem item;
  item.fetch = {"trans"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // ASSERT (not EXPECT): tensors_expected[0] is indexed at the end.
  ASSERT_EQ(1, tensors_expected.size());

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);
  NodeMap node_map(&output);

  EXPECT_EQ(7, output.node_size());

  const string p = "ArithmeticOptimizer/FoldConjugateIntoTranspose";
  const string optimized_name = strings::StrCat(p, "_", "trans");

  const NodeDef* trans_fused_node = node_map.GetNode(optimized_name);
  ASSERT_NE(trans_fused_node, nullptr);
  EXPECT_EQ("ConjugateTranspose", trans_fused_node->op());
  ASSERT_GE(trans_fused_node->input_size(), 2);
  EXPECT_EQ("z", trans_fused_node->input(0));
  EXPECT_EQ("perm", trans_fused_node->input(1));

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorEqual<complex64>(tensors_expected[0], tensors[0]);
}
713 
// ConjugateTranspose(Conj(z)) collapses into a plain Transpose(z, perm),
// because the two conjugations cancel.
TEST_F(ArithmeticOptimizerTest, FuseConjAndConjugateTranspose) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output re = ops::Const(s.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output im = ops::Const(s.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output z = ops::Complex(s.WithOpName("z"), re, im);
  Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
  Output conj = ops::Conj(s.WithOpName("conj"), z);
  Output transp =
      ops::ConjugateTranspose(s.WithOpName("conjugate_trans"), conj, perm);

  GrapplerItem item;
  item.fetch = {"conjugate_trans"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // ASSERT (not EXPECT): tensors_expected[0] is indexed at the end.
  ASSERT_EQ(1, tensors_expected.size());

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);
  NodeMap node_map(&output);

  EXPECT_EQ(7, output.node_size());

  const string p = "ArithmeticOptimizer/FoldConjugateIntoTranspose";
  const string optimized_name = strings::StrCat(p, "_", "conjugate_trans");

  const NodeDef* conjugate_trans_fused_node = node_map.GetNode(optimized_name);
  ASSERT_NE(conjugate_trans_fused_node, nullptr);
  EXPECT_EQ("Transpose", conjugate_trans_fused_node->op());
  ASSERT_GE(conjugate_trans_fused_node->input_size(), 2);
  EXPECT_EQ("z", conjugate_trans_fused_node->input(0));
  EXPECT_EQ("perm", conjugate_trans_fused_node->input(1));

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorEqual<complex64>(tensors_expected[0], tensors[0]);
}
752 
// Transpose feeding a Conj: the test expects both to collapse into one
// "ConjugateTranspose" of "z" with the original "perm".
TEST_F(ArithmeticOptimizerTest, FuseTransposeAndConj) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output re_vals =
      ops::Const(root.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output im_vals =
      ops::Const(root.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output z = ops::Complex(root.WithOpName("z"), re_vals, im_vals);
  Output perm = ops::Const(root.WithOpName("perm"), {1, 0}, {2});
  Output transposed = ops::Transpose(root.WithOpName("trans"), z, perm);
  Output result = ops::Conj(root.WithOpName("conj"), transposed);

  GrapplerItem item;
  item.fetch = {"conj"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  const auto before = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(1, before.size());

  ArithmeticOptimizer optimizer;
  GraphDef optimized;
  OptimizeTwice(&optimizer, &item, &optimized);
  NodeMap node_map(&optimized);

  EXPECT_EQ(7, optimized.node_size());

  const string prefix = "ArithmeticOptimizer/FoldConjugateIntoTranspose";
  const NodeDef* fused = node_map.GetNode(strings::StrCat(prefix, "_", "conj"));
  ASSERT_NE(fused, nullptr);
  EXPECT_EQ("ConjugateTranspose", fused->op());
  EXPECT_EQ("z", fused->input(0));
  EXPECT_EQ("perm", fused->input(1));

  const auto after = EvaluateNodes(optimized, item.fetch);
  EXPECT_EQ(1, after.size());
  test::ExpectTensorEqual<complex64>(before[0], after[0]);
}
789 
// For each matmul flavour, explicit Transposes on both operands are expected
// to be absorbed into the matmul's own transpose/adjoint attributes, leaving
// the matmul reading "a" and "b" directly.
TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
  for (const string op_name : {"MatMul", "SparseMatMul", "BatchMatMul"}) {
    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

    Output a =
        ops::Const(scope.WithOpName("a"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
    Output b =
        ops::Const(scope.WithOpName("b"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
    Output perm = ops::Const(scope.WithOpName("perm"), {1, 0}, {2});
    Output lhs = ops::Transpose(scope.WithOpName("trans_a"), a, perm);
    Output rhs = ops::Transpose(scope.WithOpName("trans_b"), b, perm);

    const auto mm_scope = scope.WithOpName("matmul");
    if (op_name == "BatchMatMul") {
      Output mm = ops::BatchMatMul(mm_scope, lhs, rhs);
    } else if (op_name == "SparseMatMul") {
      Output mm = ops::SparseMatMul(mm_scope, lhs, rhs);
    } else if (op_name == "MatMul") {
      Output mm = ops::MatMul(mm_scope, lhs, rhs);
    }

    GrapplerItem item;
    item.fetch = {"matmul"};
    TF_CHECK_OK(scope.ToGraphDef(&item.graph));

    const auto expected = EvaluateNodes(item.graph, item.fetch);
    EXPECT_EQ(1, expected.size());

    ArithmeticOptimizer optimizer;
    EnableOnlyFoldTransposeIntoMatMul(&optimizer);
    GraphDef optimized;
    OptimizeTwice(&optimizer, &item, &optimized);
    NodeMap node_map(&optimized);

    EXPECT_EQ(7, optimized.node_size());

    const NodeDef* fused = node_map.GetNode(strings::StrCat(
        "ArithmeticOptimizer/FoldTransposeIntoMatMul", "_", "matmul"));
    ASSERT_NE(fused, nullptr);
    EXPECT_EQ("a", fused->input(0));
    EXPECT_EQ("b", fused->input(1));

    // BatchMatMul names its flags adj_x/adj_y; the others use transpose_a/b.
    const bool batched = op_name == "BatchMatMul";
    const string lhs_flag = batched ? "adj_x" : "transpose_a";
    const string rhs_flag = batched ? "adj_y" : "transpose_b";
    EXPECT_TRUE(fused->attr().at(lhs_flag).b());
    EXPECT_TRUE(fused->attr().at(rhs_flag).b());

    const auto actual = EvaluateNodes(optimized, item.fetch);
    EXPECT_EQ(1, actual.size());
    test::ExpectTensorNear<float>(expected[0], actual[0], 1e-6);
  }
}
845 
// ConjugateTranspose on both BatchMatMul operands is expected to fold into the
// BatchMatMul's adj_x/adj_y attributes, with the matmul reading "a" and "b".
TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output a_real =
      ops::Const(scope.WithOpName("re_a"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output a_imag = ops::Const(scope.WithOpName("im_a"),
                             {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2});
  Output b_real =
      ops::Const(scope.WithOpName("re_b"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output b_imag = ops::Const(scope.WithOpName("im_b"),
                             {-5.0f, -6.0f, -7.0f, -8.0f}, {2, 2});
  Output a = ops::Complex(scope.WithOpName("a"), a_real, a_imag);
  Output b = ops::Complex(scope.WithOpName("b"), b_real, b_imag);
  Output perm = ops::Const(scope.WithOpName("perm"), {1, 0}, {2});
  Output a_adj = ops::ConjugateTranspose(scope.WithOpName("trans_a"), a, perm);
  Output b_adj = ops::ConjugateTranspose(scope.WithOpName("trans_b"), b, perm);
  Output product = ops::BatchMatMul(scope.WithOpName("matmul"), a_adj, b_adj);

  GrapplerItem item;
  item.fetch = {"matmul"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  const auto expected = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(1, expected.size());

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  OptimizeTwice(&optimizer, &item, &optimized);

  NodeMap node_map(&optimized);
  ASSERT_EQ(11, optimized.node_size());

  const NodeDef* fused = node_map.GetNode(strings::StrCat(
      "ArithmeticOptimizer/FoldTransposeIntoMatMul", "_", "matmul"));
  ASSERT_NE(fused, nullptr);
  EXPECT_EQ("a", fused->input(0));
  EXPECT_EQ("b", fused->input(1));
  EXPECT_TRUE(fused->attr().at("adj_x").b());
  EXPECT_TRUE(fused->attr().at("adj_y").b());

  const auto actual = EvaluateNodes(optimized, item.fetch);
  EXPECT_EQ(1, actual.size());
  test::ExpectTensorNear<complex64>(expected[0], actual[0], 1e-6);
}
892 
// The Reshape target is concat(batch_size, [3, 28, 28]), where batch_size is
// sliced from the input's own runtime shape; the test expects every Reshape to
// be removed.
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_IdentityReshape) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output input = ops::Placeholder(scope, DT_FLOAT,
                                  ops::Placeholder::Shape({-1, 3, 28, 28}));
  Output input_shape = ops::Shape(scope, input);
  // Target shape = concat(`batch_size`, [3, 28, 28]).
  Output batch_dim = ops::Slice(scope, input_shape, ops::Const(scope, {0}, {1}),
                                ops::Const(scope, {1}, {1}));
  Output new_shape =
      ops::Concat(scope.WithOpName("target_shape"),
                  {batch_dim, ops::Const(scope, {3, 28, 28}, {3})},
                  ops::Const(scope, {0}, {}));
  Output reshaped = ops::Reshape(scope, input, new_shape);
  Output result = ops::Identity(scope.WithOpName("outputs"), reshaped);

  const auto feed_value =
      GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28}));

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  const auto expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", feed_value}});
  EXPECT_EQ(1, expected.size());

  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantReshape(&optimizer);
  GraphDef optimized;
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized);

  EXPECT_EQ(0, CountOpNodes(optimized, "Reshape"));
  const auto actual =
      EvaluateNodes(optimized, item.fetch, {{"Placeholder", feed_value}});
  EXPECT_EQ(1, actual.size());
  test::ExpectTensorNear<float>(expected[0], actual[0], 1e-6);
}
926 
// Reshape to concat(batch_size, 3, height, width), where each symbolic
// dimension is sliced from the input's own runtime shape. With the AGGRESSIVE
// config the test expects the Reshape to be removed.
TEST_F(ArithmeticOptimizerTest,
       RemoveRedundantReshape_IdentityReshapeBetweenSymbolicShapes) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output input = ops::Placeholder(scope, DT_FLOAT,
                                  ops::Placeholder::Shape({-1, 3, -1, -1}));
  Output input_shape = ops::Shape(scope, input);
  // Target shape is the concatenation of `batch_size`, 3, `height`, `width`.
  Output batch_size = ops::Slice(scope, input_shape,
                                 ops::Const(scope, {0}, {1}),
                                 ops::Const(scope, {1}, {1}));
  Output height = ops::Slice(scope, input_shape, ops::Const(scope, {2}, {1}),
                             ops::Const(scope, {1}, {1}));
  Output width = ops::Slice(scope, input_shape, ops::Const(scope, {3}, {1}),
                            ops::Const(scope, {1}, {1}));
  Output target_shape =
      ops::Concat(scope.WithOpName("target_shape"),
                  {batch_size, ops::Const(scope, {3}, {1}), height, width},
                  ops::Const(scope, {0}, {}));
  Output reshaped = ops::Reshape(scope, input, target_shape);
  Output result = ops::Identity(scope.WithOpName("outputs"), reshaped);

  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"Placeholder",
                GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28}))}};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  const auto expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  EXPECT_EQ(1, expected.size());

  // Assume valid feed shape in aggressive mode.
  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
  EnableOnlyRemoveRedundantReshape(&optimizer);
  GraphDef optimized;
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized);

  EXPECT_EQ(0, CountOpNodes(optimized, "Reshape"));
  const auto actual = EvaluateNodes(optimized, item.fetch, item.feed);
  EXPECT_EQ(1, actual.size());
  test::ExpectTensorNear<float>(expected[0], actual[0], 1e-6);
}
968 
// With the default config, a Reshape to the Placeholder's declared shape is
// expected to be kept, because an actual feed may not match that shape.
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_NotAssumeValidFeeds) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output input = ops::Placeholder(scope, DT_FLOAT,
                                  ops::Placeholder::Shape({4, 3, 28, 28}));
  Output declared_shape = ops::Const(scope, {4, 3, 28, 28}, {4});
  Output reshaped = ops::Reshape(scope, input, declared_shape);
  Output result = ops::Identity(scope.WithOpName("outputs"), reshaped);

  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"Placeholder",
                GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 3, 28, 28}))}};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  const auto expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  EXPECT_EQ(1, expected.size());

  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantReshape(&optimizer);
  GraphDef optimized;
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized);

  // The Placeholder's declared shape may differ from the real feed's shape,
  // so the Reshape has to stay.
  EXPECT_EQ(1, CountOpNodes(optimized, "Reshape"));

  const auto actual = EvaluateNodes(optimized, item.fetch, item.feed);
  EXPECT_EQ(1, actual.size());
  test::ExpectTensorNear<float>(expected[0], actual[0], 1e-6);
}
999 
// The Reshape targets the Placeholder's fully-known declared shape. Under the
// AGGRESSIVE config the test expects that Reshape to be removed.
TEST_F(ArithmeticOptimizerTest,
       RemoveRedundantReshape_AssumeValidFeedsInAggressiveMode) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output input = ops::Placeholder(scope, DT_FLOAT,
                                  ops::Placeholder::Shape({4, 3, 28, 28}));
  Output declared_shape = ops::Const(scope, {4, 3, 28, 28}, {4});
  Output reshaped = ops::Reshape(scope, input, declared_shape);
  Output result = ops::Identity(scope.WithOpName("outputs"), reshaped);

  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"Placeholder",
                GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 3, 28, 28}))}};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  const auto expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  EXPECT_EQ(1, expected.size());

  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
  EnableOnlyRemoveRedundantReshape(&optimizer);
  GraphDef optimized;
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized);

  EXPECT_EQ(0, CountOpNodes(optimized, "Reshape"));
  const auto actual = EvaluateNodes(optimized, item.fetch, item.feed);
  EXPECT_EQ(1, actual.size());
  test::ExpectTensorNear<float>(expected[0], actual[0], 1e-6);
}
1028 
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_NotIdentityReshape) {
  // Reshape from [-1,3,28,28] to [8,-1,28,28] is not an identity, because it
  // can map e.g. [16,3,28,28] to [8,6,28,28] (16*3 == 8*6).
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, 28, 28}));
  Output reshape = ops::Reshape(s, inputs, ops::Const(s, {8, -1, 28, 28}, {4}));
  Output outputs = ops::Identity(s.WithOpName("outputs"), reshape);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({8, 3, 28, 28}));
  item.feed = {{"Placeholder", x_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  // Fatal: tensors_expected[0] is indexed below.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  EXPECT_EQ(1, CountOpNodes(output, "Reshape"));
  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
1056 
// The target shape [-1, -1] has more than one unknown dimension; the test
// expects the Reshape to be kept.
TEST_F(ArithmeticOptimizerTest,
       RemoveRedundantReshape_NotIdentityReshapeTooManyUnknownDimSizes) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output input =
      ops::Placeholder(scope, DT_FLOAT, ops::Placeholder::Shape({4, 3}));
  Output reshaped = ops::Reshape(scope, input, ops::Const(scope, {-1, -1}, {2}));
  Output result = ops::Identity(scope.WithOpName("outputs"), reshaped);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantReshape(&optimizer);
  GraphDef optimized;
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized);

  EXPECT_EQ(1, CountOpNodes(optimized, "Reshape"));
}
1076 
// An NCHW_VECT_C tensor is transposed, reshaped to NHWC, then flattened to 2D.
// The test expects the two back-to-back Reshapes to be combined into one.
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_CombineReshapes) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output nchw_vect_c =
      ops::Placeholder(scope.WithOpName("nchw_vect_c"), DT_INT8,
                       ops::Placeholder::Shape({8, 3, 28, 28, 4}));
  Output perm = ops::Const(scope.WithOpName("perm"), {0, 2, 3, 1, 4}, {5});
  Output transposed =
      ops::Transpose(scope.WithOpName("transpose"), nchw_vect_c, perm);
  Output nhwc_shape =
      ops::Const(scope.WithOpName("nhwc_shape"), {8, 28, 28, 12}, {4});
  Output nhwc = ops::Reshape(scope.WithOpName("nhwc"), transposed, nhwc_shape);
  Output flatten_shape =
      ops::Const(scope.WithOpName("flatten_shape"), {8, 28 * 28 * 12}, {2});
  Output flattened =
      ops::Reshape(scope.WithOpName("flatten"), nhwc, flatten_shape);
  Output result = ops::Identity(scope.WithOpName("outputs"), flattened);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.feed = {{"nchw_vect_c", GenerateRandomTensor<DT_INT8>(
                                   TensorShape({8, 3, 28, 28, 4}))}};
  const auto expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  EXPECT_EQ(1, expected.size());

  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantReshape(&optimizer);
  GraphDef optimized;
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized);

  EXPECT_EQ(1, CountOpNodes(optimized, "Reshape"));
  const auto actual = EvaluateNodes(optimized, item.fetch, item.feed);
  EXPECT_EQ(1, actual.size());
  test::ExpectTensorEqual<int8>(expected[0], actual[0]);
}
1113 
// Cast(uint8->float) followed by Transpose: the test expects the Transpose to
// be moved ahead of the Cast so that it operates on uint8 data.
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCast_ProducerIsCast) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output nhwc_uint8 =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output nhwc_fp32 = ops::Cast(s, nhwc_uint8, DT_FLOAT);
  Output nchw_fp32 =
      ops::Transpose(s, nhwc_fp32, ops::Const(s, {0, 3, 1, 2}, {4}));
  Output outputs = ops::Identity(s.WithOpName("outputs"), nchw_fp32);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto input_t = GenerateRandomTensor<DT_UINT8>(TensorShape({8, 28, 28, 3}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  // Fatal: tensors_expected[0] is indexed below.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Exactly one Transpose must remain, and it must now be typed uint8.
  const NodeDef* transpose_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Transpose") {
      EXPECT_EQ(transpose_node, nullptr);
      EXPECT_EQ(DT_UINT8, node.attr().at("T").type());
      transpose_node = &node;
    }
  }
  // Fatal: transpose_node is dereferenced in the loop below.
  ASSERT_NE(transpose_node, nullptr);

  for (const NodeDef& node : output.node()) {
    if (node.op() == "Cast") {
      EXPECT_EQ(NodeName(node.input(0)), transpose_node->name());
    }
  }

  // Evaluate the final, pruned graph whose structure was checked above (as
  // other tests in this file do), rather than item.graph.
  auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}
1157 
// Cast(uint8->float) followed by SpaceToDepth: the test expects SpaceToDepth
// to be moved ahead of the Cast so that it operates on uint8 data.
TEST_F(ArithmeticOptimizerTest, ReorderS2DCast_ProducerIsCast) {
  // TODO(jingyue): Evaluate S2D+Cast on GPU as well. We can't simply put nodes
  // under a /GPU:0 scope, because this test would fail if the testing machine
  // doesn't have a GPU. Maybe EvaluateNodes should allow soft placement?
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output outputs =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  outputs = ops::Cast(s, outputs, DT_FLOAT);
  outputs = ops::SpaceToDepth(s, outputs, 2);
  outputs = ops::Identity(s.WithOpName("outputs"), outputs);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto input_t = GenerateRandomTensor<DT_UINT8>(TensorShape({8, 28, 28, 3}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  // Fatal: tensors_expected[0] is indexed below.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Exactly one SpaceToDepth must remain, and it must now be typed uint8.
  const NodeDef* s2d_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "SpaceToDepth") {
      EXPECT_EQ(s2d_node, nullptr);
      EXPECT_EQ(DT_UINT8, node.attr().at("T").type());
      s2d_node = &node;
    }
  }
  // Fatal: s2d_node is dereferenced in the loop below.
  ASSERT_NE(s2d_node, nullptr);

  for (const NodeDef& node : output.node()) {
    if (node.op() == "Cast") {
      EXPECT_EQ(NodeName(node.input(0)), s2d_node->name());
    }
  }

  // Evaluate the final, pruned graph whose structure was checked above (as
  // other tests in this file do), rather than item.graph.
  auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}
1203 
// Transpose(float) followed by Cast(float->uint8): the test expects the Cast
// to be moved ahead of the Transpose so that the Transpose operates on uint8.
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCast_ProducerIsTranspose) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output nhwc_fp32 =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output nchw_fp32 =
      ops::Transpose(s, nhwc_fp32, ops::Const(s, {0, 3, 1, 2}, {4}));
  Output nchw_uint8 = ops::Cast(s, nchw_fp32, DT_UINT8);
  Output outputs = ops::Identity(s.WithOpName("outputs"), nchw_uint8);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto input_t =
      GenerateConstantTensor<DT_FLOAT>(TensorShape({8, 28, 28, 3}), 42.0f);
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  // Fatal: tensors_expected[0] is indexed below.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Exactly one Cast must remain, now reading the Placeholder directly.
  const NodeDef* cast_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Cast") {
      EXPECT_EQ(cast_node, nullptr);
      cast_node = &node;
      EXPECT_EQ(NodeName(node.input(0)), "Placeholder");
    }
  }
  // Fatal: cast_node is dereferenced in the loop below.
  ASSERT_NE(cast_node, nullptr);

  for (const NodeDef& node : output.node()) {
    if (node.op() == "Transpose") {
      EXPECT_EQ(DT_UINT8, node.attr().at("T").type());
      EXPECT_EQ(NodeName(node.input(0)), cast_node->name());
    }
  }

  // Evaluate the final, pruned graph whose structure was checked above (as
  // other tests in this file do), rather than item.graph.
  auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorEqual<uint8>(tensors_expected[0], tensors[0]);
}
1249 
// Cast(uint8->float) -> Reverse -> Transpose: the test expects both the
// ReverseV2 and the Transpose to be moved ahead of the Cast, so that the chain
// becomes Placeholder -> ReverseV2 -> Transpose -> Cast, all but the Cast
// typed uint8.
TEST_F(ArithmeticOptimizerTest, ReorderTransposeReverseCast) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output nhwc_uint8 =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output nhwc_fp32 = ops::Cast(s, nhwc_uint8, DT_FLOAT);
  Output nhwc_fp32_reversed =
      ops::Reverse(s, nhwc_fp32, ops::Const(s, {0}, {1}));
  Output nchw_fp32_reversed =
      ops::Transpose(s, nhwc_fp32_reversed, ops::Const(s, {0, 3, 1, 2}, {4}));

  Output outputs = ops::Identity(s.WithOpName("outputs"), nchw_fp32_reversed);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto input_t = GenerateRandomTensor<DT_UINT8>(TensorShape({8, 28, 28, 3}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  // Fatal: tensors_expected[0] is indexed below.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  const NodeDef* reverse_node = nullptr;
  const NodeDef* transpose_node = nullptr;
  const NodeDef* cast_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Transpose") {
      EXPECT_EQ(transpose_node, nullptr);
      EXPECT_EQ(DT_UINT8, node.attr().at("T").type());
      transpose_node = &node;
    } else if (node.op() == "ReverseV2") {
      EXPECT_EQ(reverse_node, nullptr);
      EXPECT_EQ(DT_UINT8, node.attr().at("T").type());
      reverse_node = &node;
    } else if (node.op() == "Cast") {
      cast_node = &node;
    }
  }
  // Fatal: all three pointers are dereferenced immediately below.
  ASSERT_NE(cast_node, nullptr);
  ASSERT_NE(reverse_node, nullptr);
  ASSERT_NE(transpose_node, nullptr);
  EXPECT_EQ(NodeName(reverse_node->input(0)), "Placeholder");
  EXPECT_EQ(NodeName(transpose_node->input(0)), reverse_node->name());
  EXPECT_EQ(NodeName(cast_node->input(0)), transpose_node->name());

  // Evaluate the final, pruned graph whose structure was checked above (as
  // other tests in this file do), rather than item.graph.
  auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}
1303 
// A Cast feeding CheckNumerics: the test expects the optimizer to leave the
// graph exactly as it was built.
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCast_CheckNumericsToIdentity) {
  tensorflow::Scope scope =
      tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output raw =
      ops::Placeholder(scope, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output as_float = ops::Cast(scope, raw, DT_FLOAT);
  Output checked = ops::CheckNumerics(scope, as_float, "foo");
  Output result = ops::Identity(scope.WithOpName("outputs"), checked);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ArithmeticOptimizer optimizer;
  GraphDef optimized;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized));
  CompareGraphs(item.graph, optimized);
}
1320 
// Cast(float->uint8) followed by Transpose: the Transpose already consumes
// uint8 data, and the test expects the graph to come back unchanged.
TEST_F(ArithmeticOptimizerTest, NoReorderTransposeCast_ProducerIsCast) {
  tensorflow::Scope scope =
      tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output as_float =
      ops::Placeholder(scope, DT_FLOAT, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output as_uint8 = ops::Cast(scope, as_float, DT_UINT8);
  Output perm = ops::Const(scope, {0, 3, 1, 2}, {4});
  Output transposed = ops::Transpose(scope, as_uint8, perm);
  Output result = ops::Identity(scope.WithOpName("outputs"), transposed);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ArithmeticOptimizer optimizer;
  GraphDef optimized;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized));
  CompareGraphs(item.graph, optimized);
}
1338 
// Transpose(uint8) followed by Cast(uint8->float): moving the Transpose after
// the Cast would make it run on float data, so the test expects no rewrite.
TEST_F(ArithmeticOptimizerTest, NoReorderTransposeCast_ProducerIsTranspose) {
  tensorflow::Scope scope =
      tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output as_uint8 =
      ops::Placeholder(scope, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output perm = ops::Const(scope, {0, 3, 1, 2}, {4});
  Output transposed = ops::Transpose(scope, as_uint8, perm);
  Output as_float = ops::Cast(scope, transposed, DT_FLOAT);
  Output result = ops::Identity(scope.WithOpName("outputs"), as_float);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ArithmeticOptimizer optimizer;
  GraphDef optimized;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized));
  CompareGraphs(item.graph, optimized);
}
1356 
// transpose2(transpose1(x)) composes perm1 and perm2 into the identity, and
// transpose3 uses the identity perm directly. After optimization and pruning,
// only the two Identity sinks and the RandomUniform input (with its shape)
// are expected to remain.
TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposes) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output shape =
      ops::Const(scope.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
  Output x = ops::RandomUniform(scope.WithOpName("inputs"), shape, DT_FLOAT);
  Output perm1 = ops::Const(scope.WithOpName("perm1"), {0, 2, 3, 1}, {4});
  Output perm2 = ops::Const(scope.WithOpName("perm2"), {0, 3, 1, 2}, {4});
  Output perm3 = ops::Const(scope.WithOpName("perm3"), {0, 1, 2, 3}, {4});
  Output t1 = ops::Transpose(scope.WithOpName("transpose1"), x, perm1);
  Output t2 = ops::Transpose(scope.WithOpName("transpose2"), t1, perm2);
  Output t3 = ops::Transpose(scope.WithOpName("transpose3"), x, perm3);
  Output id1 = ops::Identity(scope.WithOpName("id1"), t2);
  Output id2 = ops::Identity(scope.WithOpName("id2"), t3);

  GrapplerItem item;
  item.fetch = {"id1", "id2"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  GraphDef optimized;
  OptimizeAndPrune(&optimizer, &item, &optimized);

  std::set<string> surviving;
  for (const NodeDef& node : optimized.node()) surviving.insert(node.name());
  const std::set<string> expected = {"id1", "id2", "inputs_shape", "inputs"};
  EXPECT_EQ(expected, surviving);
}
1389 
// An identity Transpose pair sits on the second output of a 3-way Split. After
// removal, the concat must read "Split", "Split:1", "Split:2" directly; the
// non-zero output index must be preserved.
TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesMultipleOutputs) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs_shape =
      ops::Const(s.WithOpName("inputs_shape"), {8, 9, 28, 28}, {4});
  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
                                   ops::Placeholder::Shape({8, 12, 28, 28}));
  OutputList split = ops::Split(s, ops::Const(s, 1), inputs, 3).output;
  Output perm1 = ops::Const(s, {0, 2, 3, 1}, {4});
  Output perm2 = ops::Const(s, {0, 3, 1, 2}, {4});
  Output branch0 = split[0];
  Output branch1 = ops::Transpose(s, ops::Transpose(s, split[1], perm1), perm2);
  Output branch2 = split[2];
  Output concat = ops::Concat(s, {branch0, branch1, branch2}, ops::Const(s, 1));
  Output outputs = ops::Identity(s.WithOpName("outputs"), concat);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({8, 12, 28, 28}));
  item.feed = {{"inputs", x_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  // Fatal: tensors_expected[0] is indexed below.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // ops::Concat builds a "ConcatV2" node (values..., axis). Filtering on
  // "Concat" never matched, so these input checks previously never ran.
  int num_concat_nodes = 0;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "ConcatV2") {
      ++num_concat_nodes;
      // The three value inputs come first; the axis follows them.
      ASSERT_GE(node.input_size(), 3);
      EXPECT_EQ(node.input(0), "Split");
      EXPECT_EQ(node.input(1), "Split:1");
      EXPECT_EQ(node.input(2), "Split:2");
    }
  }
  // Guard against the loop above silently matching nothing.
  EXPECT_EQ(1, num_concat_nodes);

  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
1431 
// Two {1, 0} transposes cancel out even though their only consumer is a
// control dependency of "outputs". After optimization that control edge must
// point at the transposes' original input ("Placeholder").
TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({2, 3}));
  Output transpose1 = ops::Transpose(s, inputs, ops::Const(s, {1, 0}));
  Output transpose2 = ops::Transpose(s, transpose1, ops::Const(s, {1, 0}));
  Output outputs =
      ops::Identity(s.WithOpName("outputs").WithControlDependencies(transpose2),
                    ops::Const(s.WithOpName("outputs_const"), 1.0f));

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3}));
  item.feed = {{"Placeholder", x_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  // Fatal: tensors_expected[0] is indexed below.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  NodeMap node_map(&output);
  const NodeDef* outputs_node = node_map.GetNode("outputs");
  // Fatal checks: the node is dereferenced and input(1) is read below, so a
  // non-fatal failure would fall through to a null / out-of-range access.
  ASSERT_NE(outputs_node, nullptr);
  ASSERT_EQ(2, outputs_node->input_size());
  EXPECT_EQ(outputs_node->input(0), "outputs_const");
  EXPECT_EQ(outputs_node->input(1), "^Placeholder");

  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
1466 
// Negative case for identity-transpose removal. Applying the permutation
// {1, 2, 3, 0} twice yields {2, 3, 0, 1}, which is not the identity, so the
// optimizer must leave the graph's six nodes intact: inputs_shape, inputs,
// perm, transpose1, transpose2 and outputs.
TEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output shape_const =
      ops::Const(root.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
  Output random_inputs =
      ops::RandomUniform(root.WithOpName("inputs"), shape_const, DT_FLOAT);
  Output rotate_perm = ops::Const(root.WithOpName("perm"), {1, 2, 3, 0}, {4});
  Output first_transpose =
      ops::Transpose(root.WithOpName("transpose1"), random_inputs, rotate_perm);
  Output second_transpose = ops::Transpose(root.WithOpName("transpose2"),
                                           first_transpose, rotate_perm);
  Output fetched = ops::Identity(root.WithOpName("outputs"), second_transpose);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  GraphDef optimized;
  OptimizeAndPrune(&optimizer, &item, &optimized);

  EXPECT_EQ(6, optimized.node_size());
}
1490 
// A cancelling transpose pair ({0, 2, 3, 1} then {0, 3, 1, 2}) with an
// Identity in between is still removed. The surviving "id" node is fed by
// "inputs" directly and inherits transpose1's control dependency on "perm2".
TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesThroughChain) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs_shape =
      ops::Const(s.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
  Output inputs =
      ops::RandomUniform(s.WithOpName("inputs"), inputs_shape, DT_FLOAT);
  Output perm1 = ops::Const(s.WithOpName("perm1"), {0, 2, 3, 1}, {4});
  Output perm2 = ops::Const(s.WithOpName("perm2"), {0, 3, 1, 2}, {4});
  Output transpose1 = ops::Transpose(
      s.WithOpName("transpose1").WithControlDependencies(perm2), inputs, perm1);
  Output identity = ops::Identity(s.WithOpName("id"), transpose1);
  Output transpose2 =
      ops::Transpose(s.WithOpName("transpose2"), identity, perm2);
  Output id1 = ops::Identity(s.WithOpName("id1"), transpose2);

  GrapplerItem item;
  item.fetch = {"id1"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  std::set<string> nodes_after_optimization;
  for (const NodeDef& node : output.node()) {
    nodes_after_optimization.insert(node.name());
    // Input counts are fatal checks because specific inputs are indexed
    // right after them.
    if (node.name() == "id") {
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("inputs", node.input(0));
      EXPECT_EQ("^perm2", node.input(1));
    }
    if (node.name() == "id1") {
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("id", node.input(0));
    }
  }
  EXPECT_EQ(nodes_after_optimization,
            std::set<string>({"id", "id1", "inputs_shape", "inputs", "perm2"}));
}
1531 
// Conv2D(Transpose(Mul(inputs, scale)), weights) with a scalar `scale`: the
// Mul is folded into the convolution weights. The conv then reads
// Transpose(inputs) and a Mul node that produces the scaled weights.
TEST_F(ArithmeticOptimizerTest, FoldMulToTransposeConv) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
                                   ops::Placeholder::Shape({8, 28, 28, 3}));
  Output scale = ops::Const(s.WithOpName("scale"), 1.0f / 255.0f, {});
  Output scaled_inputs =
      ops::Multiply(s.WithOpName("scaled_inputs"), inputs, scale);
  Output perm_nhwc_to_nchw =
      ops::Const(s.WithOpName("perm_nhwc_to_nchw"), {0, 3, 1, 2}, {4});
  Output inputs_nchw = ops::Transpose(s.WithOpName("inputs_nchw"),
                                      scaled_inputs, perm_nhwc_to_nchw);
  Output weights = ops::Const(s.WithOpName("weights"),
                              Input::Initializer(127.0f, {5, 5, 3, 16}));
  Output conv =
      ops::Conv2D(s.WithOpName("conv"), inputs_nchw, weights, {1, 1, 1, 1},
                  "VALID", ops::Conv2D::DataFormat("NCHW"));
  Output outputs = ops::Identity(s.WithOpName("outputs"), conv);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyFoldMultipleIntoConv(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  NodeMap node_map(&output);

  // `conv` is now a folded convolution with scaled weights.
  const NodeDef* folded_conv = node_map.GetNode(conv.node()->name());
  ASSERT_NE(folded_conv, nullptr);

  // Strip any ":port" / "^" decoration before looking the input up. This is
  // the same treatment as the data input below.
  const NodeDef* folded_conv_weights =
      node_map.GetNode(NodeName(folded_conv->input(1)));
  ASSERT_NE(folded_conv_weights, nullptr);
  EXPECT_EQ("Mul", folded_conv_weights->op());

  // Its input should be a transpose of `inputs`.
  const NodeDef* transpose = node_map.GetNode(NodeName(folded_conv->input(0)));
  ASSERT_NE(transpose, nullptr);
  EXPECT_EQ("Transpose", transpose->op());
  EXPECT_EQ("inputs", transpose->input(0));
}
1574 
// "inputs_nchw" is a fed node, so it must be preserved. The Mul feeding it
// therefore cannot be folded into the conv across that transpose, and the
// transpose must still read "scaled_inputs" after optimization.
TEST_F(ArithmeticOptimizerTest, NotFoldMulAcrossPreservedTranspose) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
                                   ops::Placeholder::Shape({8, 28, 28, 3}));
  Output scale = ops::Const(s.WithOpName("scale"), 1.0f / 255.0f, {});
  Output scaled_inputs =
      ops::Multiply(s.WithOpName("scaled_inputs"), inputs, scale);
  Output perm_nhwc_to_nchw =
      ops::Const(s.WithOpName("perm_nhwc_to_nchw"), {0, 3, 1, 2}, {4});
  Output inputs_nchw = ops::Transpose(s.WithOpName("inputs_nchw"),
                                      scaled_inputs, perm_nhwc_to_nchw);
  Output weights = ops::Const(s.WithOpName("weights"),
                              Input::Initializer(127.0f, {5, 5, 3, 16}));
  Output conv =
      ops::Conv2D(s.WithOpName("conv"), inputs_nchw, weights, {1, 1, 1, 1},
                  "VALID", ops::Conv2D::DataFormat("NCHW"));
  Output outputs = ops::Identity(s.WithOpName("outputs"), conv);

  // Zero-filled feed value. Go through the typed Eigen view instead of
  // memset on a const_cast of tensor_data(), which writes through a
  // pointer the API exposes as read-only.
  Tensor inputs_nchw_tensor(DT_FLOAT, {8, 3, 28, 28});
  inputs_nchw_tensor.flat<float>().setZero();

  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"inputs_nchw", inputs_nchw_tensor}};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));

  item.graph.Swap(&output);
  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));

  NodeMap node_map(&output);
  const NodeDef* inputs_nchw_node_def =
      node_map.GetNode(inputs_nchw.node()->name());
  ASSERT_NE(inputs_nchw_node_def, nullptr);
  EXPECT_EQ(NodeName(inputs_nchw_node_def->input(0)),
            scaled_inputs.node()->name());
}
1614 
// Conv3D(Mul(inputs, scale), weights) with a scalar `scale`: the Mul is
// folded into the weights. The conv then reads `inputs` directly, and its
// filter input is produced by a Mul node.
TEST_F(ArithmeticOptimizerTest, FoldMulToConv) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
                                   ops::Placeholder::Shape({8, 28, 28, 28, 3}));
  Output scale = ops::Const(s.WithOpName("scale"), 1.0f / 255.0f, {});
  Output scaled_inputs =
      ops::Multiply(s.WithOpName("scaled_inputs"), inputs, scale);
  Output weights = ops::Const(s.WithOpName("weights"),
                              Input::Initializer(127.0f, {5, 5, 5, 3, 16}));
  Output conv = ops::Conv3D(s.WithOpName("conv"), scaled_inputs, weights,
                            {1, 1, 1, 1, 1}, "VALID");
  Output outputs = ops::Identity(s.WithOpName("outputs"), conv);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));

  item.graph.Swap(&output);
  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));

  NodeMap node_map(&output);
  // `conv` is now a folded convolution on `inputs` and scaled weights.
  // Use gtest assertions rather than CHECK_EQ: a mismatch should fail this
  // test, not abort the whole test binary. Null-check each node before
  // dereferencing it.
  const NodeDef* folded_conv = node_map.GetNode(conv.node()->name());
  ASSERT_NE(folded_conv, nullptr);
  EXPECT_EQ(inputs.node()->name(), NodeName(folded_conv->input(0)));

  const NodeDef* folded_weights =
      node_map.GetNode(NodeName(folded_conv->input(1)));
  ASSERT_NE(folded_weights, nullptr);
  EXPECT_EQ("Mul", folded_weights->op());
}
1644 
TEST_F(ArithmeticOptimizerTest,OptimizeCastMulTransposeConv)1645 TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) {
1646   // This unit test exercises two optimizations, folding mul into conv, and
1647   // reordering cast and transpose.
1648   //
1649   //   Conv2D(Transpose(Mul(Cast(I), S)), W)
1650   //     =>
1651   //   Conv2D(Transpose(Cast(I)), W*S)
1652   //     =>
1653   //   Conv2D(Cast(Transpose(I)), W*S)
1654   tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/cpu:0");
1655 
1656   Output inputs =
1657       ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
1658   Output cast = ops::Cast(s, inputs, DT_FLOAT);
1659   Output mul = ops::Mul(s, cast, ops::Const(s, 1.0f / 255.0f));
1660   Output transpose =
1661       ops::Transpose(s, mul, ops::Const(s.WithOpName("perm"), {0, 3, 1, 2}));
1662   Output weights = ops::Const(s.WithOpName("weights"),
1663                               Input::Initializer(127.0f, {5, 5, 3, 16}));
1664   Output conv = ops::Conv2D(s, transpose, weights, {1, 1, 1, 1}, "VALID",
1665                             ops::Conv2D::DataFormat("NCHW"));
1666   Output outputs = ops::Identity(s.WithOpName("outputs"), conv);
1667 
1668   GrapplerItem item;
1669   item.fetch = {"outputs"};
1670   TF_CHECK_OK(s.ToGraphDef(&item.graph));
1671 
1672   GraphDef output;
1673   ArithmeticOptimizer optimizer;  // all optimization stages are on
1674   OptimizeTwiceAndPrune(&optimizer, &item, &output, /*const_folding=*/true);
1675   NodeMap node_map(&output);
1676 
1677   // Expected names for reordered cast and transpose.
1678   const string p = "ArithmeticOptimizer/ReorderCastLikeAndValuePreserving_";
1679   const string optimized_cast_name = strings::StrCat(p, "float_Cast");
1680   const string optimized_transpose_name = strings::StrCat(p, "uint8_Transpose");
1681 
1682   // Expected names for folded multiply and conv.
1683   const string optimized_weights =
1684       "ArithmeticOptimizer/FoldMultiplyIntoConv_scaled_Conv2D_weights";
1685 
1686   const NodeDef* inputs_node = node_map.GetNode("Placeholder");
1687   const NodeDef* transpose_node = node_map.GetNode(optimized_transpose_name);
1688   const NodeDef* cast_node = node_map.GetNode(optimized_cast_name);
1689 
1690   const NodeDef* weights_node = node_map.GetNode(optimized_weights);
1691   const NodeDef* conv_node = node_map.GetNode("Conv2D");
1692 
1693   ASSERT_NE(inputs_node, nullptr);
1694   ASSERT_NE(transpose_node, nullptr);
1695   ASSERT_NE(cast_node, nullptr);
1696   ASSERT_NE(weights_node, nullptr);
1697   ASSERT_NE(conv_node, nullptr);
1698 
1699   EXPECT_EQ(output.node_size(), 7);
1700   EXPECT_EQ(transpose_node->input(0), inputs_node->name());
1701   EXPECT_EQ(cast_node->input(0), transpose_node->name());
1702   EXPECT_EQ(conv_node->input(0), cast_node->name());
1703   EXPECT_EQ(conv_node->input(1), weights_node->name());
1704 }
1705 
// Two independent Mul -> Conv2D chains are summed. FoldMultiplyIntoConv must
// rewrite both of them, and each conv must end up reading its own scaled
// weights node ("..._weights" and "..._weights_1").
TEST_F(ArithmeticOptimizerTest, OptimizeMultipleMulTransposeConv) {
  tensorflow::Scope root =
      tensorflow::Scope::NewRootScope().WithDevice("/cpu:0");

  GrapplerItem item;
  Output convs[2];

  // Build the two chains in order, so the default node names come out as
  // Conv2D / Conv2D_1 and weights / weights_1.
  for (Output& conv_out : convs) {
    Output placeholder = ops::Placeholder(
        root, DT_FLOAT, ops::Placeholder::Shape({8, 3, 28, 28}));
    Output scaled = ops::Mul(root, placeholder, ops::Const(root, 1.0f / 255.0f));
    Output filter = ops::Const(root.WithOpName("weights"),
                               Input::Initializer(127.0f, {5, 5, 3, 16}));
    conv_out = ops::Conv2D(root, scaled, filter, {1, 1, 1, 1}, "VALID",
                           ops::Conv2D::DataFormat("NCHW"));
  }
  Output sum = ops::Add(root.WithOpName("outputs"), convs[0], convs[1]);

  item.fetch = {"outputs"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyFoldMultipleIntoConv(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized, /*const_folding=*/true);

  NodeMap node_map(&optimized);

  const string fold_prefix = "ArithmeticOptimizer/FoldMultiplyIntoConv_";
  const NodeDef* first_filter =
      node_map.GetNode(strings::StrCat(fold_prefix, "scaled_Conv2D_weights"));
  const NodeDef* second_filter = node_map.GetNode(
      strings::StrCat(fold_prefix, "scaled_Conv2D_1_weights_1"));
  const NodeDef* first_conv = node_map.GetNode("Conv2D");
  const NodeDef* second_conv = node_map.GetNode("Conv2D_1");

  ASSERT_NE(first_filter, nullptr);
  ASSERT_NE(second_filter, nullptr);
  ASSERT_NE(first_conv, nullptr);
  ASSERT_NE(second_conv, nullptr);

  EXPECT_EQ(first_filter->name(), first_conv->input(1));
  EXPECT_EQ(second_filter->name(), second_conv->input(1));
}
1753 
// A uint8 -> qint8 -> int8 Bitcast chain collapses into a single Bitcast
// ("bc2") fed directly by "inputs", without changing the fetched value.
TEST_F(ArithmeticOptimizerTest, CombineBitcasts) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_UINT8,
                                   ops::Placeholder::Shape({2, 3}));
  Output bc1 = ops::Bitcast(s.WithOpName("bc1"), inputs, DT_QINT8);
  Output bc2 = ops::Bitcast(s.WithOpName("bc2"), bc1, DT_INT8);
  Output outputs = ops::Identity(s.WithOpName("outputs"), bc2);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto x_t = GenerateRandomTensor<DT_UINT8>(TensorShape({2, 3}));
  item.feed = {{"inputs", x_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  // Fatal: tensors_expected[0] is indexed below.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantBitcast(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);
  NodeMap node_map(&output);

  // Bitcasts combined into a single op and inputs redirected to updated Bitcast
  EXPECT_EQ(3, output.node_size());
  EXPECT_EQ(1, CountOpNodes(output, "Bitcast"));
  EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "bc2"));

  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorEqual<int8>(tensors_expected[0], tensors[0]);
}
1787 
// An int8 -> qint8 -> int8 Bitcast round trip is a no-op. Both Bitcasts are
// removed and "inputs" feeds "outputs" directly.
TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_INT8,
                                   ops::Placeholder::Shape({2, 3}));
  Output bc1 = ops::Bitcast(s, inputs, DT_QINT8);
  Output bc2 = ops::Bitcast(s, bc1, DT_INT8);
  Output outputs = ops::Identity(s.WithOpName("outputs"), bc2);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto x_t = GenerateRandomTensor<DT_INT8>(TensorShape({2, 3}));
  item.feed = {{"inputs", x_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  // Fatal: tensors_expected[0] is indexed below.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantBitcast(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);
  NodeMap node_map(&output);

  // Bitcasts removed and inputs redirected to outputs
  EXPECT_EQ(2, output.node_size());
  EXPECT_EQ(0, CountOpNodes(output, "Bitcast"));
  EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));

  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorEqual<int8>(tensors_expected[0], tensors[0]);
}
1821 
// A Cast from int8 to int8 is a no-op. It is removed and "inputs" feeds
// "outputs" directly.
TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_INT8,
                                   ops::Placeholder::Shape({2, 3}));
  Output cast = ops::Cast(s, inputs, DT_INT8);
  Output outputs = ops::Identity(s.WithOpName("outputs"), cast);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto x_t = GenerateRandomTensor<DT_INT8>(TensorShape({2, 3}));
  item.feed = {{"inputs", x_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  // Fatal: tensors_expected[0] is indexed below.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantCast(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);
  NodeMap node_map(&output);

  // Cast removed and inputs redirected to outputs
  EXPECT_EQ(2, output.node_size());
  EXPECT_EQ(0, CountOpNodes(output, "Cast"));
  EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));

  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorEqual<int8>(tensors_expected[0], tensors[0]);
}
1854 
TEST_F(ArithmeticOptimizerTest,AddOpsRewrite_AddOpsOfIdenticalShape)1855 TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfIdenticalShape) {
1856   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1857   tensorflow::Scope sx = s.NewSubScope("x");
1858   tensorflow::Scope sy = s.NewSubScope("y");
1859 
1860   auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
1861   auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
1862   auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);
1863   auto add_ab = ops::Add(sx.WithOpName("Add_ab"), a, b);
1864   auto add_abc = ops::Add(sy.WithOpName("Add_abc"), add_ab, c);
1865 
1866   auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);
1867 
1868   GrapplerItem item;
1869   item.fetch = {"outputs"};
1870   TF_CHECK_OK(s.ToGraphDef(&item.graph));
1871 
1872   auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
1873   auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
1874   auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
1875   std::vector<std::pair<string, Tensor>> feed = {
1876       {"a", a_t}, {"b", b_t}, {"c", c_t}};
1877   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
1878   EXPECT_EQ(1, tensors_expected.size());
1879 
1880   GraphDef output;
1881   ArithmeticOptimizer optimizer;
1882   EnableOnlyAddToAddNCombining(&optimizer);
1883 
1884   OptimizeAndPrune(&optimizer, &item, &output);
1885 
1886   // We expect the following rewrite(s) to occur:
1887   //
1888   //     +
1889   //    / \
1890   //   +   c      -->    AddN(a, b, c)
1891   //  / \
1892   // a   b
1893   EXPECT_EQ(5, output.node_size());
1894 
1895   NodeMap node_map(&output);
1896 
1897   // check add tree was replaced with AddN
1898   const NodeDef* collapsed_add =
1899       node_map.GetNode("y/ArithmeticOptimizer/AddOpsRewrite_Add_abc");
1900   ASSERT_NE(collapsed_add, nullptr);
1901 
1902   EXPECT_EQ("AddN", collapsed_add->op());
1903   EXPECT_EQ(3, collapsed_add->input_size());
1904   EXPECT_EQ("a", collapsed_add->input(0));
1905   EXPECT_EQ("b", collapsed_add->input(1));
1906   EXPECT_EQ("c", collapsed_add->input(2));
1907 
1908   // check output was re-wired to new node
1909   const NodeDef* updated_outputs = node_map.GetNode("outputs");
1910   ASSERT_NE(updated_outputs, nullptr);
1911 
1912   EXPECT_EQ(collapsed_add->name(), updated_outputs->input(0));
1913 
1914   auto tensors = EvaluateNodes(output, item.fetch, feed);
1915   EXPECT_EQ(1, tensors.size());
1916   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
1917 }
1918 
TEST_F(ArithmeticOptimizerTest,AddOpsRewrite_MultiplePasses)1919 TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MultiplePasses) {
1920   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1921 
1922   auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
1923   auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
1924   auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);
1925   auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
1926   auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);
1927 
1928   auto x = ops::Variable(s.WithOpName("x"), {2, 2}, DT_FLOAT);
1929   auto y = ops::Variable(s.WithOpName("y"), {2, 2}, DT_FLOAT);
1930   auto z = ops::Variable(s.WithOpName("z"), {2, 2}, DT_FLOAT);
1931   auto add_xy = ops::Add(s.WithOpName("Add_xy"), x, y);
1932   auto add_xyz = ops::Add(s.WithOpName("Add_xyz"), add_xy, z);
1933 
1934   auto mul = ops::Multiply(s.WithOpName("Mul"), add_abc, add_xyz);
1935   auto outputs = ops::Identity(s.WithOpName("outputs"), mul);
1936 
1937   GrapplerItem item;
1938   item.fetch = {"outputs"};
1939   TF_CHECK_OK(s.ToGraphDef(&item.graph));
1940 
1941   auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
1942   auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
1943   auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
1944   auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
1945   auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
1946   auto z_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
1947   std::vector<std::pair<string, Tensor>> feed = {
1948       {"a", a_t}, {"b", b_t}, {"c", c_t}, {"x", x_t}, {"y", y_t}, {"z", z_t}};
1949   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
1950   EXPECT_EQ(1, tensors_expected.size());
1951 
1952   GraphDef output;
1953   ArithmeticOptimizer optimizer;
1954   EnableOnlyAddToAddNCombining(&optimizer);
1955 
1956   OptimizeAndPrune(&optimizer, &item, &output);
1957 
1958   // We expect the following rewrite(s) to occur:
1959   //
1960   //         *
1961   //      /     \
1962   //     +       +                        *
1963   //    / \     / \                    /     \
1964   //   +   c   x   + -->    AddN(a, b, c)  AddN(x, y, z))
1965   //  / \         / \
1966   // a   b       y   z
1967   EXPECT_EQ(10, output.node_size());
1968 
1969   NodeMap node_map(&output);
1970 
1971   // check left Add subtree replaced with AddN
1972   const NodeDef* collapsed_left =
1973       node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc");
1974   ASSERT_NE(collapsed_left, nullptr);
1975 
1976   EXPECT_EQ("AddN", collapsed_left->op());
1977   EXPECT_EQ(3, collapsed_left->input_size());
1978   EXPECT_EQ("a", collapsed_left->input(0));
1979   EXPECT_EQ("b", collapsed_left->input(1));
1980   EXPECT_EQ("c", collapsed_left->input(2));
1981 
1982   // check right Add subtree replaced with AddN
1983   const NodeDef* collapsed_right =
1984       node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_xyz");
1985   ASSERT_NE(collapsed_right, nullptr);
1986 
1987   EXPECT_EQ("AddN", collapsed_right->op());
1988   EXPECT_EQ(3, collapsed_right->input_size());
1989   EXPECT_EQ("x", collapsed_right->input(0));
1990   EXPECT_EQ("y", collapsed_right->input(1));
1991   EXPECT_EQ("z", collapsed_right->input(2));
1992 
1993   // check that Mul inputs re-wired to new Nodes
1994   const NodeDef* updated_mul = node_map.GetNode("Mul");
1995   ASSERT_NE(updated_mul, nullptr);
1996 
1997   EXPECT_EQ("Mul", updated_mul->op());
1998   EXPECT_EQ(2, updated_mul->input_size());
1999   EXPECT_EQ(collapsed_left->name(), updated_mul->input(0));
2000   EXPECT_EQ(collapsed_right->name(), updated_mul->input(1));
2001 
2002   auto tensors = EvaluateNodes(output, item.fetch, feed);
2003   EXPECT_EQ(1, tensors.size());
2004   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
2005 }
2006 
// Add(Add(a, b), Add(b, c)) collapses into AddN(a, b, b, c). The shared
// input `b` must appear twice, not be deduplicated.
TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddInputMultipleTimes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
  auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
  auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);
  auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
  auto add_bc = ops::Add(s.WithOpName("Add_bc"), b, c);
  auto add_all = ops::Add(s.WithOpName("Add_all"), add_ab, add_bc);
  auto outputs = ops::Identity(s.WithOpName("outputs"), add_all);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  // Fatal: tensors_expected[0] is indexed below.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur (block comment so that the
  // trailing '\' in the diagram does not splice lines as it would in //):
  /*
         +
        / \
       +   +     -->    AddN(a, b, b, c)
      / \ / \                   ^
     a   b   c                  b added twice!
  */
  EXPECT_EQ(5, output.node_size());

  NodeMap node_map(&output);

  // check Add tree replaced with AddN
  const NodeDef* collapsed_add =
      node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_all");
  ASSERT_NE(collapsed_add, nullptr);

  EXPECT_EQ("AddN", collapsed_add->op());
  ASSERT_EQ(4, collapsed_add->input_size());  // inputs 0..3 indexed below
  EXPECT_EQ("a", collapsed_add->input(0));
  EXPECT_EQ("b", collapsed_add->input(1));
  EXPECT_EQ("b", collapsed_add->input(2));
  EXPECT_EQ("c", collapsed_add->input(3));

  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
2063 
TEST_F(ArithmeticOptimizerTest,AddOpsRewrite_AddOpsOfSymbolicallyEqualShape)2064 TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfSymbolicallyEqualShape) {
2065   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2066 
2067   // unknown input shape propagated symbolically through the graph
2068   auto input = ops::Variable(s.WithOpName("input"), {-1, 2}, DT_FLOAT);
2069 
2070   // [a, b, c] have symbolically equal shapes
2071   auto a = ops::Sqrt(s.WithOpName("a"), input);
2072   auto b = ops::Square(s.WithOpName("b"), input);
2073   auto c = ops::Round(s.WithOpName("c"), input);
2074 
2075   // [add_ab, add_abc] shape must be inferred from inputs
2076   auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
2077   auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);
2078 
2079   auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);
2080 
2081   GrapplerItem item;
2082   item.fetch = {"outputs"};
2083   TF_CHECK_OK(s.ToGraphDef(&item.graph));
2084 
2085   auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
2086   std::vector<std::pair<string, Tensor>> feed = {{"input", x_t}};
2087   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
2088   EXPECT_EQ(1, tensors_expected.size());
2089 
2090   GraphDef output;
2091   ArithmeticOptimizer optimizer;
2092   EnableOnlyAddToAddNCombining(&optimizer);
2093 
2094   OptimizeAndPrune(&optimizer, &item, &output);
2095 
2096   // We expect the following rewrite(s) to occur:
2097   //
2098   //     +
2099   //    / \
2100   //   +   c      -->    AddN(a, b, c)
2101   //  / \
2102   // a   b
2103   EXPECT_EQ(6, output.node_size());
2104 
2105   NodeMap node_map(&output);
2106 
2107   // check add tree was replaced with AddN
2108   const NodeDef* collapsed_add =
2109       node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc");
2110   ASSERT_NE(collapsed_add, nullptr);
2111   EXPECT_EQ("AddN", collapsed_add->op());
2112   EXPECT_EQ(3, collapsed_add->input_size());
2113   EXPECT_EQ("a", collapsed_add->input(0));
2114   EXPECT_EQ("b", collapsed_add->input(1));
2115   EXPECT_EQ("c", collapsed_add->input(2));
2116 
2117   // check output was re-wired to new node
2118   const NodeDef* updated_outputs = node_map.GetNode("outputs");
2119   ASSERT_NE(updated_outputs, nullptr);
2120   EXPECT_EQ(collapsed_add->name(), updated_outputs->input(0));
2121 
2122   auto tensors = EvaluateNodes(output, item.fetch, feed);
2123   EXPECT_EQ(1, tensors.size());
2124   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
2125 }
2126 
TEST_F(ArithmeticOptimizerTest,AddOpsRewrite_MinimizeBCast)2127 TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MinimizeBCast) {
2128   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2129 
2130   auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
2131   auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_FLOAT);
2132   auto c = ops::Variable(s.WithOpName("c"), {32, 32, 32}, DT_FLOAT);
2133   auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
2134   auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);
2135 
2136   auto x = ops::Variable(s.WithOpName("x"), {32}, DT_FLOAT);
2137   auto y = ops::Variable(s.WithOpName("y"), {32, 32}, DT_FLOAT);
2138   auto z = ops::Variable(s.WithOpName("z"), {32, 32, 32}, DT_FLOAT);
2139   auto add_xy = ops::Add(s.WithOpName("Add_xy"), x, y);
2140   auto add_xyz = ops::Add(s.WithOpName("Add_xyz"), add_xy, z);
2141 
2142   auto add_all = ops::Add(s.WithOpName("AddAll"), add_abc, add_xyz);
2143   auto outputs = ops::Identity(s.WithOpName("outputs"), add_all);
2144 
2145   GrapplerItem item;
2146   item.fetch = {"outputs"};
2147   TF_CHECK_OK(s.ToGraphDef(&item.graph));
2148 
2149   auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
2150   auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
2151   auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32, 32}));
2152   auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
2153   auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
2154   auto z_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32, 32}));
2155   std::vector<std::pair<string, Tensor>> feed = {
2156       {"a", a_t}, {"b", b_t}, {"c", c_t}, {"x", x_t}, {"y", y_t}, {"z", z_t}};
2157   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
2158   EXPECT_EQ(1, tensors_expected.size());
2159 
2160   GraphDef output;
2161   ArithmeticOptimizer optimizer;
2162   EnableOnlyAddToAddNCombining(&optimizer);
2163 
2164   OptimizeAndPrune(&optimizer, &item, &output);
2165 
2166   // We expect the following rewrite(s) to occur:
2167   //  1) [a, x], [b, y], [c, z] - aggregate same shapes first
2168   //  2) Build an aggregation tree minimizing cost of broadcast
2169   //
2170   //         +                              +
2171   //      /     \                       /       \
2172   //     +       +                     +       AddN(c, z)
2173   //    / \     / \                 /     \
2174   //   +   c   x   + -->    AddN(a, x)  AddN(b, y)
2175   //  / \         / \
2176   // a   b       y   z
2177   EXPECT_EQ(12, output.node_size());
2178   NodeMap node_map(&output);
2179 
2180   // expected names of outer and inner nodes
2181   string outer_add_name = "ArithmeticOptimizer/AddOpsRewrite_AddAll";
2182   string outer_0_add_name =
2183       "ArithmeticOptimizer/AddOpsRewrite_Internal_0_AddAll";
2184   string inner_0_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_0_AddAll";
2185   string inner_1_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_1_AddAll";
2186   string inner_2_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_2_AddAll";
2187 
2188   // Add [a, x] first
2189   const NodeDef* add_ax_node = node_map.GetNode(inner_0_add_name);
2190   ASSERT_NE(add_ax_node, nullptr);
2191   EXPECT_EQ("AddN", add_ax_node->op());
2192   EXPECT_EQ(2, add_ax_node->input_size());
2193   EXPECT_EQ("a", add_ax_node->input(0));
2194   EXPECT_EQ("x", add_ax_node->input(1));
2195 
2196   // Then add [b, y]
2197   const NodeDef* add_by_node = node_map.GetNode(inner_1_add_name);
2198   ASSERT_NE(add_by_node, nullptr);
2199   EXPECT_EQ("AddN", add_by_node->op());
2200   EXPECT_EQ(2, add_by_node->input_size());
2201   EXPECT_EQ("b", add_by_node->input(0));
2202   EXPECT_EQ("y", add_by_node->input(1));
2203 
2204   // Then add [c, z]
2205   const NodeDef* add_cz_node = node_map.GetNode(inner_2_add_name);
2206   ASSERT_NE(add_cz_node, nullptr);
2207   EXPECT_EQ("AddN", add_cz_node->op());
2208   EXPECT_EQ(2, add_cz_node->input_size());
2209   EXPECT_EQ("c", add_cz_node->input(0));
2210   EXPECT_EQ("z", add_cz_node->input(1));
2211 
2212   // Then add results together starting from smaller shapes [a, x] + [b, y]
2213   const NodeDef* outer_0_node = node_map.GetNode(outer_0_add_name);
2214   ASSERT_NE(outer_0_node, nullptr);
2215   EXPECT_EQ("Add", outer_0_node->op());
2216   EXPECT_EQ(2, outer_0_node->input_size());
2217   EXPECT_EQ(inner_0_add_name, outer_0_node->input(0));
2218   EXPECT_EQ(inner_1_add_name, outer_0_node->input(1));
2219 
2220   // And finally top level Add node
2221   const NodeDef* outer_node = node_map.GetNode(outer_add_name);
2222   ASSERT_NE(outer_node, nullptr);
2223   EXPECT_EQ("Add", outer_node->op());
2224   EXPECT_EQ(2, outer_node->input_size());
2225   EXPECT_EQ(outer_0_add_name, outer_node->input(0));
2226   EXPECT_EQ(inner_2_add_name, outer_node->input(1));
2227 
2228   // And outputs reading new top level Add node
2229   const NodeDef* updated_outputs = node_map.GetNode("outputs");
2230   ASSERT_NE(updated_outputs, nullptr);
2231   EXPECT_EQ(outer_add_name, updated_outputs->input(0));
2232 
2233   auto tensors = EvaluateNodes(output, item.fetch, feed);
2234   EXPECT_EQ(1, tensors.size());
2235   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
2236 }
2237 
TEST_F(ArithmeticOptimizerTest,AddOpsRewrite_MinimizeBCastWithSymbolicShapes)2238 TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MinimizeBCastWithSymbolicShapes) {
2239   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2240 
2241   // We have a small input with one unknown dimension
2242   auto small = ops::Variable(s.WithOpName("small"), {-1, 1, 1}, DT_DOUBLE);
2243 
2244   // And second input which is larger, but has the same unknown dimension
2245   // device spec prevents this node from rewriting
2246   auto d = "/device:CPU:0";
2247   auto v = ops::Variable(s.WithOpName("v"), {1, 32, 32}, DT_DOUBLE);
2248   auto large = ops::Add(s.WithOpName("large").WithDevice(d), small, v);
2249 
2250   // [a, c] have {?, 1, 1} shape, [b] has {?, 32, 32}
2251   auto a = ops::Sqrt(s.WithOpName("a"), small);
2252   auto b = ops::Square(s.WithOpName("b"), large);
2253   auto c = ops::Round(s.WithOpName("c"), small);
2254 
2255   // [add_ab, add_abc] shape must be inferred from inputs
2256   auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
2257   auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);
2258 
2259   auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);
2260 
2261   GrapplerItem item;
2262   item.fetch = {"outputs"};
2263   TF_CHECK_OK(s.ToGraphDef(&item.graph));
2264 
2265   auto s_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({8, 1, 1}));
2266   auto v_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({1, 32, 32}));
2267   std::vector<std::pair<string, Tensor>> feed = {{"small", s_t}, {"v", v_t}};
2268   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
2269   EXPECT_EQ(1, tensors_expected.size());
2270 
2271   GraphDef output;
2272   ArithmeticOptimizer optimizer;
2273   EnableOnlyAddToAddNCombining(&optimizer);
2274   OptimizeAndPrune(&optimizer, &item, &output);
2275 
2276   // We expect the following rewrite(s) to occur: it's much cheaper to add small
2277   // tensors, and do the broadcast just once
2278   //
2279   //     +                  +
2280   //    / \                / \
2281   //   +   c      -->     +   b
2282   //  / \                / \
2283   // a   b              a   c
2284   EXPECT_EQ(9, output.node_size());
2285   NodeMap node_map(&output);
2286 
2287   // expected names of outer and inner nodes
2288   string outer_add_name = "ArithmeticOptimizer/AddOpsRewrite_Add_abc";
2289   string inner_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_0_Add_abc";
2290 
2291   // outer Add node
2292   const NodeDef* outer_add = node_map.GetNode(outer_add_name);
2293   ASSERT_NE(outer_add, nullptr);
2294   EXPECT_EQ("Add", outer_add->op());
2295   EXPECT_EQ(inner_add_name, outer_add->input(0));
2296   EXPECT_EQ("b", outer_add->input(1));
2297 
2298   // inner AddN node
2299   const NodeDef* inner_add = node_map.GetNode(inner_add_name);
2300   ASSERT_NE(inner_add, nullptr);
2301   EXPECT_EQ(2, inner_add->input_size());
2302   EXPECT_EQ("a", inner_add->input(0));
2303   EXPECT_EQ("c", inner_add->input(1));
2304 
2305   // check output was re-wired to new node
2306   const NodeDef* updated_outputs = node_map.GetNode("outputs");
2307   ASSERT_NE(updated_outputs, nullptr);
2308   EXPECT_EQ(outer_add_name, updated_outputs->input(0));
2309 
2310   auto tensors = EvaluateNodes(output, item.fetch, feed);
2311   EXPECT_EQ(1, tensors.size());
2312   test::ExpectTensorNear<double>(tensors_expected[0], tensors[0], 1e-6);
2313 }
2314 
TEST_F(ArithmeticOptimizerTest,RemoveNegation)2315 TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
2316   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2317   auto x = ops::Variable(s.WithOpName("x"), {2, 2}, DT_FLOAT);
2318   auto y = ops::Variable(s.WithOpName("y"), {2, 2}, DT_FLOAT);
2319   Output neg_x = ops::Neg(s.WithOpName("Neg_x"), x);
2320   Output neg_y = ops::Neg(s.WithOpName("Neg_y"), y);
2321   Output add_x_y = ops::Add(s.WithOpName("Add_x_y"), x, y);
2322   Output add_negx_y = ops::Add(s.WithOpName("Add_negx_y"), neg_x, y);
2323   Output add_x_negy = ops::Add(s.WithOpName("Add_x_negy"), x, neg_y);
2324   Output add_negx_negy = ops::Add(s.WithOpName("Add_negx_negy"), neg_x, neg_y);
2325   Output sub_x_y = ops::Sub(s.WithOpName("Sub_x_y"), x, y);
2326   Output sub_negx_y = ops::Sub(s.WithOpName("Sub_negx_y"), neg_x, y);
2327   Output sub_x_negy = ops::Sub(s.WithOpName("Sub_x_negy"), x, neg_y);
2328   Output sub_negx_negy = ops::Sub(s.WithOpName("Sub_negx_negy"), neg_x, neg_y);
2329   Output neg_x_with_dep = ops::Neg(
2330       s.WithOpName("Neg_x_with_dep").WithControlDependencies({add_x_y}), x);
2331   Output add_negx_with_dep_y =
2332       ops::Add(s.WithOpName("Add_negx_with_dep_y"), neg_x_with_dep, y);
2333   auto add_all =
2334       ops::AddN(s.WithOpName("add_all"),
2335                 {add_x_y, add_negx_y, add_x_negy, add_negx_negy, sub_x_y,
2336                  sub_negx_y, sub_x_negy, sub_negx_negy, add_negx_with_dep_y});
2337 
2338   GrapplerItem item;
2339   item.fetch = {"add_all"};
2340   TF_CHECK_OK(s.ToGraphDef(&item.graph));
2341 
2342   auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
2343   auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
2344   std::vector<std::pair<string, Tensor>> feed = {{"x", x_t}, {"y", y_t}};
2345   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
2346   EXPECT_EQ(1, tensors_expected.size());
2347 
2348   GraphDef output;
2349   ArithmeticOptimizer optimizer;
2350   EnableOnlyRemoveNegation(&optimizer);
2351   OptimizeTwice(&optimizer, &item, &output);
2352 
2353   EXPECT_EQ(item.graph.node_size(), output.node_size());
2354   int found = 0;
2355   for (int i = 0; i < output.node_size(); ++i) {
2356     const NodeDef& node = output.node(i);
2357     if (node.name() == "Add_negx_y") {
2358       ++found;
2359       EXPECT_EQ("Sub", node.op());
2360       EXPECT_EQ(2, node.input_size());
2361       EXPECT_EQ("y", node.input(0));
2362       EXPECT_EQ("x", node.input(1));
2363     } else if (node.name() == "Add_x_negy") {
2364       ++found;
2365       EXPECT_EQ("Sub", node.op());
2366       EXPECT_EQ(2, node.input_size());
2367       EXPECT_EQ("x", node.input(0));
2368       EXPECT_EQ("y", node.input(1));
2369     } else if (node.name() == "Add_negx_negy") {
2370       ++found;
2371       EXPECT_EQ("Sub", node.op());
2372       EXPECT_EQ(2, node.input_size());
2373       EXPECT_EQ("Neg_x", node.input(0));
2374       EXPECT_EQ("y", node.input(1));
2375     } else if (node.name() == "Sub_x_negy") {
2376       ++found;
2377       EXPECT_EQ("Add", node.op());
2378       EXPECT_EQ(2, node.input_size());
2379       EXPECT_EQ("x", node.input(0));
2380       EXPECT_EQ("y", node.input(1));
2381     } else if (node.name() == "Sub_negx_negy") {
2382       ++found;
2383       EXPECT_EQ("Sub", node.op());
2384       EXPECT_EQ(2, node.input_size());
2385       EXPECT_EQ("y", node.input(0));
2386       EXPECT_EQ("x", node.input(1));
2387     } else if (node.name() == "Add_negx_with_dep_y") {
2388       ++found;
2389       EXPECT_EQ("Sub", node.op());
2390       EXPECT_EQ(3, node.input_size());
2391       EXPECT_EQ("y", node.input(0));
2392       EXPECT_EQ("x", node.input(1));
2393       EXPECT_EQ("^Add_x_y", node.input(2));
2394     }
2395   }
2396   EXPECT_EQ(6, found);
2397 
2398   auto tensors = EvaluateNodes(output, item.fetch, feed);
2399   EXPECT_EQ(1, tensors.size());
2400   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
2401 }
2402 
// Checks that Div(x, Sqrt(y)) is rewritten in place to Mul(x, Rsqrt(y)). The
// "output" node becomes a Mul and "sqrt_y" becomes an Rsqrt, while the node
// count and the computed value stay unchanged.
TEST_F(ArithmeticOptimizerTest, ConvertSqrtDivToRsqrtMul) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  Output sqrt_y = ops::Sqrt(s.WithOpName("sqrt_y"), y);
  Output div_x_sqrt_y = ops::Div(s.WithOpName("output"), x, sqrt_y);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // ASSERT rather than EXPECT: both result vectors are indexed below.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlySqrtDivToRsqrtMul(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(1, tensors.size());

  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
  EXPECT_EQ(item.graph.node_size(), output.node_size());
  for (const NodeDef& node : output.node()) {
    if (node.name() == "output") {
      EXPECT_EQ("Mul", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("sqrt_y", node.input(1));
    } else if (node.name() == "sqrt_y") {
      EXPECT_EQ("Rsqrt", node.op());
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("y", node.input(0));
    }
  }
}
2439 
// Checks that Div(mul1, Sqrt(floats)) is left untouched when the Sqrt node
// ("output0") is itself fetched. Rewriting it to Rsqrt would change a fetched
// value, so "grad" must stay a Div and "output0" must stay a Sqrt.
TEST_F(ArithmeticOptimizerTest, DoNotConvertSqrtDivToRsqrtMulDivisorFetchNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output floats = ops::Const(s.WithOpName("floats"),
                             {0.7423212f, 0.19757693f, 0.53124744f}, {1, 3});
  Output output0 = ops::Sqrt(s.WithOpName("output0"), floats);
  Output const1 = ops::Const(s.WithOpName("const1"), 1.0f, {3});
  Output mul1 = ops::Multiply(s.WithOpName("mul1"), const1, 0.5f);
  Output grad = ops::Div(s.WithOpName("grad"), mul1, output0);

  GrapplerItem item;
  item.fetch = {"grad", "output0"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(2, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlySqrtDivToRsqrtMul(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(2, tensors.size());

  // size_t index avoids a signed/unsigned comparison with tensors.size().
  for (size_t i = 0; i < tensors.size(); ++i) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
  EXPECT_EQ(item.graph.node_size(), output.node_size());
  for (const NodeDef& node : output.node()) {
    if (node.name() == "grad") {
      EXPECT_EQ("Div", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("mul1", node.input(0));
      EXPECT_EQ("output0", node.input(1));
    } else if (node.name() == "output0") {
      EXPECT_EQ("Sqrt", node.op());
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("floats", node.input(0));
    }
  }
}
2481 
// Checks that Square(Sub(x, y)) is fused in place: "sub_x_y" becomes
// SquaredDifference(x, y) and "output" becomes an Identity of it, with the
// node count and the computed value unchanged.
TEST_F(ArithmeticOptimizerTest, FuseSquaredDiff) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  Output sub_x_y = ops::Sub(s.WithOpName("sub_x_y"), x, y);
  Output square_sub_x_y = ops::Square(s.WithOpName("output"), sub_x_y);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // ASSERT rather than EXPECT: both result vectors are indexed below.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyFuseSquaredDiff(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(1, tensors.size());

  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
  EXPECT_EQ(item.graph.node_size(), output.node_size());
  for (const NodeDef& node : output.node()) {
    if (node.name() == "output") {
      EXPECT_EQ("Identity", node.op());
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("sub_x_y", node.input(0));
    } else if (node.name() == "sub_x_y") {
      EXPECT_EQ("SquaredDifference", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("y", node.input(1));
    }
  }
}
2518 
// Checks that Square(Sub(x, y)) is NOT fused when the Sub node ("sub_x_y") is
// itself fetched: both "output" (Square) and "sub_x_y" (Sub) keep their ops.
TEST_F(ArithmeticOptimizerTest, DoNotFuseSquaredDiffFetchNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  Output sub_x_y = ops::Sub(s.WithOpName("sub_x_y"), x, y);
  Output square_sub_x_y = ops::Square(s.WithOpName("output"), sub_x_y);

  GrapplerItem item;
  item.fetch = {"output", "sub_x_y"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(2, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyFuseSquaredDiff(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(2, tensors.size());

  // size_t index avoids a signed/unsigned comparison with tensors.size().
  for (size_t i = 0; i < tensors.size(); ++i) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
  EXPECT_EQ(item.graph.node_size(), output.node_size());
  for (const NodeDef& node : output.node()) {
    if (node.name() == "output") {
      EXPECT_EQ("Square", node.op());
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("sub_x_y", node.input(0));
    } else if (node.name() == "sub_x_y") {
      EXPECT_EQ("Sub", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("y", node.input(1));
    }
  }
}
2558 
// Checks that Log(Softmax(x)) is collapsed into a single LogSoftmax(x) node
// named "output", removing exactly one node and preserving the value.
TEST_F(ArithmeticOptimizerTest, ConvertLogSoftmax) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output softmax = ops::Softmax(s.WithOpName("softmax"), x);
  Output logsoftmax = ops::Log(s.WithOpName("output"), softmax);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // ASSERT rather than EXPECT: both result vectors are indexed below.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyLogSoftmax(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(1, tensors.size());

  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
  EXPECT_EQ(item.graph.node_size() - 1, output.node_size());
  for (const NodeDef& node : output.node()) {
    if (node.name() == "output") {
      EXPECT_EQ("LogSoftmax", node.op());
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("x", node.input(0));
    }
  }
}
2589 
// Checks that Log(Softmax(x)) is left as-is when the Softmax node is itself
// fetched: the optimized graph must match the input graph exactly.
TEST_F(ArithmeticOptimizerTest, DoNotConvertLogSoftmaxArgFetchNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output floats = ops::Const(s.WithOpName("floats"),
                             {0.7423212f, 0.19757693f, 0.53124744f}, {1, 3});
  Output softmax = ops::Softmax(s.WithOpName("softmax"), floats);
  Output final_output = ops::Log(s.WithOpName("final_output"), softmax);

  GrapplerItem item;
  item.fetch = {"softmax", "final_output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(2, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyLogSoftmax(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(2, tensors.size());

  // Should be a NoOp since we are not allowed to change the output of fetch
  // nodes.
  VerifyGraphsMatch(item.graph, output, __LINE__);

  // size_t index avoids a signed/unsigned comparison with tensors.size().
  for (size_t i = 0; i < tensors.size(); ++i) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
2619 
TEST_F(ArithmeticOptimizerTest,ConvertPow)2620 TEST_F(ArithmeticOptimizerTest, ConvertPow) {
2621   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2622   auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
2623   auto y2 = ops::Const(s.WithOpName("y2"), {2.0f, 2.0f}, {1, 2});
2624   auto y1 = ops::Const(s.WithOpName("y1"), {1.0f, 1.0f}, {1, 2});
2625   auto yPoint5 = ops::Const(s.WithOpName("y.5"), {0.5f, 0.5f}, {1, 2});
2626   auto y0 = ops::Const(s.WithOpName("y0"), {0.0f, 0.0f}, {1, 2});
2627   auto y_Point5 = ops::Const(s.WithOpName("y_.5"), {-0.5f, -0.5f}, {1, 2});
2628   auto y_1 = ops::Const(s.WithOpName("y_1"), {-1.0f, -1.0f}, {1, 2});
2629   auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
2630   auto z = ops::Const(s.WithOpName("z"), {42.0f}, {});
2631   auto ones = ops::Const(s.WithOpName("ones"), {1.0f, 1.0f, 1.0f}, {1, 3});
2632   auto zeros = ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f, 0.0f}, {1, 3});
2633   Output out2 = ops::Pow(s.WithOpName("out2"), x, y2);
2634   Output out1 = ops::Pow(s.WithOpName("out1"), x, y1);
2635   Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5);
2636   Output out0 = ops::Pow(s.WithOpName("out0"), x, y0);
2637   Output out_Point5 = ops::Pow(s.WithOpName("out_.5"), x, y_Point5);
2638   Output out_1 = ops::Pow(s.WithOpName("out_1"), x, y_1);
2639   Output out = ops::Pow(s.WithOpName("out"), x, y);
2640   Output out_bcast1 = ops::Pow(s.WithOpName("out_bcast1"), z, ones);
2641   Output out_bcast2 = ops::Pow(s.WithOpName("out_bcast2"), z, zeros);
2642 
2643   GrapplerItem item;
2644   item.fetch = {"out2",  "out1", "out.5",      "out0",      "out_.5",
2645                 "out_1", "out",  "out_bcast1", "out_bcast2"};
2646   TF_CHECK_OK(s.ToGraphDef(&item.graph));
2647   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
2648   EXPECT_EQ(9, tensors_expected.size());
2649 
2650   GraphDef got;
2651   ArithmeticOptimizer optimizer;
2652   EnableOnlyConvertPow(&optimizer);
2653   OptimizeAndPrune(&optimizer, &item, &got);
2654   auto tensors = EvaluateNodes(got, item.fetch);
2655   EXPECT_EQ(9, tensors.size());
2656 
2657   for (int i = 0; i < tensors.size(); ++i) {
2658     EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
2659     test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
2660   }
2661 
2662   GraphDef want;
2663   AddNode("x", "Const", {}, {}, &want);
2664   AddNode("y2", "Const", {}, {}, &want);
2665   AddNode("y1", "Const", {}, {}, &want);
2666   AddNode("y.5", "Const", {}, {}, &want);
2667   AddNode("y0", "Const", {}, {}, &want);
2668   AddNode("y_.5", "Const", {}, {}, &want);
2669   AddNode("y_1", "Const", {}, {}, &want);
2670   AddNode("y", "Const", {}, {}, &want);
2671   AddNode("z", "Const", {}, {}, &want);
2672   AddNode("ones", "Const", {}, {}, &want);
2673   AddNode("zeros", "Const", {}, {}, &want);
2674   AddNode("out2", "Square", {"x", AsControlDependency("y2")}, {}, &want);
2675   AddNode("out1", "Identity", {"x", AsControlDependency("y1")}, {}, &want);
2676   AddNode("out.5", "Sqrt", {"x", AsControlDependency("y.5")}, {}, &want);
2677   AddNode("out0", "Const",
2678           {AsControlDependency("x"), AsControlDependency("y0")}, {}, &want);
2679   AddNode("out_.5", "Rsqrt", {"x", AsControlDependency("y_.5")}, {}, &want);
2680   AddNode("out_1", "Reciprocal", {"x", AsControlDependency("y_1")}, {}, &want);
2681   AddNode("out", "Pow", {"x", "y"}, {}, &want);
2682   AddNode("out_bcast1", "Pow", {"z", "ones"}, {}, &want);
2683   AddNode("out_bcast2", "Pow", {"z", "zeros"}, {}, &want);
2684 
2685   CompareGraphs(want, got);
2686 }
2687 
// Checks that Log(Add(x1, x2)) with x1 a constant of ones becomes Log1p(x2).
// The rewritten node keeps x1 and a12's control input (x3) as control
// dependencies. Log(Add(x2, x3)) has no ones operand and is left unchanged.
TEST_F(ArithmeticOptimizerTest, Log1p) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto x1 = ops::Const(s.WithOpName("x1"), {1.0f, 1.0f}, {1, 2});
  auto x2 = ops::Const(s.WithOpName("x2"), {2.0f, 2.0f}, {1, 2});
  auto x3 = ops::Const(s.WithOpName("x3"), {3.0f, 3.0f}, {1, 2});
  auto a12 = ops::Add(s.WithOpName("a12").WithControlDependencies(x3), x1, x2);
  auto a23 = ops::Add(s.WithOpName("a23"), x2, x3);
  Output out1 = ops::Log(s.WithOpName("out1"), a12);
  Output out2 = ops::Log(s.WithOpName("out2"), a23);

  GrapplerItem item;
  item.fetch = {"out1", "out2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // ASSERT rather than EXPECT: both vectors are indexed in lockstep below.
  ASSERT_EQ(2, tensors_expected.size());

  GraphDef got;
  ArithmeticOptimizer optimizer;
  EnableOnlyLog1p(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &got);
  auto tensors = EvaluateNodes(got, item.fetch);
  ASSERT_EQ(2, tensors.size());

  for (size_t i = 0; i < tensors.size(); ++i) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }

  GraphDef want;
  AddNode("x1", "Const", {}, {}, &want);
  AddNode("x2", "Const", {}, {}, &want);
  AddNode("x3", "Const", {}, {}, &want);
  AddNode("a23", "Add", {"x2", "x3"}, {}, &want);
  AddNode("out1", "Log1p",
          {"x2", AsControlDependency("x1"), AsControlDependency("x3")}, {},
          &want);
  AddNode("out2", "Log", {"a23"}, {}, &want);

  CompareGraphs(want, got);
}
2729 
// Checks that Sub(Exp(x1), x2) with x2 a constant of ones becomes Expm1(x1).
// The rewritten node keeps x2 and exp1's control input (x3) as control
// dependencies. Sub(Exp(x1), x3) does not subtract ones and is left as a Sub.
TEST_F(ArithmeticOptimizerTest, Expm1) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto x1 = ops::Const(s.WithOpName("x1"), {2.0f, 2.0f}, {1, 2});
  auto x2 = ops::Const(s.WithOpName("x2"), {1.0f, 1.0f}, {1, 2});
  auto x3 = ops::Const(s.WithOpName("x3"), {3.0f, 3.0f}, {1, 2});
  auto exp1 = ops::Exp(s.WithOpName("exp1").WithControlDependencies(x3), x1);
  Output out1 = ops::Sub(s.WithOpName("out1"), exp1, x2);
  Output out2 = ops::Sub(s.WithOpName("out2"), exp1, x3);

  GrapplerItem item;
  item.fetch = {"out1", "out2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // ASSERT rather than EXPECT: both vectors are indexed in lockstep below.
  ASSERT_EQ(2, tensors_expected.size());

  GraphDef got;
  ArithmeticOptimizer optimizer;
  EnableOnlyExpm1(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &got);
  auto tensors = EvaluateNodes(got, item.fetch);
  ASSERT_EQ(2, tensors.size());

  for (size_t i = 0; i < tensors.size(); ++i) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }

  GraphDef want;
  AddNode("x1", "Const", {}, {}, &want);
  AddNode("x2", "Const", {}, {}, &want);
  AddNode("x3", "Const", {}, {}, &want);
  AddNode("exp1", "Exp", {"x1", AsControlDependency("x3")}, {}, &want);
  AddNode("out1", "Expm1",
          {"x1", AsControlDependency("x2"), AsControlDependency("x3")}, {},
          &want);
  AddNode("out2", "Sub", {"exp1", "x3"}, {}, &want);

  CompareGraphs(want, got);
}
2770 
// Checks that MinimizeBroadcasts reorders the operands of a Mul chain so the
// two same-shaped {32} operands (a, c) are multiplied first. The {32, 32}
// operand (b) is then applied once, at the outer Mul.
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
  auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_FLOAT);
  auto c = ops::Variable(s.WithOpName("c"), {32}, DT_FLOAT);

  auto mul1 = ops::Mul(s.WithOpName("mul1"), a, b);
  auto mul2 = ops::Mul(s.WithOpName("mul2"), mul1, c);

  auto outputs = ops::Identity(s.WithOpName("outputs"), mul2);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
  auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  // ASSERT rather than EXPECT: tensors_expected[0] is dereferenced below.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyMinimizeBroadcasts(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  //
  //     *                  *
  //    / \                / \
  //   *   c      -->     *   b
  //  / \                / \
  // a   b              a   c
  NodeMap node_map(&output);

  // Guard input_size() before indexing inputs.
  const NodeDef* mul1_node = node_map.GetNode("mul1");
  ASSERT_NE(mul1_node, nullptr);
  ASSERT_EQ(2, mul1_node->input_size());
  EXPECT_EQ("a", mul1_node->input(0));
  EXPECT_EQ("c", mul1_node->input(1));

  const NodeDef* mul2_node = node_map.GetNode("mul2");
  ASSERT_NE(mul2_node, nullptr);
  ASSERT_EQ(2, mul2_node->input_size());
  EXPECT_EQ("mul1", mul2_node->input(0));
  EXPECT_EQ("b", mul2_node->input(1));

  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
2824 
TEST_F(ArithmeticOptimizerTest,MinimizeBroadcasts_FlattenTallGraph)2825 TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_FlattenTallGraph) {
2826   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2827 
2828   auto a = ops::Variable(s.WithOpName("a"), {32}, DT_DOUBLE);
2829   auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_DOUBLE);
2830   auto c = ops::Variable(s.WithOpName("c"), {32}, DT_DOUBLE);
2831   auto d = ops::Variable(s.WithOpName("d"), {32}, DT_DOUBLE);
2832   auto e = ops::Variable(s.WithOpName("e"), {32}, DT_DOUBLE);
2833 
2834   auto mul1 = ops::Mul(s.WithOpName("mul1"), a, b);
2835   auto mul2 = ops::Mul(s.WithOpName("mul2"), mul1, c);
2836   auto mul3 = ops::Mul(s.WithOpName("mul3"), mul2, d);
2837   auto mul4 = ops::Mul(s.WithOpName("mul4"), mul3, e);
2838 
2839   auto outputs = ops::Identity(s.WithOpName("outputs"), mul4);
2840 
2841   GrapplerItem item;
2842   item.fetch = {"outputs"};
2843   TF_CHECK_OK(s.ToGraphDef(&item.graph));
2844 
2845   auto a_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
2846   auto b_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32, 32}));
2847   auto c_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
2848   auto d_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
2849   auto e_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
2850   std::vector<std::pair<string, Tensor>> feed = {
2851       {"a", a_t}, {"b", b_t}, {"c", c_t}, {"d", d_t}, {"e", e_t}};
2852   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
2853   EXPECT_EQ(1, tensors_expected.size());
2854 
2855   GraphDef output;
2856   ArithmeticOptimizer optimizer;
2857   EnableOnlyMinimizeBroadcasts(&optimizer);
2858 
2859   OptimizeAndPrune(&optimizer, &item, &output);
2860 
2861   // We expect the following rewrite(s) to occur: Graph is "flattened" and
2862   // largest shape pushed to the top.
2863   //
2864   //          *
2865   //        /   \
2866   //       *     e                *
2867   //      /  \                  /   \
2868   //     *    d               *      b
2869   //    / \                 /  \
2870   //   *   c      -->     *      *
2871   //  / \                / \    / \
2872   // a   b              a   c  d   e
2873   NodeMap node_map(&output);
2874 
2875   const NodeDef* mul1_node = node_map.GetNode("mul1");
2876   ASSERT_NE(mul1_node, nullptr);
2877   EXPECT_EQ("a", mul1_node->input(0));
2878   EXPECT_EQ("c", mul1_node->input(1));
2879 
2880   const NodeDef* mul2_node = node_map.GetNode("mul2");
2881   ASSERT_NE(mul2_node, nullptr);
2882   EXPECT_EQ("d", mul2_node->input(0));
2883   EXPECT_EQ("e", mul2_node->input(1));
2884 
2885   const NodeDef* mul3_node = node_map.GetNode("mul3");
2886   ASSERT_NE(mul3_node, nullptr);
2887   EXPECT_EQ("mul1", mul3_node->input(0));
2888   EXPECT_EQ("mul2", mul3_node->input(1));
2889 
2890   const NodeDef* mul4_node = node_map.GetNode("mul4");
2891   ASSERT_NE(mul4_node, nullptr);
2892   EXPECT_EQ("mul3", mul4_node->input(0));
2893   EXPECT_EQ("b", mul4_node->input(1));
2894 
2895   auto tensors = EvaluateNodes(output, item.fetch, feed);
2896   EXPECT_EQ(1, tensors.size());
2897   test::ExpectTensorNear<double>(tensors_expected[0], tensors[0], 1e-6);
2898 }
2899 
TEST_F(ArithmeticOptimizerTest,MinimizeBroadcasts_BuildTreeUp)2900 TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_BuildTreeUp) {
2901   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2902 
2903   // [a, b, c] - scalars, [d] - matrix
2904   auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
2905   auto b = ops::Variable(s.WithOpName("b"), {32}, DT_FLOAT);
2906   auto c = ops::Variable(s.WithOpName("c"), {32}, DT_FLOAT);
2907   auto d = ops::Variable(s.WithOpName("D"), {32, 32}, DT_FLOAT);
2908 
2909   auto mul1 = ops::Mul(s.WithOpName("mul1"), a, b);
2910   auto mul2 = ops::Mul(s.WithOpName("mul2"), c, d);
2911   auto mul3 = ops::Mul(s.WithOpName("mul3"), mul1, mul2);
2912 
2913   auto outputs = ops::Identity(s.WithOpName("outputs"), mul3);
2914 
2915   GrapplerItem item;
2916   item.fetch = {"outputs"};
2917   TF_CHECK_OK(s.ToGraphDef(&item.graph));
2918 
2919   auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
2920   auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
2921   auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
2922   auto d_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
2923   std::vector<std::pair<string, Tensor>> feed = {
2924       {"a", a_t}, {"b", b_t}, {"c", c_t}, {"D", d_t}};
2925   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
2926   EXPECT_EQ(1, tensors_expected.size());
2927 
2928   GraphDef output;
2929   ArithmeticOptimizer optimizer;
2930   EnableOnlyMinimizeBroadcasts(&optimizer);
2931 
2932   OptimizeAndPrune(&optimizer, &item, &output);
2933 
2934   // We expect the following rewrite(s) to occur:
2935   //
2936   //                              *
2937   //                            /  \
2938   //       *                   *    D
2939   //     /   \                / \
2940   //    *     *      ->      *   c
2941   //   / \   / \            / \
2942   //  a   b c   D          a   b
2943   NodeMap node_map(&output);
2944 
2945   const NodeDef* mul1_node = node_map.GetNode("mul2");
2946   ASSERT_NE(mul1_node, nullptr);
2947   EXPECT_EQ("a", mul1_node->input(0));
2948   EXPECT_EQ("b", mul1_node->input(1));
2949 
2950   const NodeDef* mul2_node = node_map.GetNode("mul1");
2951   ASSERT_NE(mul2_node, nullptr);
2952   EXPECT_EQ("mul2", mul2_node->input(0));
2953   EXPECT_EQ("c", mul2_node->input(1));
2954 
2955   const NodeDef* mul3_node = node_map.GetNode("mul3");
2956   ASSERT_NE(mul3_node, nullptr);
2957   EXPECT_EQ("D", mul3_node->input(0));
2958   EXPECT_EQ("mul1", mul3_node->input(1));
2959 
2960   auto tensors = EvaluateNodes(output, item.fetch, feed);
2961   EXPECT_EQ(1, tensors.size());
2962   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
2963 }
2964 
TEST_F(ArithmeticOptimizerTest,HoistCWiseUnaryFromConcat)2965 TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryFromConcat) {
2966   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2967   Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
2968   Output b = ops::Const(s.WithOpName("b"), 1.0f, {32});
2969   Output c = ops::Const(s.WithOpName("c"), 42.0f, {32});
2970   Output axis = ops::Const(s.WithOpName("axis"), 0, {});
2971   Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {});
2972   Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {});
2973   Output ctrl3 = ops::Const(s.WithOpName("ctrl3"), 3, {});
2974   // Test case with chains of length 1.
2975   // Rewrites
2976   //       Concat({Exp(a), Exp(b), Exp(c)})
2977   // into
2978   //       Exp(Concat({a, b, c})).
2979   Output sin_a =
2980       ops::Sin(s.WithOpName("sin_a").WithControlDependencies(ctrl3), a);
2981   Output exp_a =
2982       ops::Exp(s.WithOpName("exp_a").WithControlDependencies(ctrl1), sin_a);
2983   Output exp_b = ops::Exp(s.WithOpName("exp_b"), b);
2984   Output exp_c =
2985       ops::Exp(s.WithOpName("exp_c").WithControlDependencies(ctrl2), c);
2986   Output concat =
2987       ops::Concat(s.WithOpName("concat"), {exp_a, exp_b, exp_c}, axis);
2988   Output id = ops::Identity(s.WithOpName("id"), concat);
2989 
2990   // Test case with chains of length 2.
2991   // Rewrites
2992   //       Concat({Cos(Exp(a)), Cos(Exp(b)), Cos(Exp(c))})
2993   // into
2994   //       Cos(Exp(Concat({a, b, c}))).
2995   Output exp_a2 =
2996       ops::Exp(s.WithOpName("exp_a2").WithControlDependencies(ctrl1), sin_a);
2997   Output exp_b2 = ops::Exp(s.WithOpName("exp_b2"), b);
2998   Output exp_c2 =
2999       ops::Exp(s.WithOpName("exp_c2").WithControlDependencies(ctrl2), c);
3000   Output cos_exp_a2 = ops::Cos(
3001       s.WithOpName("cos_exp_a2").WithControlDependencies(ctrl1), exp_a2);
3002   Output cos_exp_b2 = ops::Cos(
3003       s.WithOpName("cos_exp_b2").WithControlDependencies(ctrl3), exp_b2);
3004   Output cos_exp_c2 = ops::Cos(s.WithOpName("cos_exp_c2"), exp_c2);
3005   Output concat2 = ops::Concat(s.WithOpName("concat2"),
3006                                {cos_exp_a2, cos_exp_b2, cos_exp_c2}, axis);
3007   Output id2 = ops::Identity(s.WithOpName("id2"), concat2);
3008   GrapplerItem item;
3009   item.fetch = {"id", "id2"};
3010   TF_CHECK_OK(s.ToGraphDef(&item.graph));
3011 
3012   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
3013 
3014   GraphDef output;
3015   ArithmeticOptimizer optimizer;
3016   EnableOnlyHoistCWiseUnaryChains(&optimizer);
3017   OptimizeTwiceAndPrune(&optimizer, &item, &output);
3018 
3019   int found = 0;
3020   for (const NodeDef& node : output.node()) {
3021     if (node.name() == "concat") {
3022       EXPECT_EQ(6, node.input_size());
3023       EXPECT_EQ("sin_a", node.input(0));
3024       EXPECT_EQ("b", node.input(1));
3025       EXPECT_EQ("c", node.input(2));
3026       EXPECT_EQ("axis", node.input(3));
3027       EXPECT_EQ("^ctrl1", node.input(4));
3028       EXPECT_EQ("^ctrl2", node.input(5));
3029       found++;
3030     }
3031     if (node.name() == "exp_a") {
3032       EXPECT_EQ(2, node.input_size());
3033       EXPECT_EQ("concat", node.input(0));
3034       EXPECT_EQ("^ctrl1", node.input(1));
3035       found++;
3036     }
3037     if (node.name() == "id") {
3038       EXPECT_EQ(1, node.input_size());
3039       EXPECT_EQ("exp_a", node.input(0));
3040       found++;
3041     }
3042 
3043     if (node.name() == "concat2") {
3044       EXPECT_EQ(7, node.input_size());
3045       EXPECT_EQ("sin_a", node.input(0));
3046       EXPECT_EQ("b", node.input(1));
3047       EXPECT_EQ("c", node.input(2));
3048       EXPECT_EQ("axis", node.input(3));
3049       EXPECT_EQ("^ctrl1", node.input(4));
3050       EXPECT_EQ("^ctrl2", node.input(5));
3051       EXPECT_EQ("^ctrl3", node.input(6));
3052       found++;
3053     }
3054     if (node.name() == "exp_a2") {
3055       EXPECT_EQ(2, node.input_size());
3056       EXPECT_EQ("concat2", node.input(0));
3057       EXPECT_EQ("^ctrl1", node.input(1));
3058       found++;
3059     }
3060     if (node.name() == "cos_exp_a2") {
3061       EXPECT_EQ(2, node.input_size());
3062       EXPECT_EQ("exp_a2", node.input(0));
3063       EXPECT_EQ("^ctrl1", node.input(1));
3064       found++;
3065     }
3066     if (node.name() == "id2") {
3067       EXPECT_EQ(1, node.input_size());
3068       EXPECT_EQ("cos_exp_a2", node.input(0));
3069       found++;
3070     }
3071   }
3072   EXPECT_EQ(7, found);
3073 
3074   auto tensors = EvaluateNodes(output, item.fetch);
3075   EXPECT_EQ(tensors.size(), tensors_expected.size());
3076   EXPECT_EQ(tensors.size(), item.fetch.size());
3077   for (int i = 0; i < item.fetch.size(); ++i) {
3078     test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
3079   }
3080 }
3081 
TEST_F(ArithmeticOptimizerTest,HoistCWiseUnaryIntoSplit)3082 TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) {
3083   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3084   Output x = ops::Const(s.WithOpName("x"), 3.1415f, {32});
3085   Output axis = ops::Const(s.WithOpName("axis"), 0, {});
3086   Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {});
3087   Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {});
3088   Output ctrl3 = ops::Const(s.WithOpName("ctrl3"), 3, {});
3089   // Test case with chains of length 1.
3090   // Rewrites
3091   //          [Sin(y) for y in Split(x)]
3092   // into
3093   //          [y for y in Split(Sin(x))].
3094   ops::Split split1(s.WithOpName("split1"), axis, x, 2);
3095   Output sin_a =
3096       ops::Sin(s.WithOpName("sin_a").WithControlDependencies(ctrl1), split1[0]);
3097   Output id_a = ops::Identity(s.WithOpName("id_a"), sin_a);
3098   Output sin_b = ops::Sin(s.WithOpName("sin_b"), split1[1]);
3099   Output exp_b = ops::Exp(s.WithOpName("exp_b"), sin_b);
3100   Output id_b = ops::Identity(s.WithOpName("id_b"), exp_b);
3101 
3102   // Test case with SplitV and chains of length 2.
3103   // Rewrites
3104   //          [Cos(Exp(y)) for y in Split(x)]
3105   // into
3106   //          [y for y in Split(Cos(Exp(x)))].
3107   Output size_splits2 = ops::Const(s.WithOpName("size_splits2"), {20, 12}, {2});
3108   ops::SplitV split2(s.WithOpName("split2"), x, size_splits2, axis, 2);
3109   Output exp_a2 = ops::Exp(
3110       s.WithOpName("exp_a2").WithControlDependencies(ctrl1), split2[0]);
3111   Output exp_b2 = ops::Exp(s.WithOpName("exp_b2"), split2[1]);
3112   Output cos_exp_a2 = ops::Cos(
3113       s.WithOpName("cos_exp_a2").WithControlDependencies(ctrl2), exp_a2);
3114   Output cos_exp_b2 = ops::Cos(
3115       s.WithOpName("cos_exp_b2").WithControlDependencies(ctrl3), exp_b2);
3116   Output id_a2 = ops::Identity(s.WithOpName("id_a2"), cos_exp_a2);
3117   Output id_b2 = ops::Identity(s.WithOpName("id_b2"), cos_exp_b2);
3118 
3119   GrapplerItem item;
3120   item.fetch = {"id_a", "id_b", "id_a2", "id_b2"};
3121   TF_CHECK_OK(s.ToGraphDef(&item.graph));
3122 
3123   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
3124 
3125   GraphDef output;
3126   ArithmeticOptimizer optimizer;
3127   EnableOnlyHoistCWiseUnaryChains(&optimizer);
3128   OptimizeTwiceAndPrune(&optimizer, &item, &output);
3129 
3130   int found = 0;
3131   for (const NodeDef& node : output.node()) {
3132     // The following 6 nodes should be pruned.
3133     EXPECT_NE(node.name(), "sin_a");
3134     EXPECT_NE(node.name(), "sin_b");
3135     EXPECT_NE(node.name(), "exp_a2");
3136     EXPECT_NE(node.name(), "exp_b2");
3137     EXPECT_NE(node.name(), "cos_exp_a2");
3138     EXPECT_NE(node.name(), "cos_exp_b2");
3139 
3140     if (node.name() == "split1") {
3141       EXPECT_EQ(2, node.input_size());
3142       EXPECT_EQ("axis", node.input(0));
3143       EXPECT_EQ("ArithmeticOptimizer/_sin_a_split1", node.input(1));
3144       found++;
3145     }
3146     if (node.name() == "ArithmeticOptimizer/_sin_a_split1") {
3147       EXPECT_EQ("Sin", node.op());
3148       EXPECT_EQ(2, node.input_size());
3149       EXPECT_EQ("x", node.input(0));
3150       EXPECT_EQ("^ctrl1", node.input(1));
3151       found++;
3152     }
3153     if (node.name() == "id_a") {
3154       EXPECT_EQ(1, node.input_size());
3155       EXPECT_EQ("split1", node.input(0));
3156       found++;
3157     }
3158     if (node.name() == "exp_b") {
3159       EXPECT_EQ(1, node.input_size());
3160       EXPECT_EQ("split1:1", node.input(0));
3161       found++;
3162     }
3163     if (node.name() == "id_b") {
3164       EXPECT_EQ(1, node.input_size());
3165       EXPECT_EQ("exp_b", node.input(0));
3166       found++;
3167     }
3168     if (node.name() == "ArithmeticOptimizer/_exp_a2_split2") {
3169       EXPECT_EQ("Exp", node.op());
3170       EXPECT_EQ(4, node.input_size());
3171       EXPECT_EQ("x", node.input(0));
3172       EXPECT_EQ("^ctrl1", node.input(1));
3173       EXPECT_EQ("^ctrl2", node.input(2));
3174       EXPECT_EQ("^ctrl3", node.input(3));
3175       found++;
3176     }
3177     if (node.name() == "ArithmeticOptimizer/_cos_exp_a2_split2") {
3178       EXPECT_EQ("Cos", node.op());
3179       EXPECT_EQ(1, node.input_size());
3180       EXPECT_EQ("ArithmeticOptimizer/_exp_a2_split2", node.input(0));
3181       found++;
3182     }
3183     if (node.name() == "split2") {
3184       EXPECT_EQ(3, node.input_size());
3185       EXPECT_EQ("ArithmeticOptimizer/_cos_exp_a2_split2", node.input(0));
3186       EXPECT_EQ("size_splits2", node.input(1));
3187       EXPECT_EQ("axis", node.input(2));
3188       found++;
3189     }
3190     if (node.name() == "id_a2") {
3191       EXPECT_EQ(1, node.input_size());
3192       EXPECT_EQ("split2", node.input(0));
3193       found++;
3194     }
3195     if (node.name() == "id_b2") {
3196       EXPECT_EQ(1, node.input_size());
3197       EXPECT_EQ("split2:1", node.input(0));
3198       found++;
3199     }
3200   }
3201   EXPECT_EQ(10, found);
3202 
3203   auto tensors = EvaluateNodes(output, item.fetch);
3204   EXPECT_EQ(tensors.size(), tensors_expected.size());
3205   EXPECT_EQ(tensors.size(), item.fetch.size());
3206   for (int i = 0; i < item.fetch.size(); ++i) {
3207     test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
3208   }
3209 }
3210 
// Verifies that RemoveIdempotent collapses back-to-back applications of the
// same idempotent op (Snapshot(Snapshot(a)), Identity(Identity(a))) so that
// consumers read from the first application directly.
TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
  Output sn1 = ops::Snapshot(s.WithOpName("sn1"), a);
  Output sn2 = ops::Snapshot(s.WithOpName("sn2"), sn1);
  Output out1 = ops::Identity(s.WithOpName("out1"), sn2);
  Output id1 = ops::Identity(s.WithOpName("id1"), a);
  Output id2 = ops::Identity(s.WithOpName("id2"), id1);
  Output out2 = ops::Identity(s.WithOpName("out2"), id2);
  GrapplerItem item;
  item.fetch = {"out1", "out2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdempotent(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  // No pruning is run, so the node count is unchanged; only inputs are
  // rewired.
  EXPECT_EQ(7, output.node_size());
  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "out1") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("sn1", node.input(0));
      found++;
    } else if (node.name() == "out2") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("id1", node.input(0));
      found++;
    } else if (node.name() == "sn1") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("a", node.input(0));
      found++;
    }
  }
  EXPECT_EQ(3, found);

  auto tensors = EvaluateNodes(output, item.fetch);
  // Fatal checks: both vectors are indexed in the loop below.
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  ASSERT_EQ(tensors.size(), item.fetch.size());
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
3257 
TEST_F(ArithmeticOptimizerTest,RemoveLogicalNot)3258 TEST_F(ArithmeticOptimizerTest, RemoveLogicalNot) {
3259   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3260   Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
3261   Output b = ops::Const(s.WithOpName("b"), -3.14f, {32});
3262   Output eq = ops::Equal(s.WithOpName("eq"), a, b);
3263   Output neq = ops::NotEqual(s.WithOpName("neq"), a, b);
3264   Output lt = ops::Less(s.WithOpName("lt"), a, b);
3265   Output le = ops::LessEqual(s.WithOpName("le"), a, b);
3266   Output gt = ops::Greater(s.WithOpName("gt"), a, b);
3267   Output ge = ops::GreaterEqual(s.WithOpName("ge"), a, b);
3268   // not_eq is reserved
3269   Output not_eq1 = ops::LogicalNot(s.WithOpName("not_eq1"), eq);
3270   Output not_neq = ops::LogicalNot(s.WithOpName("not_neq"), neq);
3271   Output not_lt = ops::LogicalNot(s.WithOpName("not_lt"), lt);
3272   Output not_le = ops::LogicalNot(s.WithOpName("not_le"), le);
3273   Output not_gt = ops::LogicalNot(s.WithOpName("not_gt"), gt);
3274   Output not_ge = ops::LogicalNot(s.WithOpName("not_ge"), ge);
3275   Output id_not_eq = ops::Identity(s.WithOpName("id_not_eq"), not_eq1);
3276   Output id_not_neq = ops::Identity(s.WithOpName("id_not_neq"), not_neq);
3277   Output id_not_lt = ops::Identity(s.WithOpName("id_not_lt"), not_lt);
3278   Output id_not_le = ops::Identity(s.WithOpName("id_not_le"), not_le);
3279   Output id_not_gt = ops::Identity(s.WithOpName("id_not_gt"), not_gt);
3280   Output id_not_ge = ops::Identity(s.WithOpName("id_not_ge"), not_ge);
3281 
3282   GrapplerItem item;
3283   item.fetch = {"id_not_eq", "id_not_neq", "id_not_lt",
3284                 "id_not_le", "id_not_gt",  "id_not_ge"};
3285   TF_CHECK_OK(s.ToGraphDef(&item.graph));
3286 
3287   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
3288 
3289   GraphDef output;
3290   ArithmeticOptimizer optimizer;
3291   EnableOnlyRemoveLogicalNot(&optimizer);
3292   OptimizeTwice(&optimizer, &item, &output);
3293 
3294   int found = 0;
3295   for (const NodeDef& node : output.node()) {
3296     if (node.name() == "id_not_eq") {
3297       EXPECT_EQ("eq", node.input(0));
3298       ++found;
3299     }
3300     if (node.name() == "id_not_neq") {
3301       EXPECT_EQ("neq", node.input(0));
3302       ++found;
3303     }
3304     if (node.name() == "id_not_lt") {
3305       EXPECT_EQ("lt", node.input(0));
3306       ++found;
3307     }
3308     if (node.name() == "id_not_le") {
3309       EXPECT_EQ("le", node.input(0));
3310       ++found;
3311     }
3312     if (node.name() == "id_not_gt") {
3313       EXPECT_EQ("gt", node.input(0));
3314       ++found;
3315     }
3316     if (node.name() == "id_not_ge") {
3317       EXPECT_EQ("ge", node.input(0));
3318       ++found;
3319     }
3320 
3321     if (node.name() == "eq") {
3322       EXPECT_EQ("NotEqual", node.op());
3323       ++found;
3324     }
3325     if (node.name() == "neq") {
3326       EXPECT_EQ("Equal", node.op());
3327       ++found;
3328     }
3329     if (node.name() == "lt") {
3330       EXPECT_EQ("GreaterEqual", node.op());
3331       ++found;
3332     }
3333     if (node.name() == "le") {
3334       EXPECT_EQ("Greater", node.op());
3335       ++found;
3336     }
3337     if (node.name() == "gt") {
3338       EXPECT_EQ("LessEqual", node.op());
3339       ++found;
3340     }
3341     if (node.name() == "ge") {
3342       EXPECT_EQ("Less", node.op());
3343       ++found;
3344     }
3345   }
3346   EXPECT_EQ(12, found);
3347 
3348   auto tensors = EvaluateNodes(output, item.fetch);
3349   EXPECT_EQ(tensors.size(), tensors_expected.size());
3350   EXPECT_EQ(tensors.size(), item.fetch.size());
3351   for (int i = 0; i < item.fetch.size(); ++i) {
3352     test::ExpectTensorEqual<bool>(tensors_expected[i], tensors[i]);
3353   }
3354 }
3355 
// Verifies that Max(Sqrt(x)) is rewritten to Sqrt(Max(x)): the reduction and
// the monotonic element-wise op swap places while keeping their node names.
TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWise) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x);
  Output reduce_max = ops::Max(s.WithOpName("reduce_max"), sqrt, {0});
  Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // Fatal check: tensors_expected[0] is indexed below.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(1, tensors.size());

  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
  EXPECT_EQ(item.graph.node_size(), output.node_size());
  // Check if the inputs are switched
  int required_node_count = 0;
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "sqrt") {
      EXPECT_EQ("Sqrt", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("reduce_max", node.input(0));
      ++required_node_count;
    } else if (node.name() == "reduce_max") {
      EXPECT_EQ("Max", node.op());
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      ++required_node_count;
    }
  }
  EXPECT_EQ(2, required_node_count);
}
3396 
// Verifies that ArgMax(Sqrt(x)) is rewritten to ArgMax(x): the index of the
// maximum is unaffected by the monotonic Sqrt, so the Sqrt node is dropped
// (the optimized graph has one node fewer).
TEST_F(ArithmeticOptimizerTest, OptimizeArgMaxOrArgMinOfMonotonicElementWise) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  const auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x);
  Output arg_max = ops::ArgMax(s.WithOpName("arg_max"), sqrt, 1);
  Output final_out = ops::Identity(s.WithOpName("final_out"), arg_max);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // Fatal check: tensors_expected[0] is indexed below.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(1, tensors.size());

  test::ExpectTensorEqual<int64>(tensors_expected[0], tensors[0]);
  EXPECT_EQ(item.graph.node_size() - 1, output.node_size());
  // Check if the inputs are switched
  int required_node_count = 0;
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "final_out") {
      EXPECT_EQ("Identity", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("arg_max", node.input(0));
      ++required_node_count;
    } else if (node.name() == "arg_max") {
      EXPECT_EQ("ArgMax", node.op());
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      ++required_node_count;
    }
  }
  EXPECT_EQ(2, required_node_count);
}
3437 
// Verifies that Max(Sqrt(x)) is NOT rewritten when `sqrt` is itself a fetch
// node, and that both fetched tensors evaluate to the same values as before.
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWise_DoNotChangeFetchNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x);
  Output reduce_max = ops::Max(s.WithOpName("reduce_max"), sqrt, {0});
  Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max);

  GrapplerItem item;
  item.fetch = {"sqrt", "final_out"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // Fatal check: tensors_expected is indexed in the comparison loop below.
  ASSERT_EQ(2, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  // Should be a NoOp since we are not allowed to change the output of fetch
  // nodes.
  VerifyGraphsMatch(item.graph, output, __LINE__);

  // The fetched values must also be unchanged.
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors_expected.size(), tensors.size());
  for (size_t i = 0; i < tensors.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
3461 
// Max(Neg(Reshape(x))) must be left as-is when the reduction "z" is itself
// the fetch node, because rewriting it would change a fetched output.
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseDoNotChangeFetchNodeReduction) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto input = ops::Const(root.WithOpName("x"), {2, 3}, {1, 2});
  Output flat = ops::Reshape(root.WithOpName("reshape"), input, {-1});
  Output negated = ops::Neg(root.WithOpName("y"), flat);
  Output reduced = ops::Max(root.WithOpName("z"), negated, {0});

  GrapplerItem item;
  item.fetch = {"z"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  const auto baseline = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(1, baseline.size());

  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  GraphDef optimized;
  OptimizeTwice(&optimizer, &item, &optimized);

  // Expect a no-op: the graph must come back structurally unchanged.
  VerifyGraphsMatch(item.graph, optimized, __LINE__);

  const auto actual = EvaluateNodes(optimized, item.fetch);
  ASSERT_EQ(1, actual.size());
  test::ExpectTensorEqual<int>(actual[0], baseline[0]);
  // max(-[2, 3]) == -2.
  test::ExpectTensorEqual<int>(actual[0], Tensor(-2));
}
3490 
// Verifies that Max(Neg(x)) is rewritten to Neg(Min(x)): for a non-increasing
// element-wise op the reduction is swapped with it AND flipped from Max to
// Min, while the node names stay the same.
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasing) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output neg = ops::Neg(s.WithOpName("neg"), x);
  Output reduce_max = ops::Max(s.WithOpName("reduce_max"), neg, {0});
  Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // Fatal check: tensors_expected[0] is indexed below.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(1, tensors.size());

  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
  EXPECT_EQ(item.graph.node_size(), output.node_size());
  // Check if the inputs are switched
  int required_node_count = 0;
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "neg") {
      EXPECT_EQ("Neg", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("reduce_max", node.input(0));
      ++required_node_count;
    } else if (node.name() == "reduce_max") {
      EXPECT_EQ("Min", node.op());
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      ++required_node_count;
    }
  }
  EXPECT_EQ(2, required_node_count);
}
3532 
// MaxPool(Neg(x)) is expected to pass through the MaxOrMinOfMonotonic stage
// untouched, and the pooled output must match the unoptimized graph.
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasingDoNotChangeMaxPool) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto input = ops::Const(root.WithOpName("x"), 1.5f, {3, 3, 3, 1});
  Output negated = ops::Neg(root.WithOpName("neg"), input);
  Output pooled = ops::MaxPool(root.WithOpName("max_pool"), negated,
                               {1, 2, 2, 1}, {1, 2, 2, 1}, "VALID");

  GrapplerItem item;
  item.fetch = {"max_pool"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  const auto baseline = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(1, baseline.size());

  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  GraphDef optimized;
  OptimizeTwice(&optimizer, &item, &optimized);

  // Expect a no-op: the graph must come back structurally unchanged.
  VerifyGraphsMatch(item.graph, optimized, __LINE__);

  const auto actual = EvaluateNodes(optimized, item.fetch);
  ASSERT_EQ(1, actual.size());
  test::ExpectTensorNear<float>(baseline[0], actual[0], 1e-6);
}
3559 
// Verifies that MaxPool(Sqrt(x)) is rewritten to Sqrt(MaxPool(x)): the pooling
// op and the monotonic element-wise op swap places while keeping their names.
TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWiseMaxPool) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), 1.5f, {3, 3, 3, 1});
  Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x);
  Output max_pool = ops::MaxPool(s.WithOpName("max_pool"), sqrt, {1, 2, 2, 1},
                                 {1, 2, 2, 1}, "VALID");
  Output final_out = ops::Identity(s.WithOpName("final_out"), max_pool);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // Fatal check: tensors_expected[0] is indexed below.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(1, tensors.size());

  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
  EXPECT_EQ(item.graph.node_size(), output.node_size());
  // Check if the inputs are switched
  int required_node_count = 0;
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "sqrt") {
      EXPECT_EQ("Sqrt", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("max_pool", node.input(0));
      ++required_node_count;
    } else if (node.name() == "max_pool") {
      EXPECT_EQ("MaxPool", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("x", node.input(0));
      ++required_node_count;
    }
  }
  EXPECT_EQ(2, required_node_count);
}
3601 
// Verifies that a chain of unary ops (Sqrt -> Log -> Relu) placed on CPU is
// fused by UnaryOpsComposition into a single `_UnaryOpsComposition` node named
// "relu/unary_ops_composition". That node reads `x` and records the fused ops,
// in order, in its "op_names" attr. The fetched value must be unchanged.
TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x);
  Output log = ops::Log(s.WithOpName("log"), sqrt);
  Output relu = ops::Relu(s.WithOpName("relu"), log);
  Output final_out = ops::Identity(s.WithOpName("final_out"), relu);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Place all nodes on CPU.
  for (int i = 0; i < item.graph.node_size(); ++i) {
    item.graph.mutable_node(i)->set_device("/device:CPU:0");
  }

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // ASSERT rather than EXPECT: tensors_expected[0] is dereferenced below.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyUnaryOpsComposition(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  EXPECT_EQ(3, output.node_size());

  // Check that Sqrt/Log/Relu were replaced with a single op.
  int required_node_count = 0;
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "final_out") {
      EXPECT_EQ("Identity", node.op());
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("relu/unary_ops_composition", node.input(0));
      ++required_node_count;
    } else if (node.name() == "relu/unary_ops_composition") {
      EXPECT_EQ("_UnaryOpsComposition", node.op());
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("x", node.input(0));

      // Map::at() aborts on a missing key, so fail the test cleanly first.
      ASSERT_EQ(1, node.attr().count("op_names"));
      // Bind by reference; the repeated field does not need to be copied.
      const auto& op_names = node.attr().at("op_names").list().s();
      // ASSERT: op_names[0..2] are indexed below.
      ASSERT_EQ(3, op_names.size());
      EXPECT_EQ("Sqrt", op_names[0]);
      EXPECT_EQ("Log", op_names[1]);
      EXPECT_EQ("Relu", op_names[2]);
      ++required_node_count;
    }
  }
  EXPECT_EQ(2, required_node_count);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
3658 
TEST_F(ArithmeticOptimizerTest,RemoveStackStridedSliceSameAxis)3659 TEST_F(ArithmeticOptimizerTest, RemoveStackStridedSliceSameAxis) {
3660   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3661   auto a_in =
3662       ops::Const(s.WithOpName("a_in"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
3663   auto b_in =
3664       ops::Const(s.WithOpName("b_in"), {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2});
3665   auto c_in =
3666       ops::Const(s.WithOpName("c_in"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
3667   auto a = ops::PlaceholderWithDefault(s.WithOpName("a"), a_in,
3668                                        PartialTensorShape({-1, -1}));
3669   auto b = ops::PlaceholderWithDefault(s.WithOpName("b"), b_in,
3670                                        PartialTensorShape({-1, -1}));
3671   auto c = ops::PlaceholderWithDefault(s.WithOpName("c"), c_in,
3672                                        PartialTensorShape({-1, -1}));
3673   // stacked = tf.stack((a, b, c), axis=1).
3674   // stacked.shape == [2, 3, 2] (a, b, c are stacked along new axis 1)
3675   auto stacked =
3676       ops::Stack(s.WithOpName("stacked"), {a.output, b.output, c.output},
3677                  ops::Stack::Axis(1));
3678   auto expanded_a = ops::ExpandDims(s.WithOpName("expanded_a"), a, {1});
3679   auto expanded_b = ops::ExpandDims(s.WithOpName("expanded_b"), b, {1});
3680   auto expanded_c = ops::ExpandDims(s.WithOpName("expanded_c"), c, {1});
3681   auto begin_a = ops::Const(s.WithOpName("begin_a"), {0, 0, 0}, {3});
3682   auto end_a = ops::Const(s.WithOpName("end_a"), {0, 1, 0}, {3});
3683   auto begin_b = ops::Const(s.WithOpName("begin_b"), {0, 1, 0}, {3});
3684   auto end_b = ops::Const(s.WithOpName("end_b"), {0, 2, 0}, {3});
3685   auto begin_c = ops::Const(s.WithOpName("begin_c"), {0, 2, 0}, {3});
3686   auto end_c = ops::Const(s.WithOpName("end_c"), {0, 3, 0}, {3});
3687   auto end_c_1to = ops::Const(s.WithOpName("begin_c_2to"), {0, 0, 0}, {3});
3688   auto strides = ops::Const(s.WithOpName("strides"), {1, 1, 1}, {3});
3689 
3690   // stacked[:, 0]
3691   using SS = ops::StridedSlice;
3692   auto pa_slice = ops::Identity(
3693       s.WithOpName("pa_slice_out"),
3694       SS(s.WithOpName("pa_slice"), stacked, begin_a, end_a, strides,
3695          SS::BeginMask(0b0101)  // 5
3696              .EllipsisMask(0)
3697              .EndMask(0b0101)  // 5
3698              .NewAxisMask(0)
3699              .ShrinkAxisMask(0b0010)));  // 2
3700 
3701   // stacked[:, 1]
3702   auto pb_slice = ops::Identity(
3703       s.WithOpName("pb_slice_out"),
3704       SS(s.WithOpName("pb_slice"), stacked, begin_b, end_b, strides,
3705          SS::BeginMask(0b0101)  // 5
3706              .EllipsisMask(0)
3707              .EndMask(0b0101)  // 5
3708              .NewAxisMask(0)
3709              .ShrinkAxisMask(0b0010)));  // 2
3710 
3711   // stacked[:, 2]
3712   auto pc_slice = ops::Identity(
3713       s.WithOpName("pc_slice_out"),
3714       SS(s.WithOpName("pc_slice"), stacked, begin_c, end_c, strides,
3715          SS::BeginMask(0b0101)  // 5
3716              .EllipsisMask(0)
3717              .EndMask(0b0101)  // 5
3718              .NewAxisMask(0)
3719              .ShrinkAxisMask(0b0010)));  // 2
3720 
3721   // stacked[:, 0:1, :]
3722   auto pa_slice_01 = ops::Identity(
3723       s.WithOpName("pa_slice_01_out"),
3724       SS(s.WithOpName("pa_slice_01"), stacked, begin_a, end_a, strides,
3725          SS::BeginMask(0b0101)  // 5
3726              .EllipsisMask(0)
3727              .EndMask(0b0101)  // 5
3728              .NewAxisMask(0)
3729              .ShrinkAxisMask(0)));
3730 
3731   // stacked[:, :1, :]
3732   auto pa_slice_to1 = ops::Identity(
3733       s.WithOpName("pa_slice_to1_out"),
3734       SS(s.WithOpName("pa_slice_to1"), stacked, begin_a, end_a, strides,
3735          SS::BeginMask(0b0111)  // 7
3736              .EllipsisMask(0)
3737              .EndMask(0b0101)  // 5
3738              .NewAxisMask(0)
3739              .ShrinkAxisMask(0)));
3740 
3741   // stacked[:, 1:2, :]
3742   auto pb_slice_12 = ops::Identity(
3743       s.WithOpName("pb_slice_12_out"),
3744       SS(s.WithOpName("pb_slice_12"), stacked, begin_b, end_b, strides,
3745          SS::BeginMask(0b0101)  // 5
3746              .EllipsisMask(0)
3747              .EndMask(0b0101)  // 5
3748              .NewAxisMask(0)
3749              .ShrinkAxisMask(0)));
3750 
3751   // stacked[:, 2:, :].
3752   auto pc_slice_2to = ops::Identity(
3753       s.WithOpName("pc_slice_2to_out"),
3754       SS(s.WithOpName("pc_slice_2to"), stacked, begin_c, end_c_1to, strides,
3755          SS::BeginMask(0b0101)  // 5
3756              .EllipsisMask(0)
3757              .EndMask(0b0111)  // 7
3758              .NewAxisMask(0)
3759              .ShrinkAxisMask(0)));
3760 
3761   GrapplerItem item;
3762   item.fetch = {"a",
3763                 "b",
3764                 "c",
3765                 "pa_slice_out",
3766                 "pb_slice_out",
3767                 "pc_slice_out",
3768                 "expanded_a",
3769                 "expanded_b",
3770                 "expanded_c",
3771                 "pa_slice_01_out",
3772                 "pa_slice_to1_out",
3773                 "pb_slice_12_out",
3774                 "pc_slice_2to_out"};
3775   enum FetchItem {
3776     fA,
3777     fB,
3778     fC,
3779     fASliceOut,
3780     fBSliceOut,
3781     fCSliceOut,
3782     fExpandedA,
3783     fExpandedB,
3784     fExpandedC,
3785     fASlice01Out,
3786     fASliceTo1Out,
3787     fBSlice12Out,
3788     fCSlice2ToOut,
3789   };
3790   TF_CHECK_OK(s.ToGraphDef(&item.graph));
3791   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
3792 
3793   // stacked[:, 0, :] == a.
3794   test::ExpectTensorEqual<float>(tensors_expected[fA],
3795                                  tensors_expected[fASliceOut]);
3796   // stacked[:, 1, :] == b.
3797   test::ExpectTensorEqual<float>(tensors_expected[fB],
3798                                  tensors_expected[fBSliceOut]);
3799   // stacked[:, 2, :] == c.
3800   test::ExpectTensorEqual<float>(tensors_expected[fC],
3801                                  tensors_expected[fCSliceOut]);
3802 
3803   // stacked[:, 0:1, :] == expand_dims(a, 1).
3804   test::ExpectTensorEqual<float>(tensors_expected[fExpandedA],
3805                                  tensors_expected[fASlice01Out]);
3806 
3807   // stacked[:, :1, :] == expand_dims(a, 1).
3808   test::ExpectTensorEqual<float>(tensors_expected[fExpandedA],
3809                                  tensors_expected[fASliceTo1Out]);
3810 
3811   // stacked[:, 1:2, :] == expand_dims(b, 1).
3812   test::ExpectTensorEqual<float>(tensors_expected[fExpandedB],
3813                                  tensors_expected[fBSlice12Out]);
3814   // stacked[:, 2:, :] == expand_dims(c, 1).
3815   test::ExpectTensorEqual<float>(tensors_expected[fExpandedC],
3816                                  tensors_expected[fCSlice2ToOut]);
3817 
3818   GraphDef output;
3819   ArithmeticOptimizer optimizer;
3820   EnableOnlyRemoveStackStridedSliceSameAxis(&optimizer);
3821   OptimizeAndPrune(&optimizer, &item, &output);
3822 
3823   for (const auto& node : output.node()) {
3824     if (node.name() == "pa_slice_out") {
3825       EXPECT_EQ(node.input(0), "a");
3826     } else if (node.name() == "pb_slice_out") {
3827       EXPECT_EQ(node.input(0), "b");
3828     } else if (node.name() == "pc_slice_out") {
3829       EXPECT_EQ(node.input(0), "c");
3830     } else if (str_util::EndsWith(node.name(), "_out")) {
3831       EXPECT_EQ(strings::StrCat(node.input(0), "_out"),
3832                 strings::StrCat(
3833                     "ArithmeticOptimizer/RemoveStackStridedSliceSameAxis_",
3834                     node.name()));
3835     }
3836   }
3837 
3838   auto tensors = EvaluateNodes(output, item.fetch);
3839 
3840   // stacked[:, 0, :] == a.
3841   test::ExpectTensorEqual<float>(tensors_expected[fA], tensors[fASliceOut]);
3842 
3843   // stacked[:, 1, :] == b.
3844   test::ExpectTensorEqual<float>(tensors_expected[fB], tensors[fBSliceOut]);
3845   // stacked[:, 2, :] == c.
3846   test::ExpectTensorEqual<float>(tensors_expected[fC], tensors[fCSliceOut]);
3847 
3848   // stacked[:, 0:1, :] == expand_dims(a, 1).
3849   test::ExpectTensorEqual<float>(tensors_expected[fExpandedA],
3850                                  tensors[fASlice01Out]);
3851 
3852   // stacked[:, :1, :] == expand_dims(a, 1).
3853   test::ExpectTensorEqual<float>(tensors_expected[fExpandedA],
3854                                  tensors[fASliceTo1Out]);
3855 
3856   // stacked[:, 1:2, :] == expand_dims(b, 1).
3857   test::ExpectTensorEqual<float>(tensors_expected[fExpandedB],
3858                                  tensors[fBSlice12Out]);
3859   // stacked[:, 2:, :] == expand_dims(c, 1).
3860   test::ExpectTensorEqual<float>(tensors_expected[fExpandedC],
3861                                  tensors[fCSlice2ToOut]);
3862 }
3863 
// Verifies that SimplifyAggregation handles an AddN of two identical
// DT_BFLOAT16 inputs. The optimized graph is expected to contain exactly one
// extra node, and the fetched value must match the original exactly.
TEST_F(ArithmeticOptimizerTest, SimplifyAggregationBFloat16) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output cast = ops::Cast(s.WithOpName("cast"), x, DT_BFLOAT16);
  Output add = ops::AddN(s.WithOpName("add"), {cast, cast});
  Output id = ops::Identity(s.WithOpName("id"), add);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"id"};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // ASSERT rather than EXPECT: tensors_expected[0] is dereferenced below.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlySimplifyAggregation(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // Extra node created for multiplier.
  EXPECT_EQ(5, output.node_size());

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorEqual<bfloat16>(tensors_expected[0], tensors[0]);
}
3889 
3890 }  // namespace grappler
3891 }  // namespace tensorflow
3892