1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
17 
18 #include <complex>
19 
20 #include "absl/strings/match.h"
21 #include "absl/strings/str_cat.h"
22 #include "tensorflow/cc/ops/array_ops.h"
23 #include "tensorflow/cc/ops/math_ops.h"
24 #include "tensorflow/cc/ops/standard_ops.h"
25 #include "tensorflow/core/framework/node_def.pb.h"
26 #include "tensorflow/core/framework/tensor_testutil.h"
27 #include "tensorflow/core/grappler/grappler_item.h"
28 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
29 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h"
30 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
31 #include "tensorflow/core/grappler/utils.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/core/platform/test.h"
34 
35 namespace tensorflow {
36 namespace grappler {
37 
38 namespace {
39 
// Node-name prefixes used by the name helpers in this file. Each helper passes
// one of these to AddPrefixToNodeName() together with the name of the node
// that was rewritten.

// Prefix used by HoistDivName().
constexpr char kHoistFactorOptimizerDiv[] =
    "ArithmeticOptimizer/HoistCommonFactor_Div_";

// Prefix used by HoistMulName().
constexpr char kHoistFactorOptimizerMul[] =
    "ArithmeticOptimizer/HoistCommonFactor_Mul_";

// Prefix used by HoistAddName().
constexpr char kHoistFactorOptimizerAdd[] =
    "ArithmeticOptimizer/HoistCommonFactor_AddV2_";

// Prefix used by AggregationConstName().
constexpr char kSimplifyAggregationConst[] =
    "ArithmeticOptimizer/SimplifyAggregation_Const_";

// Prefix used by AggregationMulName().
constexpr char kSimplifyAggregationMul[] =
    "ArithmeticOptimizer/SimplifyAggregation_Mul_";
54 
55 // Optimized name of outer Mul node by HoistCommonFactorOutOfAggregation.
HoistMulName(const string & name)56 string HoistMulName(const string& name) {
57   return AddPrefixToNodeName(name, kHoistFactorOptimizerMul, "");
58 }
59 
60 // Optimized name of outer Div node by HoistCommonFactorOutOfAggregation.
HoistDivName(const string & name)61 string HoistDivName(const string& name) {
62   return AddPrefixToNodeName(name, kHoistFactorOptimizerDiv, "");
63 }
64 
65 // Optimized name of inner Add node by HoistCommonFactorOutOfAggregation.
HoistAddName(const string & name)66 string HoistAddName(const string& name) {
67   return AddPrefixToNodeName(name, kHoistFactorOptimizerAdd, "");
68 }
69 
70 // Optimized name of Const node by SimplifyAggregation.
AggregationConstName(const string & name)71 string AggregationConstName(const string& name) {
72   return AddPrefixToNodeName(name, kSimplifyAggregationConst, "");
73 }
74 
75 // Optimized name of Mul node by SimplifyAggregation.
AggregationMulName(const string & name)76 string AggregationMulName(const string& name) {
77   return AddPrefixToNodeName(name, kSimplifyAggregationMul, "");
78 }
79 
VerifyGraphsMatch(const GraphDef & original_graph,const GraphDef & optimized_graph,int line)80 void VerifyGraphsMatch(const GraphDef& original_graph,
81                        const GraphDef& optimized_graph, int line) {
82   EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << line;
83   for (int i = 0; i < original_graph.node_size(); ++i) {
84     const NodeDef& original = original_graph.node(i);
85     const NodeDef& optimized = optimized_graph.node(i);
86     EXPECT_EQ(original.name(), optimized.name()) << line;
87     EXPECT_EQ(original.op(), optimized.op()) << line;
88     EXPECT_EQ(original.input_size(), optimized.input_size()) << line;
89     for (int j = 0; j < original.input_size(); ++j) {
90       EXPECT_EQ(original.input(j), optimized.input(j)) << line;
91     }
92   }
93 }
94 }  // namespace
95 
// The optimizer must return a graph identical to its input when there is
// nothing to optimize.
TEST_F(ArithmeticOptimizerTest, NoOp) {
  // The trivial test graph is so basic that no arithmetic rewrite applies.
  GrapplerItem item;
  TrivialTestGraphInputYielder trivial_input(4, 1, 10, false, {"CPU:0"});
  CHECK(trivial_input.NextItem(&item));

  GraphDef optimized_graph;
  ArithmeticOptimizer optimizer;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));

  VerifyGraphsMatch(item.graph, optimized_graph, __LINE__);
}
108 
// Mul(c, c) and MulNoNan(d, d) are each expected to be replaced by a Square
// node. The control dependency of "mul" on `d` must be carried over to the
// replacement, and both fetched outputs must keep their values.
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithSquare) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  Output d = ops::Const(s.WithOpName("d"), {3.0f, 4.0f}, {1, 2});
  Output mul = ops::Mul(s.WithControlDependencies(d).WithOpName("mul"), c, c);
  Output mul_no_nan = ops::MulNoNan(s.WithOpName("mul_no_nan"), d, d);
  Output id = ops::Identity(s.WithOpName("id"), mul);
  Output id2 = ops::Identity(s.WithOpName("id2"), mul_no_nan);

  GrapplerItem item;
  item.fetch = {"id", "id2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithSquare(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  EXPECT_EQ(output.node_size(), 6);

  NodeMap node_map(&output);
  const string p = "ArithmeticOptimizer/ReplaceMulWithSquare";
  const NodeDef* square_node = node_map.GetNode(absl::StrCat(p, "_", "mul"));

  // Square replacing "mul": single data input plus the control dependency.
  ASSERT_NE(square_node, nullptr);
  EXPECT_EQ(square_node->op(), "Square");
  ASSERT_EQ(square_node->input_size(), 2);
  EXPECT_EQ(square_node->input(0), "c");
  EXPECT_EQ(square_node->input(1), "^d");

  // Square replacing "mul_no_nan": single data input.
  const NodeDef* square_node2 =
      node_map.GetNode(absl::StrCat(p, "_", "mul_no_nan"));
  ASSERT_NE(square_node2, nullptr);
  EXPECT_EQ(square_node2->op(), "Square");
  ASSERT_EQ(square_node2->input_size(), 1);
  EXPECT_EQ(square_node2->input(0), "d");

  // Compare every fetched tensor, not only the first one, so that the
  // MulNoNan rewrite is validated numerically as well.
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 2);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  test::ExpectTensorNear<float>(tensors[1], tensors_expected[1], 1e-6);
}
152 
// Two directly chained Neg nodes followed by two directly chained Reciprocal
// nodes are expected to cancel out, leaving only the constant and the fetch.
TEST_F(ArithmeticOptimizerTest, RemoveInvolutionAdjacentNodes) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  // c -> neg1 -> neg2 -> recip1 -> recip2 -> id
  auto c = ops::Const(scope.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  auto neg1 = ops::Neg(scope.WithOpName("neg1"), c);
  auto neg2 = ops::Neg(scope.WithOpName("neg2"), neg1);
  auto recip1 = ops::Reciprocal(scope.WithOpName("recip1"), neg2);
  auto recip2 = ops::Reciprocal(scope.WithOpName("recip2"), recip1);
  auto id = ops::Identity(scope.WithOpName("id"), recip2);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto expected_tensors = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(expected_tensors.size(), 1);

  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveInvolution(&optimizer);
  GraphDef optimized_graph;
  OptimizeAndPrune(&optimizer, &item, &optimized_graph);

  // Only two nodes may remain, and "id" must read straight from the constant.
  ASSERT_EQ(optimized_graph.node_size(), 2);
  const NodeDef& id_node = optimized_graph.node(1);
  EXPECT_EQ(id_node.name(), "id");
  ASSERT_EQ(id_node.input_size(), 1);
  EXPECT_EQ(id_node.input(0), "c");

  auto actual_tensors = EvaluateNodes(optimized_graph, item.fetch);
  ASSERT_EQ(actual_tensors.size(), 1);
  test::ExpectTensorNear<float>(actual_tensors[0], expected_tensors[0], 1e-6);
}
184 
// Two Reciprocal nodes separated by an Identity -> Squeeze chain are expected
// to be removed, so that the constant feeds the Squeeze directly.
TEST_F(ArithmeticOptimizerTest, RemoveInvolutionAroundValuePreservingChain) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  // c -> recip1 -> id1 -> squeeze -> recip2 -> id2
  auto c = ops::Const(scope.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  auto recip1 = ops::Reciprocal(scope.WithOpName("recip1"), c);
  auto id1 = ops::Identity(scope.WithOpName("id1"), recip1);
  auto squeeze = ops::Squeeze(scope.WithOpName("squeeze"), id1);
  auto recip2 = ops::Reciprocal(scope.WithOpName("recip2"), squeeze);
  auto id2 = ops::Identity(scope.WithOpName("id2"), recip2);

  const std::vector<string> fetch = {"id2"};

  GrapplerItem item;
  item.fetch = fetch;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto expected_tensors = EvaluateNodes(item.graph, fetch);
  ASSERT_EQ(expected_tensors.size(), 1);

  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveInvolution(&optimizer);
  GraphDef optimized_graph;
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized_graph);

  // With both Reciprocal nodes gone, three nodes are left.
  EXPECT_EQ(optimized_graph.node_size(), 3);

  // "squeeze" must now read from "c", and "id2" from "squeeze".
  int checked_nodes = 0;
  for (const NodeDef& node : optimized_graph.node()) {
    const bool is_squeeze = node.name() == "squeeze";
    const bool is_id2 = node.name() == "id2";
    if (!is_squeeze && !is_id2) {
      continue;
    }
    const string expected_input = is_squeeze ? "c" : "squeeze";
    ASSERT_EQ(node.input_size(), 1);
    EXPECT_EQ(node.input(0), expected_input);
    ++checked_nodes;
  }
  EXPECT_EQ(checked_nodes, 2);

  auto actual_tensors = EvaluateNodes(optimized_graph, fetch);
  ASSERT_EQ(actual_tensors.size(), 1);
  test::ExpectTensorNear<float>(actual_tensors[0], expected_tensors[0], 1e-6);
}
230 
// When the second Reciprocal is linked to the first chain only through a
// control dependency, the optimizer is expected to leave the graph unchanged.
TEST_F(ArithmeticOptimizerTest, RemoveInvolutionSkipControlDependencies) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  auto c = ops::Const(scope.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  auto recip1 = ops::Reciprocal(scope.WithOpName("recip1"), c);
  auto id1 = ops::Identity(scope.WithOpName("id1"), recip1);
  auto squeeze = ops::Squeeze(scope.WithOpName("squeeze"), id1);
  // recip2 takes `c` as data input and `squeeze` only as a control input.
  auto recip2 = ops::Reciprocal(
      scope.WithOpName("recip2").WithControlDependencies(squeeze), c);
  auto id2 = ops::Identity(scope.WithOpName("id2"), recip2);

  const std::vector<string> fetch = {"id2"};

  GrapplerItem item;
  item.fetch = fetch;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto expected_tensors = EvaluateNodes(item.graph, fetch);
  ASSERT_EQ(expected_tensors.size(), 1);

  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveInvolution(&optimizer);
  GraphDef optimized_graph;
  // Deliberately no pruning here, so the graphs can be compared one-to-one.
  OptimizeTwice(&optimizer, &item, &optimized_graph);

  // Nothing may have been rewritten.
  VerifyGraphsMatch(item.graph, optimized_graph, __LINE__);

  auto actual_tensors = EvaluateNodes(optimized_graph, fetch);
  ASSERT_EQ(actual_tensors.size(), 1);
  test::ExpectTensorNear<float>(actual_tensors[0], expected_tensors[0], 1e-6);
}
263 
// Add(x, x) is expected to become Mul(Const, x), where the new Const carries
// a control dependency on `x`.
TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output add = ops::Add(scope.WithOpName("add"), x, x);
  Output id = ops::Identity(scope.WithOpName("id"), add);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto expected_tensors = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(expected_tensors.size(), 1);

  GraphDef optimized_graph;
  ArithmeticOptimizer optimizer;
  OptimizeTwice(&optimizer, &item, &optimized_graph);
  NodeMap node_map(&optimized_graph);

  EXPECT_EQ(optimized_graph.node_size(), 5);

  const string const_name = AggregationConstName("add");
  const string mul_name = AggregationMulName("add");

  // The generated constant: anchored to `x` by a control input only.
  const NodeDef* const_node = node_map.GetNode(const_name);
  ASSERT_NE(const_node, nullptr);
  ASSERT_EQ(const_node->input_size(), 1);
  EXPECT_EQ(const_node->input(0), "^x");
  EXPECT_EQ(const_node->attr().at("value").tensor().tensor_content(),
            string("\0\0\0@", 4));

  // The generated multiplication: Const * x.
  const NodeDef* mul_node = node_map.GetNode(mul_name);
  ASSERT_NE(mul_node, nullptr);
  ASSERT_EQ(mul_node->input_size(), 2);
  EXPECT_EQ(mul_node->input(0), const_name);
  EXPECT_EQ(mul_node->input(1), "x");

  // The fetch node must have been rewired to the multiplication.
  const NodeDef* id_node = node_map.GetNode("id");
  ASSERT_NE(id_node, nullptr);
  ASSERT_EQ(id_node->input_size(), 1);
  EXPECT_EQ(id_node->input(0), mul_name);

  auto actual_tensors = EvaluateNodes(optimized_graph, item.fetch);
  ASSERT_EQ(actual_tensors.size(), 1);
  test::ExpectTensorNear<float>(actual_tensors[0], expected_tensors[0], 1e-6);
}
309 
// Same rewrite as TrivialSumsSimple, but the Add has a control dependency on
// `y`, which is expected to show up as a control input of the new Mul.
TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output y = ops::Const(scope.WithOpName("y"), {1.0f, 2.0f}, {1, 2});
  Output x = ops::Const(scope.WithOpName("x"), {3.0f, 4.0f}, {1, 2});
  Output add =
      ops::Add(scope.WithOpName("add").WithControlDependencies(y), x, x);
  Output id = ops::Identity(scope.WithOpName("id"), add);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  const std::vector<string> fetch = {"id"};
  auto expected_tensors = EvaluateNodes(item.graph, fetch);
  ASSERT_EQ(expected_tensors.size(), 1);

  GraphDef optimized_graph;
  ArithmeticOptimizer optimizer;
  OptimizeTwice(&optimizer, &item, &optimized_graph);
  NodeMap node_map(&optimized_graph);

  EXPECT_EQ(optimized_graph.node_size(), 6);

  const string const_name = AggregationConstName("add");
  const string mul_name = AggregationMulName("add");

  // The generated constant: anchored to `x` by a control input only.
  const NodeDef* const_node = node_map.GetNode(const_name);
  ASSERT_NE(const_node, nullptr);
  ASSERT_EQ(const_node->input_size(), 1);
  EXPECT_EQ(const_node->input(0), "^x");
  EXPECT_EQ(const_node->attr().at("value").tensor().tensor_content(),
            string("\0\0\0@", 4));

  // The generated multiplication: Const * x, plus the control input from `y`.
  const NodeDef* mul_node = node_map.GetNode(mul_name);
  ASSERT_NE(mul_node, nullptr);
  ASSERT_EQ(mul_node->input_size(), 3);
  EXPECT_EQ(mul_node->input(0), const_name);
  EXPECT_EQ(mul_node->input(1), "x");
  EXPECT_EQ(mul_node->input(2), "^y");

  // The fetch node must have been rewired to the multiplication.
  const NodeDef* id_node = node_map.GetNode("id");
  ASSERT_NE(id_node, nullptr);
  ASSERT_EQ(id_node->input_size(), 1);
  EXPECT_EQ(id_node->input(0), mul_name);

  auto actual_tensors = EvaluateNodes(optimized_graph, fetch);
  ASSERT_EQ(actual_tensors.size(), 1);
  test::ExpectTensorNear<float>(actual_tensors[0], expected_tensors[0], 1e-6);
}
357 
TEST_F(ArithmeticOptimizerTest,TrivialSumsRepeatedAdd)358 TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
359   // Test case from b/69059093.
360   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
361   Output p = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({10, 10}));
362   Output add = ops::Add(s.WithOpName("Add"), p, p);
363   Output add1 = ops::Add(s.WithOpName("Add_1"), p, p);
364   Output add4 = ops::Add(s.WithOpName("Add_4"), add, add1);
365   Output add5 = ops::Add(s.WithOpName("Add_5"), add, add1);
366   Output add6 = ops::Add(s.WithOpName("Add_6"), add4, add5);
367   Output id = ops::Identity(s.WithOpName("id"), add6);
368 
369   GrapplerItem item;
370   item.fetch = {"id"};
371   TF_CHECK_OK(s.ToGraphDef(&item.graph));
372 
373   const std::vector<string> devices{
374       "/device:CPU:0", "/device:GPU:0", "/device:CPU:0", "/device:GPU:1",
375       "/device:CPU:0", "/device:CPU:0", "/device:CPU:0",
376   };
377   for (int i = 0; i < item.graph.node_size(); ++i) {
378     item.graph.mutable_node(i)->set_device(devices[i]);
379   }
380 
381   ArithmeticOptimizer optimizer;
382   DisableAddToAddNCombining(&optimizer);
383 
384   GraphDef output;
385   DedupAndOptimizeTwiceAndPrune(&optimizer, &item, &output);
386 
387   // We expect the following rewrite(s) to occur:
388   //
389   // Mul(p,
390   //     Add_6(Add_4(Const(2), Const(2)),
391   //           Add_5(Const(2), Const(2)))
392   NodeMap node_map(&output);
393 
394   EXPECT_EQ(output.node_size(), 8);
395 
396   const NodeDef* id_node = node_map.GetNode("id");
397   ASSERT_NE(id_node, nullptr);
398   ASSERT_EQ(id_node->input_size(), 1);
399   EXPECT_EQ(id_node->input(0), HoistMulName("Add_6"));
400 
401   const NodeDef* mul_node = node_map.GetNode(HoistMulName("Add_6"));
402   ASSERT_NE(mul_node, nullptr);
403   ASSERT_EQ(mul_node->input_size(), 2);
404   EXPECT_EQ(mul_node->input(0), "Placeholder");
405   EXPECT_EQ(mul_node->input(1), HoistAddName("Add_6"));
406 
407   const NodeDef* add_6_node = node_map.GetNode(HoistAddName("Add_6"));
408   ASSERT_NE(add_6_node, nullptr);
409   ASSERT_EQ(add_6_node->input_size(), 2);
410   EXPECT_EQ(add_6_node->input(0), HoistAddName("Add_4"));
411   EXPECT_EQ(add_6_node->input(1), HoistAddName("Add_5"));
412 
413   const NodeDef* add_4_node = node_map.GetNode(HoistAddName("Add_4"));
414   ASSERT_NE(add_4_node, nullptr);
415   EXPECT_EQ(add_4_node->op(), "Add");
416   ASSERT_EQ(2, add_4_node->input_size());
417   EXPECT_EQ(add_4_node->input(0), AggregationConstName("Add"));
418   EXPECT_EQ(add_4_node->input(1), AggregationConstName("Add_1"));
419 
420   const NodeDef* add_5_node = node_map.GetNode(HoistAddName("Add_5"));
421   ASSERT_NE(add_5_node, nullptr);
422   EXPECT_EQ(add_5_node->op(), "Add");
423   ASSERT_EQ(add_5_node->input_size(), 2);
424   EXPECT_EQ(add_5_node->input(0), AggregationConstName("Add"));
425   EXPECT_EQ(add_5_node->input(1), AggregationConstName("Add_1"));
426 
427   const NodeDef* add_const_node = node_map.GetNode(AggregationConstName("Add"));
428   ASSERT_NE(add_const_node, nullptr);
429   EXPECT_EQ(add_const_node->op(), "Const");
430   ASSERT_EQ(add_const_node->input_size(), 1);
431   EXPECT_EQ(add_const_node->input(0), "^Placeholder");
432 
433   const NodeDef* add_1_const_node =
434       node_map.GetNode(AggregationConstName("Add_1"));
435   ASSERT_NE(add_1_const_node, nullptr);
436   EXPECT_EQ(add_1_const_node->op(), "Const");
437   ASSERT_EQ(add_1_const_node->input_size(), 1);
438   EXPECT_EQ(add_1_const_node->input(0), "^Placeholder");
439 }
440 
TEST_F(ArithmeticOptimizerTest,HoistFactorMul)441 TEST_F(ArithmeticOptimizerTest, HoistFactorMul) {
442   for (bool matching_shapes : {true, false}) {
443     for (bool use_addn : {true, false}) {
444       tensorflow::Scope s = tensorflow::Scope::NewRootScope();
445       Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
446       Output y1 = ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
447       Output y2 = matching_shapes
448                       ? ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2})
449                       : ops::Const(s.WithOpName("y2"), {5.0f}, {1, 1});
450       Output mul1 = ops::Mul(s.WithOpName("mul1"), x, y1);
451       Output mul2 = ops::Mul(s.WithOpName("mul2"), y2, x);
452       Output id =
453           use_addn ? ops::Identity(s.WithOpName("id"),
454                                    ops::AddN(s.WithOpName("add"), {mul1, mul2}))
455                    : ops::Identity(s.WithOpName("id"),
456                                    ops::Add(s.WithOpName("add"), mul1, mul2));
457 
458       GrapplerItem item;
459       item.fetch = {"id"};
460       TF_CHECK_OK(s.ToGraphDef(&item.graph));
461       auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
462       ASSERT_EQ(tensors_expected.size(), 1);
463       ArithmeticOptimizer optimizer;
464       EnableOnlyHoistCommonFactor(&optimizer);
465 
466       GraphDef output;
467       OptimizeTwice(&optimizer, &item, &output);
468 
469       // We expect the following rewrite(s) to occur:
470       //
471       //        Add                 Mul
472       //      /    \               /   \
473       //    Mul    Mul       ->   x    Add
474       //    / \    / \                 / \
475       //   x  y1  y2  x              y1   y2
476       //
477       // If "root" op is AddN and shapes does not match, this rewrite is not
478       // possible and graph should stay intact.
479       NodeMap node_map(&output);
480 
481       if (use_addn && !matching_shapes) {
482         VerifyGraphsMatch(item.graph, output, __LINE__);
483       } else {
484         EXPECT_EQ(output.node_size(), 9);
485 
486         const NodeDef* new_add_node = node_map.GetNode(HoistAddName("add"));
487         ASSERT_NE(new_add_node, nullptr) << "Hoisted Add node not found";
488         ASSERT_EQ(new_add_node->input_size(), 2);
489         EXPECT_EQ(new_add_node->input(0), "y1");
490         EXPECT_EQ(new_add_node->input(1), "y2");
491 
492         const NodeDef* new_mul_node = node_map.GetNode(HoistMulName("add"));
493         ASSERT_NE(new_mul_node, nullptr) << "Hoisted Mul node not found";
494         ASSERT_EQ(new_mul_node->input_size(), 2);
495         EXPECT_EQ(new_mul_node->input(0), "x");
496         EXPECT_EQ(new_mul_node->input(1), new_add_node->name());
497 
498         const NodeDef* id_node = node_map.GetNode("id");
499         ASSERT_NE(id_node, nullptr) << "Id node not found";
500         EXPECT_EQ(id_node->name(), "id");
501         ASSERT_EQ(id_node->input_size(), 1);
502         EXPECT_EQ(id_node->input(0), HoistMulName("add"));
503       }
504       auto tensors = EvaluateNodes(output, item.fetch);
505       ASSERT_EQ(tensors.size(), 1);
506       test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
507     }
508   }
509 }
510 
TEST_F(ArithmeticOptimizerTest,HoistFactorDiv)511 TEST_F(ArithmeticOptimizerTest, HoistFactorDiv) {
512   for (bool matching_shapes : {true, false}) {
513     for (bool use_addn : {true, false}) {
514       for (bool use_ints : {true, false}) {
515         tensorflow::Scope s = tensorflow::Scope::NewRootScope();
516         Output x = use_ints
517                        ? ops::Const(s.WithOpName("x"), {1, 2}, {1, 2})
518                        : ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
519         Output y1 = use_ints
520                         ? ops::Const(s.WithOpName("y1"), {3, 4}, {1, 2})
521                         : ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
522         Output y2;
523         if (matching_shapes) {
524           y2 = use_ints ? ops::Const(s.WithOpName("y2"), {5, 6}, {1, 2})
525                         : ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2});
526         } else {
527           y2 = use_ints ? ops::Const(s.WithOpName("y2"), {5}, {1, 1})
528                         : ops::Const(s.WithOpName("y2"), {5.0f}, {1, 1});
529         }
530         Output div1 = ops::Div(s.WithOpName("div1"), y1, x);
531         Output div2 = ops::Div(s.WithOpName("div2"), y2, x);
532         Output id =
533             use_addn
534                 ? ops::Identity(s.WithOpName("id"),
535                                 ops::AddN(s.WithOpName("add"), {div1, div2}))
536                 : ops::Identity(s.WithOpName("id"),
537                                 ops::Add(s.WithOpName("add"), div1, div2));
538 
539         GrapplerItem item;
540         item.fetch = {"id"};
541         TF_CHECK_OK(s.ToGraphDef(&item.graph));
542 
543         auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
544         ASSERT_EQ(tensors_expected.size(), 1);
545 
546         ArithmeticOptimizer optimizer;
547         EnableOnlyHoistCommonFactor(&optimizer);
548 
549         GraphDef output;
550         OptimizeTwice(&optimizer, &item, &output);
551 
552         // We expect the following rewrite(s) to occur:
553         //
554         //        Add                 Div
555         //      /    \               /   \
556         //    Div    Div       ->  Add    x
557         //    / \    / \           / \
558         //   y1  x  y2  x         y1  y2
559         //
560         // If "root" op is AddN and shapes does not match, this rewrite is not
561         // possible and graph should stay intact.
562         NodeMap node_map(&output);
563 
564         if ((use_addn && !matching_shapes) || use_ints) {
565           VerifyGraphsMatch(item.graph, output, __LINE__);
566         } else {
567           EXPECT_EQ(output.node_size(), 9);
568 
569           const NodeDef* new_add_node = node_map.GetNode(HoistAddName("add"));
570           ASSERT_TRUE(new_add_node != nullptr) << "Hoisted Add node not found";
571           ASSERT_EQ(new_add_node->input_size(), 2);
572           EXPECT_EQ(new_add_node->input(0), "y1");
573           EXPECT_EQ(new_add_node->input(1), "y2");
574 
575           const NodeDef* new_div_node = node_map.GetNode(HoistDivName("add"));
576           ASSERT_TRUE(new_div_node != nullptr) << "Hoisted Div node not found";
577           ASSERT_EQ(new_div_node->input_size(), 2);
578           EXPECT_EQ(new_div_node->input(0), new_add_node->name());
579           EXPECT_EQ(new_div_node->input(1), "x");
580 
581           const NodeDef* id_node = node_map.GetNode("id");
582           ASSERT_TRUE(id_node != nullptr) << "Id node not found";
583           EXPECT_EQ("id", id_node->name());
584           ASSERT_EQ(id_node->input_size(), 1);
585           EXPECT_EQ(id_node->input(0), HoistDivName("add"));
586         }
587         auto tensors = EvaluateNodes(output, item.fetch);
588         ASSERT_EQ(tensors.size(), 1);
589         if (use_ints) {
590           test::ExpectTensorEqual<int32>(tensors[0], tensors_expected[0]);
591         } else {
592           test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
593         }
594       }
595     }
596   }
597 }
598 
// Transpose(Conj(z), perm) is expected to be fused into a single
// ConjugateTranspose node that reads `z` and `perm` directly.
TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output re =
      ops::Const(scope.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output im =
      ops::Const(scope.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output z = ops::Complex(scope.WithOpName("z"), re, im);
  Output perm = ops::Const(scope.WithOpName("perm"), {1, 0}, {2});
  Output conj = ops::Conj(scope.WithOpName("conj"), z);
  Output transp = ops::Transpose(scope.WithOpName("trans"), conj, perm);

  GrapplerItem item;
  item.fetch = {"trans"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto expected_tensors = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(expected_tensors.size(), 1);

  GraphDef optimized_graph;
  ArithmeticOptimizer optimizer;
  OptimizeTwice(&optimizer, &item, &optimized_graph);
  NodeMap node_map(&optimized_graph);

  EXPECT_EQ(optimized_graph.node_size(), 7);

  // The fused node is looked up under the stage-prefixed name of "trans".
  const string fused_name = absl::StrCat(
      "ArithmeticOptimizer/FoldConjugateIntoTranspose", "_", "trans");

  const NodeDef* fused_node = node_map.GetNode(fused_name);
  ASSERT_NE(fused_node, nullptr);
  EXPECT_EQ(fused_node->op(), "ConjugateTranspose");
  ASSERT_EQ(fused_node->input_size(), 2);
  EXPECT_EQ(fused_node->input(0), "z");
  EXPECT_EQ(fused_node->input(1), "perm");

  auto actual_tensors = EvaluateNodes(optimized_graph, item.fetch);
  ASSERT_EQ(actual_tensors.size(), 1);
  test::ExpectTensorEqual<complex64>(actual_tensors[0], expected_tensors[0]);
}
636 
// ConjugateTranspose(Conj(z), perm) is expected to be fused into a plain
// Transpose node that reads `z` and `perm` directly.
TEST_F(ArithmeticOptimizerTest, FuseConjAndConjugateTranspose) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output re =
      ops::Const(scope.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output im =
      ops::Const(scope.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output z = ops::Complex(scope.WithOpName("z"), re, im);
  Output perm = ops::Const(scope.WithOpName("perm"), {1, 0}, {2});
  Output conj = ops::Conj(scope.WithOpName("conj"), z);
  Output transp = ops::ConjugateTranspose(
      scope.WithOpName("conjugate_trans"), conj, perm);

  GrapplerItem item;
  item.fetch = {"conjugate_trans"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto expected_tensors = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(expected_tensors.size(), 1);

  GraphDef optimized_graph;
  ArithmeticOptimizer optimizer;
  OptimizeTwice(&optimizer, &item, &optimized_graph);
  NodeMap node_map(&optimized_graph);

  EXPECT_EQ(optimized_graph.node_size(), 7);

  // The fused node is looked up under the stage-prefixed name of
  // "conjugate_trans".
  const string fused_name = absl::StrCat(
      "ArithmeticOptimizer/FoldConjugateIntoTranspose", "_", "conjugate_trans");

  const NodeDef* fused_node = node_map.GetNode(fused_name);
  ASSERT_NE(fused_node, nullptr);
  EXPECT_EQ(fused_node->op(), "Transpose");
  ASSERT_EQ(fused_node->input_size(), 2);
  EXPECT_EQ(fused_node->input(0), "z");
  EXPECT_EQ(fused_node->input(1), "perm");

  auto actual_tensors = EvaluateNodes(optimized_graph, item.fetch);
  ASSERT_EQ(actual_tensors.size(), 1);
  test::ExpectTensorEqual<complex64>(actual_tensors[0], expected_tensors[0]);
}
676 
// Conj(Transpose(z, perm)) is expected to be fused into a single
// ConjugateTranspose node that reads `z` and `perm` directly.
TEST_F(ArithmeticOptimizerTest, FuseTransposeAndConj) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output re =
      ops::Const(scope.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output im =
      ops::Const(scope.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output z = ops::Complex(scope.WithOpName("z"), re, im);
  Output perm = ops::Const(scope.WithOpName("perm"), {1, 0}, {2});
  Output trans = ops::Transpose(scope.WithOpName("trans"), z, perm);
  Output conj = ops::Conj(scope.WithOpName("conj"), trans);

  GrapplerItem item;
  item.fetch = {"conj"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto expected_tensors = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(expected_tensors.size(), 1);

  GraphDef optimized_graph;
  ArithmeticOptimizer optimizer;
  OptimizeTwice(&optimizer, &item, &optimized_graph);
  NodeMap node_map(&optimized_graph);

  EXPECT_EQ(optimized_graph.node_size(), 7);

  // The fused node is looked up under the stage-prefixed name of "conj".
  const string fused_name = absl::StrCat(
      "ArithmeticOptimizer/FoldConjugateIntoTranspose", "_", "conj");

  const NodeDef* fused_node = node_map.GetNode(fused_name);
  ASSERT_NE(fused_node, nullptr);
  EXPECT_EQ(fused_node->op(), "ConjugateTranspose");
  ASSERT_EQ(fused_node->input_size(), 2);
  EXPECT_EQ(fused_node->input(0), "z");
  EXPECT_EQ(fused_node->input(1), "perm");

  auto actual_tensors = EvaluateNodes(optimized_graph, item.fetch);
  ASSERT_EQ(actual_tensors.size(), 1);
  test::ExpectTensorEqual<complex64>(actual_tensors[0], expected_tensors[0]);
}
714 
TEST_F(ArithmeticOptimizerTest,FoldTransposeIntoMatMul)715 TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
716   for (const string matmul_type :
717        {"MatMul", "SparseMatMul", "BatchMatMul", "BatchMatMulV2"}) {
718     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
719 
720     Output a = ops::Const(s.WithOpName("a"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
721     Output b = ops::Const(s.WithOpName("b"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
722     Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
723     Output trans_a = ops::Transpose(s.WithOpName("trans_a"), a, perm);
724     Output trans_b = ops::Transpose(s.WithOpName("trans_b"), b, perm);
725 
726     Output matmul;
727     auto matmul_op = s.WithOpName("matmul");
728     if (matmul_type == "MatMul") {
729       matmul = ops::MatMul(matmul_op, trans_a, trans_b);
730     } else if (matmul_type == "SparseMatMul") {
731       matmul = ops::SparseMatMul(matmul_op, trans_a, trans_b);
732     } else if (matmul_type == "BatchMatMul") {
733       matmul = ops::BatchMatMul(matmul_op, trans_a, trans_b);
734     } else if (matmul_type == "BatchMatMulV2") {
735       matmul = ops::BatchMatMulV2(matmul_op, trans_a, trans_b);
736     }
737 
738     auto identity = ops::Identity(s.WithOpName("identity"), matmul);
739 
740     GrapplerItem item;
741     item.fetch = {"identity"};
742     TF_CHECK_OK(s.ToGraphDef(&item.graph));
743 
744     auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
745     ASSERT_EQ(tensors_expected.size(), 1);
746 
747     ArithmeticOptimizer optimizer;
748     EnableOnlyFoldTransposeIntoMatMul(&optimizer);
749     GraphDef output;
750     OptimizeTwice(&optimizer, &item, &output);
751     NodeMap node_map(&output);
752 
753     EXPECT_EQ(output.node_size(), 8);
754 
755     const string p = "ArithmeticOptimizer/FoldTransposeIntoMatMul";
756     const string optimized_name = absl::StrCat(p, "_", "matmul");
757 
758     const NodeDef* matmul_fused_node = node_map.GetNode(optimized_name);
759     ASSERT_NE(matmul_fused_node, nullptr);
760     ASSERT_EQ(matmul_fused_node->input_size(), 2);
761     EXPECT_EQ(matmul_fused_node->input(0), "a");
762     EXPECT_EQ(matmul_fused_node->input(1), "b");
763 
764     if (matmul_type == "BatchMatMul" || matmul_type == "BatchMatMulV2") {
765       EXPECT_TRUE(matmul_fused_node->attr().at("adj_x").b());
766       EXPECT_TRUE(matmul_fused_node->attr().at("adj_y").b());
767     } else {
768       EXPECT_TRUE(matmul_fused_node->attr().at("transpose_a").b());
769       EXPECT_TRUE(matmul_fused_node->attr().at("transpose_b").b());
770     }
771 
772     const NodeDef* identity_node = node_map.GetNode("identity");
773     ASSERT_NE(identity_node, nullptr);
774     ASSERT_EQ(identity_node->input_size(), 1);
775     EXPECT_EQ(identity_node->input(0), optimized_name);
776 
777     auto tensors = EvaluateNodes(output, item.fetch);
778     ASSERT_EQ(tensors.size(), 1);
779     test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
780   }
781 }
782 
// ConjugateTranspose on both operands of a complex BatchMatMul is expected to
// be folded into the matmul as adj_x/adj_y, so that the fused matmul reads the
// complex tensors "a" and "b" directly.
TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();

  // Real and imaginary parts of the two 2x2 complex operands.
  Output a_real =
      ops::Const(root.WithOpName("re_a"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output a_imag =
      ops::Const(root.WithOpName("im_a"), {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2});
  Output b_real =
      ops::Const(root.WithOpName("re_b"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output b_imag =
      ops::Const(root.WithOpName("im_b"), {-5.0f, -6.0f, -7.0f, -8.0f}, {2, 2});
  Output lhs = ops::Complex(root.WithOpName("a"), a_real, a_imag);
  Output rhs = ops::Complex(root.WithOpName("b"), b_real, b_imag);
  Output swap_axes = ops::Const(root.WithOpName("perm"), {1, 0}, {2});
  Output lhs_h =
      ops::ConjugateTranspose(root.WithOpName("trans_a"), lhs, swap_axes);
  Output rhs_h =
      ops::ConjugateTranspose(root.WithOpName("trans_b"), rhs, swap_axes);
  Output product = ops::BatchMatMul(root.WithOpName("matmul"), lhs_h, rhs_h);
  Output fetched = ops::Identity(root.WithOpName("identity"), product);

  GrapplerItem item;
  item.fetch = {"identity"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  // Reference values from the unoptimized graph.
  auto reference = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(reference.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef optimized;
  OptimizeTwice(&optimizer, &item, &optimized);

  NodeMap nodes(&optimized);
  EXPECT_EQ(optimized.node_size(), 12);

  const string fused_name = absl::StrCat(
      "ArithmeticOptimizer/FoldTransposeIntoMatMul", "_", "matmul");

  const NodeDef* fused = nodes.GetNode(fused_name);
  ASSERT_NE(fused, nullptr);
  ASSERT_EQ(fused->input_size(), 2);
  EXPECT_EQ(fused->input(0), "a");
  EXPECT_EQ(fused->input(1), "b");
  EXPECT_TRUE(fused->attr().at("adj_x").b());
  EXPECT_TRUE(fused->attr().at("adj_y").b());

  auto actual = EvaluateNodes(optimized, item.fetch);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorNear<complex64>(actual[0], reference[0], 1e-6);
}
831 
// A Reshape (or BroadcastTo) whose target shape is rebuilt from the input's
// own shape is an identity and should disappear from the optimized graph.
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeIdentityReshape) {
  for (bool is_broadcastto : {false, true}) {
    tensorflow::Scope root = tensorflow::Scope::NewRootScope();
    Output inputs = ops::Placeholder(root, DT_FLOAT,
                                     ops::Placeholder::Shape({-1, 3, 28, 28}));
    Output inputs_shape = ops::Shape(root, inputs);
    // Target shape = [`batch_size`] ++ [3,28,28], with `batch_size` sliced out
    // of the runtime shape of the input.
    Output batch_size = ops::Slice(root, inputs_shape,
                                   ops::Const(root, {0}, {1}),
                                   ops::Const(root, {1}, {1}));
    Output target_shape = ops::Concat(
        root.WithOpName("target_shape"),
        {batch_size, ops::Const(root, {3, 28, 28}, {3})},
        ops::Const(root, {0}, {}));

    // The op under test, wrapped in the fetched Identity below.
    Output reshaped;
    if (is_broadcastto) {
      reshaped = ops::BroadcastTo(root, inputs, target_shape);
    } else {
      reshaped = ops::Reshape(root, inputs, target_shape);
    }
    Output outputs = ops::Identity(root.WithOpName("outputs"), reshaped);

    GrapplerItem item;
    item.fetch = {"outputs"};
    TF_CHECK_OK(root.ToGraphDef(&item.graph));
    auto input_tensor =
        GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28}));
    auto reference =
        EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_tensor}});
    ASSERT_EQ(reference.size(), 1);

    GraphDef optimized;
    ArithmeticOptimizer optimizer;
    EnableOnlyRemoveRedundantReshape(&optimizer);
    OptimizeTwiceAndPrune(&optimizer, &item, &optimized);

    // Neither flavour of the identity op may survive.
    EXPECT_EQ(CountOpNodes(optimized, "Reshape"), 0);
    EXPECT_EQ(CountOpNodes(optimized, "BroadcastTo"), 0);
    auto actual =
        EvaluateNodes(optimized, item.fetch, {{"Placeholder", input_tensor}});
    ASSERT_EQ(actual.size(), 1);
    test::ExpectTensorNear<float>(actual[0], reference[0], 1e-6);
  }
}
873 
// Same as the identity-reshape case, but with three unknown dimensions: the
// target shape is reassembled from slices of the input's runtime shape, and in
// aggressive mode the Reshape is expected to be removed.
TEST_F(ArithmeticOptimizerTest,
       RemoveRedundantReshapeIdentityReshapeBetweenSymbolicShapes) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(root, DT_FLOAT,
                                   ops::Placeholder::Shape({-1, 3, -1, -1}));
  Output inputs_shape = ops::Shape(root, inputs);
  // The target shape of the reshape is the concatenation of `batch_size`, 3,
  // `height`, and `width`.
  Output batch_size = ops::Slice(root, inputs_shape,
                                 ops::Const(root, {0}, {1}),
                                 ops::Const(root, {1}, {1}));
  Output height = ops::Slice(root, inputs_shape, ops::Const(root, {2}, {1}),
                             ops::Const(root, {1}, {1}));
  Output width = ops::Slice(root, inputs_shape, ops::Const(root, {3}, {1}),
                            ops::Const(root, {1}, {1}));
  Output target_shape =
      ops::Concat(root.WithOpName("target_shape"),
                  {batch_size, ops::Const(root, {3}, {1}), height, width},
                  ops::Const(root, {0}, {}));
  Output reshaped = ops::Reshape(root, inputs, target_shape);
  Output outputs = ops::Identity(root.WithOpName("outputs"), reshaped);

  auto input_tensor =
      GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28}));
  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"Placeholder", input_tensor}};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  auto reference = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(reference.size(), 1);

  GraphDef optimized;
  // Assume valid feed shape in aggressive mode.
  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized);

  EXPECT_EQ(CountOpNodes(optimized, "Reshape"), 0);
  auto actual = EvaluateNodes(optimized, item.fetch, item.feed);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorNear<float>(actual[0], reference[0], 1e-6);
}
915 
// With the default optimizer mode, a Reshape of a fed Placeholder to the
// Placeholder's own declared shape must be kept.
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeNotAssumeValidFeeds) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output placeholder = ops::Placeholder(
      root, DT_FLOAT, ops::Placeholder::Shape({4, 3, 28, 28}));
  Output declared_shape = ops::Const(root, {4, 3, 28, 28}, {4});
  Output reshaped = ops::Reshape(root, placeholder, declared_shape);
  Output outputs = ops::Identity(root.WithOpName("outputs"), reshaped);

  auto fed_tensor = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 3, 28, 28}));
  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"Placeholder", fed_tensor}};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  auto reference = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(reference.size(), 1);

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized);

  // The reshape is preserved because the shape of the placeholder can be
  // different from the shape of the actual feed.
  EXPECT_EQ(CountOpNodes(optimized, "Reshape"), 1);

  auto actual = EvaluateNodes(optimized, item.fetch, item.feed);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorNear<float>(actual[0], reference[0], 1e-6);
}
946 
// Counterpart of the "NotAssumeValidFeeds" case: with
// RewriterConfig::AGGRESSIVE the same Reshape of a fed Placeholder to its
// declared shape is expected to be removed.
TEST_F(ArithmeticOptimizerTest,
       RemoveRedundantReshapeAssumeValidFeedsInAggressiveMode) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output placeholder = ops::Placeholder(
      root, DT_FLOAT, ops::Placeholder::Shape({4, 3, 28, 28}));
  Output declared_shape = ops::Const(root, {4, 3, 28, 28}, {4});
  Output reshaped = ops::Reshape(root, placeholder, declared_shape);
  Output outputs = ops::Identity(root.WithOpName("outputs"), reshaped);

  auto fed_tensor = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 3, 28, 28}));
  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"Placeholder", fed_tensor}};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  auto reference = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(reference.size(), 1);

  GraphDef optimized;
  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized);

  EXPECT_EQ(CountOpNodes(optimized, "Reshape"), 0);
  auto actual = EvaluateNodes(optimized, item.fetch, item.feed);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorNear<float>(actual[0], reference[0], 1e-6);
}
975 
// A Reshape that is not provably an identity must be kept.
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeNotIdentityReshape) {
  // Reshape from [-1,3,28,28] to [8,-1,28,28] is not an identity in general:
  // for example an input of shape [16,3,28,28] is reshaped to [8,6,28,28].
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output placeholder = ops::Placeholder(
      root, DT_FLOAT, ops::Placeholder::Shape({-1, 3, 28, 28}));
  Output reshaped =
      ops::Reshape(root, placeholder, ops::Const(root, {8, -1, 28, 28}, {4}));
  Output outputs = ops::Identity(root.WithOpName("outputs"), reshaped);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  auto fed_tensor = GenerateRandomTensor<DT_FLOAT>(TensorShape({8, 3, 28, 28}));
  item.feed = {{"Placeholder", fed_tensor}};
  auto reference = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(reference.size(), 1);

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized);

  EXPECT_EQ(CountOpNodes(optimized, "Reshape"), 1);
  auto actual = EvaluateNodes(optimized, item.fetch, item.feed);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorNear<float>(actual[0], reference[0], 1e-6);
}
1003 
// A Reshape whose target shape constant is {-1, -1} is expected to be left in
// place. Only the node count is checked; the graph is not evaluated.
TEST_F(ArithmeticOptimizerTest,
       RemoveRedundantReshapeNotIdentityReshapeTooManyUnknownDimSizes) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output placeholder =
      ops::Placeholder(root, DT_FLOAT, ops::Placeholder::Shape({4, 3}));
  Output reshaped =
      ops::Reshape(root, placeholder, ops::Const(root, {-1, -1}, {2}));
  Output outputs = ops::Identity(root.WithOpName("outputs"), reshaped);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized);

  EXPECT_EQ(CountOpNodes(optimized, "Reshape"), 1);
}
1023 
TEST_F(ArithmeticOptimizerTest,RemoveRedundantReshapeCombineReshapes)1024 TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeCombineReshapes) {
1025   // Converts an NCHW_VECT_C tensor to NHWC and then flattens it to 2D. The two
1026   // reshapes should be combined.
1027   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1028   Output nchw_vect_c =
1029       ops::Placeholder(s.WithOpName("nchw_vect_c"), DT_INT8,
1030                        ops::Placeholder::Shape({8, 3, 28, 28, 4}));
1031   Output transpose =
1032       ops::Transpose(s.WithOpName("transpose"), nchw_vect_c,
1033                      ops::Const(s.WithOpName("perm"), {0, 2, 3, 1, 4}, {5}));
1034   Output nhwc = ops::Reshape(
1035       s.WithOpName("nhwc"), transpose,
1036       ops::Const(
1037           s.WithControlDependencies(nchw_vect_c).WithOpName("nhwc_shape"),
1038           {8, 28, 28, 12}, {4}));
1039   Output flatten = ops::Reshape(
1040       s.WithOpName("flatten"), nhwc,
1041       ops::Const(s.WithOpName("flatten_shape"), {8, 28 * 28 * 12}, {2}));
1042   Output outputs = ops::Identity(s.WithOpName("outputs"), flatten);
1043 
1044   GraphDef graph;
1045   TF_CHECK_OK(s.ToGraphDef(&graph));
1046   auto x_t = GenerateRandomTensor<DT_INT8>(TensorShape({8, 3, 28, 28, 4}));
1047   auto eval = EvaluateNodes(graph, {"outputs", "nhwc"}, {{"nchw_vect_c", x_t}});
1048 
1049   ASSERT_EQ(eval.size(), 2);
1050   auto expected_output_t = eval[0];
1051   auto nhwc_t = eval[1];
1052 
1053   {
1054     GrapplerItem item;
1055     item.graph = graph;
1056     item.fetch = {"outputs"};
1057     item.feed = {{"nchw_vect_c", x_t}};
1058 
1059     GraphDef output;
1060     ArithmeticOptimizer optimizer;
1061     EnableOnlyRemoveRedundantReshape(&optimizer);
1062     OptimizeTwiceAndPrune(&optimizer, &item, &output);
1063 
1064     EXPECT_EQ(CountOpNodes(output, "Reshape"), 1);
1065     auto tensors = EvaluateNodes(output, item.fetch, item.feed);
1066     ASSERT_EQ(tensors.size(), 1);
1067     test::ExpectTensorEqual<int8>(tensors[0], expected_output_t);
1068   }
1069 
1070   // Test when the first reshape node output is the feed tensor.
1071   // (Expected no reshape removal to happen.)
1072   {
1073     GrapplerItem item;
1074     item.graph = graph;
1075     item.fetch = {"outputs"};
1076     item.feed = {{"nhwc", nhwc_t}};
1077 
1078     GraphDef output;
1079     ArithmeticOptimizer optimizer;
1080     EnableOnlyRemoveRedundantReshape(&optimizer);
1081     OptimizeTwiceAndPrune(&optimizer, &item, &output);
1082 
1083     EXPECT_EQ(CountOpNodes(output, "Reshape"), 2);
1084     auto tensors = EvaluateNodes(output, item.fetch, item.feed);
1085     ASSERT_EQ(tensors.size(), 1);
1086     test::ExpectTensorEqual<int8>(tensors[0], expected_output_t);
1087   }
1088 }
1089 
// Cast(uint8 -> float) followed by Transpose: the optimizer is expected to
// move the Transpose in front of the Cast so that it runs on the uint8 data,
// without changing the values produced by the graph.
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCastProducerIsCast) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output nhwc_uint8 =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output nhwc_fp32 = ops::Cast(s, nhwc_uint8, DT_FLOAT);
  Output nchw_fp32 =
      ops::Transpose(s, nhwc_fp32, ops::Const(s, {0, 3, 1, 2}, {4}));
  Output outputs = ops::Identity(s.WithOpName("outputs"), nchw_fp32);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Reference values from the unoptimized graph.
  auto input_t = GenerateRandomTensor<DT_UINT8>(TensorShape({8, 28, 28, 3}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Exactly one Transpose must remain, and it must operate on uint8.
  const NodeDef* transpose_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Transpose") {
      EXPECT_EQ(transpose_node, nullptr);
      EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
      transpose_node = &node;
    }
  }
  ASSERT_NE(transpose_node, nullptr);

  // Every Cast must now consume the Transpose.
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Cast") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(transpose_node->name(), NodeName(node.input(0)));
    }
  }

  // Evaluate the optimized graph, not the original one; otherwise the value
  // comparison below would compare the original graph against itself.
  auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
}
1134 
// Cast(uint8 -> float) followed by SpaceToDepth: the optimizer is expected to
// move the SpaceToDepth in front of the Cast so that it runs on the uint8
// data, without changing the values produced by the graph.
TEST_F(ArithmeticOptimizerTest, ReorderS2DCastProducerIsCast) {
  // TODO(jingyue): Evaluate S2D+Cast on GPU as well. We can't simply put nodes
  // under a /GPU:0 scope, because this test would fail if the testing machine
  // doesn't have a GPU. Maybe EvaluateNodes should allow soft placement?
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output outputs =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  outputs = ops::Cast(s, outputs, DT_FLOAT);
  outputs = ops::SpaceToDepth(s, outputs, 2);
  outputs = ops::Identity(s.WithOpName("outputs"), outputs);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Reference values from the unoptimized graph.
  auto input_t = GenerateRandomTensor<DT_UINT8>(TensorShape({8, 28, 28, 3}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Exactly one SpaceToDepth must remain, and it must operate on uint8.
  const NodeDef* s2d_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "SpaceToDepth") {
      EXPECT_EQ(s2d_node, nullptr);
      EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
      s2d_node = &node;
    }
  }
  ASSERT_NE(s2d_node, nullptr);

  // Every Cast must now consume the SpaceToDepth.
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Cast") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(s2d_node->name(), NodeName(node.input(0)));
    }
  }

  // Evaluate the optimized graph, not the original one; otherwise the value
  // comparison below would compare the original graph against itself.
  auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
}
1181 
// Transpose followed by Cast(float -> uint8): the optimizer is expected to
// move the Cast in front of the Transpose so that the Transpose runs on the
// uint8 data, without changing the values produced by the graph.
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCastProducerIsTranspose) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output nhwc_fp32 =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output nchw_fp32 =
      ops::Transpose(s, nhwc_fp32, ops::Const(s, {0, 3, 1, 2}, {4}));
  Output nchw_uint8 = ops::Cast(s, nchw_fp32, DT_UINT8);
  Output outputs = ops::Identity(s.WithOpName("outputs"), nchw_uint8);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Reference values from the unoptimized graph.
  auto input_t =
      GenerateConstantTensor<DT_FLOAT>(TensorShape({8, 28, 28, 3}), 42.0f);
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Exactly one Cast must remain, and it must read the placeholder directly.
  const NodeDef* cast_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Cast") {
      EXPECT_EQ(cast_node, nullptr);
      cast_node = &node;
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(NodeName(node.input(0)), "Placeholder");
    }
  }
  ASSERT_NE(cast_node, nullptr);

  // Every Transpose must now be a uint8 op that consumes the Cast.
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Transpose") {
      EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(cast_node->name(), NodeName(node.input(0)));
    }
  }

  // Evaluate the optimized graph, not the original one; otherwise the value
  // comparison below would compare the original graph against itself.
  auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<uint8>(tensors[0], tensors_expected[0]);
}
1229 
// Cast(uint8 -> float) -> Reverse -> Transpose: the optimizer is expected to
// push the Cast to the end of the chain, so that the resulting order is
// Placeholder -> ReverseV2 -> Transpose -> Cast with both ReverseV2 and
// Transpose running on uint8, without changing the values produced.
TEST_F(ArithmeticOptimizerTest, ReorderTransposeReverseCast) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output nhwc_uint8 =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output nhwc_fp32 = ops::Cast(s, nhwc_uint8, DT_FLOAT);
  Output nhwc_fp32_reversed =
      ops::Reverse(s, nhwc_fp32, ops::Const(s, {0}, {1}));
  Output nchw_fp32_reversed =
      ops::Transpose(s, nhwc_fp32_reversed, ops::Const(s, {0, 3, 1, 2}, {4}));

  Output outputs = ops::Identity(s.WithOpName("outputs"), nchw_fp32_reversed);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Reference values from the unoptimized graph.
  auto input_t = GenerateRandomTensor<DT_UINT8>(TensorShape({8, 28, 28, 3}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Locate the three nodes of interest; Transpose and ReverseV2 must each be
  // unique and typed uint8.
  const NodeDef* reverse_node = nullptr;
  const NodeDef* transpose_node = nullptr;
  const NodeDef* cast_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Transpose") {
      EXPECT_EQ(transpose_node, nullptr);
      EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
      transpose_node = &node;
    } else if (node.op() == "ReverseV2") {
      EXPECT_EQ(reverse_node, nullptr);
      EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
      reverse_node = &node;
    } else if (node.op() == "Cast") {
      cast_node = &node;
    }
  }
  ASSERT_NE(cast_node, nullptr);
  ASSERT_NE(reverse_node, nullptr);
  ASSERT_NE(transpose_node, nullptr);

  // Check the chain Placeholder -> ReverseV2 -> Transpose -> Cast.
  ASSERT_EQ(reverse_node->input_size(), 2);
  EXPECT_EQ(NodeName(reverse_node->input(0)), "Placeholder");
  ASSERT_EQ(transpose_node->input_size(), 2);
  EXPECT_EQ(NodeName(transpose_node->input(0)), reverse_node->name());
  ASSERT_EQ(cast_node->input_size(), 1);
  EXPECT_EQ(NodeName(cast_node->input(0)), transpose_node->name());

  // Evaluate the optimized graph, not the original one; otherwise the value
  // comparison below would compare the original graph against itself.
  auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
}
1286 
// Cast(uint8 -> float) followed by CheckNumerics: the optimizer is expected to
// leave the graph exactly as it was.
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCastCheckNumericsToIdentity) {
  tensorflow::Scope root =
      tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output input_uint8 =
      ops::Placeholder(root, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output input_fp32 = ops::Cast(root, input_uint8, DT_FLOAT);
  Output checked = ops::CheckNumerics(root, input_fp32, "foo");
  Output outputs = ops::Identity(root.WithOpName("outputs"), checked);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized));
  CompareGraphs(item.graph, optimized);
}
1303 
// Cast(float -> uint8) followed by Transpose: the Transpose already runs on
// uint8, and the optimizer is expected to leave the graph untouched.
TEST_F(ArithmeticOptimizerTest, NoReorderTransposeCastProducerIsCast) {
  tensorflow::Scope root =
      tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output input_fp32 =
      ops::Placeholder(root, DT_FLOAT, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output input_uint8 = ops::Cast(root, input_fp32, DT_UINT8);
  Output transposed_uint8 =
      ops::Transpose(root, input_uint8, ops::Const(root, {0, 3, 1, 2}, {4}));
  Output outputs = ops::Identity(root.WithOpName("outputs"), transposed_uint8);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized));
  CompareGraphs(item.graph, optimized);
}
1321 
// Transpose on uint8 followed by Cast(uint8 -> float): the Transpose already
// runs on uint8, and the optimizer is expected to leave the graph untouched.
TEST_F(ArithmeticOptimizerTest, NoReorderTransposeCastProducerIsTranspose) {
  tensorflow::Scope root =
      tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output input_uint8 =
      ops::Placeholder(root, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output transposed_uint8 =
      ops::Transpose(root, input_uint8, ops::Const(root, {0, 3, 1, 2}, {4}));
  Output transposed_fp32 = ops::Cast(root, transposed_uint8, DT_FLOAT);
  Output outputs = ops::Identity(root.WithOpName("outputs"), transposed_fp32);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized));
  CompareGraphs(item.graph, optimized);
}
1339 
// Two transposes with mutually inverse permutations (perm1 then perm2) and a
// transpose with the identity permutation (perm3) should all be removed,
// together with their permutation constants, after optimizing and pruning.
TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposes) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output inputs_shape =
      ops::Const(root.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
  Output inputs =
      ops::RandomUniform(root.WithOpName("inputs"), inputs_shape, DT_FLOAT);
  Output to_nhwc = ops::Const(root.WithOpName("perm1"), {0, 2, 3, 1}, {4});
  Output to_nchw = ops::Const(root.WithOpName("perm2"), {0, 3, 1, 2}, {4});
  Output no_op_perm = ops::Const(root.WithOpName("perm3"), {0, 1, 2, 3}, {4});
  Output transpose1 =
      ops::Transpose(root.WithOpName("transpose1"), inputs, to_nhwc);
  Output transpose2 =
      ops::Transpose(root.WithOpName("transpose2"), transpose1, to_nchw);
  Output transpose3 =
      ops::Transpose(root.WithOpName("transpose3"), inputs, no_op_perm);
  Output id1 = ops::Identity(root.WithOpName("id1"), transpose2);
  Output id2 = ops::Identity(root.WithOpName("id2"), transpose3);

  GrapplerItem item;
  item.fetch = {"id1", "id2"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &optimized);

  // Only the fetched identities and the input producer chain may remain.
  std::set<string> surviving_nodes;
  for (const NodeDef& node : optimized.node()) {
    surviving_nodes.insert(node.name());
  }
  const std::set<string> expected_nodes(
      {"id1", "id2", "inputs_shape", "inputs"});
  EXPECT_EQ(surviving_nodes, expected_nodes);
}
1372 
// A ConjugateTranspose with the identity permutation {0, 1} still has to
// conjugate its input, so it is expected to be replaced by a plain Conj node
// reading "z".
TEST_F(ArithmeticOptimizerTest, RemoveIdentityConjugateTransposes) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output real_part =
      ops::Const(root.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output imag_part =
      ops::Const(root.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output z = ops::Complex(root.WithOpName("z"), real_part, imag_part);
  Output identity_perm = ops::Const(root.WithOpName("perm"), {0, 1}, {2});
  Output conj_transposed =
      ops::ConjugateTranspose(root.WithOpName("trans"), z, identity_perm);
  Output fetched = ops::Identity(root.WithOpName("id"), conj_transposed);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &optimized);
  NodeMap nodes(&optimized);

  EXPECT_EQ(optimized.node_size(), 5);

  const string replacement_name = absl::StrCat(
      "ArithmeticOptimizer/RemoveIdentityTranspose", "_", "trans");

  const NodeDef* replacement = nodes.GetNode(replacement_name);
  ASSERT_NE(replacement, nullptr);
  EXPECT_EQ(replacement->op(), "Conj");
  ASSERT_EQ(replacement->input_size(), 1);
  EXPECT_EQ(replacement->input(0), "z");
}
1403 
// A pair of mutually inverse transposes applied to one output of a
// multi-output node (Split:1) should be removed, so that the concat consumes
// the three Split outputs directly and the graph still evaluates to the same
// values.
TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesMultipleOutputs) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // Not consumed by any other node in this graph.
  Output inputs_shape =
      ops::Const(s.WithOpName("inputs_shape"), {8, 9, 28, 28}, {4});
  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
                                   ops::Placeholder::Shape({8, 12, 28, 28}));
  OutputList split = ops::Split(s, ops::Const(s, 1), inputs, 3).output;
  Output perm1 = ops::Const(s, {0, 2, 3, 1}, {4});
  Output perm2 = ops::Const(s, {0, 3, 1, 2}, {4});
  Output branch0 = split[0];
  Output branch1 = ops::Transpose(s, ops::Transpose(s, split[1], perm1), perm2);
  Output branch2 = split[2];
  Output concat = ops::Concat(s, {branch0, branch1, branch2}, ops::Const(s, 1));
  Output outputs = ops::Identity(s.WithOpName("outputs"), concat);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Reference values from the unoptimized graph.
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({8, 12, 28, 28}));
  item.feed = {{"inputs", x_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // ops::Concat takes (values, axis), i.e. it builds a ConcatV2 node whose
  // inputs are the three values followed by the axis. Matching on the legacy
  // "Concat" op name would never find the node and would make the checks
  // below vacuous, so also require that the node was actually seen.
  int num_concat_nodes = 0;
  for (const NodeDef& node : output.node()) {
    if (node.op() != "ConcatV2") continue;
    ++num_concat_nodes;
    ASSERT_EQ(node.input_size(), 4);
    EXPECT_EQ(node.input(0), "Split");
    EXPECT_EQ(node.input(1), "Split:1");
    EXPECT_EQ(node.input(2), "Split:2");
  }
  EXPECT_EQ(num_concat_nodes, 1);

  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
1446 
// Checks that when a cancelling pair of transposes is only reachable through a
// control dependency, the pair is removed and the control dependency is
// redirected to the original (untransposed) producer.
TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({2, 3}));
  Output transpose1 = ops::Transpose(s, inputs, ops::Const(s, {1, 0}));
  Output transpose2 = ops::Transpose(s, transpose1, ops::Const(s, {1, 0}));
  Output outputs =
      ops::Identity(s.WithOpName("outputs").WithControlDependencies(transpose2),
                    ops::Const(s.WithOpName("outputs_const"), 1.0f));

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // The placeholder was created without an explicit name, hence "Placeholder".
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3}));
  item.feed = {{"Placeholder", x_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  NodeMap node_map(&output);
  const NodeDef* outputs_node = node_map.GetNode("outputs");
  // Fail cleanly instead of dereferencing a null pointer if the fetch node is
  // missing from the optimized graph.
  ASSERT_NE(outputs_node, nullptr);
  ASSERT_EQ(outputs_node->input_size(), 2);
  EXPECT_EQ(outputs_node->input(0), "outputs_const");
  EXPECT_EQ(outputs_node->input(1), "^Placeholder");

  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
1481 
// Two consecutive transposes that reuse the permutation {1, 2, 3, 0} do not
// compose to the identity, so the optimizer is expected to leave the graph
// untouched.
TEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output shape =
      ops::Const(root.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
  Output random_input =
      ops::RandomUniform(root.WithOpName("inputs"), shape, DT_FLOAT);
  Output permutation = ops::Const(root.WithOpName("perm"), {1, 2, 3, 0}, {4});
  Output first_transpose =
      ops::Transpose(root.WithOpName("transpose1"), random_input, permutation);
  Output second_transpose = ops::Transpose(root.WithOpName("transpose2"),
                                           first_transpose, permutation);
  Output fetch_node =
      ops::Identity(root.WithOpName("outputs"), second_transpose);

  GrapplerItem grappler_item;
  grappler_item.fetch.push_back("outputs");
  TF_CHECK_OK(root.ToGraphDef(&grappler_item.graph));

  ArithmeticOptimizer arithmetic_optimizer;
  EnableOnlyRemoveIdentityTranspose(&arithmetic_optimizer);
  GraphDef optimized_graph;
  OptimizeAndPrune(&arithmetic_optimizer, &grappler_item, &optimized_graph);

  // All six nodes built above must survive optimization and pruning.
  EXPECT_EQ(optimized_graph.node_size(), 6);
}
1505 
// In aggressive mode, a pair of inverse transposes separated by an Identity
// node is expected to be eliminated: "id" reads straight from "inputs" and
// "id1" reads from "id".
TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesThroughChain) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output shape =
      ops::Const(root.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
  Output random_input =
      ops::RandomUniform(root.WithOpName("inputs"), shape, DT_FLOAT);
  Output forward_perm = ops::Const(root.WithOpName("perm1"), {0, 2, 3, 1}, {4});
  Output inverse_perm = ops::Const(root.WithOpName("perm2"), {0, 3, 1, 2}, {4});
  // The first transpose also carries a control dependency on perm2.
  Output first_transpose = ops::Transpose(
      root.WithOpName("transpose1").WithControlDependencies(inverse_perm),
      random_input, forward_perm);
  Output pass_through = ops::Identity(root.WithOpName("id"), first_transpose);
  Output second_transpose = ops::Transpose(root.WithOpName("transpose2"),
                                           pass_through, inverse_perm);
  Output fetch_node = ops::Identity(root.WithOpName("id1"), second_transpose);

  GrapplerItem grappler_item;
  grappler_item.fetch.push_back("id1");
  TF_CHECK_OK(root.ToGraphDef(&grappler_item.graph));

  ArithmeticOptimizer arithmetic_optimizer(RewriterConfig::AGGRESSIVE);
  EnableOnlyRemoveIdentityTranspose(&arithmetic_optimizer);
  GraphDef optimized_graph;
  OptimizeAndPrune(&arithmetic_optimizer, &grappler_item, &optimized_graph);

  // Collect the surviving node names while validating the rewired inputs.
  std::set<string> surviving_names;
  for (int i = 0; i < optimized_graph.node_size(); ++i) {
    const NodeDef& current = optimized_graph.node(i);
    const string& current_name = current.name();
    surviving_names.insert(current_name);
    if (current_name == "id") {
      ASSERT_EQ(current.input_size(), 1);
      EXPECT_EQ(current.input(0), "inputs");
    } else if (current_name == "id1") {
      ASSERT_EQ(current.input_size(), 1);
      EXPECT_EQ(current.input(0), "id");
    }
  }

  const std::set<string> expected_names(
      {"id", "id1", "inputs_shape", "inputs"});
  EXPECT_EQ(surviving_names, expected_names);
}
1545 
// Checks that a scalar multiply feeding a Transpose -> Conv2D chain is folded
// into the convolution weights, regardless of the operand order of the Mul.
// After the rewrite the transpose reads directly from `inputs` and the second
// convolution input is a Mul node (the scaled weights).
TEST_F(ArithmeticOptimizerTest, FoldMulToTransposeConv) {
  for (bool swap_inputs : {false, true}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
                                     ops::Placeholder::Shape({1, 28, 28, 3}));
    Output scale = ops::Const(s.WithOpName("scale"), 1.0f / 255.0f, {});
    // Exercise both Mul(inputs, scale) and Mul(scale, inputs).
    Output scaled_inputs = ops::Multiply(s.WithOpName("scaled_inputs"),
                                         swap_inputs ? scale : inputs,
                                         swap_inputs ? inputs : scale);
    Output perm_nhwc_to_nchw =
        ops::Const(s.WithOpName("perm_nhwc_to_nchw"), {0, 3, 1, 2}, {4});
    Output inputs_nchw = ops::Transpose(s.WithOpName("inputs_nchw"),
                                        scaled_inputs, perm_nhwc_to_nchw);
    Output weights = ops::Const(s.WithOpName("weights"),
                                Input::Initializer(127.0f, {5, 5, 3, 4}));
    Output conv =
        ops::Conv2D(s.WithOpName("conv"), inputs_nchw, weights, {1, 1, 1, 1},
                    "VALID", ops::Conv2D::DataFormat("NCHW"));
    Output outputs = ops::Identity(s.WithOpName("outputs"), conv);

    GrapplerItem item;
    item.fetch = {"outputs"};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));

    GraphDef output;
    ArithmeticOptimizer optimizer;
    EnableOnlyFoldMultipleIntoConv(&optimizer);
    OptimizeTwiceAndPrune(&optimizer, &item, &output);

    NodeMap node_map(&output);
    // `conv` is now a folded convolution with scaled weights.
    const NodeDef* folded_conv = node_map.GetNode(conv.node()->name());
    ASSERT_NE(folded_conv, nullptr);
    // Guard the indexed accesses below so a malformed node fails this test
    // instead of crashing the test binary.
    ASSERT_GE(folded_conv->input_size(), 2);

    const NodeDef* folded_conv_weights =
        node_map.GetNode(NodeName(folded_conv->input(1)));
    ASSERT_NE(folded_conv_weights, nullptr);
    EXPECT_EQ(folded_conv_weights->op(), "Mul");

    // Its input should be a transpose of `inputs`.
    const NodeDef* transpose =
        node_map.GetNode(NodeName(folded_conv->input(0)));
    ASSERT_NE(transpose, nullptr);
    ASSERT_EQ(transpose->input_size(), 2);
    EXPECT_EQ(transpose->input(0), "inputs");
  }
}
1595 
// Checks that the multiply is NOT folded into the convolution when the
// intermediate transpose ("inputs_nchw") is fed, and therefore has to be
// preserved: the transpose must keep `scaled_inputs` as its data input.
TEST_F(ArithmeticOptimizerTest, NotFoldMulAcrossPreservedTranspose) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
                                   ops::Placeholder::Shape({8, 28, 28, 3}));
  Output scale = ops::Const(s.WithOpName("scale"), 1.0f / 255.0f, {});
  Output scaled_inputs =
      ops::Multiply(s.WithOpName("scaled_inputs"), inputs, scale);
  Output perm_nhwc_to_nchw =
      ops::Const(s.WithOpName("perm_nhwc_to_nchw"), {0, 3, 1, 2}, {4});
  Output inputs_nchw = ops::Transpose(s.WithOpName("inputs_nchw"),
                                      scaled_inputs, perm_nhwc_to_nchw);
  Output weights = ops::Const(s.WithOpName("weights"),
                              Input::Initializer(127.0f, {5, 5, 3, 16}));
  Output conv =
      ops::Conv2D(s.WithOpName("conv"), inputs_nchw, weights, {1, 1, 1, 1},
                  "VALID", ops::Conv2D::DataFormat("NCHW"));
  Output outputs = ops::Identity(s.WithOpName("outputs"), conv);

  // Zero-filled feed value for the transpose output. Use the typed accessor
  // rather than memset-ing through a const_cast of the read-only byte view.
  Tensor inputs_nchw_tensor(DT_FLOAT, {8, 3, 28, 28});
  inputs_nchw_tensor.flat<float>().setZero();

  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"inputs_nchw", inputs_nchw_tensor}};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));

  // Prune the optimized graph before inspecting it.
  item.graph.Swap(&output);
  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));

  NodeMap node_map(&output);
  const NodeDef* inputs_nchw_node_def =
      node_map.GetNode(inputs_nchw.node()->name());
  ASSERT_NE(inputs_nchw_node_def, nullptr);
  ASSERT_EQ(inputs_nchw_node_def->input_size(), 2);
  EXPECT_EQ(NodeName(inputs_nchw_node_def->input(0)),
            scaled_inputs.node()->name());
}
1637 
// Checks that a scalar multiply applied to the input of a Conv3D is folded
// into the convolution weights: the convolution then reads `inputs` directly
// and its second input is a Mul node.
TEST_F(ArithmeticOptimizerTest, FoldMulToConv) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
                                   ops::Placeholder::Shape({8, 28, 28, 28, 3}));
  Output scale = ops::Const(s.WithOpName("scale"), 1.0f / 255.0f, {});
  Output scaled_inputs =
      ops::Multiply(s.WithOpName("scaled_inputs"), inputs, scale);
  Output weights = ops::Const(s.WithOpName("weights"),
                              Input::Initializer(127.0f, {5, 5, 5, 3, 16}));
  Output conv = ops::Conv3D(s.WithOpName("conv"), scaled_inputs, weights,
                            {1, 1, 1, 1, 1}, "VALID");
  Output outputs = ops::Identity(s.WithOpName("outputs"), conv);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));

  // Prune the optimized graph before inspecting it.
  item.graph.Swap(&output);
  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));

  NodeMap node_map(&output);
  // `conv` is now a folded convolution on `inputs` and scaled weights.
  const NodeDef* folded_conv = node_map.GetNode(conv.node()->name());
  ASSERT_NE(folded_conv, nullptr);
  ASSERT_EQ(folded_conv->input_size(), 2);
  // Use test expectations rather than CHECK_EQ: a mismatch should be reported
  // as a test failure, not abort the whole test binary.
  EXPECT_EQ(NodeName(folded_conv->input(0)), inputs.node()->name());
  const NodeDef* folded_conv_input_1 =
      node_map.GetNode(NodeName(folded_conv->input(1)));
  ASSERT_NE(folded_conv_input_1, nullptr);
  EXPECT_EQ(folded_conv_input_1->op(), "Mul");
}
1672 
TEST_F(ArithmeticOptimizerTest,OptimizeCastMulTransposeConv)1673 TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) {
1674   // This unit test exercises two optimizations, folding mul into conv, and
1675   // reordering cast and transpose.
1676   //
1677   //   Conv2D(Transpose(Mul(Cast(I), S)), W)
1678   //     =>
1679   //   Conv2D(Transpose(Cast(I)), W*S)
1680   //     =>
1681   //   Conv2D(Cast(Transpose(I)), W*S)
1682   tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/cpu:0");
1683 
1684   Output inputs =
1685       ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
1686   Output cast = ops::Cast(s, inputs, DT_FLOAT);
1687   Output mul = ops::Mul(s, cast, ops::Const(s, 1.0f / 255.0f));
1688   Output transpose =
1689       ops::Transpose(s, mul, ops::Const(s.WithOpName("perm"), {0, 3, 1, 2}));
1690   Output weights = ops::Const(s.WithOpName("weights"),
1691                               Input::Initializer(127.0f, {5, 5, 3, 16}));
1692   Output conv = ops::Conv2D(s, transpose, weights, {1, 1, 1, 1}, "VALID",
1693                             ops::Conv2D::DataFormat("NCHW"));
1694   Output outputs = ops::Identity(s.WithOpName("outputs"), conv);
1695 
1696   GrapplerItem item;
1697   item.fetch = {"outputs"};
1698   TF_CHECK_OK(s.ToGraphDef(&item.graph));
1699 
1700   GraphDef output;
1701   ArithmeticOptimizer optimizer;  // all optimization stages are on
1702   OptimizeTwiceAndPrune(&optimizer, &item, &output, /*const_folding=*/true);
1703   NodeMap node_map(&output);
1704 
1705   // Expected names for reordered cast and transpose.
1706   const string p = "ArithmeticOptimizer/ReorderCastLikeAndValuePreserving_";
1707   const string optimized_cast_name = absl::StrCat(p, "float_Cast");
1708   const string optimized_transpose_name = absl::StrCat(p, "uint8_Transpose");
1709 
1710   // Expected names for folded multiply and conv.
1711   const string optimized_weights =
1712       "ArithmeticOptimizer/FoldMultiplyIntoConv_scaled_Conv2D_weights";
1713 
1714   const NodeDef* inputs_node = node_map.GetNode("Placeholder");
1715   const NodeDef* transpose_node = node_map.GetNode(optimized_transpose_name);
1716   const NodeDef* cast_node = node_map.GetNode(optimized_cast_name);
1717 
1718   const NodeDef* weights_node = node_map.GetNode(optimized_weights);
1719   const NodeDef* conv_node = node_map.GetNode("Conv2D");
1720 
1721   ASSERT_NE(inputs_node, nullptr);
1722   ASSERT_NE(transpose_node, nullptr);
1723   ASSERT_NE(cast_node, nullptr);
1724   ASSERT_NE(weights_node, nullptr);
1725   ASSERT_NE(conv_node, nullptr);
1726 
1727   EXPECT_EQ(output.node_size(), 7);
1728   ASSERT_EQ(transpose_node->input_size(), 2);
1729   EXPECT_EQ(transpose_node->input(0), inputs_node->name());
1730   ASSERT_EQ(cast_node->input_size(), 1);
1731   EXPECT_EQ(cast_node->input(0), transpose_node->name());
1732   ASSERT_EQ(conv_node->input_size(), 2);
1733   EXPECT_EQ(conv_node->input(0), cast_node->name());
1734   EXPECT_EQ(conv_node->input(1), weights_node->name());
1735 }
1736 
TEST_F(ArithmeticOptimizerTest, OptimizeMultipleMulTransposeConv) {
  // Builds two independent Mul -> Conv2D branches in one graph and verifies
  // that the multiply is folded into the weights of each convolution.
  tensorflow::Scope root =
      tensorflow::Scope::NewRootScope().WithDevice("/cpu:0");

  Output branches[2];
  for (Output& branch : branches) {
    Output placeholder = ops::Placeholder(
        root, DT_FLOAT, ops::Placeholder::Shape({8, 3, 28, 28}));
    Output scale = ops::Const(root, 1.0f / 255.0f);
    Output scaled = ops::Mul(root, placeholder, scale);
    Output kernel = ops::Const(root.WithOpName("weights"),
                               Input::Initializer(127.0f, {5, 5, 3, 16}));
    branch = ops::Conv2D(root, scaled, kernel, {1, 1, 1, 1}, "VALID",
                         ops::Conv2D::DataFormat("NCHW"));
  }
  Output sum =
      ops::Add(root.WithOpName("outputs"), branches[0], branches[1]);

  GrapplerItem grappler_item;
  grappler_item.fetch.push_back("outputs");
  TF_CHECK_OK(root.ToGraphDef(&grappler_item.graph));

  ArithmeticOptimizer arithmetic_optimizer;
  EnableOnlyFoldMultipleIntoConv(&arithmetic_optimizer);
  GraphDef optimized_graph;
  OptimizeTwiceAndPrune(&arithmetic_optimizer, &grappler_item,
                        &optimized_graph, /*const_folding=*/true);

  NodeMap optimized_nodes(&optimized_graph);

  // Expected names of the scaled weights of the first and second branch.
  const string fold_prefix = "ArithmeticOptimizer/FoldMultiplyIntoConv_";
  const string first_weights_name =
      absl::StrCat(fold_prefix, "scaled_Conv2D_weights");
  const string second_weights_name =
      absl::StrCat(fold_prefix, "scaled_Conv2D_1_weights_1");

  const NodeDef* first_weights = optimized_nodes.GetNode(first_weights_name);
  const NodeDef* second_weights = optimized_nodes.GetNode(second_weights_name);
  const NodeDef* first_conv = optimized_nodes.GetNode("Conv2D");
  const NodeDef* second_conv = optimized_nodes.GetNode("Conv2D_1");

  ASSERT_NE(first_weights, nullptr);
  ASSERT_NE(second_weights, nullptr);
  ASSERT_NE(first_conv, nullptr);
  ASSERT_NE(second_conv, nullptr);

  // Each convolution must consume its own scaled weights.
  ASSERT_EQ(first_conv->input_size(), 2);
  ASSERT_EQ(second_conv->input_size(), 2);
  EXPECT_EQ(first_conv->input(1), first_weights->name());
  EXPECT_EQ(second_conv->input(1), second_weights->name());
}
1786 
// A chain uint8 -> Bitcast(qint8) -> Bitcast(int8) is expected to collapse
// into a single Bitcast ("bc2") that reads directly from the placeholder.
TEST_F(ArithmeticOptimizerTest, CombineBitcasts) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output placeholder = ops::Placeholder(root.WithOpName("inputs"), DT_UINT8,
                                        ops::Placeholder::Shape({2, 3}));
  Output as_qint8 = ops::Bitcast(root.WithOpName("bc1"), placeholder, DT_QINT8);
  Output as_int8 = ops::Bitcast(root.WithOpName("bc2"), as_qint8, DT_INT8);
  Output fetch_node = ops::Identity(root.WithOpName("outputs"), as_int8);

  GrapplerItem grappler_item;
  grappler_item.fetch.push_back("outputs");
  TF_CHECK_OK(root.ToGraphDef(&grappler_item.graph));

  // Reference values from the unoptimized graph.
  auto input_tensor = GenerateRandomTensor<DT_UINT8>(TensorShape({2, 3}));
  grappler_item.feed.emplace_back("inputs", input_tensor);
  auto reference = EvaluateNodes(grappler_item.graph, grappler_item.fetch,
                                 grappler_item.feed);
  ASSERT_EQ(reference.size(), 1);

  ArithmeticOptimizer arithmetic_optimizer;
  EnableOnlyRemoveRedundantBitcast(&arithmetic_optimizer);
  GraphDef optimized_graph;
  OptimizeAndPrune(&arithmetic_optimizer, &grappler_item, &optimized_graph);
  NodeMap optimized_nodes(&optimized_graph);

  // Only one Bitcast remains, fed directly by the placeholder.
  EXPECT_EQ(optimized_graph.node_size(), 3);
  EXPECT_EQ(CountOpNodes(optimized_graph, "Bitcast"), 1);
  EXPECT_TRUE(IsNodesDirectlyConnected(optimized_nodes, "inputs", "bc2"));

  auto actual = EvaluateNodes(optimized_graph, grappler_item.fetch,
                              grappler_item.feed);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorEqual<int8>(actual[0], reference[0]);
}
1820 
// A chain int8 -> Bitcast(qint8) -> Bitcast(int8) ends in the input type, so
// both Bitcasts are expected to disappear and "outputs" reads from "inputs".
TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output placeholder = ops::Placeholder(root.WithOpName("inputs"), DT_INT8,
                                        ops::Placeholder::Shape({2, 3}));
  Output as_qint8 = ops::Bitcast(root, placeholder, DT_QINT8);
  Output back_to_int8 = ops::Bitcast(root, as_qint8, DT_INT8);
  Output fetch_node = ops::Identity(root.WithOpName("outputs"), back_to_int8);

  GrapplerItem grappler_item;
  grappler_item.fetch.push_back("outputs");
  TF_CHECK_OK(root.ToGraphDef(&grappler_item.graph));

  // Reference values from the unoptimized graph.
  auto input_tensor = GenerateRandomTensor<DT_INT8>(TensorShape({2, 3}));
  grappler_item.feed.emplace_back("inputs", input_tensor);
  auto reference = EvaluateNodes(grappler_item.graph, grappler_item.fetch,
                                 grappler_item.feed);
  ASSERT_EQ(reference.size(), 1);

  ArithmeticOptimizer arithmetic_optimizer;
  EnableOnlyRemoveRedundantBitcast(&arithmetic_optimizer);
  GraphDef optimized_graph;
  OptimizeAndPrune(&arithmetic_optimizer, &grappler_item, &optimized_graph);
  NodeMap optimized_nodes(&optimized_graph);

  // No Bitcast is left; only the placeholder and the fetch node remain.
  EXPECT_EQ(optimized_graph.node_size(), 2);
  EXPECT_EQ(CountOpNodes(optimized_graph, "Bitcast"), 0);
  EXPECT_TRUE(IsNodesDirectlyConnected(optimized_nodes, "inputs", "outputs"));

  auto actual = EvaluateNodes(optimized_graph, grappler_item.fetch,
                              grappler_item.feed);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorEqual<int8>(actual[0], reference[0]);
}
1854 
// A Cast from int8 to int8 is a no-op and is expected to be removed, leaving
// "outputs" connected directly to "inputs".
TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output placeholder = ops::Placeholder(root.WithOpName("inputs"), DT_INT8,
                                        ops::Placeholder::Shape({2, 3}));
  Output same_type_cast = ops::Cast(root, placeholder, DT_INT8);
  Output fetch_node = ops::Identity(root.WithOpName("outputs"), same_type_cast);

  GrapplerItem grappler_item;
  grappler_item.fetch.push_back("outputs");
  TF_CHECK_OK(root.ToGraphDef(&grappler_item.graph));

  // Reference values from the unoptimized graph.
  auto input_tensor = GenerateRandomTensor<DT_INT8>(TensorShape({2, 3}));
  grappler_item.feed.emplace_back("inputs", input_tensor);
  auto reference = EvaluateNodes(grappler_item.graph, grappler_item.fetch,
                                 grappler_item.feed);
  ASSERT_EQ(reference.size(), 1);

  ArithmeticOptimizer arithmetic_optimizer;
  EnableOnlyRemoveRedundantCast(&arithmetic_optimizer);
  GraphDef optimized_graph;
  OptimizeAndPrune(&arithmetic_optimizer, &grappler_item, &optimized_graph);
  NodeMap optimized_nodes(&optimized_graph);

  // The Cast is gone; only the placeholder and the fetch node remain.
  EXPECT_EQ(optimized_graph.node_size(), 2);
  EXPECT_EQ(CountOpNodes(optimized_graph, "Cast"), 0);
  EXPECT_TRUE(IsNodesDirectlyConnected(optimized_nodes, "inputs", "outputs"));

  auto actual = EvaluateNodes(optimized_graph, grappler_item.fetch,
                              grappler_item.feed);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorEqual<int8>(actual[0], reference[0]);
}
1887 
TEST_F(ArithmeticOptimizerTest,AddOpsRewriteAddOpsOfIdenticalShape)1888 TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddOpsOfIdenticalShape) {
1889   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1890   tensorflow::Scope sx = s.NewSubScope("x");
1891   tensorflow::Scope sy = s.NewSubScope("y");
1892 
1893   auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
1894   auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
1895   auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);
1896   auto add_bc = ops::Add(sx.WithOpName("Add_bc"), b, c);
1897   auto add_abc = ops::Add(sy.WithOpName("Add_abc"), a, add_bc);
1898 
1899   auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);
1900 
1901   GrapplerItem item;
1902   item.fetch = {"outputs"};
1903   TF_CHECK_OK(s.ToGraphDef(&item.graph));
1904 
1905   auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
1906   auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
1907   auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
1908   std::vector<std::pair<string, Tensor>> feed = {
1909       {"a", a_t}, {"b", b_t}, {"c", c_t}};
1910   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
1911   ASSERT_EQ(tensors_expected.size(), 1);
1912 
1913   GraphDef output;
1914   ArithmeticOptimizer optimizer;
1915   EnableOnlyAddToAddNCombining(&optimizer);
1916 
1917   OptimizeAndPrune(&optimizer, &item, &output);
1918 
1919   // We expect the following rewrite(s) to occur:
1920   //
1921   //     +
1922   //    / \
1923   //   a   +         -->    AddN(a, b, c)
1924   //      / \
1925   //     b   c
1926   EXPECT_EQ(output.node_size(), 5);
1927 
1928   NodeMap node_map(&output);
1929 
1930   // check add tree was replaced with AddN
1931   const NodeDef* collapsed_add =
1932       node_map.GetNode("y/ArithmeticOptimizer/AddOpsRewrite_Add_abc");
1933   ASSERT_NE(collapsed_add, nullptr);
1934 
1935   EXPECT_EQ(collapsed_add->op(), "AddN");
1936   ASSERT_EQ(collapsed_add->input_size(), 3);
1937   EXPECT_EQ(collapsed_add->input(0), "a");
1938   EXPECT_EQ(collapsed_add->input(1), "b");
1939   EXPECT_EQ(collapsed_add->input(2), "c");
1940 
1941   // check output was re-wired to new node
1942   const NodeDef* updated_outputs = node_map.GetNode("outputs");
1943   ASSERT_NE(updated_outputs, nullptr);
1944   ASSERT_EQ(updated_outputs->input_size(), 1);
1945   EXPECT_EQ(updated_outputs->input(0), collapsed_add->name());
1946 
1947   auto tensors = EvaluateNodes(output, item.fetch, feed);
1948   ASSERT_EQ(tensors.size(), 1);
1949   test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
1950 }
1951 
TEST_F(ArithmeticOptimizerTest,AddOpsRewriteMultiplePasses)1952 TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) {
1953   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1954 
1955   auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
1956   auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
1957   auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);
1958   auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
1959   auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);
1960 
1961   auto x = ops::Variable(s.WithOpName("x"), {2, 2}, DT_FLOAT);
1962   auto y = ops::Variable(s.WithOpName("y"), {2, 2}, DT_FLOAT);
1963   auto z = ops::Variable(s.WithOpName("z"), {2, 2}, DT_FLOAT);
1964   auto add_xy = ops::Add(s.WithOpName("Add_xy"), x, y);
1965   auto add_xyz = ops::Add(s.WithOpName("Add_xyz"), add_xy, z);
1966 
1967   auto mul = ops::Multiply(s.WithOpName("Mul"), add_abc, add_xyz);
1968   auto outputs = ops::Identity(s.WithOpName("outputs"), mul);
1969 
1970   GrapplerItem item;
1971   item.fetch = {"outputs"};
1972   TF_CHECK_OK(s.ToGraphDef(&item.graph));
1973 
1974   auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
1975   auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
1976   auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
1977   auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
1978   auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
1979   auto z_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
1980   std::vector<std::pair<string, Tensor>> feed = {
1981       {"a", a_t}, {"b", b_t}, {"c", c_t}, {"x", x_t}, {"y", y_t}, {"z", z_t}};
1982   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
1983   ASSERT_EQ(tensors_expected.size(), 1);
1984 
1985   GraphDef output;
1986   ArithmeticOptimizer optimizer;
1987   EnableOnlyAddToAddNCombining(&optimizer);
1988 
1989   OptimizeAndPrune(&optimizer, &item, &output);
1990 
1991   // We expect the following rewrite(s) to occur:
1992   //
1993   //         *
1994   //      /     \
1995   //     +       +                        *
1996   //    / \     / \                    /     \
1997   //   +   c   x   + -->    AddN(a, b, c)  AddN(x, y, z))
1998   //  / \         / \
1999   // a   b       y   z
2000   EXPECT_EQ(output.node_size(), 10);
2001 
2002   NodeMap node_map(&output);
2003 
2004   // check left Add subtree replaced with AddN
2005   const NodeDef* collapsed_left =
2006       node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc");
2007   ASSERT_NE(collapsed_left, nullptr);
2008 
2009   EXPECT_EQ(collapsed_left->op(), "AddN");
2010   ASSERT_EQ(collapsed_left->input_size(), 3);
2011   EXPECT_EQ(collapsed_left->input(0), "a");
2012   EXPECT_EQ(collapsed_left->input(1), "b");
2013   EXPECT_EQ(collapsed_left->input(2), "c");
2014 
2015   // check right Add subtree replaced with AddN
2016   const NodeDef* collapsed_right =
2017       node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_xyz");
2018   ASSERT_NE(collapsed_right, nullptr);
2019 
2020   EXPECT_EQ(collapsed_right->op(), "AddN");
2021   ASSERT_EQ(collapsed_right->input_size(), 3);
2022   EXPECT_EQ(collapsed_right->input(0), "x");
2023   EXPECT_EQ(collapsed_right->input(1), "y");
2024   EXPECT_EQ(collapsed_right->input(2), "z");
2025 
2026   // check that Mul inputs re-wired to new Nodes
2027   const NodeDef* updated_mul = node_map.GetNode("Mul");
2028   ASSERT_NE(updated_mul, nullptr);
2029 
2030   EXPECT_EQ(updated_mul->op(), "Mul");
2031   ASSERT_EQ(updated_mul->input_size(), 2);
2032   EXPECT_EQ(updated_mul->input(0), collapsed_left->name());
2033   EXPECT_EQ(updated_mul->input(1), collapsed_right->name());
2034 
2035   auto tensors = EvaluateNodes(output, item.fetch, feed);
2036   ASSERT_EQ(tensors.size(), 1);
2037   test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
2038 }
2039 
// When the same tensor is a leaf of the Add tree more than once, it is
// expected to appear once per use among the AddN inputs.
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddInputMultipleTimes) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();

  Output var_a = ops::Variable(root.WithOpName("a"), {2, 2}, DT_FLOAT);
  Output var_b = ops::Variable(root.WithOpName("b"), {2, 2}, DT_FLOAT);
  Output var_c = ops::Variable(root.WithOpName("c"), {2, 2}, DT_FLOAT);
  // `b` feeds both intermediate sums.
  Output sum_ab = ops::Add(root.WithOpName("Add_ab"), var_a, var_b);
  Output sum_bc = ops::Add(root.WithOpName("Add_bc"), var_b, var_c);
  Output sum_all = ops::Add(root.WithOpName("Add_all"), sum_ab, sum_bc);
  Output fetch_node = ops::Identity(root.WithOpName("outputs"), sum_all);

  GrapplerItem grappler_item;
  grappler_item.fetch.push_back("outputs");
  TF_CHECK_OK(root.ToGraphDef(&grappler_item.graph));

  // Reference values from the unoptimized graph.
  const TensorShape shape({2, 2});
  auto value_a = GenerateRandomTensor<DT_FLOAT>(shape);
  auto value_b = GenerateRandomTensor<DT_FLOAT>(shape);
  auto value_c = GenerateRandomTensor<DT_FLOAT>(shape);
  std::vector<std::pair<string, Tensor>> feed_values;
  feed_values.emplace_back("a", value_a);
  feed_values.emplace_back("b", value_b);
  feed_values.emplace_back("c", value_c);
  auto reference =
      EvaluateNodes(grappler_item.graph, grappler_item.fetch, feed_values);
  ASSERT_EQ(reference.size(), 1);

  ArithmeticOptimizer arithmetic_optimizer;
  EnableOnlyAddToAddNCombining(&arithmetic_optimizer);
  GraphDef optimized_graph;
  OptimizeAndPrune(&arithmetic_optimizer, &grappler_item, &optimized_graph);

  // Expected rewrite (note that `b` is listed twice):
  //
  //   Add(Add(a, b), Add(b, c))   -->   AddN(a, b, b, c)
  //
  // which leaves five nodes: a, b, c, the AddN and "outputs".
  EXPECT_EQ(optimized_graph.node_size(), 5);

  NodeMap optimized_nodes(&optimized_graph);

  // The Add tree must have been replaced by one AddN over all four leaves.
  const NodeDef* add_n =
      optimized_nodes.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_all");
  ASSERT_NE(add_n, nullptr);

  EXPECT_EQ(add_n->op(), "AddN");
  ASSERT_EQ(add_n->input_size(), 4);
  EXPECT_EQ(add_n->input(0), "a");
  EXPECT_EQ(add_n->input(1), "b");
  EXPECT_EQ(add_n->input(2), "b");
  EXPECT_EQ(add_n->input(3), "c");

  auto actual =
      EvaluateNodes(optimized_graph, grappler_item.fetch, feed_values);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorNear<float>(actual[0], reference[0], 1e-6);
}
2096 
TEST_F(ArithmeticOptimizerTest,AddOpsRewriteAddOpsOfSymbolicallyEqualShape)2097 TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddOpsOfSymbolicallyEqualShape) {
2098   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2099 
2100   // unknown input shape propagated symbolically through the graph
2101   auto input = ops::Variable(s.WithOpName("input"), {-1, 2}, DT_FLOAT);
2102 
2103   // [a, b, c] have symbolically equal shapes
2104   auto a = ops::Sqrt(s.WithOpName("a"), input);
2105   auto b = ops::Square(s.WithOpName("b"), input);
2106   auto c = ops::Round(s.WithOpName("c"), input);
2107 
2108   // [add_ab, add_abc] shape must be inferred from inputs
2109   auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
2110   auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);
2111 
2112   auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);
2113 
2114   GrapplerItem item;
2115   item.fetch = {"outputs"};
2116   TF_CHECK_OK(s.ToGraphDef(&item.graph));
2117 
2118   auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
2119   std::vector<std::pair<string, Tensor>> feed = {{"input", x_t}};
2120   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
2121   ASSERT_EQ(tensors_expected.size(), 1);
2122 
2123   GraphDef output;
2124   ArithmeticOptimizer optimizer;
2125   EnableOnlyAddToAddNCombining(&optimizer);
2126 
2127   OptimizeAndPrune(&optimizer, &item, &output);
2128 
2129   // We expect the following rewrite(s) to occur:
2130   //
2131   //     +
2132   //    / \
2133   //   +   c      -->    AddN(a, b, c)
2134   //  / \
2135   // a   b
2136   EXPECT_EQ(output.node_size(), 6);
2137 
2138   NodeMap node_map(&output);
2139 
2140   // check add tree was replaced with AddN
2141   const NodeDef* collapsed_add =
2142       node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc");
2143   ASSERT_NE(collapsed_add, nullptr);
2144   EXPECT_EQ(collapsed_add->op(), "AddN");
2145   ASSERT_EQ(collapsed_add->input_size(), 3);
2146   EXPECT_EQ(collapsed_add->input(0), "a");
2147   EXPECT_EQ(collapsed_add->input(1), "b");
2148   EXPECT_EQ(collapsed_add->input(2), "c");
2149 
2150   // check output was re-wired to new node
2151   const NodeDef* updated_outputs = node_map.GetNode("outputs");
2152   ASSERT_NE(updated_outputs, nullptr);
2153   ASSERT_EQ(updated_outputs->input_size(), 1);
2154   EXPECT_EQ(updated_outputs->input(0), collapsed_add->name());
2155 
2156   auto tensors = EvaluateNodes(output, item.fetch, feed);
2157   ASSERT_EQ(tensors.size(), 1);
2158   test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
2159 }
2160 
TEST_F(ArithmeticOptimizerTest,AddOpsRewriteMinimizeBCast)2161 TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMinimizeBCast) {
2162   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2163 
2164   auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
2165   auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_FLOAT);
2166   auto c = ops::Variable(s.WithOpName("c"), {32, 32, 32}, DT_FLOAT);
2167   auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
2168   auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);
2169 
2170   auto x = ops::Variable(s.WithOpName("x"), {32}, DT_FLOAT);
2171   auto y = ops::Variable(s.WithOpName("y"), {32, 32}, DT_FLOAT);
2172   auto z = ops::Variable(s.WithOpName("z"), {32, 32, 32}, DT_FLOAT);
2173   auto add_xy = ops::Add(s.WithOpName("Add_xy"), x, y);
2174   auto add_xyz = ops::Add(s.WithOpName("Add_xyz"), add_xy, z);
2175 
2176   auto add_all = ops::Add(s.WithOpName("AddAll"), add_abc, add_xyz);
2177   auto outputs = ops::Identity(s.WithOpName("outputs"), add_all);
2178 
2179   GrapplerItem item;
2180   item.fetch = {"outputs"};
2181   TF_CHECK_OK(s.ToGraphDef(&item.graph));
2182 
2183   auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
2184   auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
2185   auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32, 32}));
2186   auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
2187   auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
2188   auto z_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32, 32}));
2189   std::vector<std::pair<string, Tensor>> feed = {
2190       {"a", a_t}, {"b", b_t}, {"c", c_t}, {"x", x_t}, {"y", y_t}, {"z", z_t}};
2191   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
2192   ASSERT_EQ(tensors_expected.size(), 1);
2193 
2194   GraphDef output;
2195   ArithmeticOptimizer optimizer;
2196   EnableOnlyAddToAddNCombining(&optimizer);
2197 
2198   OptimizeAndPrune(&optimizer, &item, &output);
2199 
2200   // We expect the following rewrite(s) to occur:
2201   //  1) [a, x], [b, y], [c, z] - aggregate same shapes first
2202   //  2) Build an aggregation tree minimizing cost of broadcast
2203   //
2204   //         +                              +
2205   //      /     \                       /       \
2206   //     +       +                     +       AddN(c, z)
2207   //    / \     / \                 /     \
2208   //   +   c   x   + -->    AddN(a, x)  AddN(b, y)
2209   //  / \         / \
2210   // a   b       y   z
2211   EXPECT_EQ(output.node_size(), 12);
2212   NodeMap node_map(&output);
2213 
2214   // expected names of outer and inner nodes
2215   string outer_add_name = "ArithmeticOptimizer/AddOpsRewrite_AddAll";
2216   string outer_0_add_name =
2217       "ArithmeticOptimizer/AddOpsRewrite_Internal_0_AddAll";
2218   string inner_0_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_0_AddAll";
2219   string inner_1_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_1_AddAll";
2220   string inner_2_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_2_AddAll";
2221 
2222   // Add [a, x] first
2223   const NodeDef* add_ax_node = node_map.GetNode(inner_0_add_name);
2224   ASSERT_NE(add_ax_node, nullptr);
2225   EXPECT_EQ(add_ax_node->op(), "AddN");
2226   ASSERT_EQ(add_ax_node->input_size(), 2);
2227   EXPECT_EQ(add_ax_node->input(0), "a");
2228   EXPECT_EQ(add_ax_node->input(1), "x");
2229 
2230   // Then add [b, y]
2231   const NodeDef* add_by_node = node_map.GetNode(inner_1_add_name);
2232   ASSERT_NE(add_by_node, nullptr);
2233   EXPECT_EQ(add_by_node->op(), "AddN");
2234   ASSERT_EQ(2, add_by_node->input_size());
2235   EXPECT_EQ(add_by_node->input(0), "b");
2236   EXPECT_EQ(add_by_node->input(1), "y");
2237 
2238   // Then add [c, z]
2239   const NodeDef* add_cz_node = node_map.GetNode(inner_2_add_name);
2240   ASSERT_NE(add_cz_node, nullptr);
2241   EXPECT_EQ(add_cz_node->op(), "AddN");
2242   ASSERT_EQ(add_cz_node->input_size(), 2);
2243   EXPECT_EQ(add_cz_node->input(0), "c");
2244   EXPECT_EQ(add_cz_node->input(1), "z");
2245 
2246   // Then add results together starting from smaller shapes [a, x] + [b, y]
2247   const NodeDef* outer_0_node = node_map.GetNode(outer_0_add_name);
2248   ASSERT_NE(outer_0_node, nullptr);
2249   EXPECT_EQ(outer_0_node->op(), "AddV2");
2250   ASSERT_EQ(outer_0_node->input_size(), 2);
2251   EXPECT_EQ(outer_0_node->input(0), inner_0_add_name);
2252   EXPECT_EQ(outer_0_node->input(1), inner_1_add_name);
2253 
2254   // And finally top level Add node
2255   const NodeDef* outer_node = node_map.GetNode(outer_add_name);
2256   ASSERT_NE(outer_node, nullptr);
2257   EXPECT_EQ(outer_node->op(), "AddV2");
2258   ASSERT_EQ(outer_node->input_size(), 2);
2259   EXPECT_EQ(outer_node->input(0), outer_0_add_name);
2260   EXPECT_EQ(outer_node->input(1), inner_2_add_name);
2261 
2262   // And outputs reading new top level Add node
2263   const NodeDef* updated_outputs = node_map.GetNode("outputs");
2264   ASSERT_NE(updated_outputs, nullptr);
2265   ASSERT_EQ(updated_outputs->input_size(), 1);
2266   EXPECT_EQ(updated_outputs->input(0), outer_add_name);
2267 
2268   auto tensors = EvaluateNodes(output, item.fetch, feed);
2269   ASSERT_EQ(tensors.size(), 1);
2270   test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
2271 }
2272 
TEST_F(ArithmeticOptimizerTest,AddOpsRewriteMinimizeBCastWithSymbolicShapes)2273 TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMinimizeBCastWithSymbolicShapes) {
2274   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2275 
2276   // We have a small input with one unknown dimension
2277   auto small = ops::Variable(s.WithOpName("small"), {-1, 1, 1}, DT_DOUBLE);
2278 
2279   // And second input which is larger, but has the same unknown dimension
2280   // device spec prevents this node from rewriting
2281   auto d = "/device:CPU:0";
2282   auto v = ops::Variable(s.WithOpName("v"), {1, 32, 32}, DT_DOUBLE);
2283   auto large = ops::Add(s.WithOpName("large").WithDevice(d), small, v);
2284 
2285   // [a, c] have {?, 1, 1} shape, [b] has {?, 32, 32}
2286   auto a = ops::Sqrt(s.WithOpName("a"), small);
2287   auto b = ops::Square(s.WithOpName("b"), large);
2288   auto c = ops::Round(s.WithOpName("c"), small);
2289 
2290   // [add_ab, add_abc] shape must be inferred from inputs
2291   auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
2292   auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);
2293 
2294   auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);
2295 
2296   GrapplerItem item;
2297   item.fetch = {"outputs"};
2298   TF_CHECK_OK(s.ToGraphDef(&item.graph));
2299 
2300   auto s_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({8, 1, 1}));
2301   auto v_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({1, 32, 32}));
2302   std::vector<std::pair<string, Tensor>> feed = {{"small", s_t}, {"v", v_t}};
2303   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
2304   ASSERT_EQ(tensors_expected.size(), 1);
2305 
2306   GraphDef output;
2307   ArithmeticOptimizer optimizer;
2308   EnableOnlyAddToAddNCombining(&optimizer);
2309   OptimizeAndPrune(&optimizer, &item, &output);
2310 
2311   // We expect the following rewrite(s) to occur: it's much cheaper to add small
2312   // tensors, and do the broadcast just once
2313   //
2314   //     +                  +
2315   //    / \                / \
2316   //   +   c      -->     +   b
2317   //  / \                / \
2318   // a   b              a   c
2319   EXPECT_EQ(output.node_size(), 9);
2320   NodeMap node_map(&output);
2321 
2322   // expected names of outer and inner nodes
2323   string outer_add_name = "ArithmeticOptimizer/AddOpsRewrite_Add_abc";
2324   string inner_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_0_Add_abc";
2325 
2326   // outer Add node
2327   const NodeDef* outer_add = node_map.GetNode(outer_add_name);
2328   ASSERT_NE(outer_add, nullptr);
2329   EXPECT_EQ(outer_add->op(), "AddV2");
2330   ASSERT_EQ(outer_add->input_size(), 2);
2331   EXPECT_EQ(outer_add->input(0), inner_add_name);
2332   EXPECT_EQ(outer_add->input(1), "b");
2333 
2334   // inner AddN node
2335   const NodeDef* inner_add = node_map.GetNode(inner_add_name);
2336   ASSERT_NE(inner_add, nullptr);
2337   ASSERT_EQ(inner_add->input_size(), 2);
2338   EXPECT_EQ(inner_add->input(0), "a");
2339   EXPECT_EQ(inner_add->input(1), "c");
2340 
2341   // check output was re-wired to new node
2342   const NodeDef* updated_outputs = node_map.GetNode("outputs");
2343   ASSERT_NE(updated_outputs, nullptr);
2344   ASSERT_EQ(updated_outputs->input_size(), 1);
2345   EXPECT_EQ(updated_outputs->input(0), outer_add_name);
2346 
2347   auto tensors = EvaluateNodes(output, item.fetch, feed);
2348   ASSERT_EQ(tensors.size(), 1);
2349   test::ExpectTensorNear<double>(tensors[0], tensors_expected[0], 1e-6);
2350 }
2351 
TEST_F(ArithmeticOptimizerTest,RemoveNegation)2352 TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
2353   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2354   auto x = ops::Variable(s.WithOpName("x"), {2, 2}, DT_FLOAT);
2355   auto y = ops::Variable(s.WithOpName("y"), {2, 2}, DT_FLOAT);
2356   Output neg_x = ops::Neg(s.WithOpName("Neg_x"), x);
2357   Output neg_y = ops::Neg(s.WithOpName("Neg_y"), y);
2358   Output add_x_y = ops::Add(s.WithOpName("Add_x_y"), x, y);
2359   Output add_negx_y = ops::Add(s.WithOpName("Add_negx_y"), neg_x, y);
2360   Output add_x_negy = ops::Add(s.WithOpName("Add_x_negy"), x, neg_y);
2361   Output add_negx_negy = ops::Add(s.WithOpName("Add_negx_negy"), neg_x, neg_y);
2362   Output sub_x_y = ops::Sub(s.WithOpName("Sub_x_y"), x, y);
2363   Output sub_negx_y = ops::Sub(s.WithOpName("Sub_negx_y"), neg_x, y);
2364   Output sub_x_negy = ops::Sub(s.WithOpName("Sub_x_negy"), x, neg_y);
2365   Output sub_negx_negy = ops::Sub(s.WithOpName("Sub_negx_negy"), neg_x, neg_y);
2366   Output neg_x_with_dep = ops::Neg(
2367       s.WithOpName("Neg_x_with_dep").WithControlDependencies({add_x_y}), x);
2368   Output add_negx_with_dep_y =
2369       ops::Add(s.WithOpName("Add_negx_with_dep_y"), neg_x_with_dep, y);
2370   auto add_all =
2371       ops::AddN(s.WithOpName("add_all"),
2372                 {add_x_y, add_negx_y, add_x_negy, add_negx_negy, sub_x_y,
2373                  sub_negx_y, sub_x_negy, sub_negx_negy, add_negx_with_dep_y});
2374 
2375   GrapplerItem item;
2376   item.fetch = {"add_all"};
2377   TF_CHECK_OK(s.ToGraphDef(&item.graph));
2378 
2379   auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
2380   auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
2381   std::vector<std::pair<string, Tensor>> feed = {{"x", x_t}, {"y", y_t}};
2382   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
2383   ASSERT_EQ(tensors_expected.size(), 1);
2384 
2385   GraphDef output;
2386   ArithmeticOptimizer optimizer;
2387   EnableOnlyRemoveNegation(&optimizer);
2388   OptimizeTwice(&optimizer, &item, &output);
2389 
2390   EXPECT_EQ(output.node_size(), item.graph.node_size());
2391   int found = 0;
2392   for (int i = 0; i < output.node_size(); ++i) {
2393     const NodeDef& node = output.node(i);
2394     if (node.name() == "Add_negx_y") {
2395       ++found;
2396       EXPECT_EQ(node.op(), "Sub");
2397       ASSERT_EQ(node.input_size(), 2);
2398       EXPECT_EQ(node.input(0), "y");
2399       EXPECT_EQ(node.input(1), "x");
2400     } else if (node.name() == "Add_x_negy") {
2401       ++found;
2402       EXPECT_EQ(node.op(), "Sub");
2403       ASSERT_EQ(node.input_size(), 2);
2404       EXPECT_EQ(node.input(0), "x");
2405       EXPECT_EQ(node.input(1), "y");
2406     } else if (node.name() == "Add_negx_negy") {
2407       ++found;
2408       EXPECT_EQ(node.op(), "Sub");
2409       ASSERT_EQ(node.input_size(), 2);
2410       EXPECT_EQ(node.input(0), "Neg_x");
2411       EXPECT_EQ(node.input(1), "y");
2412     } else if (node.name() == "Sub_x_negy") {
2413       ++found;
2414       EXPECT_EQ(node.op(), "AddV2");
2415       ASSERT_EQ(node.input_size(), 2);
2416       EXPECT_EQ(node.input(0), "x");
2417       EXPECT_EQ(node.input(1), "y");
2418     } else if (node.name() == "Sub_negx_negy") {
2419       ++found;
2420       EXPECT_EQ(node.op(), "Sub");
2421       ASSERT_EQ(node.input_size(), 2);
2422       EXPECT_EQ(node.input(0), "y");
2423       EXPECT_EQ(node.input(1), "x");
2424     } else if (node.name() == "Add_negx_with_dep_y") {
2425       ++found;
2426       EXPECT_EQ(node.op(), "Sub");
2427       ASSERT_EQ(node.input_size(), 3);
2428       EXPECT_EQ(node.input(0), "y");
2429       EXPECT_EQ(node.input(1), "x");
2430       EXPECT_EQ(node.input(2), "^Add_x_y");
2431     }
2432   }
2433   EXPECT_EQ(found, 6);
2434 
2435   auto tensors = EvaluateNodes(output, item.fetch, feed);
2436   ASSERT_EQ(tensors.size(), 1);
2437   test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
2438 }
2439 
// x / sqrt(y) must be rewritten in place to x * rsqrt(y): the Div node becomes
// a Mul, the Sqrt node becomes an Rsqrt, node names and node count are kept,
// and the computed values are unchanged.
TEST_F(ArithmeticOptimizerTest, ConvertSqrtDivToRsqrtMul) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto numerator = ops::Const(root.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto radicand = ops::Const(root.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  Output root_of_y = ops::Sqrt(root.WithOpName("sqrt_y"), radicand);
  Output quotient = ops::Div(root.WithOpName("output"), numerator, root_of_y);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  // Reference values from the unoptimized graph.
  auto reference = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(reference.size(), 1);

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlySqrtDivToRsqrtMul(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &optimized);
  auto actual = EvaluateNodes(optimized, item.fetch);
  ASSERT_EQ(actual.size(), 1);

  test::ExpectTensorNear<float>(actual[0], reference[0], 1e-6);
  EXPECT_EQ(optimized.node_size(), item.graph.node_size());

  for (const NodeDef& node : optimized.node()) {
    const string& name = node.name();
    if (name == "output") {
      // The division was turned into a multiplication.
      EXPECT_EQ(node.op(), "Mul");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      EXPECT_EQ(node.input(1), "sqrt_y");
    } else if (name == "sqrt_y") {
      // The node keeps its name but now computes the reciprocal square root.
      EXPECT_EQ(node.op(), "Rsqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "y");
    }
  }
}
2476 
// When the Sqrt divisor is itself a fetch node, the Div/Sqrt pair must be left
// untouched: rewriting Sqrt to Rsqrt would change the value fetched from
// "output0". Both fetched tensors must match the unoptimized graph.
TEST_F(ArithmeticOptimizerTest, DoNotConvertSqrtDivToRsqrtMulDivisorFetchNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output floats = ops::Const(s.WithOpName("floats"),
                             {0.7423212f, 0.19757693f, 0.53124744f}, {1, 3});
  Output output0 = ops::Sqrt(s.WithOpName("output0"), floats);
  Output const1 = ops::Const(s.WithOpName("const1"), 1.0f, {3});
  Output mul1 = ops::Multiply(s.WithOpName("mul1"), const1, 0.5f);
  Output grad = ops::Div(s.WithOpName("grad"), mul1, output0);

  GrapplerItem item;
  // The divisor "output0" is fetched alongside the quotient.
  item.fetch = {"grad", "output0"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlySqrtDivToRsqrtMul(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 2);

  // size_t index: vector::size() is unsigned, so an int counter would mix
  // signed and unsigned operands in the loop condition.
  for (size_t i = 0; i < tensors.size(); ++i) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  for (const NodeDef& node : output.node()) {
    if (node.name() == "grad") {
      // Still a division by the original Sqrt node.
      EXPECT_EQ(node.op(), "Div");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "mul1");
      EXPECT_EQ(node.input(1), "output0");
    } else if (node.name() == "output0") {
      EXPECT_EQ(node.op(), "Sqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "floats");
    }
  }
}
2518 
// The Sqrt/Div -> Rsqrt/Mul rewrite must not be applied to FloorDiv: both the
// FloorDiv node and its Sqrt divisor have to survive unchanged, and the
// computed values must match the unoptimized graph.
TEST_F(ArithmeticOptimizerTest, ConvertSqrtDivToRsqrtMulExcludeFloorDiv) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto numerator = ops::Const(root.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto radicand = ops::Const(root.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  Output root_of_y = ops::Sqrt(root.WithOpName("sqrt_y"), radicand);
  Output floored =
      ops::FloorDiv(root.WithOpName("output"), numerator, root_of_y);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  // Reference values from the unoptimized graph.
  auto reference = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(reference.size(), 1);

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlySqrtDivToRsqrtMul(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &optimized);
  auto actual = EvaluateNodes(optimized, item.fetch);
  ASSERT_EQ(actual.size(), 1);

  test::ExpectTensorNear<float>(actual[0], reference[0], 1e-6);
  EXPECT_EQ(optimized.node_size(), item.graph.node_size());

  for (const NodeDef& node : optimized.node()) {
    const string& name = node.name();
    if (name == "output") {
      // Still a FloorDiv over the original operands.
      EXPECT_EQ(node.op(), "FloorDiv");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      EXPECT_EQ(node.input(1), "sqrt_y");
    } else if (name == "sqrt_y") {
      // Still a plain square root.
      EXPECT_EQ(node.op(), "Sqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "y");
    }
  }
}
2555 
// Square(Sub(x, y)) must be fused into SquaredDifference(x, y) followed by an
// Identity for real inputs, and must be left untouched for complex inputs.
// Either way the computed values must match the unoptimized graph.
TEST_F(ArithmeticOptimizerTest, FuseSquaredDiff) {
  for (bool is_complex : {false, true}) {
    tensorflow::Scope root = tensorflow::Scope::NewRootScope();
    Output real_x = ops::Const(root.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
    Output real_y = ops::Const(root.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
    Output cplx_x = ops::Complex(root.WithOpName("complex_x"), real_x, real_x);
    Output cplx_y = ops::Complex(root.WithOpName("complex_y"), real_y, real_y);

    // Only one of the two Sub nodes is ever created, so the name is unique.
    Output difference =
        is_complex ? ops::Sub(root.WithOpName("sub_x_y"), cplx_x, cplx_y)
                   : ops::Sub(root.WithOpName("sub_x_y"), real_x, real_y);
    Output squared = ops::Square(root.WithOpName("output"), difference);

    GrapplerItem item;
    item.fetch = {"output"};
    TF_CHECK_OK(root.ToGraphDef(&item.graph));
    const auto reference = EvaluateNodes(item.graph, item.fetch);
    ASSERT_EQ(reference.size(), 1);

    GraphDef optimized;
    ArithmeticOptimizer optimizer;
    EnableOnlyFuseSquaredDiff(&optimizer);
    OptimizeAndPrune(&optimizer, &item, &optimized);
    const auto actual = EvaluateNodes(optimized, item.fetch);
    ASSERT_EQ(actual.size(), 1);

    if (is_complex) {
      test::ExpectTensorNear<std::complex<float>>(actual[0], reference[0],
                                                  1e-6);
    } else {
      test::ExpectTensorNear<float>(actual[0], reference[0], 1e-6);
    }
    // In the real case the two unused Complex nodes should get pruned.
    const int pruned_nodes = is_complex ? 0 : 2;
    EXPECT_EQ(optimized.node_size(), item.graph.node_size() - pruned_nodes);

    // Expected ops and operands for the two nodes of interest.
    const char* const output_op = is_complex ? "Square" : "Identity";
    const char* const sub_op = is_complex ? "Sub" : "SquaredDifference";
    const char* const lhs_name = is_complex ? "complex_x" : "x";
    const char* const rhs_name = is_complex ? "complex_y" : "y";

    for (const NodeDef& node : optimized.node()) {
      if (node.name() == "output") {
        EXPECT_EQ(node.op(), output_op);
        ASSERT_EQ(node.input_size(), 1);
        EXPECT_EQ(node.input(0), "sub_x_y");
      } else if (node.name() == "sub_x_y") {
        EXPECT_EQ(node.op(), sub_op);
        ASSERT_EQ(node.input_size(), 2);
        EXPECT_EQ(node.input(0), lhs_name);
        EXPECT_EQ(node.input(1), rhs_name);
      }
    }
  }
}
2605 
// When the Sub feeding a Square is itself a fetch node, the pair must not be
// fused into SquaredDifference: doing so would change the value fetched from
// "sub_x_y". Both fetched tensors must match the unoptimized graph.
TEST_F(ArithmeticOptimizerTest, DoNotFuseSquaredDiffFetchNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  Output sub_x_y = ops::Sub(s.WithOpName("sub_x_y"), x, y);
  Output square_sub_x_y = ops::Square(s.WithOpName("output"), sub_x_y);

  GrapplerItem item;
  // The intermediate difference is fetched alongside the final result.
  item.fetch = {"output", "sub_x_y"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyFuseSquaredDiff(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 2);

  // size_t index: vector::size() is unsigned, so an int counter would mix
  // signed and unsigned operands in the loop condition.
  for (size_t i = 0; i < tensors.size(); ++i) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  for (const NodeDef& node : output.node()) {
    if (node.name() == "output") {
      // Still a Square of the original Sub.
      EXPECT_EQ(node.op(), "Square");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "sub_x_y");
    } else if (node.name() == "sub_x_y") {
      EXPECT_EQ(node.op(), "Sub");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      EXPECT_EQ(node.input(1), "y");
    }
  }
}
2645 
// Log(Softmax(x)) must be replaced by a single LogSoftmax(x) node: the graph
// shrinks by one node, the fetch node reads directly from "x", and the
// computed values are unchanged.
TEST_F(ArithmeticOptimizerTest, ConvertLogSoftmax) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto logits = ops::Const(root.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output probabilities = ops::Softmax(root.WithOpName("softmax"), logits);
  Output log_probabilities =
      ops::Log(root.WithOpName("output"), probabilities);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  // Reference values from the unoptimized graph.
  const auto reference = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(reference.size(), 1);

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyLogSoftmax(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &optimized);
  const auto actual = EvaluateNodes(optimized, item.fetch);
  ASSERT_EQ(actual.size(), 1);

  test::ExpectTensorNear<float>(actual[0], reference[0], 1e-6);

  // One node fewer than in the original graph.
  EXPECT_EQ(optimized.node_size(), item.graph.node_size() - 1);

  for (const NodeDef& node : optimized.node()) {
    if (node.name() != "output") continue;
    EXPECT_EQ(node.op(), "LogSoftmax");
    ASSERT_EQ(node.input_size(), 1);
    EXPECT_EQ(node.input(0), "x");
  }
}
2676 
// When the Softmax feeding a Log is itself a fetch node, the pair must not be
// fused into LogSoftmax: the optimized graph has to be identical to the input
// graph, and both fetched tensors must match the unoptimized values.
TEST_F(ArithmeticOptimizerTest, DoNotConvertLogSoftmaxArgFetchNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output floats = ops::Const(s.WithOpName("floats"),
                             {0.7423212f, 0.19757693f, 0.53124744f}, {1, 3});
  Output softmax = ops::Softmax(s.WithOpName("softmax"), floats);
  Output final_output = ops::Log(s.WithOpName("final_output"), softmax);

  GrapplerItem item;
  // The Softmax result is fetched alongside the final Log.
  item.fetch = {"softmax", "final_output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyLogSoftmax(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 2);

  // Should be a NoOp since we are not allowed to change the output of fetch
  // nodes.
  VerifyGraphsMatch(item.graph, output, __LINE__);

  // size_t index: vector::size() is unsigned, so an int counter would mix
  // signed and unsigned operands in the loop condition.
  for (size_t i = 0; i < tensors.size(); ++i) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
}
2706 
TEST_F(ArithmeticOptimizerTest,ConvertPow)2707 TEST_F(ArithmeticOptimizerTest, ConvertPow) {
2708   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2709   auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
2710   auto y2 = ops::Const(s.WithOpName("y2"), {2.0f, 2.0f}, {1, 2});
2711   auto y3 = ops::Const(s.WithOpName("y3"), {3.0f, 3.0f}, {1, 2});
2712   auto y1 = ops::Const(s.WithOpName("y1"), {1.0f, 1.0f}, {1, 2});
2713   auto yPoint5 = ops::Const(s.WithOpName("y.5"), {0.5f, 0.5f}, {1, 2});
2714   auto y0 = ops::Const(s.WithOpName("y0"), {0.0f, 0.0f}, {1, 2});
2715   auto y_Point5 = ops::Const(s.WithOpName("y_.5"), {-0.5f, -0.5f}, {1, 2});
2716   auto y_1 = ops::Const(s.WithOpName("y_1"), {-1.0f, -1.0f}, {1, 2});
2717   auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
2718   auto z = ops::Const(s.WithOpName("z"), {42.0f}, {});
2719   auto ones = ops::Const(s.WithOpName("ones"), {1.0f, 1.0f, 1.0f}, {1, 3});
2720   auto zeros = ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f, 0.0f}, {1, 3});
2721   Output out2 = ops::Pow(s.WithOpName("out2"), x, y2);
2722   Output out3 =
2723       ops::Pow(s.WithOpName("out3").WithDevice("/device:CPU:0"), x, y3);
2724   Output out1 = ops::Pow(s.WithOpName("out1"), x, y1);
2725   Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5);
2726   Output out0 = ops::Pow(s.WithOpName("out0"), x, y0);
2727   Output out_Point5 = ops::Pow(s.WithOpName("out_.5"), x, y_Point5);
2728   Output out_1 = ops::Pow(s.WithOpName("out_1"), x, y_1);
2729   Output out = ops::Pow(s.WithOpName("out"), x, y);
2730   Output out_bcast1 = ops::Pow(s.WithOpName("out_bcast1"), z, ones);
2731   Output out_bcast2 = ops::Pow(s.WithOpName("out_bcast2"), z, zeros);
2732 
2733   GrapplerItem item;
2734   item.fetch = {"out2",   "out3",  "out1", "out.5",      "out0",
2735                 "out_.5", "out_1", "out",  "out_bcast1", "out_bcast2"};
2736   TF_CHECK_OK(s.ToGraphDef(&item.graph));
2737   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
2738   ASSERT_EQ(tensors_expected.size(), 10);
2739 
2740   GraphDef got;
2741   ArithmeticOptimizer optimizer;
2742   EnableOnlyConvertPow(&optimizer);
2743   OptimizeAndPrune(&optimizer, &item, &got);
2744   auto tensors = EvaluateNodes(got, item.fetch);
2745   ASSERT_EQ(tensors.size(), 10);
2746 
2747   for (int i = 0; i < tensors.size(); ++i) {
2748     EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
2749     test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
2750   }
2751 
2752   GraphDef want;
2753   AddNode("x", "Const", {}, {}, &want);
2754   AddNode("y", "Const", {}, {}, &want);
2755   AddNode("z", "Const", {}, {}, &want);
2756   AddNode("ones", "Const", {}, {}, &want);
2757   AddNode("zeros", "Const", {}, {}, &want);
2758   AddNode("out2", "Square", {"x"}, {}, &want);
2759   AddNode("ArithmeticOptimizer/ConvertPow__inner_out3", "Square", {"x"}, {},
2760           &want)
2761       ->set_device("/device:CPU:0");
2762   AddNode("out3", "Mul", {"x", "ArithmeticOptimizer/ConvertPow__inner_out3"},
2763           {}, &want)
2764       ->set_device("/device:CPU:0");
2765   AddNode("out1", "Identity", {"x"}, {}, &want);
2766   AddNode("out.5", "Sqrt", {"x"}, {}, &want);
2767   AddNode("out0", "Const", {AsControlDependency("x")}, {}, &want);
2768   AddNode("out_.5", "Rsqrt", {"x"}, {}, &want);
2769   AddNode("out_1", "Reciprocal", {"x"}, {}, &want);
2770   AddNode("out", "Pow", {"x", "y"}, {}, &want);
2771   AddNode("out_bcast1", "Pow", {"z", "ones"}, {}, &want);
2772   AddNode("out_bcast2", "Pow", {"z", "zeros"}, {}, &want);
2773 
2774   CompareGraphs(want, got);
2775 }
2776 
// Log(Add(ones, x)) must be rewritten to Log1p(x), carrying over the Add's
// control dependency, while Log(Add(a, b)) with no all-ones operand must stay
// as it is. Fetched values must match the unoptimized graph.
TEST_F(ArithmeticOptimizerTest, Log1p) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();

  auto all_ones = ops::Const(root.WithOpName("x1"), {1.0f, 1.0f}, {1, 2});
  auto all_twos = ops::Const(root.WithOpName("x2"), {2.0f, 2.0f}, {1, 2});
  auto all_threes = ops::Const(root.WithOpName("x3"), {3.0f, 3.0f}, {1, 2});

  // 1 + x2, with a control dependency on x3: candidate for the rewrite.
  auto one_plus_two = ops::Add(
      root.WithOpName("a12").WithControlDependencies(all_threes), all_ones,
      all_twos);
  // x2 + x3: neither operand is all ones.
  auto two_plus_three = ops::Add(root.WithOpName("a23"), all_twos, all_threes);
  Output log_of_a12 = ops::Log(root.WithOpName("out1"), one_plus_two);
  Output log_of_a23 = ops::Log(root.WithOpName("out2"), two_plus_three);

  GrapplerItem item;
  item.fetch = {"out1", "out2"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  // Reference values from the unoptimized graph.
  auto reference = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(reference.size(), 2);

  GraphDef got;
  ArithmeticOptimizer optimizer;
  EnableOnlyLog1p(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &got);
  auto actual = EvaluateNodes(got, item.fetch);
  ASSERT_EQ(actual.size(), 2);

  for (int fetch_idx = 0; fetch_idx < 2; ++fetch_idx) {
    EXPECT_EQ(actual[fetch_idx].NumElements(),
              reference[fetch_idx].NumElements());
    test::ExpectTensorNear<float>(actual[fetch_idx], reference[fetch_idx],
                                  1e-6);
  }

  // Expected graph: "x1" and "a12" are gone, "out1" reads "x2" directly.
  GraphDef want;
  AddNode("x2", "Const", {}, {}, &want);
  AddNode("x3", "Const", {}, {}, &want);
  AddNode("a23", "Add", {"x2", "x3"}, {}, &want);
  AddNode("out1", "Log1p", {"x2", AsControlDependency("x3")}, {}, &want);
  AddNode("out2", "Log", {"a23"}, {}, &want);

  CompareGraphs(want, got);
}
2815 
// Sub(Exp(x), ones) must be rewritten to Expm1(x), carrying over the Exp's
// control dependency, while Sub(Exp(x), c) with c not all ones must stay as it
// is. Fetched values must match the unoptimized graph.
TEST_F(ArithmeticOptimizerTest, Expm1) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();

  auto all_twos = ops::Const(root.WithOpName("x1"), {2.0f, 2.0f}, {1, 2});
  auto all_ones = ops::Const(root.WithOpName("x2"), {1.0f, 1.0f}, {1, 2});
  auto all_threes = ops::Const(root.WithOpName("x3"), {3.0f, 3.0f}, {1, 2});

  // exp(x1), with a control dependency on x3; shared by both outputs.
  auto exp_of_x1 = ops::Exp(
      root.WithOpName("exp1").WithControlDependencies(all_threes), all_twos);
  // exp(x1) - 1: candidate for the rewrite.
  Output minus_one = ops::Sub(root.WithOpName("out1"), exp_of_x1, all_ones);
  // exp(x1) - 3: the subtrahend is not all ones.
  Output minus_three = ops::Sub(root.WithOpName("out2"), exp_of_x1, all_threes);

  GrapplerItem item;
  item.fetch = {"out1", "out2"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  // Reference values from the unoptimized graph.
  auto reference = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(reference.size(), 2);

  GraphDef got;
  ArithmeticOptimizer optimizer;
  EnableOnlyExpm1(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &got);
  auto actual = EvaluateNodes(got, item.fetch);
  ASSERT_EQ(actual.size(), 2);

  for (int fetch_idx = 0; fetch_idx < 2; ++fetch_idx) {
    EXPECT_EQ(actual[fetch_idx].NumElements(),
              reference[fetch_idx].NumElements());
    test::ExpectTensorNear<float>(actual[fetch_idx], reference[fetch_idx],
                                  1e-6);
  }

  // Expected graph: "x2" is gone, "exp1" stays because "out2" still uses it.
  GraphDef want;
  AddNode("x1", "Const", {}, {}, &want);
  AddNode("x3", "Const", {}, {}, &want);
  AddNode("exp1", "Exp", {"x1", AsControlDependency("x3")}, {}, &want);
  AddNode("out1", "Expm1", {"x1", AsControlDependency("x3")}, {}, &want);
  AddNode("out2", "Sub", {"exp1", "x3"}, {}, &want);

  CompareGraphs(want, got);
}
2853 
// MinimizeBroadcasts on a two-level Mul chain. The {32, 32} operand `b`
// starts in the inner Mul and is expected to trade places with the {32}
// operand `c`, so the larger shape only enters at the outermost Mul.
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  auto a = ops::Variable(scope.WithOpName("a"), {32}, DT_FLOAT);
  auto b = ops::Variable(scope.WithOpName("b"), {32, 32}, DT_FLOAT);
  auto c = ops::Variable(scope.WithOpName("c"), {32}, DT_FLOAT);

  // Input expression: (a * b) * c.
  auto mul1 = ops::Mul(scope.WithOpName("mul1"), a, b);
  auto mul2 = ops::Mul(scope.WithOpName("mul2"), mul1, c);

  auto outputs = ops::Identity(scope.WithOpName("outputs"), mul2);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  // Feed every variable a generated tensor of its declared shape.
  std::vector<std::pair<string, Tensor>> feeds = {
      {"a", GenerateRandomTensor<DT_FLOAT>(TensorShape({32}))},
      {"b", GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}))},
      {"c", GenerateRandomTensor<DT_FLOAT>(TensorShape({32}))}};
  auto reference = EvaluateNodes(item.graph, item.fetch, feeds);
  ASSERT_EQ(reference.size(), 1);

  ArithmeticOptimizer optimizer;
  EnableOnlyMinimizeBroadcasts(&optimizer);

  GraphDef optimized;
  OptimizeAndPrune(&optimizer, &item, &optimized);

  // Expected rewrite:
  //
  //     *                  *
  //    / \                / \
  //   *   c      -->     *   b
  //  / \                / \
  // a   b              a   c
  NodeMap node_map(&optimized);

  // Inner Mul now multiplies the two {32} operands.
  const NodeDef* inner_mul = node_map.GetNode("mul1");
  ASSERT_NE(inner_mul, nullptr);
  ASSERT_EQ(inner_mul->input_size(), 2);
  EXPECT_EQ(inner_mul->input(0), "a");
  EXPECT_EQ(inner_mul->input(1), "c");

  // Outer Mul combines that product with the {32, 32} operand.
  const NodeDef* outer_mul = node_map.GetNode("mul2");
  ASSERT_NE(outer_mul, nullptr);
  ASSERT_EQ(outer_mul->input_size(), 2);
  EXPECT_EQ(outer_mul->input(0), "mul1");
  EXPECT_EQ(outer_mul->input(1), "b");

  // The rewrite must not change the computed value.
  auto actual = EvaluateNodes(optimized, item.fetch, feeds);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorNear<float>(actual[0], reference[0], 1e-6);
}
2909 
TEST_F(ArithmeticOptimizerTest,MinimizeBroadcasts_FlattenTallGraph)2910 TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_FlattenTallGraph) {
2911   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2912 
2913   auto a = ops::Variable(s.WithOpName("a"), {32}, DT_DOUBLE);
2914   auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_DOUBLE);
2915   auto c = ops::Variable(s.WithOpName("c"), {32}, DT_DOUBLE);
2916   auto d = ops::Variable(s.WithOpName("d"), {32}, DT_DOUBLE);
2917   auto e = ops::Variable(s.WithOpName("e"), {32}, DT_DOUBLE);
2918 
2919   auto mul1 = ops::Mul(s.WithOpName("mul1"), a, b);
2920   auto mul2 = ops::Mul(s.WithOpName("mul2"), mul1, c);
2921   auto mul3 = ops::Mul(s.WithOpName("mul3"), mul2, d);
2922   auto mul4 = ops::Mul(s.WithOpName("mul4"), mul3, e);
2923 
2924   auto outputs = ops::Identity(s.WithOpName("outputs"), mul4);
2925 
2926   GrapplerItem item;
2927   item.fetch = {"outputs"};
2928   TF_CHECK_OK(s.ToGraphDef(&item.graph));
2929 
2930   auto a_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
2931   auto b_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32, 32}));
2932   auto c_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
2933   auto d_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
2934   auto e_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
2935   std::vector<std::pair<string, Tensor>> feed = {
2936       {"a", a_t}, {"b", b_t}, {"c", c_t}, {"d", d_t}, {"e", e_t}};
2937   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
2938   ASSERT_EQ(tensors_expected.size(), 1);
2939 
2940   GraphDef output;
2941   ArithmeticOptimizer optimizer;
2942   EnableOnlyMinimizeBroadcasts(&optimizer);
2943 
2944   OptimizeAndPrune(&optimizer, &item, &output);
2945 
2946   // We expect the following rewrite(s) to occur: Graph is "flattened" and
2947   // largest shape pushed to the top.
2948   //
2949   //          *
2950   //        /   \
2951   //       *     e                *
2952   //      /  \                  /   \
2953   //     *    d               *      b
2954   //    / \                 /  \
2955   //   *   c      -->     *      *
2956   //  / \                / \    / \
2957   // a   b              a   c  d   e
2958   NodeMap node_map(&output);
2959 
2960   const NodeDef* mul1_node = node_map.GetNode("mul1");
2961   ASSERT_NE(mul1_node, nullptr);
2962   ASSERT_EQ(mul1_node->input_size(), 2);
2963   EXPECT_EQ(mul1_node->input(0), "a");
2964   EXPECT_EQ(mul1_node->input(1), "c");
2965 
2966   const NodeDef* mul2_node = node_map.GetNode("mul2");
2967   ASSERT_NE(mul2_node, nullptr);
2968   ASSERT_EQ(mul2_node->input_size(), 2);
2969   EXPECT_EQ(mul2_node->input(0), "d");
2970   EXPECT_EQ(mul2_node->input(1), "e");
2971 
2972   const NodeDef* mul3_node = node_map.GetNode("mul3");
2973   ASSERT_NE(mul3_node, nullptr);
2974   ASSERT_EQ(mul3_node->input_size(), 2);
2975   EXPECT_EQ(mul3_node->input(0), "mul1");
2976   EXPECT_EQ(mul3_node->input(1), "mul2");
2977 
2978   const NodeDef* mul4_node = node_map.GetNode("mul4");
2979   ASSERT_NE(mul4_node, nullptr);
2980   ASSERT_EQ(mul4_node->input_size(), 2);
2981   EXPECT_EQ(mul4_node->input(0), "mul3");
2982   EXPECT_EQ(mul4_node->input(1), "b");
2983 
2984   auto tensors = EvaluateNodes(output, item.fetch, feed);
2985   ASSERT_EQ(tensors.size(), 1);
2986   test::ExpectTensorNear<double>(tensors[0], tensors_expected[0], 1e-6);
2987 }
2988 
TEST_F(ArithmeticOptimizerTest,MinimizeBroadcasts_BuildTreeUp)2989 TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_BuildTreeUp) {
2990   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2991 
2992   // [a, b, c] - scalars, [d] - matrix
2993   auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
2994   auto b = ops::Variable(s.WithOpName("b"), {32}, DT_FLOAT);
2995   auto c = ops::Variable(s.WithOpName("c"), {32}, DT_FLOAT);
2996   auto d = ops::Variable(s.WithOpName("D"), {32, 32}, DT_FLOAT);
2997 
2998   auto mul1 = ops::Mul(s.WithOpName("mul1"), a, b);
2999   auto mul2 = ops::Mul(s.WithOpName("mul2"), c, d);
3000   auto mul3 = ops::Mul(s.WithOpName("mul3"), mul1, mul2);
3001 
3002   auto outputs = ops::Identity(s.WithOpName("outputs"), mul3);
3003 
3004   GrapplerItem item;
3005   item.fetch = {"outputs"};
3006   TF_CHECK_OK(s.ToGraphDef(&item.graph));
3007 
3008   auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
3009   auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
3010   auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
3011   auto d_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
3012   std::vector<std::pair<string, Tensor>> feed = {
3013       {"a", a_t}, {"b", b_t}, {"c", c_t}, {"D", d_t}};
3014   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
3015   ASSERT_EQ(tensors_expected.size(), 1);
3016 
3017   GraphDef output;
3018   ArithmeticOptimizer optimizer;
3019   EnableOnlyMinimizeBroadcasts(&optimizer);
3020 
3021   OptimizeAndPrune(&optimizer, &item, &output);
3022 
3023   // We expect the following rewrite(s) to occur:
3024   //
3025   //                              *
3026   //                            /  \
3027   //       *                   *    D
3028   //     /   \                / \
3029   //    *     *      ->      *   c
3030   //   / \   / \            / \
3031   //  a   b c   D          a   b
3032   NodeMap node_map(&output);
3033 
3034   const NodeDef* mul1_node = node_map.GetNode("mul2");
3035   ASSERT_NE(mul1_node, nullptr);
3036   ASSERT_EQ(mul1_node->input_size(), 2);
3037   EXPECT_EQ(mul1_node->input(0), "a");
3038   EXPECT_EQ(mul1_node->input(1), "b");
3039 
3040   const NodeDef* mul2_node = node_map.GetNode("mul1");
3041   ASSERT_NE(mul2_node, nullptr);
3042   ASSERT_EQ(mul2_node->input_size(), 2);
3043   EXPECT_EQ(mul2_node->input(0), "mul2");
3044   EXPECT_EQ(mul2_node->input(1), "c");
3045 
3046   const NodeDef* mul3_node = node_map.GetNode("mul3");
3047   ASSERT_NE(mul3_node, nullptr);
3048   ASSERT_EQ(mul3_node->input_size(), 2);
3049   EXPECT_EQ(mul3_node->input(0), "D");
3050   EXPECT_EQ(mul3_node->input(1), "mul1");
3051 
3052   auto tensors = EvaluateNodes(output, item.fetch, feed);
3053   ASSERT_EQ(tensors.size(), 1);
3054   test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
3055 }
3056 
// Builds two Conv2D -> BiasAdd -> Relu branches that are concatenated, runs
// the optimizer with its default-constructed configuration (no stage is
// singled out), and checks that both Relu nodes are still present and that
// the fetched value is unchanged.
TEST_F(ArithmeticOptimizerTest, DoNotHoistReluFromConcat) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output weights1 = ops::Const(s.WithOpName("weights1"),
                               Input::Initializer(1.0f, {5, 5, 3, 4}));
  Output weights2 = ops::Const(s.WithOpName("weights2"),
                               Input::Initializer(2.0f, {5, 5, 3, 4}));
  Output biases =
      ops::Const(s.WithOpName("biases"), Input::Initializer(2.0f, {4}));
  Output axis = ops::Const(s.WithOpName("axis"), 3, {});
  Output input = ops::Const(s.WithOpName("input"),
                            Input::Initializer(1.0f, {1, 28, 28, 3}));
  // First branch.
  Output branch1 =
      ops::Conv2D(s.WithOpName("conv1"), input, weights1, {1, 1, 1, 1}, "SAME");
  branch1 = ops::BiasAdd(s.WithOpName("biasadd1"), branch1, biases);
  branch1 = ops::Relu(s.WithOpName("relu1"), branch1);
  // Second branch: same structure, different weights.
  Output branch2 =
      ops::Conv2D(s.WithOpName("conv2"), input, weights2, {1, 1, 1, 1}, "SAME");
  branch2 = ops::BiasAdd(s.WithOpName("biasadd2"), branch2, biases);
  branch2 = ops::Relu(s.WithOpName("relu2"), branch2);
  Output concat = ops::Concat(s.WithOpName("concat"), {branch1, branch2}, axis);
  Output output = ops::Identity(s.WithOpName("output"), concat);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Reference values from the unoptimized graph. The size is asserted so the
  // indexed comparison below cannot read past the end of the vector.
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), item.fetch.size());

  GraphDef new_graph;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &new_graph);

  // Verify that the two Relus are not hoisted.
  EXPECT_EQ(CountOpNodes(new_graph, "Relu"), 2);

  auto tensors = EvaluateNodes(new_graph, item.fetch);
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  for (size_t i = 0; i < tensors.size(); ++i) {
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
}
3097 
TEST_F(ArithmeticOptimizerTest,HoistCWiseUnaryFromConcat)3098 TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryFromConcat) {
3099   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3100   Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
3101   Output b = ops::Const(s.WithOpName("b"), 1.0f, {32});
3102   Output c = ops::Const(s.WithOpName("c"), 42.0f, {32});
3103   Output axis = ops::Const(s.WithOpName("axis"), 0, {});
3104   Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {});
3105   Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {});
3106   Output ctrl3 = ops::Const(s.WithOpName("ctrl3"), 3, {});
3107   // Test case with chains of length 1.
3108   // Rewrites
3109   //       Concat({Exp(a), Exp(b), Exp(c)})
3110   // into
3111   //       Exp(Concat({a, b, c})).
3112   Output sin_a =
3113       ops::Sin(s.WithOpName("sin_a").WithControlDependencies(ctrl3), a);
3114   Output exp_a =
3115       ops::Exp(s.WithOpName("exp_a").WithControlDependencies(ctrl1), sin_a);
3116   Output exp_b = ops::Exp(s.WithOpName("exp_b"), b);
3117   Output exp_c =
3118       ops::Exp(s.WithOpName("exp_c").WithControlDependencies(ctrl2), c);
3119   Output concat =
3120       ops::Concat(s.WithOpName("concat"), {exp_a, exp_b, exp_c}, axis);
3121   Output id = ops::Identity(s.WithOpName("id"), concat);
3122 
3123   // Test case with chains of length 2.
3124   // Rewrites
3125   //       Concat({Cos(Exp(a)), Cos(Exp(b)), Cos(Exp(c))})
3126   // into
3127   //       Cos(Exp(Concat({a, b, c}))).
3128   Output exp_a2 =
3129       ops::Exp(s.WithOpName("exp_a2").WithControlDependencies(ctrl1), sin_a);
3130   Output exp_b2 = ops::Exp(s.WithOpName("exp_b2"), b);
3131   Output exp_c2 =
3132       ops::Exp(s.WithOpName("exp_c2").WithControlDependencies(ctrl2), c);
3133   Output cos_exp_a2 = ops::Cos(
3134       s.WithOpName("cos_exp_a2").WithControlDependencies(ctrl1), exp_a2);
3135   Output cos_exp_b2 = ops::Cos(
3136       s.WithOpName("cos_exp_b2").WithControlDependencies(ctrl3), exp_b2);
3137   Output cos_exp_c2 = ops::Cos(s.WithOpName("cos_exp_c2"), exp_c2);
3138   Output concat2 = ops::Concat(s.WithOpName("concat2"),
3139                                {cos_exp_a2, cos_exp_b2, cos_exp_c2}, axis);
3140   Output id2 = ops::Identity(s.WithOpName("id2"), concat2);
3141   GrapplerItem item;
3142   item.fetch = {"id", "id2"};
3143   TF_CHECK_OK(s.ToGraphDef(&item.graph));
3144 
3145   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
3146 
3147   GraphDef output;
3148   ArithmeticOptimizer optimizer;
3149   EnableOnlyHoistCWiseUnaryChains(&optimizer);
3150   OptimizeTwiceAndPrune(&optimizer, &item, &output);
3151 
3152   int found = 0;
3153   for (const NodeDef& node : output.node()) {
3154     if (node.name() == "concat") {
3155       ASSERT_EQ(node.input_size(), 4);
3156       EXPECT_EQ(node.input(0), "sin_a");
3157       EXPECT_EQ(node.input(1), "b");
3158       EXPECT_EQ(node.input(2), "c");
3159       EXPECT_EQ(node.input(3), "axis");
3160       found++;
3161     }
3162     if (node.name() == "exp_a") {
3163       ASSERT_EQ(node.input_size(), 1);
3164       EXPECT_EQ(node.input(0), "concat");
3165       found++;
3166     }
3167     if (node.name() == "id") {
3168       ASSERT_EQ(node.input_size(), 1);
3169       EXPECT_EQ(node.input(0), "exp_a");
3170       found++;
3171     }
3172 
3173     if (node.name() == "concat2") {
3174       ASSERT_EQ(node.input_size(), 4);
3175       EXPECT_EQ(node.input(0), "sin_a");
3176       EXPECT_EQ(node.input(1), "b");
3177       EXPECT_EQ(node.input(2), "c");
3178       EXPECT_EQ(node.input(3), "axis");
3179       found++;
3180     }
3181     if (node.name() == "exp_a2") {
3182       ASSERT_EQ(node.input_size(), 1);
3183       EXPECT_EQ(node.input(0), "concat2");
3184       found++;
3185     }
3186     if (node.name() == "cos_exp_a2") {
3187       ASSERT_EQ(node.input_size(), 1);
3188       EXPECT_EQ(node.input(0), "exp_a2");
3189       found++;
3190     }
3191     if (node.name() == "id2") {
3192       ASSERT_EQ(node.input_size(), 1);
3193       EXPECT_EQ(node.input(0), "cos_exp_a2");
3194       found++;
3195     }
3196   }
3197   EXPECT_EQ(found, 7);
3198 
3199   auto tensors = EvaluateNodes(output, item.fetch);
3200   ASSERT_EQ(tensors.size(), tensors_expected.size());
3201   EXPECT_EQ(tensors.size(), item.fetch.size());
3202   for (int i = 0; i < item.fetch.size(); ++i) {
3203     test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
3204   }
3205 }
3206 
TEST_F(ArithmeticOptimizerTest,HoistCWiseUnaryIntoSplit)3207 TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) {
3208   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3209   Output x = ops::Const(s.WithOpName("x"), 3.1415f, {32});
3210   Output axis = ops::Const(s.WithOpName("axis"), 0, {});
3211   Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {});
3212   Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {});
3213   Output ctrl3 = ops::Const(s.WithOpName("ctrl3"), 3, {});
3214   // Test case with chains of length 1.
3215   // Rewrites
3216   //          [Sin(y) for y in Split(x)]
3217   // into
3218   //          [y for y in Split(Sin(x))].
3219   ops::Split split1(s.WithOpName("split1"), axis, x, 2);
3220   Output sin_a =
3221       ops::Sin(s.WithOpName("sin_a").WithControlDependencies(ctrl1), split1[0]);
3222   Output id_a = ops::Identity(s.WithOpName("id_a"), sin_a);
3223   Output sin_b = ops::Sin(s.WithOpName("sin_b"), split1[1]);
3224   Output exp_b = ops::Exp(s.WithOpName("exp_b"), sin_b);
3225   Output id_b = ops::Identity(s.WithOpName("id_b"), exp_b);
3226 
3227   // Test case with SplitV and chains of length 2.
3228   // Rewrites
3229   //          [Cos(Exp(y)) for y in Split(x)]
3230   // into
3231   //          [y for y in Split(Cos(Exp(x)))].
3232   Output size_splits2 = ops::Const(s.WithOpName("size_splits2"), {20, 12}, {2});
3233   ops::SplitV split2(s.WithOpName("split2"), x, size_splits2, axis, 2);
3234   Output exp_a2 = ops::Exp(
3235       s.WithOpName("exp_a2").WithControlDependencies(ctrl1), split2[0]);
3236   Output exp_b2 = ops::Exp(s.WithOpName("exp_b2"), split2[1]);
3237   Output cos_exp_a2 = ops::Cos(
3238       s.WithOpName("cos_exp_a2").WithControlDependencies(ctrl2), exp_a2);
3239   Output cos_exp_b2 = ops::Cos(
3240       s.WithOpName("cos_exp_b2").WithControlDependencies(ctrl3), exp_b2);
3241   Output id_a2 = ops::Identity(s.WithOpName("id_a2"), cos_exp_a2);
3242   Output id_b2 = ops::Identity(s.WithOpName("id_b2"), cos_exp_b2);
3243 
3244   GrapplerItem item;
3245   item.fetch = {"id_a", "id_b", "id_a2", "id_b2"};
3246   TF_CHECK_OK(s.ToGraphDef(&item.graph));
3247 
3248   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
3249 
3250   GraphDef output;
3251   ArithmeticOptimizer optimizer;
3252   EnableOnlyHoistCWiseUnaryChains(&optimizer);
3253   OptimizeTwiceAndPrune(&optimizer, &item, &output);
3254 
3255   int found = 0;
3256   for (const NodeDef& node : output.node()) {
3257     // The following 6 nodes should be pruned.
3258     EXPECT_NE(node.name(), "sin_a");
3259     EXPECT_NE(node.name(), "sin_b");
3260     EXPECT_NE(node.name(), "exp_a2");
3261     EXPECT_NE(node.name(), "exp_b2");
3262     EXPECT_NE(node.name(), "cos_exp_a2");
3263     EXPECT_NE(node.name(), "cos_exp_b2");
3264 
3265     if (node.name() == "split1") {
3266       ASSERT_EQ(node.input_size(), 2);
3267       EXPECT_EQ(node.input(0), "axis");
3268       EXPECT_EQ(node.input(1), "ArithmeticOptimizer/_sin_a_split1");
3269       found++;
3270     }
3271     if (node.name() == "ArithmeticOptimizer/_sin_a_split1") {
3272       EXPECT_EQ(node.op(), "Sin");
3273       ASSERT_EQ(node.input_size(), 1);
3274       EXPECT_EQ(node.input(0), "x");
3275       found++;
3276     }
3277     if (node.name() == "id_a") {
3278       ASSERT_EQ(node.input_size(), 1);
3279       EXPECT_EQ(node.input(0), "split1");
3280       found++;
3281     }
3282     if (node.name() == "exp_b") {
3283       ASSERT_EQ(node.input_size(), 1);
3284       EXPECT_EQ(node.input(0), "split1:1");
3285       found++;
3286     }
3287     if (node.name() == "id_b") {
3288       ASSERT_EQ(node.input_size(), 1);
3289       EXPECT_EQ(node.input(0), "exp_b");
3290       found++;
3291     }
3292     if (node.name() == "ArithmeticOptimizer/_exp_a2_split2") {
3293       EXPECT_EQ(node.op(), "Exp");
3294       ASSERT_EQ(node.input_size(), 1);
3295       EXPECT_EQ(node.input(0), "x");
3296       found++;
3297     }
3298     if (node.name() == "ArithmeticOptimizer/_cos_exp_a2_split2") {
3299       EXPECT_EQ(node.op(), "Cos");
3300       ASSERT_EQ(node.input_size(), 1);
3301       EXPECT_EQ(node.input(0), "ArithmeticOptimizer/_exp_a2_split2");
3302       found++;
3303     }
3304     if (node.name() == "split2") {
3305       ASSERT_EQ(node.input_size(), 3);
3306       EXPECT_EQ(node.input(0), "ArithmeticOptimizer/_cos_exp_a2_split2");
3307       EXPECT_EQ(node.input(1), "size_splits2");
3308       EXPECT_EQ(node.input(2), "axis");
3309       found++;
3310     }
3311     if (node.name() == "id_a2") {
3312       ASSERT_EQ(node.input_size(), 1);
3313       EXPECT_EQ(node.input(0), "split2");
3314       found++;
3315     }
3316     if (node.name() == "id_b2") {
3317       ASSERT_EQ(node.input_size(), 1);
3318       EXPECT_EQ(node.input(0), "split2:1");
3319       found++;
3320     }
3321   }
3322   EXPECT_EQ(found, 10);
3323 
3324   auto tensors = EvaluateNodes(output, item.fetch);
3325   ASSERT_EQ(tensors.size(), tensors_expected.size());
3326   EXPECT_EQ(tensors.size(), item.fetch.size());
3327   for (int i = 0; i < item.fetch.size(); ++i) {
3328     test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
3329   }
3330 }
3331 
// RemoveIdempotent: in a chain of two identical Snapshot (or Identity) ops the
// second one is expected to be bypassed, so its consumer reads the first.
TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(scope.WithOpName("a"), 3.14f, {32});
  // Snapshot chain: a -> sn1 -> sn2 -> out1.
  Output sn1 = ops::Snapshot(scope.WithOpName("sn1"), a);
  Output sn2 = ops::Snapshot(scope.WithOpName("sn2"), sn1);
  Output out1 = ops::Identity(scope.WithOpName("out1"), sn2);
  // Identity chain: a -> id1 -> id2 -> out2.
  Output id1 = ops::Identity(scope.WithOpName("id1"), a);
  Output id2 = ops::Identity(scope.WithOpName("id2"), id1);
  Output out2 = ops::Identity(scope.WithOpName("out2"), id2);

  GrapplerItem item;
  item.fetch = {"out1", "out2"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  // Reference values are taken from the graph before it is optimized.
  auto reference = EvaluateNodes(item.graph, item.fetch);

  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdempotent(&optimizer);
  GraphDef optimized;
  OptimizeTwice(&optimizer, &item, &optimized);

  // This test does not call a pruning helper; all seven nodes built above are
  // expected to still be in the graph.
  EXPECT_EQ(optimized.node_size(), 7);

  // Node name -> the single input it must have after the rewrite.
  const std::vector<std::pair<string, string>> expected_inputs = {
      {"out1", "sn1"}, {"out2", "id1"}, {"sn1", "a"}};
  int matched = 0;
  for (const NodeDef& node : optimized.node()) {
    for (const auto& expectation : expected_inputs) {
      if (node.name() != expectation.first) continue;
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), expectation.second);
      ++matched;
    }
  }
  EXPECT_EQ(matched, 3);

  // The rewrite must not change the fetched values.
  auto actual = EvaluateNodes(optimized, item.fetch);
  ASSERT_EQ(actual.size(), reference.size());
  EXPECT_EQ(actual.size(), item.fetch.size());
  const int num_fetches = static_cast<int>(item.fetch.size());
  for (int idx = 0; idx < num_fetches; ++idx) {
    test::ExpectTensorNear<float>(actual[idx], reference[idx], 1e-6);
  }
}
3378 
TEST_F(ArithmeticOptimizerTest,RemoveLogicalNot)3379 TEST_F(ArithmeticOptimizerTest, RemoveLogicalNot) {
3380   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3381   Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
3382   Output b = ops::Const(s.WithOpName("b"), -3.14f, {32});
3383   Output eq = ops::Equal(s.WithOpName("eq"), a, b);
3384   Output neq = ops::NotEqual(s.WithOpName("neq"), a, b);
3385   Output lt = ops::Less(s.WithOpName("lt"), a, b);
3386   Output le = ops::LessEqual(s.WithOpName("le"), a, b);
3387   Output gt = ops::Greater(s.WithOpName("gt"), a, b);
3388   Output ge = ops::GreaterEqual(s.WithOpName("ge"), a, b);
3389   // not_eq is reserved
3390   Output not_eq1 = ops::LogicalNot(s.WithOpName("not_eq1"), eq);
3391   Output not_neq = ops::LogicalNot(s.WithOpName("not_neq"), neq);
3392   Output not_lt = ops::LogicalNot(s.WithOpName("not_lt"), lt);
3393   Output not_le = ops::LogicalNot(s.WithOpName("not_le"), le);
3394   Output not_gt = ops::LogicalNot(s.WithOpName("not_gt"), gt);
3395   Output not_ge = ops::LogicalNot(s.WithOpName("not_ge"), ge);
3396   Output id_not_eq = ops::Identity(s.WithOpName("id_not_eq"), not_eq1);
3397   Output id_not_neq = ops::Identity(s.WithOpName("id_not_neq"), not_neq);
3398   Output id_not_lt = ops::Identity(s.WithOpName("id_not_lt"), not_lt);
3399   Output id_not_le = ops::Identity(s.WithOpName("id_not_le"), not_le);
3400   Output id_not_gt = ops::Identity(s.WithOpName("id_not_gt"), not_gt);
3401   Output id_not_ge = ops::Identity(s.WithOpName("id_not_ge"), not_ge);
3402 
3403   GrapplerItem item;
3404   item.fetch = {"id_not_eq", "id_not_neq", "id_not_lt",
3405                 "id_not_le", "id_not_gt",  "id_not_ge"};
3406   TF_CHECK_OK(s.ToGraphDef(&item.graph));
3407 
3408   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
3409 
3410   GraphDef output;
3411   ArithmeticOptimizer optimizer;
3412   EnableOnlyRemoveLogicalNot(&optimizer);
3413   OptimizeTwice(&optimizer, &item, &output);
3414 
3415   int found = 0;
3416   for (const NodeDef& node : output.node()) {
3417     if (node.name() == "id_not_eq") {
3418       ASSERT_EQ(node.input_size(), 1);
3419       EXPECT_EQ(node.input(0), "eq");
3420       ++found;
3421     }
3422     if (node.name() == "id_not_neq") {
3423       ASSERT_EQ(node.input_size(), 1);
3424       EXPECT_EQ(node.input(0), "neq");
3425       ++found;
3426     }
3427     if (node.name() == "id_not_lt") {
3428       ASSERT_EQ(node.input_size(), 1);
3429       EXPECT_EQ(node.input(0), "lt");
3430       ++found;
3431     }
3432     if (node.name() == "id_not_le") {
3433       ASSERT_EQ(node.input_size(), 1);
3434       EXPECT_EQ(node.input(0), "le");
3435       ++found;
3436     }
3437     if (node.name() == "id_not_gt") {
3438       ASSERT_EQ(node.input_size(), 1);
3439       EXPECT_EQ(node.input(0), "gt");
3440       ++found;
3441     }
3442     if (node.name() == "id_not_ge") {
3443       ASSERT_EQ(node.input_size(), 1);
3444       EXPECT_EQ(node.input(0), "ge");
3445       ++found;
3446     }
3447 
3448     if (node.name() == "eq") {
3449       EXPECT_EQ(node.op(), "NotEqual");
3450       ++found;
3451     }
3452     if (node.name() == "neq") {
3453       EXPECT_EQ(node.op(), "Equal");
3454       ++found;
3455     }
3456     if (node.name() == "lt") {
3457       EXPECT_EQ(node.op(), "GreaterEqual");
3458       ++found;
3459     }
3460     if (node.name() == "le") {
3461       EXPECT_EQ(node.op(), "Greater");
3462       ++found;
3463     }
3464     if (node.name() == "gt") {
3465       EXPECT_EQ(node.op(), "LessEqual");
3466       ++found;
3467     }
3468     if (node.name() == "ge") {
3469       EXPECT_EQ(node.op(), "Less");
3470       ++found;
3471     }
3472   }
3473   EXPECT_EQ(found, 12);
3474 
3475   auto tensors = EvaluateNodes(output, item.fetch);
3476   ASSERT_EQ(tensors.size(), tensors_expected.size());
3477   EXPECT_EQ(tensors.size(), item.fetch.size());
3478   for (int i = 0; i < item.fetch.size(); ++i) {
3479     test::ExpectTensorEqual<bool>(tensors[i], tensors_expected[i]);
3480   }
3481 }
3482 
// Max(Sqrt(x)) is expected to be rewritten so that the reduction runs on x
// and Sqrt is applied to the reduced value: the two nodes trade places.
TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWise) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt_x = ops::Sqrt(scope.WithOpName("sqrt"), x);
  Output reduce_max = ops::Max(scope.WithOpName("reduce_max"), sqrt_x, {0});
  Output final_out = ops::Identity(scope.WithOpName("final_out"), reduce_max);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  // Reference value is taken from the graph before it is optimized.
  auto reference = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(reference.size(), 1);

  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  GraphDef optimized;
  OptimizeAndPrune(&optimizer, &item, &optimized);
  auto actual = EvaluateNodes(optimized, item.fetch);
  ASSERT_EQ(actual.size(), 1);

  test::ExpectTensorNear<float>(actual[0], reference[0], 1e-6);
  EXPECT_EQ(optimized.node_size(), item.graph.node_size());

  // Check that the inputs are switched: Sqrt reads the reduction and the
  // reduction reads x.
  int required_node_count = 0;
  for (const NodeDef& node : optimized.node()) {
    const auto& name = node.name();
    if (name == "sqrt") {
      EXPECT_EQ(node.op(), "Sqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "reduce_max");
      ++required_node_count;
    } else if (name == "reduce_max") {
      EXPECT_EQ(node.op(), "Max");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      ++required_node_count;
    }
  }
  EXPECT_EQ(required_node_count, 2);
}
3523 
// ArgMax(Sqrt(x)) is expected to be rewritten into ArgMax(x): the ArgMax node
// reads x directly and the pruned graph has one node fewer.
TEST_F(ArithmeticOptimizerTest, OptimizeArgMaxOrArgMinOfMonotonicElementWise) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  const auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt_x = ops::Sqrt(scope.WithOpName("sqrt"), x);
  Output arg_max = ops::ArgMax(scope.WithOpName("arg_max"), sqrt_x, 1);
  Output final_out = ops::Identity(scope.WithOpName("final_out"), arg_max);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  // Reference value is taken from the graph before it is optimized.
  const auto reference = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(reference.size(), 1);

  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  GraphDef optimized;
  OptimizeAndPrune(&optimizer, &item, &optimized);
  const auto actual = EvaluateNodes(optimized, item.fetch);
  ASSERT_EQ(actual.size(), 1);

  test::ExpectTensorEqual<int64>(actual[0], reference[0]);
  EXPECT_EQ(optimized.node_size(), item.graph.node_size() - 1);

  // Check the wiring: final_out still reads arg_max, and arg_max now reads x.
  int required_node_count = 0;
  for (const NodeDef& node : optimized.node()) {
    const auto& name = node.name();
    if (name == "final_out") {
      EXPECT_EQ(node.op(), "Identity");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "arg_max");
      ++required_node_count;
    } else if (name == "arg_max") {
      EXPECT_EQ(node.op(), "ArgMax");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      ++required_node_count;
    }
  }
  EXPECT_EQ(required_node_count, 2);
}
3564 
// Same Max(Sqrt(x)) graph as the basic monotonic test, but here "sqrt" is
// itself fetched. The optimizer is expected to leave the graph untouched.
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseDoNotChangeFetchNode) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt_x = ops::Sqrt(scope.WithOpName("sqrt"), x);
  Output reduce_max = ops::Max(scope.WithOpName("reduce_max"), sqrt_x, {0});
  Output final_out = ops::Identity(scope.WithOpName("final_out"), reduce_max);

  GrapplerItem item;
  item.fetch = {"sqrt", "final_out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  // Evaluate the unoptimized graph; only the number of results is checked.
  auto baseline = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(baseline.size(), 2);

  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  GraphDef optimized;
  OptimizeTwice(&optimizer, &item, &optimized);

  // Should be a NoOp since we are not allowed to change the output of fetch
  // nodes.
  VerifyGraphsMatch(item.graph, optimized, __LINE__);
}
3588 
// Builds x -> Reshape -> Neg -> Max where the reduction "z" itself is the
// fetch node. The graph is expected to stay unchanged and still evaluate to
// the same value.
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseDoNotChangeFetchNodeReduction) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto input = ops::Const(root.WithOpName("x"), {2, 3}, {1, 2});
  Output flattened = ops::Reshape(root.WithOpName("reshape"), input, {-1});
  Output negated = ops::Neg(root.WithOpName("y"), flattened);
  Output max_op = ops::Max(root.WithOpName("z"), negated, {0});

  GrapplerItem item;
  item.fetch = {"z"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  auto expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(expected.size(), 1);

  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  GraphDef optimized;
  OptimizeTwice(&optimizer, &item, &optimized);

  // The output of a fetch node must not change, so the optimizer has to
  // behave as a NoOp here.
  VerifyGraphsMatch(item.graph, optimized, __LINE__);

  auto actual = EvaluateNodes(optimized, item.fetch);
  ASSERT_EQ(actual.size(), 1);
  // Same value as before optimization, and equal to max(-2, -3) == -2.
  test::ExpectTensorEqual<int>(actual[0], expected[0]);
  test::ExpectTensorEqual<int>(actual[0], Tensor(-2));
}
3617 
// Builds x -> Neg -> Max -> Identity and checks the rewritten graph: the
// reduction node "reduce_max" must become a Min that reads "x" directly, and
// "neg" must be applied to the reduction result instead of to "x".
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasing) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output neg = ops::Neg(s.WithOpName("neg"), x);
  Output reduce_max = ops::Max(s.WithOpName("reduce_max"), neg, {0});
  Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);

  // The rewrite must preserve the fetched value and the number of nodes.
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  EXPECT_EQ(output.node_size(), item.graph.node_size());

  // Check that the inputs are switched and that Max was flipped to Min.
  int required_node_count = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "neg") {
      EXPECT_EQ(node.op(), "Neg");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "reduce_max");
      ++required_node_count;
    } else if (node.name() == "reduce_max") {
      EXPECT_EQ(node.op(), "Min");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      ++required_node_count;
    }
  }
  // Both nodes of interest must have been found in the optimized graph.
  EXPECT_EQ(required_node_count, 2);
}
3659 
// Builds x -> Neg -> MaxPool with "max_pool" as the fetch node and expects
// the optimizer to leave the graph exactly as it is.
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasingDoNotChangeMaxPool) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto input = ops::Const(root.WithOpName("x"), 1.5f, {3, 3, 3, 1});
  Output negated = ops::Neg(root.WithOpName("neg"), input);
  Output pooled = ops::MaxPool(root.WithOpName("max_pool"), negated,
                               {1, 2, 2, 1}, {1, 2, 2, 1}, "VALID");

  GrapplerItem item;
  item.fetch = {"max_pool"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  auto expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(expected.size(), 1);

  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  GraphDef optimized;
  OptimizeTwice(&optimizer, &item, &optimized);

  // The optimized graph must be identical to the original one (NoOp).
  VerifyGraphsMatch(item.graph, optimized, __LINE__);

  // The fetched value must be unaffected as well.
  auto actual = EvaluateNodes(optimized, item.fetch);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6);
}
3686 
// Builds Conv2D -> BiasAdd -> Relu -> MaxPool -> Identity and expects the
// optimizer to leave this chain untouched.
TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicBiasAddReluMaxPool) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output weights = ops::Const(root.WithOpName("weights"),
                              Input::Initializer(1.0f, {5, 5, 3, 4}));
  Output biases =
      ops::Const(root.WithOpName("biases"), Input::Initializer(2.0f, {4}));
  Output input = ops::Const(root.WithOpName("input"),
                            Input::Initializer(1.0f, {1, 28, 28, 3}));

  // One named value per layer, created in the same order as the data flows.
  Output conv = ops::Conv2D(root.WithOpName("conv"), input, weights,
                            {1, 1, 1, 1}, "SAME");
  Output biased = ops::BiasAdd(root.WithOpName("biasadd"), conv, biases);
  Output activated = ops::Relu(root.WithOpName("relu"), biased);
  Output pooled = ops::MaxPool(root.WithOpName("max_pool"), activated,
                               {1, 2, 2, 1}, {1, 2, 2, 1}, "VALID");
  Output result = ops::Identity(root.WithOpName("output"), pooled);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  auto expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(expected.size(), 1);

  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  GraphDef optimized;
  OptimizeTwice(&optimizer, &item, &optimized);

  // The optimized graph must be identical to the original one (NoOp).
  VerifyGraphsMatch(item.graph, optimized, __LINE__);

  // The fetched value must be unaffected as well.
  auto actual = EvaluateNodes(optimized, item.fetch);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6);
}
3721 
// Builds x -> Sqrt -> MaxPool -> Identity and checks that, after
// optimization, "max_pool" reads "x" directly while "sqrt" is applied to the
// pooled result.
TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWiseMaxPool) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto input = ops::Const(root.WithOpName("x"), 1.5f, {3, 3, 3, 1});
  Output sqrt_op = ops::Sqrt(root.WithOpName("sqrt"), input);
  Output pooled = ops::MaxPool(root.WithOpName("max_pool"), sqrt_op,
                               {1, 2, 2, 1}, {1, 2, 2, 1}, "VALID");
  Output identity_op = ops::Identity(root.WithOpName("final_out"), pooled);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  auto expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(expected.size(), 1);

  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  GraphDef optimized;
  OptimizeAndPrune(&optimizer, &item, &optimized);
  auto actual = EvaluateNodes(optimized, item.fetch);
  ASSERT_EQ(actual.size(), 1);

  // Same fetched value and same number of nodes as before.
  test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6);
  EXPECT_EQ(optimized.node_size(), item.graph.node_size());

  // Verify that Sqrt and MaxPool have traded places in the chain.
  int matched_nodes = 0;
  for (const NodeDef& node : optimized.node()) {
    if (node.name() == "sqrt") {
      EXPECT_EQ(node.op(), "Sqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "max_pool");
      ++matched_nodes;
    } else if (node.name() == "max_pool") {
      EXPECT_EQ(node.op(), "MaxPool");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");
      ++matched_nodes;
    }
  }
  EXPECT_EQ(matched_nodes, 2);
}
3763 
// Builds x -> Sqrt -> Log -> Relu -> Identity with every node placed on CPU
// and checks that the three unary ops are replaced by a single
// _UnaryOpsComposition node that lists them in order and reads "x".
TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x);
  Output log = ops::Log(s.WithOpName("log"), sqrt);
  Output relu = ops::Relu(s.WithOpName("relu"), log);
  Output final_out = ops::Identity(s.WithOpName("final_out"), relu);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Place all nodes on CPU.
  for (int i = 0; i < item.graph.node_size(); ++i) {
    item.graph.mutable_node(i)->set_device("/device:CPU:0");
  }

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyUnaryOpsComposition(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // Expected survivors: "x", the fused node and "final_out".
  EXPECT_EQ(output.node_size(), 3);

  // Check that Sqrt/Log/Relu were replaced with a single op.
  int required_node_count = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "final_out") {
      EXPECT_EQ(node.op(), "Identity");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "relu/unary_ops_composition");
      ++required_node_count;
    } else if (node.name() == "relu/unary_ops_composition") {
      EXPECT_EQ(node.op(), "_UnaryOpsComposition");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");

      // Bind by reference: the attribute list is only read, so there is no
      // reason to copy the repeated string field.
      const auto& op_names = node.attr().at("op_names").list().s();
      ASSERT_EQ(op_names.size(), 3);
      EXPECT_EQ(op_names[0], "Sqrt");
      EXPECT_EQ(op_names[1], "Log");
      EXPECT_EQ(op_names[2], "Relu");
      ++required_node_count;
    }
  }
  // Both nodes of interest must have been found in the optimized graph.
  EXPECT_EQ(required_node_count, 2);

  // The fused graph must produce the same value as the original one.
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
3820 
TEST_F(ArithmeticOptimizerTest,RemoveStackStridedSliceSameAxis)3821 TEST_F(ArithmeticOptimizerTest, RemoveStackStridedSliceSameAxis) {
3822   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3823   auto a_in =
3824       ops::Const(s.WithOpName("a_in"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
3825   auto b_in =
3826       ops::Const(s.WithOpName("b_in"), {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2});
3827   auto c_in =
3828       ops::Const(s.WithOpName("c_in"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
3829   auto a = ops::PlaceholderWithDefault(s.WithOpName("a"), a_in,
3830                                        PartialTensorShape({-1, -1}));
3831   auto b = ops::PlaceholderWithDefault(s.WithOpName("b"), b_in,
3832                                        PartialTensorShape({-1, -1}));
3833   auto c = ops::PlaceholderWithDefault(s.WithOpName("c"), c_in,
3834                                        PartialTensorShape({-1, -1}));
3835   // stacked = tf.stack((a, b, c), axis=1).
3836   // stacked.shape == [2, 3, 2] (a, b, c are stacked along new axis 1)
3837   auto stacked =
3838       ops::Stack(s.WithOpName("stacked"), {a.output, b.output, c.output},
3839                  ops::Stack::Axis(1));
3840   auto expanded_a = ops::ExpandDims(s.WithOpName("expanded_a"), a, {1});
3841   auto expanded_b = ops::ExpandDims(s.WithOpName("expanded_b"), b, {1});
3842   auto expanded_c = ops::ExpandDims(s.WithOpName("expanded_c"), c, {1});
3843   auto begin_a = ops::Const(s.WithOpName("begin_a"), {0, 0, 0}, {3});
3844   auto end_a = ops::Const(s.WithOpName("end_a"), {0, 1, 0}, {3});
3845   auto begin_b = ops::Const(s.WithOpName("begin_b"), {0, 1, 0}, {3});
3846   auto end_b = ops::Const(s.WithOpName("end_b"), {0, 2, 0}, {3});
3847   auto begin_c = ops::Const(s.WithOpName("begin_c"), {0, 2, 0}, {3});
3848   auto end_c = ops::Const(s.WithOpName("end_c"), {0, 3, 0}, {3});
3849   auto end_c_1to = ops::Const(s.WithOpName("begin_c_2to"), {0, 0, 0}, {3});
3850   auto strides = ops::Const(s.WithOpName("strides"), {1, 1, 1}, {3});
3851 
3852   // stacked[:, 0]
3853   using SS = ops::StridedSlice;
3854   auto pa_slice = ops::Identity(
3855       s.WithOpName("pa_slice_out"),
3856       SS(s.WithOpName("pa_slice"), stacked, begin_a, end_a, strides,
3857          SS::BeginMask(0b0101)  // 5
3858              .EllipsisMask(0)
3859              .EndMask(0b0101)  // 5
3860              .NewAxisMask(0)
3861              .ShrinkAxisMask(0b0010)));  // 2
3862 
3863   // stacked[:, 1]
3864   auto pb_slice = ops::Identity(
3865       s.WithOpName("pb_slice_out"),
3866       SS(s.WithOpName("pb_slice"), stacked, begin_b, end_b, strides,
3867          SS::BeginMask(0b0101)  // 5
3868              .EllipsisMask(0)
3869              .EndMask(0b0101)  // 5
3870              .NewAxisMask(0)
3871              .ShrinkAxisMask(0b0010)));  // 2
3872 
3873   // stacked[:, 2]
3874   auto pc_slice = ops::Identity(
3875       s.WithOpName("pc_slice_out"),
3876       SS(s.WithOpName("pc_slice"), stacked, begin_c, end_c, strides,
3877          SS::BeginMask(0b0101)  // 5
3878              .EllipsisMask(0)
3879              .EndMask(0b0101)  // 5
3880              .NewAxisMask(0)
3881              .ShrinkAxisMask(0b0010)));  // 2
3882 
3883   // stacked[:, 0:1, :]
3884   auto pa_slice_01 = ops::Identity(
3885       s.WithOpName("pa_slice_01_out"),
3886       SS(s.WithOpName("pa_slice_01"), stacked, begin_a, end_a, strides,
3887          SS::BeginMask(0b0101)  // 5
3888              .EllipsisMask(0)
3889              .EndMask(0b0101)  // 5
3890              .NewAxisMask(0)
3891              .ShrinkAxisMask(0)));
3892 
3893   // stacked[:, :1, :]
3894   auto pa_slice_to1 = ops::Identity(
3895       s.WithOpName("pa_slice_to1_out"),
3896       SS(s.WithOpName("pa_slice_to1"), stacked, begin_a, end_a, strides,
3897          SS::BeginMask(0b0111)  // 7
3898              .EllipsisMask(0)
3899              .EndMask(0b0101)  // 5
3900              .NewAxisMask(0)
3901              .ShrinkAxisMask(0)));
3902 
3903   // stacked[:, 1:2, :]
3904   auto pb_slice_12 = ops::Identity(
3905       s.WithOpName("pb_slice_12_out"),
3906       SS(s.WithOpName("pb_slice_12"), stacked, begin_b, end_b, strides,
3907          SS::BeginMask(0b0101)  // 5
3908              .EllipsisMask(0)
3909              .EndMask(0b0101)  // 5
3910              .NewAxisMask(0)
3911              .ShrinkAxisMask(0)));
3912 
3913   // stacked[:, 2:, :].
3914   auto pc_slice_2to = ops::Identity(
3915       s.WithOpName("pc_slice_2to_out"),
3916       SS(s.WithOpName("pc_slice_2to"), stacked, begin_c, end_c_1to, strides,
3917          SS::BeginMask(0b0101)  // 5
3918              .EllipsisMask(0)
3919              .EndMask(0b0111)  // 7
3920              .NewAxisMask(0)
3921              .ShrinkAxisMask(0)));
3922 
3923   GrapplerItem item;
3924   item.fetch = {"a",
3925                 "b",
3926                 "c",
3927                 "pa_slice_out",
3928                 "pb_slice_out",
3929                 "pc_slice_out",
3930                 "expanded_a",
3931                 "expanded_b",
3932                 "expanded_c",
3933                 "pa_slice_01_out",
3934                 "pa_slice_to1_out",
3935                 "pb_slice_12_out",
3936                 "pc_slice_2to_out"};
3937   enum FetchItem {
3938     fA,
3939     fB,
3940     fC,
3941     fASliceOut,
3942     fBSliceOut,
3943     fCSliceOut,
3944     fExpandedA,
3945     fExpandedB,
3946     fExpandedC,
3947     fASlice01Out,
3948     fASliceTo1Out,
3949     fBSlice12Out,
3950     fCSlice2ToOut,
3951   };
3952   TF_CHECK_OK(s.ToGraphDef(&item.graph));
3953   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
3954 
3955   // stacked[:, 0, :] == a.
3956   test::ExpectTensorEqual<float>(tensors_expected[fASliceOut],
3957                                  tensors_expected[fA]);
3958   // stacked[:, 1, :] == b.
3959   test::ExpectTensorEqual<float>(tensors_expected[fBSliceOut],
3960                                  tensors_expected[fB]);
3961   // stacked[:, 2, :] == c.
3962   test::ExpectTensorEqual<float>(tensors_expected[fCSliceOut],
3963                                  tensors_expected[fC]);
3964 
3965   // stacked[:, 0:1, :] == expand_dims(a, 1).
3966   test::ExpectTensorEqual<float>(tensors_expected[fASlice01Out],
3967                                  tensors_expected[fExpandedA]);
3968 
3969   // stacked[:, :1, :] == expand_dims(a, 1).
3970   test::ExpectTensorEqual<float>(tensors_expected[fASliceTo1Out],
3971                                  tensors_expected[fExpandedA]);
3972 
3973   // stacked[:, 1:2, :] == expand_dims(b, 1).
3974   test::ExpectTensorEqual<float>(tensors_expected[fBSlice12Out],
3975                                  tensors_expected[fExpandedB]);
3976   // stacked[:, 2:, :] == expand_dims(c, 1).
3977   test::ExpectTensorEqual<float>(tensors_expected[fCSlice2ToOut],
3978                                  tensors_expected[fExpandedC]);
3979 
3980   GraphDef output;
3981   ArithmeticOptimizer optimizer;
3982   EnableOnlyRemoveStackSliceSameAxis(&optimizer);
3983   OptimizeAndPrune(&optimizer, &item, &output);
3984 
3985   for (const auto& node : output.node()) {
3986     if (node.name() == "pa_slice_out") {
3987       ASSERT_EQ(node.input_size(), 1);
3988       EXPECT_EQ(node.input(0), "a");
3989     } else if (node.name() == "pb_slice_out") {
3990       ASSERT_EQ(node.input_size(), 1);
3991       EXPECT_EQ(node.input(0), "b");
3992     } else if (node.name() == "pc_slice_out") {
3993       ASSERT_EQ(node.input_size(), 1);
3994       EXPECT_EQ(node.input(0), "c");
3995     } else if (str_util::EndsWith(node.name(), "_out")) {
3996       ASSERT_EQ(node.input_size(), 1);
3997       EXPECT_EQ(
3998           absl::StrCat(node.input(0), "_out"),
3999           absl::StrCat("ArithmeticOptimizer/RemoveStackStridedSliceSameAxis_",
4000                        node.name()));
4001     }
4002   }
4003 
4004   auto tensors = EvaluateNodes(output, item.fetch);
4005 
4006   // stacked[:, 0, :] == a.
4007   test::ExpectTensorEqual<float>(tensors[fASliceOut], tensors_expected[fA]);
4008 
4009   // stacked[:, 1, :] == b.
4010   test::ExpectTensorEqual<float>(tensors[fBSliceOut], tensors_expected[fB]);
4011   // stacked[:, 2, :] == c.
4012   test::ExpectTensorEqual<float>(tensors[fCSliceOut], tensors_expected[fC]);
4013 
4014   // stacked[:, 0:1, :] == expand_dims(a, 1).
4015   test::ExpectTensorEqual<float>(tensors[fASlice01Out],
4016                                  tensors_expected[fExpandedA]);
4017 
4018   // stacked[:, :1, :] == expand_dims(a, 1).
4019   test::ExpectTensorEqual<float>(tensors[fASliceTo1Out],
4020                                  tensors_expected[fExpandedA]);
4021 
4022   // stacked[:, 1:2, :] == expand_dims(b, 1).
4023   test::ExpectTensorEqual<float>(tensors[fBSlice12Out],
4024                                  tensors_expected[fExpandedB]);
4025   // stacked[:, 2:, :] == expand_dims(c, 1).
4026   test::ExpectTensorEqual<float>(tensors[fCSlice2ToOut],
4027                                  tensors_expected[fExpandedC]);
4028 }
4029 
TEST_F(ArithmeticOptimizerTest,RemoveStackSimpleSliceSameAxis)4030 TEST_F(ArithmeticOptimizerTest, RemoveStackSimpleSliceSameAxis) {
4031   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
4032   auto a_in =
4033       ops::Const(s.WithOpName("a_in"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
4034   auto b_in =
4035       ops::Const(s.WithOpName("b_in"), {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2});
4036   auto c_in =
4037       ops::Const(s.WithOpName("c_in"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
4038   auto a = ops::PlaceholderWithDefault(s.WithOpName("a"), a_in,
4039                                        PartialTensorShape({-1, -1}));
4040   auto b = ops::PlaceholderWithDefault(s.WithOpName("b"), b_in,
4041                                        PartialTensorShape({-1, -1}));
4042   auto c = ops::PlaceholderWithDefault(s.WithOpName("c"), c_in,
4043                                        PartialTensorShape({-1, -1}));
4044   // stacked = tf.stack((a, b, c), axis=1).
4045   // stacked.shape == [2, 3, 2] (a, b, c are stacked along new axis 1)
4046   auto stacked =
4047       ops::Stack(s.WithOpName("stacked"), {a.output, b.output, c.output},
4048                  ops::Stack::Axis(1));
4049   auto expanded_a = ops::ExpandDims(s.WithOpName("expanded_a"), a, {1});
4050   auto expanded_b = ops::ExpandDims(s.WithOpName("expanded_b"), b, {1});
4051   auto expanded_c = ops::ExpandDims(s.WithOpName("expanded_c"), c, {1});
4052   auto begin_a = ops::Const(s.WithOpName("begin_a"), {0, 0, 0}, {3});
4053   auto begin_b = ops::Const(s.WithOpName("begin_b"), {0, 1, 0}, {3});
4054   auto begin_c = ops::Const(s.WithOpName("begin_c"), {0, 2, 0}, {3});
4055   auto sizes_to_end = ops::Const(s.WithOpName("size"), {-1, 1, -1}, {3});
4056 
4057   // stacked[:, 0:1, :]
4058   auto pa_slice = ops::Identity(
4059       s.WithOpName("pa_slice_out"),
4060       ops::Slice(s.WithOpName("pa_slice"), stacked, begin_a, sizes_to_end));
4061 
4062   // stacked[:, 1:2, :]
4063   auto pb_slice = ops::Identity(
4064       s.WithOpName("pb_slice_out"),
4065       ops::Slice(s.WithOpName("pb_slice"), stacked, begin_b, sizes_to_end));
4066 
4067   // stacked[:, 2:3, :]
4068   auto pc_slice = ops::Identity(
4069       s.WithOpName("pc_slice_out"),
4070       ops::Slice(s.WithOpName("pc_slice"), stacked, begin_c, sizes_to_end));
4071 
4072   GrapplerItem item;
4073   item.fetch = {"a",
4074                 "b",
4075                 "c",
4076                 "pa_slice_out",
4077                 "pb_slice_out",
4078                 "pc_slice_out",
4079                 "expanded_a",
4080                 "expanded_b",
4081                 "expanded_c"};
4082   enum FetchItem {
4083     fA,
4084     fB,
4085     fC,
4086     fASliceOut,
4087     fBSliceOut,
4088     fCSliceOut,
4089     fExpandedA,
4090     fExpandedB,
4091     fExpandedC,
4092   };
4093   TF_CHECK_OK(s.ToGraphDef(&item.graph));
4094   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
4095 
4096   // stacked[:, 0:1, :] == a.
4097   test::ExpectTensorEqual<float>(tensors_expected[fASliceOut],
4098                                  tensors_expected[fExpandedA]);
4099   // stacked[:, 1:2, :] == b.
4100   test::ExpectTensorEqual<float>(tensors_expected[fBSliceOut],
4101                                  tensors_expected[fExpandedB]);
4102   // stacked[:, 2:3, :] == c.
4103   test::ExpectTensorEqual<float>(tensors_expected[fCSliceOut],
4104                                  tensors_expected[fExpandedC]);
4105 
4106   GraphDef output;
4107   ArithmeticOptimizer optimizer;
4108   EnableOnlyRemoveStackSliceSameAxis(&optimizer);
4109   OptimizeAndPrune(&optimizer, &item, &output);
4110 
4111   const string kExpandDimsNamePrefix(
4112       "ArithmeticOptimizer/RemoveStackStridedSliceSameAxis_p");
4113 
4114   for (const auto& node : output.node()) {
4115     if (node.name() == "pa_slice_out") {
4116       ASSERT_EQ(node.input_size(), 1);
4117       EXPECT_EQ(node.input(0), absl::StrCat(kExpandDimsNamePrefix, "a_slice"));
4118     } else if (node.name() == "pb_slice_out") {
4119       ASSERT_EQ(node.input_size(), 1);
4120       EXPECT_EQ(node.input(0), absl::StrCat(kExpandDimsNamePrefix, "b_slice"));
4121     } else if (node.name() == "pc_slice_out") {
4122       ASSERT_EQ(node.input_size(), 1);
4123       EXPECT_EQ(node.input(0), absl::StrCat(kExpandDimsNamePrefix, "c_slice"));
4124     } else if (absl::StartsWith(node.name(), kExpandDimsNamePrefix)) {
4125       EXPECT_EQ(node.op(), "ExpandDims");
4126       // The input is "a", "b", or "c", as appropriate.
4127       EXPECT_EQ(node.input(0),
4128                 node.name().substr(kExpandDimsNamePrefix.size(), 1));
4129     }
4130   }
4131 
4132   auto tensors = EvaluateNodes(output, item.fetch);
4133 
4134   // stacked[:, 0:1, :] == a.
4135   test::ExpectTensorEqual<float>(tensors[fASliceOut],
4136                                  tensors_expected[fExpandedA]);
4137 
4138   // stacked[:, 1:2, :] == b.
4139   test::ExpectTensorEqual<float>(tensors[fBSliceOut],
4140                                  tensors_expected[fExpandedB]);
4141   // stacked[:, 2:3, :] == c.
4142   test::ExpectTensorEqual<float>(tensors[fCSliceOut],
4143                                  tensors_expected[fExpandedC]);
4144 }
4145 
// Builds x -> Cast(bfloat16) -> AddN(cast, cast) -> Identity and checks that
// the aggregation simplification yields a five-node graph that evaluates to
// the same bfloat16 tensor.
TEST_F(ArithmeticOptimizerTest, SimplifyAggregationBFloat16) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(root.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output as_bf16 = ops::Cast(root.WithOpName("cast"), input, DT_BFLOAT16);
  Output sum = ops::AddN(root.WithOpName("add"), {as_bf16, as_bf16});
  Output identity_op = ops::Identity(root.WithOpName("id"), sum);

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  item.fetch = {"id"};
  auto expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(expected.size(), 1);

  ArithmeticOptimizer optimizer;
  EnableOnlySimplifyAggregation(&optimizer);
  GraphDef optimized;
  OptimizeAndPrune(&optimizer, &item, &optimized);

  // One more node than the original four: the multiplier is an extra node.
  EXPECT_EQ(optimized.node_size(), 5);

  auto actual = EvaluateNodes(optimized, item.fetch);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorEqual<bfloat16>(actual[0], expected[0]);
}
4171 
// For both int32 and int64 Unique index types, builds
// Unique(indices) -> Gather(embeddings, ids) -> SparseSegmentSum(..., idx, ...)
// and checks that the optimized graph feeds "embeddings" and "indices"
// straight into "result", with no Unique or Gather node left.
TEST_F(ArithmeticOptimizerTest, SimplifyEmbeddingLookup) {
  for (DataType unique_idx_type : {DT_INT32, DT_INT64}) {
    tensorflow::Scope root = tensorflow::Scope::NewRootScope();
    Output embeddings = ops::Const(root.WithOpName("embeddings"),
                                   {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
    Output segment_ids =
        ops::Const(root.WithOpName("segment_ids"), {0, 1, 1, 2, 2, 2, 2});
    Output indices =
        ops::Const(root.WithOpName("indices"), {0, 0, 1, 0, 1, 0, 1});
    auto unique = ops::Unique(root.WithOpName("unique"), indices,
                              /*attrs=*/{unique_idx_type});
    // Gather the unique rows, then sum them using the positions from Unique.
    Output gathered_rows =
        ops::Gather(root.WithOpName("gathered_rows"), embeddings, unique.y);
    Output result = ops::SparseSegmentSum(root.WithOpName("result"),
                                          gathered_rows, unique.idx,
                                          segment_ids);
    Output identity_op = ops::Identity(root.WithOpName("id"), result);

    GrapplerItem item;
    TF_CHECK_OK(root.ToGraphDef(&item.graph));
    item.fetch = {"id"};
    auto expected = EvaluateNodes(item.graph, item.fetch);
    ASSERT_EQ(expected.size(), 1);

    ArithmeticOptimizer optimizer;
    EnableOnlySimplifyEmbeddingLookup(&optimizer);
    GraphDef optimized;
    OptimizeAndPrune(&optimizer, &item, &optimized);

    for (const auto& node : optimized.node()) {
      // No Unique/Gather may survive anywhere in the graph.
      const bool is_result = node.name() == "result";
      if (is_result) {
        EXPECT_EQ(node.input(0), "embeddings");
        EXPECT_EQ(node.input(1), "indices");
      }
      EXPECT_NE(node.op(), "Unique");
      EXPECT_NE(node.op(), "Gather");
    }

    auto actual = EvaluateNodes(optimized, item.fetch);
    ASSERT_EQ(actual.size(), 1);
    test::ExpectTensorEqual<float>(actual[0], expected[0]);
  }
}
4215 
// For every int32/int64 combination of index and segment-id types, feeds
// Cast nodes into SparseSegmentSum and checks that the optimized graph wires
// the uncast "indices" and "segment_ids" constants into "result", with no
// Cast node left.
TEST_F(ArithmeticOptimizerTest, RemoveCastIntoSegmentReduction) {
  for (DataType indices_type : {DT_INT32, DT_INT64}) {
    for (DataType segment_ids_type : {DT_INT32, DT_INT64}) {
      tensorflow::Scope root = tensorflow::Scope::NewRootScope();
      Output embeddings = ops::Const(root.WithOpName("embeddings"),
                                     {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});

      // Each constant is created right before the Cast that consumes it.
      Output raw_indices =
          ops::Const(root.WithOpName("indices"), {0, 0, 1, 0, 1, 0, 1});
      Output indices =
          ops::Cast(root.WithOpName("cast_indices"), raw_indices, indices_type);
      Output raw_segment_ids =
          ops::Const(root.WithOpName("segment_ids"), {0, 1, 1, 2, 2, 2, 2});
      Output segment_ids = ops::Cast(root.WithOpName("cast_segment_ids"),
                                     raw_segment_ids, segment_ids_type);

      Output result = ops::SparseSegmentSum(root.WithOpName("result"),
                                            embeddings, indices, segment_ids);
      Output identity_op = ops::Identity(root.WithOpName("id"), result);

      GrapplerItem item;
      TF_CHECK_OK(root.ToGraphDef(&item.graph));
      item.fetch = {"id"};
      auto expected = EvaluateNodes(item.graph, item.fetch);
      ASSERT_EQ(expected.size(), 1);

      ArithmeticOptimizer optimizer;
      EnableOnlyRemoveCastIntoSegmentReduction(&optimizer);
      GraphDef optimized;
      OptimizeAndPrune(&optimizer, &item, &optimized);

      for (const auto& node : optimized.node()) {
        // No Cast may survive anywhere in the graph.
        const bool is_result = node.name() == "result";
        if (is_result) {
          EXPECT_EQ(node.input(1), "indices");
          EXPECT_EQ(node.input(2), "segment_ids");
        }
        EXPECT_NE(node.op(), "Cast");
      }

      auto actual = EvaluateNodes(optimized, item.fetch);
      ASSERT_EQ(actual.size(), 1);
      test::ExpectTensorEqual<float>(actual[0], expected[0]);
    }
  }
}
4259 
4260 }  // namespace grappler
4261 }  // namespace tensorflow
4262