1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
17 
18 #include <memory>
19 #include <utility>
20 
21 #include "absl/memory/memory.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/compiler/xla/layout_util.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
27 #include "tensorflow/compiler/xla/service/hlo_computation.h"
28 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
30 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
31 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
32 #include "tensorflow/compiler/xla/service/hlo_parser.h"
33 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
34 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
35 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
36 #include "tensorflow/compiler/xla/shape_util.h"
37 #include "tensorflow/compiler/xla/test.h"
38 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
39 #include "tensorflow/compiler/xla/types.h"
40 #include "tensorflow/compiler/xla/window_util.h"
41 #include "tensorflow/compiler/xla/xla_data.pb.h"
42 #include "tensorflow/core/lib/core/status_test_util.h"
43 
44 namespace xla {
45 namespace {
46 
47 using ::testing::ElementsAre;
48 namespace m = match;
49 
// Test fixture for the AlgebraicSimplifier tests in this file. It derives from
// HloTestBase and adds only a shared options object.
class AlgebraicSimplifierTest : public HloTestBase {
 protected:
  // Default-constructed options; the tests pass this to every
  // AlgebraicSimplifier they construct.
  AlgebraicSimplifierOptions default_options_;
};
54 
55 // Test that A + 0 is simplified to A
TEST_F(AlgebraicSimplifierTest, AddZero) {
  // Entry computation: add(param0, 0.0f) on f32 scalars.
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());
  HloInstruction* operand = b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param0"));
  HloInstruction* zero_constant = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  b.AddInstruction(HloInstruction::CreateBinary(scalar_f32, HloOpcode::kAdd,
                                                operand, zero_constant));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Sanity check: the add is the root before the pass runs.
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kAdd);

  const bool changed =
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // The add must have been replaced by its non-zero operand.
  EXPECT_EQ(entry->root_instruction(), operand);
}
75 
TEST_F(AlgebraicSimplifierTest, FactorIntegerAddition) {
  // Integer case: p0*p2 + p1*p2, where both products share the factor p2.
  const char* const hlo_text = R"(
    HloModule m
    test {
      p0 = s32[8] parameter(0)
      p1 = s32[8] parameter(1)
      p2 = s32[8] parameter(2)
      x = s32[8] multiply(p0, p2)
      y = s32[8] multiply(p1, p2)
      ROOT sum = s32[8] add(x, y)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  const bool changed =
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // Expected root: (p0 + p1) * p2, with operands accepted in either order.
  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root,
              GmockMatch(m::MultiplyAnyOrder(
                  m::AddAnyOrder(m::Parameter(0), m::Parameter(1)),
                  m::Parameter(2))));
}
96 
97 // A*C + B*C => (A+B)*C if C is a floating-point power of 2.
TEST_F(AlgebraicSimplifierTest, FactorFpAddition) {
  // Both products share the scalar f32 constant 0.125.
  const char* const hlo_text = R"(
    HloModule m
    test {
      p0 = f32[] parameter(0)
      p1 = f32[] parameter(1)
      c = f32[] constant(0.125)
      x = f32[] multiply(p0, c)
      y = f32[] multiply(p1, c)
      ROOT sum = f32[] add(x, y)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // Expected root: (p0 + p1) * 0.125, with operands accepted in either order.
  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root,
              GmockMatch(m::MultiplyAnyOrder(
                  m::AddAnyOrder(m::Parameter(0), m::Parameter(1)),
                  m::ConstantScalar(0.125))));
}
117 
118 // A*C + B*C => (A+B)*C if C is a broadcast of a floating-point power of 2.
TEST_F(AlgebraicSimplifierTest, FactorFpAdditionWithBroadcast) {
  // The shared factor is the scalar 0.125 broadcast to f32[4].
  const char* const hlo_text = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      p1 = f32[4] parameter(1)
      c = f32[] constant(0.125)
      b = f32[4] broadcast(c), dimensions={}
      x = f32[4] multiply(p0, b)
      y = f32[4] multiply(p1, b)
      ROOT sum = f32[4] add(x, y)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // Expected root: (p0 + p1) * broadcast(0.125), operands in either order.
  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root,
              GmockMatch(m::MultiplyAnyOrder(
                  m::AddAnyOrder(m::Parameter(0), m::Parameter(1)),
                  m::Broadcast(m::ConstantScalar(0.125)))));
}
139 
140 // A*C + B*C => (A+B)*C simplification should not happen if C is not a
141 // floating-point power of 2.
TEST_F(AlgebraicSimplifierTest, FactorFpAdditionNotPowerOf2) {
  // Same shape of graph as FactorFpAddition, but the shared constant is 0.3.
  const char* const hlo_text = R"(
    HloModule m
    test {
      p0 = f32[] parameter(0)
      p1 = f32[] parameter(1)
      c = f32[] constant(0.3)
      x = f32[] multiply(p0, c)
      y = f32[] multiply(p1, c)
      ROOT sum = f32[] add(x, y)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  // The pass is expected to report that it changed nothing.
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
}
157 
158 // A*C + B*C => (A+B)*C simplification should not happen if A, B, and C are
159 // complex numbers.
TEST_F(AlgebraicSimplifierTest, FactorFpAdditionComplex) {
  // p0*p2 + p1*p2 where every value is c64[8].
  const char* const hlo_text = R"(
    HloModule m
    test {
      p0 = c64[8] parameter(0)
      p1 = c64[8] parameter(1)
      p2 = c64[8] parameter(2)
      x = c64[8] multiply(p0, p2)
      y = c64[8] multiply(p1, p2)
      ROOT sum = c64[8] add(x, y)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  // The pass is expected to report that it changed nothing.
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
}
175 
// A*C + B*C => (A+B)*C simplification is OK if A, B, and C are bfloat16.
TEST_F(AlgebraicSimplifierTest, FactorFpAdditionBfloat16) {
  // bf16 variant: the shared factor is the scalar 0.125 broadcast to bf16[4].
  const char* const hlo_text = R"(
    HloModule m
    test {
      p0 = bf16[4] parameter(0)
      p1 = bf16[4] parameter(1)
      c = bf16[] constant(0.125)
      b = bf16[4] broadcast(c), dimensions={}
      x = bf16[4] multiply(p0, b)
      y = bf16[4] multiply(p1, b)
      ROOT sum = bf16[4] add(x, y)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // Expected root: (p0 + p1) * broadcast(0.125), operands in either order.
  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root,
              GmockMatch(m::MultiplyAnyOrder(
                  m::AddAnyOrder(m::Parameter(0), m::Parameter(1)),
                  m::Broadcast(m::ConstantScalar(0.125)))));
}
197 
TEST_F(AlgebraicSimplifierTest, UnsignedDivideByPowerOf2) {
  // u32[4] parameter divided by the constant 8 broadcast to u32[4].
  const char* const hlo_text = R"(
    HloModule m
    test {
      p = u32[4] parameter(0)
      c = u32[] constant(8)
      b = u32[4] broadcast(c), dimensions={}
      ROOT d = u32[4] divide(p, b)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // Expected root: a logical right shift of the parameter by broadcast(3).
  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root,
              GmockMatch(m::ShiftRightLogical(
                  m::Parameter(0), m::Broadcast(m::ConstantScalar(3)))));
}
214 
TEST_F(AlgebraicSimplifierTest, SignedDivideByPowerOf2) {
  // s32[4] parameter divided by the constant 8 broadcast to s32[4].
  const char* const hlo_text = R"(
    HloModule m
    test {
      p = s32[4] parameter(0)
      c = s32[] constant(8)
      b = s32[4] broadcast(c), dimensions={}
      ROOT d = s32[4] divide(p, b)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // Expected root: select(p < 0, -(mag >> 3), mag >> 3), where
  // mag = select(p < 0, -p, p) and >> is a logical right shift.
  auto is_negative =
      m::Lt(m::Parameter(0), m::Broadcast(m::ConstantScalar(0)));
  auto magnitude =
      m::Select(is_negative, m::Negate(m::Parameter(0)), m::Parameter(0));
  auto shifted =
      m::ShiftRightLogical(magnitude, m::Broadcast(m::ConstantScalar(3)));
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Select(is_negative, m::Negate(shifted), shifted)));
}
237 
TEST_F(AlgebraicSimplifierTest, UnsignedRemainderByPowerOf2) {
  // u32[4] parameter modulo the constant 8 broadcast to u32[4].
  const char* const hlo_text = R"(
    HloModule m
    test {
      p = u32[4] parameter(0)
      c = u32[] constant(8)
      b = u32[4] broadcast(c), dimensions={}
      ROOT r = u32[4] remainder(p, b)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // Expected root: the parameter AND-ed with broadcast(7), in either order.
  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root,
              GmockMatch(m::AndAnyOrder(
                  m::Parameter(0), m::Broadcast(m::ConstantScalar(7)))));
}
254 
TEST_F(AlgebraicSimplifierTest, SignedRemainderByPowerOf2) {
  // s32[4] parameter modulo the constant 8 broadcast to s32[4].
  const char* const hlo_text = R"(
    HloModule m
    test {
      p = s32[4] parameter(0)
      c = s32[] constant(8)
      b = s32[4] broadcast(c), dimensions={}
      ROOT r = s32[4] remainder(p, b)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // Expected root: select(p < 0, -(mag & 7), mag & 7), where
  // mag = select(p < 0, -p, p).
  auto is_negative =
      m::Lt(m::Parameter(0), m::Broadcast(m::ConstantScalar(0)));
  auto magnitude =
      m::Select(is_negative, m::Negate(m::Parameter(0)), m::Parameter(0));
  auto masked =
      m::AndAnyOrder(magnitude, m::Broadcast(m::ConstantScalar(7)));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Select(is_negative, m::Negate(masked), masked)));
}
277 
278 // Test that A * 0 is simplified to 0
TEST_F(AlgebraicSimplifierTest, MulZero) {
  // Entry computation: multiply(param0, 0) on s32 scalars.
  const Shape scalar_s32 = ShapeUtil::MakeShape(S32, {});
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());
  HloInstruction* operand = b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_s32, "param0"));
  HloInstruction* zero_constant = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
  b.AddInstruction(HloInstruction::CreateBinary(
      scalar_s32, HloOpcode::kMultiply, operand, zero_constant));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Sanity check: the multiply is the root before the pass runs.
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kMultiply);

  const bool changed =
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // The multiply must have been replaced by the zero constant itself.
  EXPECT_EQ(entry->root_instruction(), zero_constant);
}
297 
298 // Test that select(true, a, b) is simplified to a
TEST_F(AlgebraicSimplifierTest, SelectTrue) {
  // Entry computation: select(true, param0, param1) on s32 scalars.
  const Shape scalar_s32 = ShapeUtil::MakeShape(S32, {});
  HloComputation::Builder b(TestName());
  HloInstruction* on_true = b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_s32, "param0"));
  HloInstruction* on_false = b.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_s32, "param1"));
  HloInstruction* always_true = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  b.AddInstruction(HloInstruction::CreateTernary(
      scalar_s32, HloOpcode::kSelect, always_true, on_true, on_false));

  auto m = CreateNewVerifiedModule();
  HloComputation* entry = m->AddEntryComputation(b.Build());

  // Sanity check: the select is the root before the pass runs.
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kSelect);

  const bool changed =
      AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // Only the on-true operand should remain as the root.
  EXPECT_EQ(entry->root_instruction(), on_true);
}
319 
320 // Test that select(false, a, b) is simplified to b
TEST_F(AlgebraicSimplifierTest, SelectFalse) {
  // Entry computation: select(false, param0, param1) on s32 scalars.
  const Shape scalar_s32 = ShapeUtil::MakeShape(S32, {});
  HloComputation::Builder b(TestName());
  HloInstruction* on_true = b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_s32, "param0"));
  HloInstruction* on_false = b.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_s32, "param1"));
  HloInstruction* always_false = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  b.AddInstruction(HloInstruction::CreateTernary(
      scalar_s32, HloOpcode::kSelect, always_false, on_true, on_false));

  auto m = CreateNewVerifiedModule();
  HloComputation* entry = m->AddEntryComputation(b.Build());

  // Sanity check: the select is the root before the pass runs.
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kSelect);

  const bool changed =
      AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // Only the on-false operand should remain as the root.
  EXPECT_EQ(entry->root_instruction(), on_false);
}
341 
342 // Test that select(a, b, b) is simplified to b
TEST_F(AlgebraicSimplifierTest, SelectIdentical) {
  // Entry computation: select(param0, param1, param1), i.e. both value
  // operands are the same instruction.
  // NOTE(review): param0 is an s32 scalar yet is passed as the select's first
  // operand -- confirm that a non-PRED predicate is intended here.
  const Shape scalar_s32 = ShapeUtil::MakeShape(S32, {});
  HloComputation::Builder b(TestName());
  HloInstruction* selector = b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_s32, "param0"));
  HloInstruction* value = b.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_s32, "param1"));
  b.AddInstruction(HloInstruction::CreateTernary(
      scalar_s32, HloOpcode::kSelect, selector, value, value));

  auto m = CreateNewVerifiedModule();
  HloComputation* entry = m->AddEntryComputation(b.Build());

  // Sanity check: the select is the root before the pass runs.
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kSelect);

  const bool changed =
      AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // The select collapses to the operand shared by both branches.
  EXPECT_EQ(entry->root_instruction(), value);
}
361 
362 // Test that Reduce(Reduce(A)) -> Reduce(A)
TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder entry_builder(TestName());
  // Initial value shared by both reduces.
  HloInstruction* init_value = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));

  // Scalar f32 addition, used as the reducer of both reduces.
  HloComputation* reducer = [&] {
    HloComputation::Builder add_builder(TestName() + ".add");
    const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
    HloInstruction* lhs = add_builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_f32, "p0"));
    HloInstruction* rhs = add_builder.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_f32, "p1"));
    add_builder.AddInstruction(
        HloInstruction::CreateBinary(scalar_f32, HloOpcode::kAdd, lhs, rhs));
    return module->AddEmbeddedComputation(add_builder.Build());
  }();

  // Inner reduce: f32[4,5,6,7] -> f32[5,6,7] over dimension 0.
  const Shape input_shape = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
  HloInstruction* input = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param"));
  const std::vector<int64> inner_dims = {0};
  const Shape inner_shape = ShapeUtil::MakeShape(F32, {5, 6, 7});
  HloInstruction* inner_reduce =
      entry_builder.AddInstruction(HloInstruction::CreateReduce(
          inner_shape, input, init_value, inner_dims, reducer));

  // Outer reduce: f32[5,6,7] -> f32[5] over dimensions 1 and 2.
  const std::vector<int64> outer_dims = {1, 2};
  const Shape outer_shape = ShapeUtil::MakeShape(F32, {5});
  entry_builder.AddInstruction(HloInstruction::CreateReduce(
      outer_shape, inner_reduce, init_value, outer_dims, reducer));
  module->AddEntryComputation(entry_builder.Build());

  const bool changed =
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // Expect one reduce of the parameter, expressed in the input's dimension
  // numbering: {0, 2, 3}.
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Reduce(m::Parameter(0), m::Op().Is(init_value))));
  EXPECT_EQ(root->dimensions(), std::vector<int64>({0, 2, 3}));
}
399 
400 // Test that Const + A is canonicalized to A + Const.
TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) {
  // Entry computation: add(42.0f, param0), with the constant on the left.
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());
  HloInstruction* operand = b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param0"));
  HloInstruction* forty_two = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f)));
  b.AddInstruction(HloInstruction::CreateBinary(scalar_f32, HloOpcode::kAdd,
                                                forty_two, operand));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Sanity check: the add is the root before the pass runs.
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kAdd);

  const bool changed =
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // The operands must now be ordered as (parameter, constant).
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Add(m::Parameter(0), m::Constant())));
}
420 
421 // Test that [(A + C1) + C2] => [A + (C1 + C2)] for constants C1 and C2.
TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) {
  // Entry computation: (param0 + 42.0f) + 3.14159f on f32 scalars.
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());
  HloInstruction* operand = b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param0"));
  HloInstruction* c1 = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f)));
  HloInstruction* c2 = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.14159f)));
  HloInstruction* inner_add = b.AddInstruction(
      HloInstruction::CreateBinary(scalar_f32, HloOpcode::kAdd, operand, c1));
  b.AddInstruction(HloInstruction::CreateBinary(scalar_f32, HloOpcode::kAdd,
                                                inner_add, c2));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Sanity check: an add is the root before the pass runs.
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kAdd);

  const bool changed =
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // Expected root: param0 + (c1 + c2), reusing the original instructions.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Add(m::Op().Is(operand),
                                m::Add(m::Op().Is(c1), m::Op().Is(c2)))));
}
448 
TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) {
  // Checks that broadcast(0) + A simplifies to A when the zero is a scalar
  // (rank-0) constant broadcast up to A's f32[3,2] shape.
  auto m = CreateNewVerifiedModule();
  Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r2f32, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  // A rank-0 operand has no dimensions to map into the output, so the
  // broadcast dimension list must be empty. (The previous {0, 1} listed two
  // dimensions for an operand that has none.)
  HloInstruction* bcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(r2f32, zero, {}));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0));

  auto computation = m->AddEntryComputation(builder.Build());
  HloInstruction* root = computation->root_instruction();
  // Sanity check: the add is the root before the pass runs.
  EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  // The add must have been replaced by the parameter.
  root = computation->root_instruction();
  EXPECT_EQ(root, param0);
}
470 
TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder entry_builder(TestName());

  // Scalar f32 addition; this is the computation the map applies.
  HloComputation* mapped_add = [&] {
    HloComputation::Builder add_builder(TestName() + ".add");
    const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
    HloInstruction* lhs = add_builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_f32, "p0"));
    HloInstruction* rhs = add_builder.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_f32, "p1"));
    add_builder.AddInstruction(
        HloInstruction::CreateBinary(scalar_f32, HloOpcode::kAdd, lhs, rhs));
    return module->AddEmbeddedComputation(add_builder.Build());
  }();

  // Entry computation: map(param0, broadcast(0.0f)) over f32[32,1].
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {32, 1});
  HloInstruction* operand = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_shape, "param0"));
  HloInstruction* zero_constant = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  HloInstruction* zero_broadcast = entry_builder.AddInstruction(
      HloInstruction::CreateBroadcast(matrix_shape, zero_constant, {}));
  entry_builder.AddInstruction(HloInstruction::CreateMap(
      matrix_shape, {operand, zero_broadcast}, mapped_add));
  HloComputation* entry = module->AddEntryComputation(entry_builder.Build());

  // Sanity check: the map is the root before the pass runs.
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kMap);

  const bool changed =
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // The map must have become a plain add of the same two operands.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Add(m::Parameter(0),
                                m::Broadcast(m::Op().Is(zero_constant)))));
}
507 
TEST_F(AlgebraicSimplifierTest, KeepNontrivialMap) {
  // The mapped computation's root is a fusion instruction rather than a
  // single elementwise op.
  const char* const hlo_text = R"(
    HloModule m
    fusion {
      x = f32[] parameter(0)
      c = f32[] constant(42)
      m = f32[] multiply(x, x)
      ROOT a = f32[] add(m, c)
    }

    map {
      x = f32[] parameter(0)
      ROOT f = f32[] fusion(x), kind=kLoop, calls=fusion
    }

    ENTRY test {
      p = f32[2,2] parameter(0)
      ROOT map = f32[2,2] map(p), dimensions={0,1}, to_apply=map
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  // The pass is expected to report that it changed nothing.
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}
531 
TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
  // Checks that broadcast(zeros) + A simplifies to A when the zeros are a
  // rank-1 constant broadcast up to A's f32[3,2] shape.
  auto m = CreateNewVerifiedModule();
  Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r2f32, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({0, 0, 0})));
  // The constant has 3 elements, so its single dimension must map to output
  // dimension 0 (size 3). The previous {1} pointed at the size-2 dimension.
  HloInstruction* bcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, zero, {0}));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0));

  auto computation = m->AddEntryComputation(builder.Build());
  HloInstruction* root = computation->root_instruction();
  // Sanity check: the add is the root before the pass runs.
  EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  // The add must have been replaced by the parameter.
  root = computation->root_instruction();
  EXPECT_EQ(root, param0);
}
553 
TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) {
  // The entry computation is a single R1 constant whose elements are all
  // 3.14f.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());
  b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({3.14f, 3.14f, 3.14f})));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Sanity check: the root is a constant before the pass runs.
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Constant()));

  const bool changed =
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // Expected root: a broadcast of a constant whose first element is 3.14f.
  HloInstruction* new_root = entry->root_instruction();
  EXPECT_THAT(new_root, GmockMatch(m::Broadcast(m::Constant())));
  EXPECT_EQ(3.14f, new_root->operand(0)->literal().GetFirstElement<float>());
}
569 
TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) {
  // The entry computation is a single R1 constant whose last element differs
  // from the first two.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());
  b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({3.14, 3.14, 4})));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Sanity check: the root is a constant before the pass runs.
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Constant()));

  // The pass is expected to report that it changed nothing.
  const bool changed =
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie();
  ASSERT_FALSE(changed);

  // The root must still be a constant.
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Constant()));
}
584 
TEST_F(AlgebraicSimplifierTest, IotaToBroadcast) {
  // The entry computation is the single R1 constant {0, 1, 2}. Despite the
  // test's name, the expected result below is an iota, not a broadcast.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());
  b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f})));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Sanity check: the root is a constant before the pass runs.
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Constant()));

  const bool changed =
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // The constant must have been replaced by an iota.
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Iota()));
}
599 
600 // Test that A - 0 is simplified to A
TEST_F(AlgebraicSimplifierTest, SubZero) {
  // Entry computation: subtract(param0, 0.0f) on f32 scalars.
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());
  HloInstruction* operand = b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param0"));
  HloInstruction* zero_constant = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  b.AddInstruction(HloInstruction::CreateBinary(
      scalar_f32, HloOpcode::kSubtract, operand, zero_constant));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Sanity check: the subtract is the root before the pass runs.
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kSubtract);

  const bool changed =
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // The subtract must have been replaced by its left operand.
  EXPECT_EQ(entry->root_instruction(), operand);
}
620 
621 // Test that A - Const is canonicalized to A + (-Const).
TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) {
  // Entry computation: subtract(param0, 42.0f) on f32 scalars.
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());
  HloInstruction* operand = b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param0"));
  HloInstruction* forty_two = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  b.AddInstruction(HloInstruction::CreateBinary(
      scalar_f32, HloOpcode::kSubtract, operand, forty_two));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Sanity check: the subtract is the root before the pass runs.
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kSubtract);

  const bool changed =
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // Expected root: param0 + negate(original constant).
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Add(m::Parameter(0),
                                m::Negate(m::Op().Is(forty_two)))));
}
642 
643 // Test that (A/B)/C is simplified to A/(B*C).
TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) {
  // Entry computation: (param0 / param1) / param2 on f32 scalars.
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());
  HloInstruction* numerator = b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param0"));
  HloInstruction* first_divisor = b.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_f32, "param1"));
  HloInstruction* second_divisor = b.AddInstruction(
      HloInstruction::CreateParameter(2, scalar_f32, "param2"));
  HloInstruction* inner_quotient =
      b.AddInstruction(HloInstruction::CreateBinary(
          scalar_f32, HloOpcode::kDivide, numerator, first_divisor));
  b.AddInstruction(HloInstruction::CreateBinary(
      scalar_f32, HloOpcode::kDivide, inner_quotient, second_divisor));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Before the pass: (param0 / param1) / param2.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Divide(m::Divide(m::Parameter(0), m::Parameter(1)),
                                   m::Parameter(2))));

  const bool changed =
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // After the pass: param0 / (param1 * param2).
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Divide(m::Parameter(0),
                           m::Multiply(m::Parameter(1), m::Parameter(2)))));
}
673 
674 // Test that A/(B/C) is simplified to (A*C)/B.
TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) {
  // Build p0 / (p1 / p2) over F32 scalars; the outer divide is the root.
  auto module = CreateNewVerifiedModule();
  const Shape f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder b(TestName());
  auto* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, f32, "param0"));
  auto* p1 =
      b.AddInstruction(HloInstruction::CreateParameter(1, f32, "param1"));
  auto* p2 =
      b.AddInstruction(HloInstruction::CreateParameter(2, f32, "param2"));
  auto* inner_div = b.AddInstruction(
      HloInstruction::CreateBinary(f32, HloOpcode::kDivide, p1, p2));
  b.AddInstruction(
      HloInstruction::CreateBinary(f32, HloOpcode::kDivide, p0, inner_div));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Sanity-check the graph before the pass runs.
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Divide(m::Parameter(0),
                           m::Divide(m::Parameter(1), m::Parameter(2)))));

  // The simplifier run has to yield true for the test to proceed.
  AlgebraicSimplifier pass(default_options_);
  ASSERT_TRUE(pass.Run(module.get()).ValueOrDie());

  // Expected result: (p0 * p2) / p1.
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Divide(m::Multiply(m::Parameter(0), m::Parameter(2)),
                           m::Parameter(1))));
}
705 
706 // Test that (A/B)/(C/D) is simplified to (A*D)/(B*C).
TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) {
  // Build (p0 / p1) / (p2 / p3) over F32[42,123] operands.
  auto module = CreateNewVerifiedModule();
  const Shape mat = ShapeUtil::MakeShape(F32, {42, 123});
  HloComputation::Builder b(TestName());
  auto* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, mat, "param0"));
  auto* p1 =
      b.AddInstruction(HloInstruction::CreateParameter(1, mat, "param1"));
  auto* p2 =
      b.AddInstruction(HloInstruction::CreateParameter(2, mat, "param2"));
  auto* p3 =
      b.AddInstruction(HloInstruction::CreateParameter(3, mat, "param3"));
  auto* lhs_div = b.AddInstruction(
      HloInstruction::CreateBinary(mat, HloOpcode::kDivide, p0, p1));
  auto* rhs_div = b.AddInstruction(
      HloInstruction::CreateBinary(mat, HloOpcode::kDivide, p2, p3));
  b.AddInstruction(
      HloInstruction::CreateBinary(mat, HloOpcode::kDivide, lhs_div, rhs_div));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Sanity-check the graph before the pass runs.
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Divide(m::Divide(m::Parameter(0), m::Parameter(1)),
                           m::Divide(m::Parameter(2), m::Parameter(3)))));

  // The simplifier run has to yield true for the test to proceed.
  AlgebraicSimplifier pass(default_options_);
  ASSERT_TRUE(pass.Run(module.get()).ValueOrDie());

  // Expected result: (p0 * p3) / (p1 * p2).
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Divide(m::Multiply(m::Parameter(0), m::Parameter(3)),
                           m::Multiply(m::Parameter(1), m::Parameter(2)))));
}
741 
742 // Test that A/exp(B) is simplified to A*exp(-B).
TEST_F(AlgebraicSimplifierTest, DivOfExp) {
  // Build p0 / exp(p1) over F32 scalars.
  auto module = CreateNewVerifiedModule();
  const Shape f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder b(TestName());
  auto* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, f32, "param0"));
  auto* p1 =
      b.AddInstruction(HloInstruction::CreateParameter(1, f32, "param1"));
  auto* exp_p1 =
      b.AddInstruction(HloInstruction::CreateUnary(f32, HloOpcode::kExp, p1));
  b.AddInstruction(
      HloInstruction::CreateBinary(f32, HloOpcode::kDivide, p0, exp_p1));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Sanity-check the graph before the pass runs.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Divide(m::Parameter(0), m::Exp(m::Parameter(1)))));

  // The simplifier run has to yield true for the test to proceed.
  AlgebraicSimplifier pass(default_options_);
  ASSERT_TRUE(pass.Run(module.get()).ValueOrDie());

  // Expected result: p0 * exp(-p1).
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(m::Parameter(0),
                                     m::Exp(m::Negate(m::Parameter(1))))));
}
768 
769 // Test that A/pow(B,C) is simplified to A*pow(B,-C).
TEST_F(AlgebraicSimplifierTest, DivOfPower) {
  // Build p0 / pow(p1, p2) over F32 scalars.
  auto module = CreateNewVerifiedModule();
  const Shape f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder b(TestName());
  auto* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, f32, "param0"));
  auto* p1 =
      b.AddInstruction(HloInstruction::CreateParameter(1, f32, "param1"));
  auto* p2 =
      b.AddInstruction(HloInstruction::CreateParameter(2, f32, "param2"));
  auto* pow_p1_p2 = b.AddInstruction(
      HloInstruction::CreateBinary(f32, HloOpcode::kPower, p1, p2));
  b.AddInstruction(
      HloInstruction::CreateBinary(f32, HloOpcode::kDivide, p0, pow_p1_p2));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Sanity-check the graph before the pass runs.
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Divide(m::Parameter(0),
                           m::Power(m::Parameter(1), m::Parameter(2)))));

  // The simplifier run has to yield true for the test to proceed.
  AlgebraicSimplifier pass(default_options_);
  ASSERT_TRUE(pass.Run(module.get()).ValueOrDie());

  // Expected result: p0 * pow(p1, -p2).
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(
                  m::Parameter(0),
                  m::Power(m::Parameter(1), m::Negate(m::Parameter(2))))));
}
800 
801 // Test that broadcasting is done on the right step when simplifying A/pow(B,C)
802 // to A*pow(B,-C).
TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) {
  // Build p0 / pow(p1, p2) where every operand is an F32[7] vector.
  auto module = CreateNewVerifiedModule();
  const Shape vec = ShapeUtil::MakeShape(F32, {7});
  HloComputation::Builder b(TestName());
  auto* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, vec, "param0"));
  auto* p1 =
      b.AddInstruction(HloInstruction::CreateParameter(1, vec, "param1"));
  auto* p2 =
      b.AddInstruction(HloInstruction::CreateParameter(2, vec, "param2"));
  auto* pow_p1_p2 = b.AddInstruction(
      HloInstruction::CreateBinary(vec, HloOpcode::kPower, p1, p2));
  b.AddInstruction(
      HloInstruction::CreateBinary(vec, HloOpcode::kDivide, p0, pow_p1_p2));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Sanity-check the graph before the pass runs.
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Divide(m::Parameter(0),
                           m::Power(m::Parameter(1), m::Parameter(2)))));

  // The simplifier run has to yield true for the test to proceed.
  AlgebraicSimplifier pass(default_options_);
  ASSERT_TRUE(pass.Run(module.get()).ValueOrDie());

  // Expected result: p0 * pow(p1, -p2).
  ASSERT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(
                  m::Parameter(0),
                  m::Power(m::Parameter(1), m::Negate(m::Parameter(2))))));
}
833 
834 // A / Const => A * InvertedConst
TEST_F(AlgebraicSimplifierTest, DivideByConstant) {
  // Build p0 / {1, 2, 3} over F32[3].
  auto module = CreateNewVerifiedModule();
  const Shape vec = ShapeUtil::MakeShape(F32, {3});
  HloComputation::Builder b(TestName());
  auto* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, vec, "param0"));
  auto* divisor = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.f, 2.f, 3.f})));
  b.AddInstruction(
      HloInstruction::CreateBinary(vec, HloOpcode::kDivide, p0, divisor));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // The simplifier run has to yield true for the test to proceed.
  AlgebraicSimplifier pass(default_options_);
  ASSERT_TRUE(pass.Run(module.get()).ValueOrDie());

  // Only the structure is checked: a multiply of p0 by some constant.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(m::Parameter(0), m::Constant())));
}
855 
856 // pow(pow(A, X), Y) => pow(A, X*Y)
TEST_F(AlgebraicSimplifierTest, PowerOfPower) {
  // Build pow(pow(a, x), y) over F32[7] vectors.
  auto module = CreateNewVerifiedModule();
  const Shape vec = ShapeUtil::MakeShape(F32, {7});
  HloComputation::Builder b(TestName());
  auto* a =
      b.AddInstruction(HloInstruction::CreateParameter(0, vec, "param0"));
  auto* x =
      b.AddInstruction(HloInstruction::CreateParameter(1, vec, "param1"));
  auto* y =
      b.AddInstruction(HloInstruction::CreateParameter(2, vec, "param2"));
  auto* pow_a_x = b.AddInstruction(
      HloInstruction::CreateBinary(vec, HloOpcode::kPower, a, x));
  b.AddInstruction(
      HloInstruction::CreateBinary(vec, HloOpcode::kPower, pow_a_x, y));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // The simplifier run has to yield true for the test to proceed.
  AlgebraicSimplifier pass(default_options_);
  ASSERT_TRUE(pass.Run(module.get()).ValueOrDie());

  // Expected result: pow(a, x * y), reusing the original instructions.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Power(m::Op().Is(a),
                                  m::Multiply(m::Op().Is(x), m::Op().Is(y)))));
}
880 
881 // Don't simplify pow(pow(A, X), Y) => pow(A, X*Y) if X and Y are complex
882 // numbers.
TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) {
  // Build pow(pow(a, x), y) over C64[7] vectors.
  auto module = CreateNewVerifiedModule();
  const Shape cvec = ShapeUtil::MakeShape(C64, {7});
  HloComputation::Builder b(TestName());
  auto* a =
      b.AddInstruction(HloInstruction::CreateParameter(0, cvec, "param0"));
  auto* x =
      b.AddInstruction(HloInstruction::CreateParameter(1, cvec, "param1"));
  auto* y =
      b.AddInstruction(HloInstruction::CreateParameter(2, cvec, "param2"));
  auto* pow_a_x = b.AddInstruction(
      HloInstruction::CreateBinary(cvec, HloOpcode::kPower, a, x));
  b.AddInstruction(
      HloInstruction::CreateBinary(cvec, HloOpcode::kPower, pow_a_x, y));
  module->AddEntryComputation(b.Build());

  // For complex operands the run is expected to yield false.
  AlgebraicSimplifier pass(default_options_);
  ASSERT_FALSE(pass.Run(module.get()).ValueOrDie());
}
902 
903 // Test that A/1 is simplified to A for a scalar.
TEST_F(AlgebraicSimplifierTest, DivOneScalar) {
  // Build p0 / 1.0 over F32 scalars.
  auto module = CreateNewVerifiedModule();
  const Shape f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder b(TestName());
  auto* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, f32, "param0"));
  auto* unit = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  auto* quotient = b.AddInstruction(
      HloInstruction::CreateBinary(f32, HloOpcode::kDivide, p0, unit));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Before the pass the divide is the root.
  EXPECT_EQ(entry->root_instruction(), quotient);

  // The simplifier run has to yield true for the test to proceed.
  AlgebraicSimplifier pass(default_options_);
  ASSERT_TRUE(pass.Run(module.get()).ValueOrDie());

  // Afterwards the root is the parameter itself.
  EXPECT_EQ(entry->root_instruction(), p0);
}
923 
924 // Test that A/1 is simplified to A for an array.
TEST_F(AlgebraicSimplifierTest, DivOneArray) {
  // Build p0 / ones over F32[2,2].
  auto module = CreateNewVerifiedModule();
  const Shape mat = ShapeUtil::MakeShape(F32, {2, 2});
  HloComputation::Builder b(TestName());
  auto* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, mat, "param0"));
  auto* ones = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 1.0}, {1.0, 1.0}})));
  auto* quotient = b.AddInstruction(
      HloInstruction::CreateBinary(mat, HloOpcode::kDivide, p0, ones));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Before the pass the divide is the root.
  EXPECT_EQ(entry->root_instruction(), quotient);

  // The simplifier run has to yield true for the test to proceed.
  AlgebraicSimplifier pass(default_options_);
  ASSERT_TRUE(pass.Run(module.get()).ValueOrDie());

  // Afterwards the root is the parameter itself.
  EXPECT_EQ(entry->root_instruction(), p0);
}
944 
945 // Test that complex(real(c), imag(c)) is simplified to c.
TEST_F(AlgebraicSimplifierTest, ComplexOfRealImagC) {
  // Build complex(real(c), imag(c)) for a C64[2,2] parameter c.
  auto module = CreateNewVerifiedModule();
  const Shape real_mat = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape complex_mat = ShapeUtil::MakeShape(C64, {2, 2});
  HloComputation::Builder b(TestName());
  auto* c = b.AddInstruction(
      HloInstruction::CreateParameter(0, complex_mat, "param0"));
  auto* re = b.AddInstruction(
      HloInstruction::CreateUnary(real_mat, HloOpcode::kReal, c));
  auto* im = b.AddInstruction(
      HloInstruction::CreateUnary(real_mat, HloOpcode::kImag, c));
  auto* rebuilt = b.AddInstruction(
      HloInstruction::CreateBinary(complex_mat, HloOpcode::kComplex, re, im));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Before the pass the complex op is the root.
  EXPECT_EQ(entry->root_instruction(), rebuilt);

  // The simplifier run has to yield true for the test to proceed.
  AlgebraicSimplifier pass(default_options_);
  ASSERT_TRUE(pass.Run(module.get()).ValueOrDie());

  // Afterwards the root is the original complex parameter.
  EXPECT_EQ(entry->root_instruction(), c);
}
968 
969 // Test that real(complex(r,i)) is simplified to r.
TEST_F(AlgebraicSimplifierTest, RealOfComplex) {
  // Build real(complex(r, i)) with F32[2,2] parts r and i.
  auto module = CreateNewVerifiedModule();
  const Shape real_mat = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape complex_mat = ShapeUtil::ChangeElementType(real_mat, C64);
  HloComputation::Builder b(TestName());
  auto* r =
      b.AddInstruction(HloInstruction::CreateParameter(0, real_mat, "param0"));
  auto* i =
      b.AddInstruction(HloInstruction::CreateParameter(1, real_mat, "param1"));
  auto* packed = b.AddInstruction(
      HloInstruction::CreateBinary(complex_mat, HloOpcode::kComplex, r, i));
  auto* re = b.AddInstruction(
      HloInstruction::CreateUnary(real_mat, HloOpcode::kReal, packed));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Before the pass the real op is the root.
  EXPECT_EQ(entry->root_instruction(), re);

  // The simplifier run has to yield true for the test to proceed.
  AlgebraicSimplifier pass(default_options_);
  ASSERT_TRUE(pass.Run(module.get()).ValueOrDie());

  // Afterwards the root is the real-part parameter.
  EXPECT_EQ(entry->root_instruction(), r);
}
992 
993 // Test that imag(complex(r,i)) is simplified to i.
TEST_F(AlgebraicSimplifierTest, ImagOfComplex) {
  // Build imag(complex(r, i)) with F32[2,2] parts r and i.
  auto module = CreateNewVerifiedModule();
  const Shape real_mat = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape complex_mat = ShapeUtil::ChangeElementType(real_mat, C64);
  HloComputation::Builder b(TestName());
  auto* r =
      b.AddInstruction(HloInstruction::CreateParameter(0, real_mat, "param0"));
  auto* i =
      b.AddInstruction(HloInstruction::CreateParameter(1, real_mat, "param1"));
  auto* packed = b.AddInstruction(
      HloInstruction::CreateBinary(complex_mat, HloOpcode::kComplex, r, i));
  auto* im = b.AddInstruction(
      HloInstruction::CreateUnary(real_mat, HloOpcode::kImag, packed));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Before the pass the imag op is the root.
  EXPECT_EQ(entry->root_instruction(), im);

  // The simplifier run has to yield true for the test to proceed.
  AlgebraicSimplifier pass(default_options_);
  ASSERT_TRUE(pass.Run(module.get()).ValueOrDie());

  // Afterwards the root is the imaginary-part parameter.
  EXPECT_EQ(entry->root_instruction(), i);
}
1016 
1017 // Test that get_element(make_tuple({A,B}),1) is simplified to B
TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) {
  // Build get-tuple-element(tuple(p0, p1), 1) + p2 over F32 scalars.
  auto module = CreateNewVerifiedModule();
  const Shape f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder b(TestName());
  auto* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, f32, "param0"));
  auto* p1 =
      b.AddInstruction(HloInstruction::CreateParameter(1, f32, "param1"));
  auto* p2 =
      b.AddInstruction(HloInstruction::CreateParameter(2, f32, "param2"));
  auto* pair = b.AddInstruction(HloInstruction::CreateTuple({p0, p1}));
  auto* second = b.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32, pair, 1));
  auto* sum = b.AddInstruction(
      HloInstruction::CreateBinary(f32, HloOpcode::kAdd, second, p2));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Before the pass the add is the root.
  EXPECT_EQ(entry->root_instruction(), sum);

  // The simplifier run has to yield true for the test to proceed.
  AlgebraicSimplifier pass(default_options_);
  ASSERT_TRUE(pass.Run(module.get()).ValueOrDie());

  // Expected result: the add reads p1 directly, i.e. p1 + p2.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Add(m::Parameter(1), m::Parameter(2))));
}
1043 
1044 // Test that exp(A)/exp(B) is simplified to exp(A-B)
TEST_F(AlgebraicSimplifierTest, ExpDiv) {
  // Build exp(p0) / exp(p1) over F32 scalars.
  auto module = CreateNewVerifiedModule();
  const Shape f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder b(TestName());
  auto* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, f32, "param0"));
  auto* p1 =
      b.AddInstruction(HloInstruction::CreateParameter(1, f32, "param1"));
  auto* exp_p0 =
      b.AddInstruction(HloInstruction::CreateUnary(f32, HloOpcode::kExp, p0));
  auto* exp_p1 =
      b.AddInstruction(HloInstruction::CreateUnary(f32, HloOpcode::kExp, p1));
  b.AddInstruction(
      HloInstruction::CreateBinary(f32, HloOpcode::kDivide, exp_p0, exp_p1));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Sanity-check the graph before the pass runs.
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Divide(m::Exp(m::Parameter(0)), m::Exp(m::Parameter(1)))));

  // The simplifier run has to yield true for the test to proceed.
  AlgebraicSimplifier pass(default_options_);
  ASSERT_TRUE(pass.Run(module.get()).ValueOrDie());

  // Expected result: exp(p0 - p1).
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Exp(m::Subtract(m::Parameter(0), m::Parameter(1)))));
}
1073 
1074 // Test that exp(A)*exp(B) is simplified to exp(A+B)
TEST_F(AlgebraicSimplifierTest, ExpMul) {
  // Build exp(p0) * exp(p1) over F32 scalars.
  auto module = CreateNewVerifiedModule();
  const Shape f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder b(TestName());
  auto* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, f32, "param0"));
  auto* p1 =
      b.AddInstruction(HloInstruction::CreateParameter(1, f32, "param1"));
  auto* exp_p0 =
      b.AddInstruction(HloInstruction::CreateUnary(f32, HloOpcode::kExp, p0));
  auto* exp_p1 =
      b.AddInstruction(HloInstruction::CreateUnary(f32, HloOpcode::kExp, p1));
  b.AddInstruction(
      HloInstruction::CreateBinary(f32, HloOpcode::kMultiply, exp_p0, exp_p1));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Sanity-check the graph before the pass runs.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(m::Exp(m::Parameter(0)),
                                     m::Exp(m::Parameter(1)))));

  // The simplifier run has to yield true for the test to proceed.
  AlgebraicSimplifier pass(default_options_);
  ASSERT_TRUE(pass.Run(module.get()).ValueOrDie());

  // Expected result: exp(p0 + p1).
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Exp(m::Add(m::Parameter(0), m::Parameter(1)))));
}
1102 
1103 // Test that pow(exp(A), B) is simplified to exp(A*B)
TEST_F(AlgebraicSimplifierTest, PowExp) {
  // Build pow(exp(p0), p1) over F32 scalars.
  auto module = CreateNewVerifiedModule();
  const Shape f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder b(TestName());
  auto* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, f32, "param0"));
  auto* p1 =
      b.AddInstruction(HloInstruction::CreateParameter(1, f32, "param1"));
  auto* exp_p0 =
      b.AddInstruction(HloInstruction::CreateUnary(f32, HloOpcode::kExp, p0));
  b.AddInstruction(
      HloInstruction::CreateBinary(f32, HloOpcode::kPower, exp_p0, p1));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Sanity-check the graph before the pass runs.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Power(m::Exp(m::Parameter(0)), m::Parameter(1))));

  // The simplifier run has to yield true for the test to proceed.
  AlgebraicSimplifier pass(default_options_);
  ASSERT_TRUE(pass.Run(module.get()).ValueOrDie());

  // Expected result: exp(p0 * p1).
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Exp(m::Multiply(m::Parameter(0), m::Parameter(1)))));
}
1129 
1130 // Test that ln(pow(A, B)) is simplified to ln(A)*B
TEST_F(AlgebraicSimplifierTest, LnPow) {
  // Build log(pow(p0, p1)) over F32 scalars.
  auto module = CreateNewVerifiedModule();
  const Shape f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder b(TestName());
  auto* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, f32, "param0"));
  auto* p1 =
      b.AddInstruction(HloInstruction::CreateParameter(1, f32, "param1"));
  auto* pow_p0_p1 = b.AddInstruction(
      HloInstruction::CreateBinary(f32, HloOpcode::kPower, p0, p1));
  b.AddInstruction(
      HloInstruction::CreateUnary(f32, HloOpcode::kLog, pow_p0_p1));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Sanity-check the graph before the pass runs.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Log(m::Power(m::Parameter(0), m::Parameter(1)))));

  // The simplifier run has to yield true for the test to proceed.
  AlgebraicSimplifier pass(default_options_);
  ASSERT_TRUE(pass.Run(module.get()).ValueOrDie());

  // Expected result: log(p0) * p1.
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Multiply(m::Log(m::Parameter(0)), m::Parameter(1))));
}
1156 
1157 // Test that ln(exp(A)) is simplified to A
TEST_F(AlgebraicSimplifierTest, LnExp) {
  // Build log(exp(p0)) over an F32 scalar.
  auto module = CreateNewVerifiedModule();
  const Shape f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder b(TestName());
  auto* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, f32, "param0"));
  auto* exp_p0 =
      b.AddInstruction(HloInstruction::CreateUnary(f32, HloOpcode::kExp, p0));
  b.AddInstruction(HloInstruction::CreateUnary(f32, HloOpcode::kLog, exp_p0));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Sanity-check the graph before the pass runs.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Log(m::Exp(m::Parameter(0)))));

  // The simplifier run has to yield true for the test to proceed.
  AlgebraicSimplifier pass(default_options_);
  ASSERT_TRUE(pass.Run(module.get()).ValueOrDie());

  // Afterwards the root is the parameter itself.
  EXPECT_EQ(entry->root_instruction(), p0);
}
1179 
1180 // Test that ln(exp(A)/exp(B)) is simplified to A-B
TEST_F(AlgebraicSimplifierTest, LnExpDiv) {
  // Build log(exp(p0) / exp(p1)) over F32 scalars.
  auto module = CreateNewVerifiedModule();
  const Shape f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder b(TestName());
  auto* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, f32, "param0"));
  auto* p1 =
      b.AddInstruction(HloInstruction::CreateParameter(1, f32, "param1"));
  auto* exp_p0 =
      b.AddInstruction(HloInstruction::CreateUnary(f32, HloOpcode::kExp, p0));
  auto* exp_p1 =
      b.AddInstruction(HloInstruction::CreateUnary(f32, HloOpcode::kExp, p1));
  auto* ratio = b.AddInstruction(
      HloInstruction::CreateBinary(f32, HloOpcode::kDivide, exp_p0, exp_p1));
  b.AddInstruction(HloInstruction::CreateUnary(f32, HloOpcode::kLog, ratio));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Sanity-check the graph before the pass runs.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Log(m::Divide(m::Exp(m::Parameter(0)),
                                          m::Exp(m::Parameter(1))))));

  // The simplifier run has to yield true for the test to proceed.
  AlgebraicSimplifier pass(default_options_);
  ASSERT_TRUE(pass.Run(module.get()).ValueOrDie());

  // Expected result: p0 - p1.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Subtract(m::Parameter(0), m::Parameter(1))));
}
1210 
1211 // Test that pow(A, 0) where A is a scalar is simplified to the scalar
1212 // constant 1.
TEST_F(AlgebraicSimplifierTest, Pow0Scalar) {
  // Build pow(param0, 0) over F32 scalars; the power op is the root.
  auto m = CreateNewVerifiedModule();
  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero));

  auto computation = m->AddEntryComputation(builder.Build());

  // Sanity-check the graph before the pass runs.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Power(m::Parameter(0), m::Op().Is(zero))));

  // The simplifier run has to yield true for the test to proceed.
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  // Fatal assertion: the literal() access below presupposes a constant root,
  // so stop here on a mismatch instead of continuing into an invalid access.
  ASSERT_THAT(root, GmockMatch(m::Constant()));
  EXPECT_EQ(root->literal().GetFirstElement<float>(), 1);
}
1236 
1237 // Test that pow(A, 0) where A is not a scalar is simplified to broadcast(1).
TEST_F(AlgebraicSimplifierTest, Pow0Vector) {
  // Build pow(param0, 0) where param0 is F32[42] and the exponent is a scalar
  // constant; the power op is the root.
  auto m = CreateNewVerifiedModule();
  Shape r1f32 = ShapeUtil::MakeShape(F32, {42});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero));

  auto computation = m->AddEntryComputation(builder.Build());

  // Sanity-check the graph before the pass runs.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Power(m::Parameter(0), m::Op().Is(zero))));

  // The simplifier run has to yield true for the test to proceed.
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  // Fatal assertion: the checks below read the broadcast's dimensions and its
  // operand's literal, so the root must be broadcast(constant) to continue.
  ASSERT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), r1f32))
      << ShapeUtil::HumanString(root->shape());
  // The broadcast maps no operand dimensions and its operand is the scalar 1.
  EXPECT_EQ(root->dimensions().size(), 0);
  EXPECT_TRUE(ShapeUtil::IsScalar(root->operand(0)->shape()));
  EXPECT_EQ(root->operand(0)->literal().GetFirstElement<float>(), 1);
}
1265 
1266 // Test that pow(A, 1) is simplified to A.
TEST_F(AlgebraicSimplifierTest, Pow1) {
  // Build pow(p0, 1) over F32 scalars.
  auto module = CreateNewVerifiedModule();
  const Shape f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder b(TestName());
  auto* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, f32, "param0"));
  auto* exponent = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
  b.AddInstruction(
      HloInstruction::CreateBinary(f32, HloOpcode::kPower, p0, exponent));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Sanity-check the graph before the pass runs.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Power(m::Parameter(0), m::Op().Is(exponent))));

  // The simplifier run has to yield true for the test to proceed.
  AlgebraicSimplifier pass(default_options_);
  ASSERT_TRUE(pass.Run(module.get()).ValueOrDie());

  // Afterwards the root is the parameter itself.
  EXPECT_EQ(entry->root_instruction(), p0);
}
1288 
1289 // Test that pow(A, 2) is simplified to A*A.
TEST_F(AlgebraicSimplifierTest, Pow2) {
  // Build pow(p0, 2) over F32 scalars.
  auto module = CreateNewVerifiedModule();
  const Shape f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder b(TestName());
  auto* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, f32, "param0"));
  auto* exponent = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2)));
  b.AddInstruction(
      HloInstruction::CreateBinary(f32, HloOpcode::kPower, p0, exponent));
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Sanity-check the graph before the pass runs.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Power(m::Parameter(0), m::Op().Is(exponent))));

  // The simplifier run has to yield true for the test to proceed.
  AlgebraicSimplifier pass(default_options_);
  ASSERT_TRUE(pass.Run(module.get()).ValueOrDie());

  // Expected result: p0 * p0.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
}
1312 
1313 // Test that pow(A, -1) is simplified to 1/A.
TEST_F(AlgebraicSimplifierTest, PowNegative1) {
  // Build pow(param0, -1) over F32 scalars; the power op is the root.
  auto m = CreateNewVerifiedModule();
  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  HloInstruction* negative_one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(-1)));
  builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower,
                                                      param0, negative_one));

  auto computation = m->AddEntryComputation(builder.Build());

  // Sanity-check the graph before the pass runs.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Power(m::Parameter(0), m::Op().Is(negative_one))));

  // The simplifier run has to yield true for the test to proceed.
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  // Fatal assertion: the literal check below walks operand(0)->operand(0), so
  // the root must be broadcast(constant) / param0 to continue. The matcher
  // also covers the broadcast opcode of the numerator.
  ASSERT_THAT(root, GmockMatch(m::Divide(m::Broadcast(m::Constant()),
                                         m::Parameter(0))));
  EXPECT_EQ(root->operand(0)->operand(0)->literal().GetFirstElement<float>(),
            1);
}
1339 
// Verifies that a convolution over operands that hold no elements is replaced
// by a broadcast of a constant.
TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  // Each operand has one dimension of extent 0.
  HloInstruction* input = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {3, 3, 0}), "lhs"));
  HloInstruction* kernel = b.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {3, 0, 3}), "rhs"));

  // Input and output use dimension order (batch, spatial, feature); the
  // kernel uses (spatial, input feature, output feature).
  ConvolutionDimensionNumbers dim_numbers;
  dim_numbers.set_input_batch_dimension(0);
  dim_numbers.set_output_batch_dimension(0);
  dim_numbers.add_input_spatial_dimensions(1);
  dim_numbers.add_output_spatial_dimensions(1);
  dim_numbers.set_input_feature_dimension(2);
  dim_numbers.set_output_feature_dimension(2);
  dim_numbers.add_kernel_spatial_dimensions(0);
  dim_numbers.set_kernel_input_feature_dimension(1);
  dim_numbers.set_kernel_output_feature_dimension(2);

  // One spatial window dimension: size 3, unit stride, no padding, no
  // dilation, no reversal.
  Window conv_window;
  WindowDimension* spatial = conv_window.add_dimensions();
  spatial->set_size(3);
  spatial->set_stride(1);
  spatial->set_padding_low(0);
  spatial->set_padding_high(0);
  spatial->set_base_dilation(1);
  spatial->set_window_dilation(1);
  spatial->set_window_reversal(false);

  // Create the convolution under test.
  b.AddInstruction(HloInstruction::CreateConvolve(
      ShapeUtil::MakeShape(F32, {3, 3, 3}), input, kernel,
      /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
      dim_numbers, DefaultPrecisionConfig(2)));
  module->AddEntryComputation(b.Build());

  // Sanity check: the root starts out as the convolution of both parameters.
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Convolution(m::Op().Is(input), m::Op().Is(kernel))));

  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Constant())));
}
1382 
// Verifies that a reduce-window is rewritten into a reduce followed by a
// reshape. The test uses a 1x2x3x1 window over a [1,2,3,4] operand with a
// declared result shape of [1,1,1,4].
TEST_F(AlgebraicSimplifierTest, ReduceWindowIsReduceAndReshape) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* operand =
      entry_builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {1, 2, 3, 4}), "param"));

  // Per-dimension window sizes 1, 2, 3, 1; unit stride, no padding and no
  // dilation in every dimension.
  Window reduce_window;
  for (int64 window_size : {1, 2, 3, 1}) {
    WindowDimension* dim = reduce_window.add_dimensions();
    dim->set_size(window_size);
    dim->set_stride(1);
    dim->set_padding_low(0);
    dim->set_padding_high(0);
    dim->set_base_dilation(1);
    dim->set_window_dilation(1);
  }

  // Scalar addition, used as the reduction function.
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder add_builder(TestName() + ".add");
  HloInstruction* add_lhs = add_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "p0"));
  HloInstruction* add_rhs = add_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_f32, "p1"));
  add_builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_f32, HloOpcode::kAdd, add_lhs, add_rhs));
  HloComputation* scalar_add =
      module->AddEmbeddedComputation(add_builder.Build());

  HloInstruction* init_value = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  entry_builder.AddInstruction(HloInstruction::CreateReduceWindow(
      ShapeUtil::MakeShape(F32, {1, 1, 1, 4}), operand, init_value,
      reduce_window, scalar_add));
  module->AddEntryComputation(entry_builder.Build());

  // Sanity check: the root starts out as the reduce-window.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::ReduceWindow(m::Parameter(0), m::Constant())));

  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Reshape(m::Reduce(m::Parameter(0), m::Constant()))));
}
1427 
// Verifies that a reduce-window over an operand that holds no elements is
// replaced by a broadcast of a constant.
TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) {
  auto m = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());
  // The operand has shape [3, 0], i.e. it contains no elements.
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {3, 0}), "op"));
  Window window;
  for (int64 i = 0; i < 2; ++i) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(1);
    // Set the stride explicitly so the window is well-formed instead of
    // relying on the field's protobuf default. The declared {5, 2} result
    // shape matches a unit stride: a size-1 window over extents 3 and 0 with
    // one element of padding on each side yields 5 and 2 positions.
    dim->set_stride(1);
    dim->set_padding_low(1);
    dim->set_padding_high(1);
    dim->set_window_dilation(1);
    dim->set_base_dilation(1);
  }
  // Create the scalar add computation used as the reduction function.
  HloComputation* add_computation = nullptr;
  {
    HloComputation::Builder add_builder(TestName() + ".add");
    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
    HloInstruction* p0 = add_builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "p0"));
    HloInstruction* p1 = add_builder.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "p1"));
    add_builder.AddInstruction(
        HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
    add_computation = m->AddEmbeddedComputation(add_builder.Build());
  }
  builder.AddInstruction(HloInstruction::CreateReduceWindow(
      ShapeUtil::MakeShape(F32, {5, 2}), param,
      builder.AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f))),
      window, add_computation));
  m->AddEntryComputation(builder.Build());
  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  // Sanity check: the root starts out as the reduce-window.
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::ReduceWindow(m::Parameter(0), m::Constant())));
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Constant())));
}
1469 
// Verifies that padding an operand that holds no elements is replaced by a
// broadcast of a constant.
TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());
  // The operand has shape [3, 0], i.e. it contains no elements.
  HloInstruction* operand = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {3, 0}), "op"));

  // One element of edge padding on each side of both dimensions and no
  // interior padding.
  PaddingConfig pad_config;
  for (int dim = 0; dim < 2; ++dim) {
    PaddingConfig::PaddingConfigDimension* dim_config =
        pad_config.add_dimensions();
    dim_config->set_interior_padding(0);
    dim_config->set_edge_padding_low(1);
    dim_config->set_edge_padding_high(1);
  }

  HloInstruction* pad_value = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
  b.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {5, 2}), operand, pad_value, pad_config));
  module->AddEntryComputation(b.Build());

  // Sanity check: the root starts out as the pad.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Constant())));

  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Constant())));
}
1496 
// Verifies that reshape(broadcast(reshape(op))) collapses to `op` itself.
// Here `op` is [3,2], it is reshaped to [6], broadcast to [1,6] and reshaped
// back to [3,2].
TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) {
  auto module = CreateNewVerifiedModule();

  auto builder = HloComputation::Builder(TestName());
  auto op = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {3, 2}), "op"));
  auto reshape1 = builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6}), op));
  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {1, 6}), reshape1, {1}));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {3, 2}), broadcast));

  module->AddEntryComputation(builder.Build());

  // Sanity check: the full reshape/broadcast/reshape chain is the root.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Reshape(m::Op().Is(op))))));

  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The root must be the very same instruction as the original parameter.
  EXPECT_EQ(module->entry_computation()->root_instruction(), op);
}
1522 
1523 // Test that convert(A, $TYPE) is simplified to A if A is of type $TYPE.
TEST_F(AlgebraicSimplifierTest,ConvertBetweenSameType)1524 TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) {
1525   auto m = CreateNewVerifiedModule();
1526   HloComputation::Builder builder(TestName());
1527   HloInstruction* input = builder.AddInstruction(
1528       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
1529   builder.AddInstruction(
1530       HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));
1531 
1532   auto computation = m->AddEntryComputation(builder.Build());
1533 
1534   EXPECT_THAT(computation->root_instruction(),
1535               GmockMatch(m::Convert(m::Op().Is(input))));
1536 
1537   AlgebraicSimplifier simplifier(default_options_);
1538   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1539 
1540   EXPECT_THAT(computation->root_instruction(), input);
1541 }
1542 
1543 // Test that copies are removed.
TEST_F(AlgebraicSimplifierTest,RemoveCopy)1544 TEST_F(AlgebraicSimplifierTest, RemoveCopy) {
1545   auto m = CreateNewVerifiedModule();
1546   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1547   HloComputation::Builder builder(TestName());
1548   HloInstruction* param0 = builder.AddInstruction(
1549       HloInstruction::CreateParameter(0, r0f32, "param0"));
1550   builder.AddInstruction(
1551       HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
1552 
1553   auto computation = m->AddEntryComputation(builder.Build());
1554 
1555   EXPECT_THAT(computation->root_instruction(),
1556               GmockMatch(m::Copy(m::Parameter(0))));
1557 
1558   AlgebraicSimplifier simplifier(default_options_);
1559   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1560 
1561   EXPECT_THAT(computation->root_instruction(), param0);
1562 }
1563 
// Verifies that, with layout sensitivity enabled, copy(reshape(copy(param)))
// is replaced by a single bitcast of the parameter.
TEST_F(AlgebraicSimplifierTest, CopyOfReshapeOfCopyEqualsBitcast) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  // Shapes (with explicit layouts) of each instruction in the chain.
  const Shape param_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {3, 2, 1, 0});
  const Shape inner_copy_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {0, 1, 2, 3});
  const Shape reshape_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {14 * 14, 64}, {0, 1});
  const Shape outer_copy_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {14 * 14, 64}, {1, 0});

  HloInstruction* input = b.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param"));
  HloInstruction* inner_copy = b.AddInstruction(
      HloInstruction::CreateUnary(inner_copy_shape, HloOpcode::kCopy, input));
  HloInstruction* reshaped = b.AddInstruction(
      HloInstruction::CreateReshape(reshape_shape, inner_copy));
  b.AddInstruction(HloInstruction::CreateUnary(outer_copy_shape,
                                               HloOpcode::kCopy, reshaped));

  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Copy(m::Reshape(m::Copy(m::Parameter(0))))));

  AlgebraicSimplifierOptions layout_sensitive_options;
  layout_sensitive_options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(layout_sensitive_options);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The whole copy/reshape/copy chain must have become one bitcast.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
1592 
// Verifies that, with layout sensitivity enabled, reshape(copy(param)) is
// replaced by a single bitcast of the parameter.
TEST_F(AlgebraicSimplifierTest, ReshapeOfCopyEqualsBitcast) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  // Shapes (with explicit layouts) of each instruction in the chain.
  const Shape param_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {3, 2, 1, 0});
  const Shape copy_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {0, 1, 2, 3});
  const Shape reshape_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {14 * 14, 64}, {1, 0});

  HloInstruction* input = b.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param"));
  HloInstruction* copied = b.AddInstruction(
      HloInstruction::CreateUnary(copy_shape, HloOpcode::kCopy, input));
  b.AddInstruction(HloInstruction::CreateReshape(reshape_shape, copied));

  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Reshape(m::Copy(m::Parameter(0)))));

  AlgebraicSimplifierOptions layout_sensitive_options;
  layout_sensitive_options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(layout_sensitive_options);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The reshape of the copy must have become one bitcast.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
1618 
// Verifies that, with layout sensitivity enabled, a layout-changing copy is
// left alone when the options carry a callback that returns false for every
// pair of shapes, and is replaced by a bitcast with default-constructed
// options.
TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());
  const Shape param_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {0, 1, 2, 3});
  const Shape copy_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {1, 2, 0, 3});
  HloInstruction* input = b.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param"));
  b.AddInstruction(
      HloInstruction::CreateUnary(copy_shape, HloOpcode::kCopy, input));
  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Copy(m::Parameter(0))));

  // First pass: the callback rejects everything, so nothing may change.
  auto reject_all = [](const Shape&, const Shape&) { return false; };
  AlgebraicSimplifierOptions rejecting_options(reject_all);
  rejecting_options.set_is_layout_sensitive(true);
  AlgebraicSimplifier rejecting_simplifier(rejecting_options);
  ASSERT_FALSE(rejecting_simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Copy(m::Parameter(0))));

  // Second pass: default options, so the copy must become a bitcast.
  AlgebraicSimplifierOptions default_callback_options;
  default_callback_options.set_is_layout_sensitive(true);
  AlgebraicSimplifier accepting_simplifier(default_callback_options);
  EXPECT_TRUE(accepting_simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
1650 
1651 // Test that unary concatenates are removed.
TEST_F(AlgebraicSimplifierTest,RemoveUnaryConcatenate)1652 TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) {
1653   auto m = CreateNewVerifiedModule();
1654   Shape r1f32 = ShapeUtil::MakeShape(F32, {100});
1655   HloComputation::Builder builder(TestName());
1656   HloInstruction* param0 = builder.AddInstruction(
1657       HloInstruction::CreateParameter(0, r1f32, "param0"));
1658   builder.AddInstruction(
1659       HloInstruction::CreateConcatenate(param0->shape(), {param0}, 0));
1660 
1661   auto computation = m->AddEntryComputation(builder.Build());
1662 
1663   EXPECT_THAT(computation->root_instruction(),
1664               GmockMatch(m::Concatenate(m::Parameter(0))));
1665 
1666   AlgebraicSimplifier simplifier(default_options_);
1667   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1668 
1669   EXPECT_THAT(computation->root_instruction(), param0);
1670 }
1671 
1672 // Test that empty operands of concatenates are removed.
TEST_F(AlgebraicSimplifierTest,RemoveEmptyConcatenateOperands)1673 TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) {
1674   auto m = CreateNewVerifiedModule();
1675   const int kParamLength = 100;
1676   Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength});
1677   HloComputation::Builder builder(TestName());
1678   HloInstruction* param0 = builder.AddInstruction(
1679       HloInstruction::CreateParameter(0, r1f32, "param0"));
1680   HloInstruction* param1 = builder.AddInstruction(
1681       HloInstruction::CreateParameter(1, r1f32, "param1"));
1682   HloInstruction* empty_literal = builder.AddInstruction(
1683       HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
1684   HloInstruction* empty_slice =
1685       builder.AddInstruction(HloInstruction::CreateSlice(
1686           ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}, {1}));
1687   Shape result_shape = ShapeUtil::MakeShape(F32, {3 * kParamLength});
1688   builder.AddInstruction(HloInstruction::CreateConcatenate(
1689       result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0));
1690 
1691   auto computation = m->AddEntryComputation(builder.Build());
1692 
1693   EXPECT_THAT(computation->root_instruction(),
1694               GmockMatch(m::Concatenate(
1695                   m::Op().Is(empty_literal), m::Parameter(0), m::Parameter(0),
1696                   m::Op().Is(empty_slice), m::Parameter(1))));
1697 
1698   AlgebraicSimplifier simplifier(default_options_);
1699   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1700 
1701   EXPECT_THAT(computation->root_instruction(),
1702               GmockMatch(m::Concatenate(m::Parameter(0), m::Parameter(0),
1703                                         m::Parameter(1))));
1704 }
1705 
1706 // Test that reduce of concat is simplified.
TEST_F(AlgebraicSimplifierTest,SimplifyReduceOfConcat)1707 TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) {
1708   auto m = CreateNewVerifiedModule();
1709   const int kParamLength = 100;
1710   Shape r3f32 =
1711       ShapeUtil::MakeShape(F32, {kParamLength, kParamLength, kParamLength});
1712   HloComputation::Builder builder(TestName());
1713   HloInstruction* param0 = builder.AddInstruction(
1714       HloInstruction::CreateParameter(0, r3f32, "param0"));
1715   HloInstruction* param1 = builder.AddInstruction(
1716       HloInstruction::CreateParameter(1, r3f32, "param1"));
1717   HloInstruction* param2 = builder.AddInstruction(
1718       HloInstruction::CreateParameter(2, r3f32, "param2"));
1719   Shape concat_shape =
1720       ShapeUtil::MakeShape(F32, {kParamLength, 3 * kParamLength, kParamLength});
1721   HloInstruction* Concatenate =
1722       builder.AddInstruction(HloInstruction::CreateConcatenate(
1723           concat_shape, {param0, param1, param2}, 1));
1724   HloComputation* add_computation = nullptr;
1725   {
1726     HloComputation::Builder builder(TestName() + ".add");
1727     const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
1728     HloInstruction* p0 = builder.AddInstruction(
1729         HloInstruction::CreateParameter(0, scalar_shape, "p0"));
1730     HloInstruction* p1 = builder.AddInstruction(
1731         HloInstruction::CreateParameter(1, scalar_shape, "p1"));
1732     builder.AddInstruction(
1733         HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
1734     add_computation = m->AddEmbeddedComputation(builder.Build());
1735   }
1736   Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
1737   Shape reduce_shape = ShapeUtil::MakeShape(F32, {kParamLength});
1738 
1739   HloInstruction* zero = builder.AddInstruction(
1740       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
1741   builder.AddInstruction(HloInstruction::CreateReduce(
1742       reduce_shape, Concatenate, zero, {1, 2}, add_computation));
1743 
1744   auto computation = m->AddEntryComputation(builder.Build());
1745 
1746   AlgebraicSimplifier simplifier(default_options_);
1747   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1748 
1749   EXPECT_THAT(
1750       computation->root_instruction(),
1751       GmockMatch(m::Map(m::Map(m::Reduce(m::Parameter(0), m::Op().Is(zero)),
1752                                m::Reduce(m::Parameter(1), m::Op().Is(zero))),
1753                         m::Reduce(m::Parameter(2), m::Op().Is(zero)))));
1754 }
1755 
1756 // Test a concatenate with only empty operands is removed.
TEST_F(AlgebraicSimplifierTest,OnlyEmptyConcatenateOperands)1757 TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) {
1758   auto m = CreateNewVerifiedModule();
1759   const int kParamLength = 100;
1760   Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength});
1761   HloComputation::Builder builder(TestName());
1762   HloInstruction* param0 = builder.AddInstruction(
1763       HloInstruction::CreateParameter(0, r1f32, "param0"));
1764   HloInstruction* empty_literal = builder.AddInstruction(
1765       HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
1766   HloInstruction* empty_slice =
1767       builder.AddInstruction(HloInstruction::CreateSlice(
1768           ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42}, {1}));
1769   Shape result_shape = ShapeUtil::MakeShape(F32, {0});
1770   builder.AddInstruction(HloInstruction::CreateConcatenate(
1771       result_shape, {empty_literal, empty_slice}, 0));
1772 
1773   auto computation = m->AddEntryComputation(builder.Build());
1774 
1775   EXPECT_THAT(computation->root_instruction(),
1776               GmockMatch(m::Concatenate(m::Op().Is(empty_literal),
1777                                         m::Op().Is(empty_slice))));
1778 
1779   AlgebraicSimplifier simplifier(default_options_);
1780   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1781 
1782   EXPECT_EQ(computation->root_instruction(), empty_literal);
1783 }
1784 
1785 // Test that concat with a scalar broadcast becomes a pad.
TEST_F(AlgebraicSimplifierTest,ConcatenateOfBroadcastBecomesPad)1786 TEST_F(AlgebraicSimplifierTest, ConcatenateOfBroadcastBecomesPad) {
1787   auto m = CreateNewVerifiedModule();
1788   Shape r1f32 = ShapeUtil::MakeShape(F32, {100});
1789   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1790   HloComputation::Builder builder(TestName());
1791   HloInstruction* param0 = builder.AddInstruction(
1792       HloInstruction::CreateParameter(0, r1f32, "param0"));
1793   HloInstruction* param1 = builder.AddInstruction(
1794       HloInstruction::CreateParameter(1, r0f32, "param1"));
1795   HloInstruction* broadcast = builder.AddInstruction(
1796       HloInstruction::CreateBroadcast(r1f32, param1, {}));
1797   builder.AddInstruction(HloInstruction::CreateConcatenate(
1798       ShapeUtil::MakeShape(F32, {200}), {broadcast, param0}, 0));
1799 
1800   auto computation = m->AddEntryComputation(builder.Build());
1801 
1802   AlgebraicSimplifier simplifier(default_options_);
1803   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1804   EXPECT_THAT(computation->root_instruction(),
1805               GmockMatch(m::Pad(m::Parameter(0), m::Parameter(1))));
1806 }
1807 
TEST_F(AlgebraicSimplifierTest,SimplifyConcatenateOfSlices)1808 TEST_F(AlgebraicSimplifierTest, SimplifyConcatenateOfSlices) {
1809   auto m = CreateNewVerifiedModule();
1810   Shape r2f32 = ShapeUtil::MakeShape(F32, {100, 99});
1811   Shape concat_shape = ShapeUtil::MakeShape(F32, {50, 80});
1812   HloComputation::Builder builder(TestName());
1813   HloInstruction* param0 = builder.AddInstruction(
1814       HloInstruction::CreateParameter(0, r2f32, "param0"));
1815   HloInstruction* param1 = builder.AddInstruction(
1816       HloInstruction::CreateParameter(1, r2f32, "param1"));
1817 
1818   HloInstruction* slice0 = builder.AddInstruction(HloInstruction::CreateSlice(
1819       ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{0, 0},
1820       /*limit_indices=*/{50, 10}, /*strides=*/{1, 1}));
1821 
1822   // Cannot merge 'slice0' and 'slice1' because of different start indices in
1823   // dimension 0.
1824   HloInstruction* slice1 = builder.AddInstruction(HloInstruction::CreateSlice(
1825       ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 10},
1826       /*limit_indices=*/{100, 20}, /*strides=*/{1, 1}));
1827 
1828   // Cannot merge 'slice1' and 'slice2' because of stride in dimension 2.
1829   HloInstruction* slice2 = builder.AddInstruction(HloInstruction::CreateSlice(
1830       ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 20},
1831       /*limit_indices=*/{100, 40}, /*strides=*/{1, 2}));
1832 
1833   // Cannot merge 'slice2' and 'slice3' because of stride in dimension 2.
1834   HloInstruction* slice3 = builder.AddInstruction(HloInstruction::CreateSlice(
1835       ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 40},
1836       /*limit_indices=*/{100, 50}, /*strides=*/{1, 1}));
1837 
1838   // Can merge 'slice3' and 'slice4'.
1839   HloInstruction* slice4 = builder.AddInstruction(HloInstruction::CreateSlice(
1840       ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 50},
1841       /*limit_indices=*/{100, 60}, /*strides=*/{1, 1}));
1842 
1843   // Can merge 'slice4' and 'slice5'.
1844   HloInstruction* slice5 = builder.AddInstruction(HloInstruction::CreateSlice(
1845       ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 60},
1846       /*limit_indices=*/{100, 70}, /*strides=*/{1, 1}));
1847 
1848   // Cannot merge 'slice5' and 'slice6' because of overlap.
1849   HloInstruction* slice6 = builder.AddInstruction(HloInstruction::CreateSlice(
1850       ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 69},
1851       /*limit_indices=*/{100, 79}, /*strides=*/{1, 1}));
1852 
1853   // Cannot merge 'slice6' and 'slice7' because of slicing from a different
1854   // parameter.
1855   HloInstruction* slice7 = builder.AddInstruction(HloInstruction::CreateSlice(
1856       ShapeUtil::MakeShape(F32, {50, 10}), param1, /*start_indices=*/{50, 79},
1857       /*limit_indices=*/{100, 89}, /*strides=*/{1, 1}));
1858 
1859   builder.AddInstruction(HloInstruction::CreateConcatenate(
1860       concat_shape,
1861       {slice0, slice1, slice2, slice3, slice4, slice5, slice6, slice7}, 1));
1862   auto computation = m->AddEntryComputation(builder.Build());
1863 
1864   AlgebraicSimplifier simplifier(default_options_);
1865   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1866   auto s = m::Slice(m::Parameter(0));
1867   EXPECT_THAT(
1868       computation->root_instruction(),
1869       GmockMatch(m::Concatenate(s, s, s, s, s, m::Slice(m::Parameter(1)))));
1870   // The operand 3 should be a merge of 'slice3', 'slice4' and 'slice5', so its
1871   // shape should have dimensions {50, 30}.
1872   EXPECT_TRUE(
1873       ShapeUtil::Equal(computation->root_instruction()->operand(3)->shape(),
1874                        ShapeUtil::MakeShape(F32, {50, 30})));
1875   EXPECT_EQ(computation->root_instruction()->operand(3)->slice_starts(1), 40);
1876 }
1877 
1878 // Test that a simplification which changes layouts is not performed if layout
1879 // sensitive is true.
TEST_F(AlgebraicSimplifierTest,CopyWithDifferentLayout)1880 TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) {
1881   auto m = CreateNewVerifiedModule();
1882   HloComputation::Builder builder(TestName());
1883   HloInstruction* param0 =
1884       builder.AddInstruction(HloInstruction::CreateParameter(
1885           0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
1886   HloInstruction* copy = builder.AddInstruction(
1887       HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
1888 
1889   auto computation = m->AddEntryComputation(builder.Build());
1890 
1891   // Set to different layouts.
1892   *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
1893   *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0});
1894 
1895   EXPECT_THAT(computation->root_instruction(),
1896               GmockMatch(m::Copy(m::Parameter(0))));
1897 
1898   AlgebraicSimplifierOptions options;
1899   options.set_is_layout_sensitive(true);
1900   AlgebraicSimplifier simplifier(options);
1901   EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie());
1902 
1903   // Copy has not been removed.
1904   EXPECT_THAT(computation->root_instruction(),
1905               GmockMatch(m::Copy(m::Parameter(0))));
1906 }
1907 
1908 // Test that a simplification which preserves layouts is performed if layout
1909 // sensitive is true.
TEST_F(AlgebraicSimplifierTest,CopyWithSameLayout)1910 TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) {
1911   auto m = CreateNewVerifiedModule();
1912   HloComputation::Builder builder(TestName());
1913   HloInstruction* param0 =
1914       builder.AddInstruction(HloInstruction::CreateParameter(
1915           0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
1916   HloInstruction* copy = builder.AddInstruction(
1917       HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
1918 
1919   auto computation = m->AddEntryComputation(builder.Build());
1920 
1921   // Set to same layouts.
1922   *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
1923   *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
1924 
1925   EXPECT_THAT(computation->root_instruction(),
1926               GmockMatch(m::Copy(m::Parameter(0))));
1927 
1928   AlgebraicSimplifierOptions options;
1929   options.set_is_layout_sensitive(true);
1930   AlgebraicSimplifier simplifier(options);
1931   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1932 
1933   // Copy has been removed.
1934   EXPECT_THAT(computation->root_instruction(), param0);
1935 }
1936 
1937 // Test that a reshape which could be replaced with a bitcast is not if
1938 // add_bitcasts is false.
TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  // reshape(param0): f32[2,2] -> f32[1,2,1,1,2,1].
  const Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape expanded_shape = ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1});
  HloInstruction* input = b.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param0"));
  HloInstruction* expand = b.AddInstruction(
      HloInstruction::CreateReshape(expanded_shape, input));

  // Overwrite the layouts of both instructions explicitly.
  *input->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
  *expand->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});

  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Reshape(m::Parameter(0))));

  // The callback handed to the options answers `false` for every shape pair.
  auto reject_all = [](const Shape&, const Shape&) { return false; };
  AlgebraicSimplifierOptions options(reject_all);
  options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(options);
  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());

  // The pass reported no change, and the root is still the reshape (it was
  // not turned into a bitcast).
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Reshape(m::Parameter(0))));
}
1968 
1969 // Test transforming reshapes and transposes of rng.
TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  // RNG_UNIFORM rng of shape f32[2,2] with scalar constant operands 0 and 1.
  HloInstruction* lower = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  HloInstruction* upper = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  const Shape square = ShapeUtil::MakeShape(F32, {2, 2});
  HloInstruction* rng = b.AddInstruction(HloInstruction::CreateRng(
      square, RandomDistribution::RNG_UNIFORM, {lower, upper}));

  // reshape(transpose(rng)) down to f32[4].
  HloInstruction* transposed = b.AddInstruction(
      HloInstruction::CreateTranspose(rng->shape(), rng, {1, 0}));
  HloInstruction* flattened = b.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {4}), transposed));
  // Take a copy of the reshape's shape before the pass runs, since the
  // reshape is expected to be replaced.
  const Shape expected_shape = flattened->shape();

  HloComputation* entry = module->AddEntryComputation(b.Build());

  AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // reshape(transpose(rng)) must have been replaced by a single rng whose
  // shape equals that of the original reshape.
  HloInstruction* root = entry->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Rng()));
  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), expected_shape));
}
1999 
2000 // Test transforming reshapes to bitcasts under various conditions.
TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) {
  // Builds three reshapes of the same f32[2,2] parameter, gathers them in a
  // tuple, and checks that a layout-sensitive AlgebraicSimplifier run turns
  // only the first of them into a bitcast.
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
  *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});

  // Reshape which can be transformed into a bitcast.
  HloInstruction* transformable_reshape =
      builder.AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));
  *transformable_reshape->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});

  // Reshape does not just add degenerate dimensions.
  HloInstruction* dimensions_wrong_reshape =
      builder.AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(F32, {1, 4, 1, 1, 1, 1}), param0));
  *dimensions_wrong_reshape->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});

  // Reshape has wrong layout.
  HloInstruction* layout_wrong_reshape =
      builder.AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));
  *layout_wrong_reshape->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({5, 4, 3, 2, 1, 0});

  // Collect all the reshapes into a tuple so they are not dead.
  builder.AddInstruction(HloInstruction::CreateTuple(
      {transformable_reshape, dimensions_wrong_reshape, layout_wrong_reshape}));

  auto computation = m->AddEntryComputation(builder.Build());

  // Sanity-check the graph before the pass runs.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Tuple(m::Op().Is(transformable_reshape),
                                  m::Op().Is(dimensions_wrong_reshape),
                                  m::Op().Is(layout_wrong_reshape))));

  AlgebraicSimplifierOptions options;
  options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(options);
  // The first reshape is rewritten, so the pass must report a change. Assert
  // on the returned flag instead of discarding it, matching the other tests
  // in this file that expect a rewrite.
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // Verify that only the first reshape is replaced.
  EXPECT_THAT(
      computation->root_instruction(),
      GmockMatch(m::Tuple(m::Bitcast(), m::Op().Is(dimensions_wrong_reshape),
                          m::Op().Is(layout_wrong_reshape))));
}
2052 
2053 // Regression test for a bug where if we failed to sink a reshape, we'd set the
2054 // 'changed' bit in AlgebraicSimplifier to false.
TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  // (param0 + 0) is an add the simplifier can rewrite.
  const Shape matrix = ShapeUtil::MakeShape(F32, {2, 2});
  HloInstruction* lhs =
      b.AddInstruction(HloInstruction::CreateParameter(0, matrix, "param0"));
  HloInstruction* zeros = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}})));
  HloInstruction* sum = b.AddInstruction(
      HloInstruction::CreateBinary(matrix, HloOpcode::kAdd, lhs, zeros));

  // The root is a reshape of that add down to f32[4].
  b.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {4}), sum));

  module->AddEntryComputation(b.Build());

  // The pass must still report that it changed the module.
  AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
}
2075 
// Regression test for a bug where if we failed to sink a broadcast, we'd set
// the 'changed' bit in AlgebraicSimplifier to false.
TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  // (param0 + 0) is an add the simplifier can rewrite.
  const Shape matrix = ShapeUtil::MakeShape(F32, {2, 2});
  HloInstruction* lhs =
      b.AddInstruction(HloInstruction::CreateParameter(0, matrix, "param0"));
  HloInstruction* zeros = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}})));
  HloInstruction* sum = b.AddInstruction(
      HloInstruction::CreateBinary(matrix, HloOpcode::kAdd, lhs, zeros));

  // The root is a broadcast of that add into f32[2,2,2].
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 2, 2});
  b.AddInstruction(HloInstruction::CreateBroadcast(
      cube, sum, /*broadcast_dimensions=*/{0, 1}));

  module->AddEntryComputation(b.Build());

  // The pass must still report that it changed the module.
  AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
}
2099 
TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  // f32[50,14,14,64] parameter with minor-to-major layout {1,2,0,3}.
  const Shape input_shape = ShapeUtil::MakeShape(F32, {50, 14, 14, 64});
  HloInstruction* input = b.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param"));
  *input->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({1, 2, 0, 3});

  // Transpose to f32[14,14,50,64] with minor-to-major layout {0,1,2,3}.
  const Shape permuted_shape = ShapeUtil::MakeShape(F32, {14, 14, 50, 64});
  HloInstruction* permuted = b.AddInstruction(
      HloInstruction::CreateTranspose(permuted_shape, input, {1, 2, 0, 3}));
  *permuted->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({0, 1, 2, 3});

  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Transpose(m::Parameter(0))));

  AlgebraicSimplifierOptions layout_sensitive;
  layout_sensitive.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(layout_sensitive);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The transpose must have been replaced by a bitcast of the parameter.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
2129 
TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  // f32[5,2,3,4] parameter with minor-to-major layout {1,2,3,0}.
  const Shape input_shape = ShapeUtil::MakeShape(F32, {5, 2, 3, 4});
  HloInstruction* input = b.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param"));
  *input->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({1, 2, 3, 0});

  // Transpose to f32[5,3,4,2] with minor-to-major layout {3,1,2,0}.
  const Shape permuted_shape = ShapeUtil::MakeShape(F32, {5, 3, 4, 2});
  HloInstruction* permuted = b.AddInstruction(
      HloInstruction::CreateTranspose(permuted_shape, input, {0, 2, 3, 1}));
  *permuted->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({3, 1, 2, 0});

  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Transpose(m::Parameter(0))));

  AlgebraicSimplifierOptions layout_sensitive;
  layout_sensitive.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(layout_sensitive);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The transpose must have been replaced by a bitcast of the parameter.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
2159 
TEST_F(AlgebraicSimplifierTest, ReshapesMerged) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  // reshape(reshape(param0)): f32[2,2] -> f32[2,1,2] -> f32[1,2,1,1,2,1].
  const Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape middle_shape = ShapeUtil::MakeShape(F32, {2, 1, 2});
  const Shape final_shape = ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1});
  HloInstruction* input = b.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param0"));
  HloInstruction* inner =
      b.AddInstruction(HloInstruction::CreateReshape(middle_shape, input));
  b.AddInstruction(HloInstruction::CreateReshape(final_shape, inner));

  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Reshape(m::Reshape(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The two reshapes collapse into a single reshape of the parameter.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Reshape(m::Parameter(0))));
}
2185 
TEST_F(AlgebraicSimplifierTest, CopiesMerged) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  // copy(copy(param0)) over f32[2,2,2]; the parameter and the two copies are
  // each built with a different layout.
  const Shape source_shape =
      ShapeUtil::MakeShapeWithDescendingLayout(F32, {2, 2, 2});
  const Shape inner_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2});
  const Shape outer_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 2, 1});

  HloInstruction* input = b.AddInstruction(
      HloInstruction::CreateParameter(0, source_shape, "param0"));
  HloInstruction* inner_copy = b.AddInstruction(
      HloInstruction::CreateUnary(inner_shape, HloOpcode::kCopy, input));
  b.AddInstruction(
      HloInstruction::CreateUnary(outer_shape, HloOpcode::kCopy, inner_copy));

  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Copy(m::Copy(m::Parameter(0)))));

  AlgebraicSimplifierOptions layout_sensitive;
  layout_sensitive.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(layout_sensitive);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Only one copy of the parameter remains.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Copy(m::Parameter(0))));
}
2215 
TEST_F(AlgebraicSimplifierTest, TransposesMerged) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  // transpose(transpose(param0)):
  //   f32[2,3,4] --{1,2,0}--> f32[3,4,2] --{1,0,2}--> f32[4,3,2].
  HloInstruction* input = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {2, 3, 4}), "param0"));
  HloInstruction* inner = b.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(F32, {3, 4, 2}), input, {1, 2, 0}));
  b.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(F32, {4, 3, 2}), inner, {1, 0, 2}));

  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Transpose(m::Op().Is(inner))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // A single transpose of the parameter remains, using the composed
  // permutation {2, 1, 0}.
  HloInstruction* root = entry->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Transpose(m::Parameter(0))));
  EXPECT_THAT(root->dimensions(), ::testing::ElementsAre(2, 1, 0));
}
2243 
TEST_F(AlgebraicSimplifierTest, TransposeIsReshape) {
  // reshape -> transpose -> reshape chain that starts and ends at f32[10].
  const char* const kModuleText = R"(
    HloModule module

    ENTRY test {
      param = f32[10] parameter(0)
      reshaped = f32[1,1,10] reshape(f32[10] param)
      transposed = f32[10,1,1] transpose(f32[1,1,10] reshaped), dimensions={2,1,0}
      ROOT reshaped_again = f32[10] reshape(f32[10,1,1] transposed)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed,
                          ParseAndReturnVerifiedModule(kModuleText));

  // Run the simplifier wrapped in HloPassFix; the whole chain should reduce
  // to the bare parameter.
  HloPassFix<AlgebraicSimplifier> fixpoint(default_options_);
  EXPECT_TRUE(fixpoint.Run(parsed.get()).ValueOrDie());
  HloInstruction* root = parsed->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Parameter()));
}
2263 
2264 // Test merging reshape and broadcast.
TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  // broadcast(reshape(param0)):
  //   f32[5] -> f32[1,5,1] -> f32[1,2,3,5,1] with dimensions {0, 3, 2}.
  HloInstruction* input = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {5}), "param0"));
  HloInstruction* expanded = b.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {1, 5, 1}), input));
  const Shape output_shape = ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1});
  b.AddInstruction(
      HloInstruction::CreateBroadcast(output_shape, expanded, {0, 3, 2}));

  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The reshape is gone; the broadcast now reads the parameter directly.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Broadcast(m::Parameter(0))));
}
2286 
2287 // Test merging broadcast and reshape.
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  // reshape(broadcast(param0)):
  //   f32[2,3] -> f32[1,2,3,7,12,1] (dimensions {1, 2})
  //            -> f32[2,3,7,2,1,3,2].
  HloInstruction* input = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {2, 3}), "param0"));
  const Shape broadcast_shape = ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1});
  HloInstruction* widened = b.AddInstruction(
      HloInstruction::CreateBroadcast(broadcast_shape, input, {1, 2}));
  const Shape output_shape = ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2});
  b.AddInstruction(HloInstruction::CreateReshape(output_shape, widened));

  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The reshape is gone; a single broadcast of the parameter remains.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Broadcast(m::Parameter(0))));
}
2309 
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  // reshape(broadcast(param)): f32[1] -> f32[3,1] (dimensions {1}) -> f32[3].
  HloInstruction* input = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1}), "param"));
  HloInstruction* widened = b.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {3, 1}), input, {1}));
  b.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), widened));

  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));

  // The pass is expected to leave this graph untouched.
  AlgebraicSimplifier simplifier(default_options_);
  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));
}
2331 
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  // reshape(broadcast(param)):
  //   f32[4] -> f32[3,2,4] (dimensions {2}) -> f32[6,1,1,4].
  HloInstruction* input = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {4}), "param"));
  HloInstruction* widened = b.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {3, 2, 4}), input, {2}));
  const Shape output_shape = ShapeUtil::MakeShape(F32, {6, 1, 1, 4});
  b.AddInstruction(HloInstruction::CreateReshape(output_shape, widened));

  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // A single broadcast remains and it maps the operand onto output
  // dimension 3.
  HloInstruction* root = entry->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Parameter(0))));
  EXPECT_THAT(root->dimensions(), ::testing::ElementsAre(3));
}
2355 
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) {
  // reshape(broadcast(param)):
  //   f32[1] -> f32[3,2,1] (dimensions {2}) -> f32[6,1,1,1].
  // The pair should merge into a single broadcast of the parameter.
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1}), "param"));
  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {3, 2, 1}), param, {2}));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), broadcast));

  HloComputation* computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Broadcast(m::Parameter(0))));
  // Exactly one broadcast dimension is expected, and it may be any of the
  // three size-1 output dimensions. A single container matcher checks the
  // size and the value together, so an unexpectedly empty dimension list is
  // reported as a failure instead of being indexed out of bounds.
  EXPECT_THAT(computation->root_instruction()->dimensions(),
              ::testing::ElementsAre(::testing::AnyOf(1, 2, 3)));
}
2381 
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  // reshape(broadcast(param)):
  //   f32[4] -> f32[3,2,4,2] (dimensions {2}) -> f32[6,8].
  HloInstruction* input = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {4}), "param"));
  HloInstruction* widened = b.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {3, 2, 4, 2}), input, {2}));
  const Shape output_shape = ShapeUtil::MakeShape(F32, {6, 8});
  b.AddInstruction(HloInstruction::CreateReshape(output_shape, widened));

  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));

  // The pass is expected to leave this graph untouched.
  AlgebraicSimplifier simplifier(default_options_);
  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));
}
2403 
TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  // reshape(iota): f32[1,2,3,7,12,1] iota along dimension 2, reshaped to
  // f32[2,3,7,2,1,3,2].
  const Shape iota_shape = ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1});
  const Shape expected_shape = ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2});
  HloInstruction* source =
      b.AddInstruction(HloInstruction::CreateIota(iota_shape, 2));
  b.AddInstruction(HloInstruction::CreateReshape(expected_shape, source));

  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The root is now an iota that carries the reshape's shape.
  HloInstruction* root = entry->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Iota()));
  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), expected_shape));
}
2424 
TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  // An f32[1,1] iota along dimension 0 is the only instruction.
  HloInstruction* source = b.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {1, 1}), 0));
  // Copy the shape before the pass runs; the iota is expected to be replaced.
  const Shape expected_shape = source->shape();

  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Iota()));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The iota becomes a broadcast of a constant whose first element is 0, and
  // the result keeps the original shape.
  HloInstruction* root = entry->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
  EXPECT_EQ(0.0f, root->operand(0)->literal().GetFirstElement<float>());
  EXPECT_TRUE(
      ShapeUtil::Equal(entry->root_instruction()->shape(), expected_shape));
}
2445 
TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2_6) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  // reshape(iota): f32[3,2] iota along dimension 1, reshaped to f32[6].
  HloInstruction* source = b.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2}), 1));
  b.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6}), source));

  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Reshape(m::Iota())));

  // The pass is expected to leave this graph untouched.
  AlgebraicSimplifier simplifier(default_options_);
  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Reshape(m::Iota())));
}
2465 
TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  // reshape(iota): f32[3,2,4] iota along dimension 2, reshaped to
  // f32[6,1,1,4].
  HloInstruction* source = b.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4}), 2));
  const Shape output_shape = ShapeUtil::MakeShape(F32, {6, 1, 1, 4});
  b.AddInstruction(HloInstruction::CreateReshape(output_shape, source));

  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The root is now an iota counting along dimension 3 of the new shape.
  HloInstruction* root = entry->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Iota()));
  EXPECT_EQ(Cast<HloIotaInstruction>(root)->iota_dimension(), 3);
}
2487 
TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x2_6x1x1x2) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  // reshape(iota): f32[3,2,2] iota along dimension 2, reshaped to
  // f32[6,1,1,2].
  HloInstruction* source = b.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 2}), 2));
  const Shape output_shape = ShapeUtil::MakeShape(F32, {6, 1, 1, 2});
  b.AddInstruction(HloInstruction::CreateReshape(output_shape, source));

  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The root is now an iota; dimensions 1, 2 and 3 are all accepted as its
  // iota dimension.
  HloInstruction* root = entry->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Iota()));
  const int64 merged_dim = Cast<HloIotaInstruction>(root)->iota_dimension();
  EXPECT_THAT(merged_dim, ::testing::AnyOf(1, 2, 3));
}
2510 
TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4x2_6x8) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  // reshape(iota): f32[3,2,4,2] iota along dimension 2, reshaped to f32[6,8].
  HloInstruction* source = b.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4, 2}), 2));
  const Shape output_shape = ShapeUtil::MakeShape(F32, {6, 8});
  b.AddInstruction(HloInstruction::CreateReshape(output_shape, source));

  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Reshape(m::Iota())));

  // The pass is expected to leave this graph untouched.
  AlgebraicSimplifier simplifier(default_options_);
  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Reshape(m::Iota())));
}
2530 
TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
  HloComputation::Builder b(TestName());
  const Shape matrix = ShapeUtil::MakeShape(F32, {2, 2});
  HloInstruction* input =
      b.AddInstruction(HloInstruction::CreateParameter(0, matrix, "param"));
  HloInstruction* pad_value = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));

  // Padding config with two dimensions, every amount set to zero.
  PaddingConfig zero_config;
  for (int dim = 0; dim < 2; ++dim) {
    auto* entry = zero_config.add_dimensions();
    entry->set_edge_padding_low(0);
    entry->set_edge_padding_high(0);
    entry->set_interior_padding(0);
  }
  b.AddInstruction(HloInstruction::CreatePad(ShapeUtil::MakeShape(F32, {2, 2}),
                                             input, pad_value, zero_config));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(b.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(pad_value))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The pad is removed, leaving the parameter as the root.
  EXPECT_THAT(computation->root_instruction(), ::testing::Eq(input));
}
2559 
TEST_F(AlgebraicSimplifierTest, NegativePadding) {
  // A pad instruction with negative edge padding should be rewritten as a
  // slice of a pad whose edge padding is non-negative.
  HloComputation::Builder b(TestName());
  HloInstruction* input = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {10, 10}), "param"));
  HloInstruction* pad_value = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));

  // Edge padding per dimension: dim 0 is (low=-1, high=2), dim 1 is
  // (low=-2, high=-3). No interior padding.
  const int64 low_edges[2] = {-1, -2};
  const int64 high_edges[2] = {2, -3};
  PaddingConfig config;
  for (int dim = 0; dim < 2; ++dim) {
    auto* entry = config.add_dimensions();
    entry->set_edge_padding_low(low_edges[dim]);
    entry->set_edge_padding_high(high_edges[dim]);
    entry->set_interior_padding(0);
  }
  HloInstruction* negative_pad = b.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {11, 5}), input, pad_value, config));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(b.Build());

  AlgebraicSimplifier simplifier(default_options_);

  // Returns true iff some dimension of `instr`'s padding config has a
  // negative low or high edge padding.
  const auto any_negative_edge = [](const HloInstruction* instr) {
    for (const auto& dim_config : instr->padding_config().dimensions()) {
      const bool low_negative = dim_config.edge_padding_low() < 0;
      if (low_negative || dim_config.edge_padding_high() < 0) {
        return true;
      }
    }
    return false;
  };

  // Before the pass: the root is the pad, and it has negative padding.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(pad_value))));
  EXPECT_TRUE(any_negative_edge(negative_pad));

  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // After the pass: slice(pad(...)), and the inner pad has no negative edges.
  HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(
      root, GmockMatch(m::Slice(m::Pad(m::Parameter(0), m::Op().Is(pad_value)))));
  EXPECT_FALSE(any_negative_edge(root->operand(0)));
}
2607 
TEST_F(AlgebraicSimplifierTest, TrivialInteriorPadding) {
  // Interior padding on a size-one dimension should be dropped from the pad
  // instruction.
  HloComputation::Builder b(TestName());
  HloInstruction* input = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {2, 1}), "param"));
  HloInstruction* pad_value = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));

  // Both dimensions get edge padding of 3 on each side; interior padding is
  // 0 for dim 0 and 3 for dim 1 (the size-one dimension).
  PaddingConfig config;
  for (int dim = 0; dim < 2; ++dim) {
    auto* entry = config.add_dimensions();
    entry->set_edge_padding_low(3);
    entry->set_edge_padding_high(3);
    entry->set_interior_padding(dim * 3);
  }
  HloInstruction* original_pad = b.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {8, 7}), input, pad_value, config));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(b.Build());

  AlgebraicSimplifier simplifier(default_options_);

  // Preconditions: the root is the pad and it does have interior padding.
  ASSERT_THAT(computation->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(pad_value))));
  ASSERT_TRUE(HasInteriorPadding(original_pad->padding_config()));

  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The root is still a pad of the parameter, now without interior padding.
  HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(pad_value))));
  EXPECT_FALSE(HasInteriorPadding(root->padding_config()));
}
2643 
TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) {
  // A reshape to the operand's own shape should be replaced by the operand.
  HloComputation::Builder builder(TestName());
  const Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param"));
  builder.AddInstruction(HloInstruction::CreateReshape(shape, param));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Parameter(0))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The parameter itself must now be the root.
  EXPECT_EQ(computation->root_instruction(), param);
}
2663 
TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) {
  // A stride-1 slice covering the whole operand should be replaced by the
  // operand.
  HloComputation::Builder builder(TestName());
  const int64 rows = 2;
  const int64 cols = 3;
  const Shape shape = ShapeUtil::MakeShape(F32, {rows, cols});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param"));
  builder.AddInstruction(HloInstruction::CreateSlice(
      shape, param, /*start_indices=*/{0, 0},
      /*limit_indices=*/{rows, cols}, /*strides=*/{1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Slice(m::Parameter(0))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The parameter itself must now be the root.
  EXPECT_EQ(computation->root_instruction(), param);
}
2686 
TEST_F(AlgebraicSimplifierTest, SliceOfSliceToSlice) {
  // Two nested stride-1 slices should collapse into a single slice of the
  // parameter whose bounds are the composition of both.
  HloComputation::Builder builder(TestName());
  const int64 dim0 = 11;
  const int64 dim1 = 12;
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {dim0, dim1}), "param"));

  // Inner slice: [1, dim0 - 1) x [2, dim1 - 2) of the parameter.
  const Shape inner_shape = ShapeUtil::MakeShape(F32, {dim0 - 2, dim1 - 4});
  HloInstruction* inner_slice = builder.AddInstruction(
      HloInstruction::CreateSlice(inner_shape, param,
                                  /*start_indices=*/{1, 2},
                                  /*limit_indices=*/{dim0 - 1, dim1 - 2},
                                  /*strides=*/{1, 1}));

  // Outer slice: [2, dim0 - 3) x [3, dim1 - 6) of the inner slice.
  const Shape outer_shape = ShapeUtil::MakeShape(F32, {dim0 - 5, dim1 - 9});
  builder.AddInstruction(
      HloInstruction::CreateSlice(outer_shape, inner_slice,
                                  /*start_indices=*/{2, 3},
                                  /*limit_indices=*/{dim0 - 3, dim1 - 6},
                                  /*strides=*/{1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Slice(m::Slice(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Starts add up (1 + 2, 2 + 3); limits are the inner start plus the outer
  // limit (1 + dim0 - 3, 2 + dim1 - 6).
  const HloInstruction* merged = computation->root_instruction();
  EXPECT_THAT(merged, GmockMatch(m::Slice(m::Parameter(0))));
  EXPECT_EQ(merged->slice_starts(0), 3);
  EXPECT_EQ(merged->slice_starts(1), 5);
  EXPECT_EQ(merged->slice_limits(0), dim0 - 2);
  EXPECT_EQ(merged->slice_limits(1), dim1 - 4);
}
2720 
TEST_F(AlgebraicSimplifierTest, SliceOfBroadcastToBroadcast) {
  // Slicing only along the broadcast-introduced dimension should turn into a
  // broadcast of the parameter, with no slice left at the root.
  HloComputation::Builder builder(TestName());
  const int64 dim0 = 11;
  const int64 dim1 = 12;
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {dim0}),
                                      "param"));

  // Operand dimension 0 maps to output dimension 0; dimension 1 is new.
  const Shape broadcast_shape = ShapeUtil::MakeShape(F32, {dim0, dim1});
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(broadcast_shape, param, {0}));

  // Keep all of dimension 0 and the range [3, dim1 - 6) of dimension 1.
  const Shape slice_shape = ShapeUtil::MakeShape(F32, {dim0, dim1 - 9});
  builder.AddInstruction(
      HloInstruction::CreateSlice(slice_shape, broadcast,
                                  /*start_indices=*/{0, 3},
                                  /*limit_indices=*/{dim0, dim1 - 6},
                                  /*strides=*/{1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Slice(m::Broadcast(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Broadcast(m::Parameter(0))));
}
2747 
TEST_F(AlgebraicSimplifierTest, SliceOfReshapeToReshapeOfSlice) {
  // slice(reshape(param)) is expected to be rewritten as
  // reshape(slice(param)) for this shape combination.
  HloComputation::Builder builder(TestName());
  const int64 dim0 = 11;
  const int64 dim1 = 12;
  const int64 dim2 = 13;
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {dim0 * dim1, dim2}), "param"));

  // Split the leading dimension: [dim0 * dim1, dim2] -> [dim0, dim1, dim2].
  const Shape reshaped = ShapeUtil::MakeShape(F32, {dim0, dim1, dim2});
  HloInstruction* reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(reshaped, param));

  // Drop the last two entries of the outermost dimension only.
  const Shape sliced = ShapeUtil::MakeShape(F32, {dim0 - 2, dim1, dim2});
  builder.AddInstruction(
      HloInstruction::CreateSlice(sliced, reshape,
                                  /*start_indices=*/{0, 0, 0},
                                  /*limit_indices=*/{dim0 - 2, dim1, dim2},
                                  /*strides=*/{1, 1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Slice(m::Reshape(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Slice(m::Parameter(0)))));
}
2776 
TEST_F(AlgebraicSimplifierTest, SliceOfReshapeUnchanged) {
  // Negative case: for this slice-of-reshape the simplifier must report that
  // it changed nothing.
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {1, 144, 25, 1, 512}), "param"));

  // Collapse the four leading dimensions: 1 * 144 * 25 * 1 = 3600.
  const Shape reshaped = ShapeUtil::MakeShape(F32, {3600, 512});
  HloInstruction* reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(reshaped, param));

  // Keep the first 960 rows of the collapsed dimension.
  const Shape sliced = ShapeUtil::MakeShape(F32, {960, 512});
  builder.AddInstruction(
      HloInstruction::CreateSlice(sliced, reshape,
                                  /*start_indices=*/{0, 0},
                                  /*limit_indices=*/{960, 512},
                                  /*strides=*/{1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Slice(m::Reshape(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
}
2799 
TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) {
  // Sorting a one-element array should be replaced by the array itself.
  HloComputation::Builder builder(TestName());
  auto module = CreateNewVerifiedModule();

  const Shape keys_shape = ShapeUtil::MakeShape(F32, {1});
  HloInstruction* keys = builder.AddInstruction(
      HloInstruction::CreateParameter(0, keys_shape, "keys"));

  // Sort along dimension 0; the sort becomes the root of the computation.
  auto sort_or = MakeSortHlo(keys_shape, {keys}, 0, /*is_stable=*/false,
                             &builder, module.get());
  TF_ASSERT_OK(sort_or.status());

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_EQ(computation->root_instruction(), keys);
}
2815 
TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) {
  // A key/value sort over these {5, 0} operands should be replaced by a plain
  // tuple of the sort's operands, in their original order.
  HloComputation::Builder builder(TestName());
  auto module = CreateNewVerifiedModule();

  const Shape keys_shape = ShapeUtil::MakeShape(F32, {5, 0});
  const Shape values_shape = ShapeUtil::MakeShape(S32, {5, 0});
  HloInstruction* keys = builder.AddInstruction(
      HloInstruction::CreateParameter(0, keys_shape, "keys"));
  HloInstruction* values0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, values_shape, "values0"));
  HloInstruction* values1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, values_shape, "values1"));

  // Sort along dimension 0; the sort becomes the root of the computation.
  const Shape sort_shape =
      ShapeUtil::MakeTupleShape({keys_shape, values_shape, values_shape});
  auto sort_or = MakeSortHlo(sort_shape, {keys, values0, values1}, 0,
                             /*is_stable=*/false, &builder, module.get());
  TF_ASSERT_OK(sort_or.status());

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Tuple(m::Op().Is(keys), m::Op().Is(values0),
                                  m::Op().Is(values1))));
}
2840 
2841 // Test that A && True is simplified to A
TEST_F(AlgebraicSimplifierTest, AndTrue) {
  auto module = CreateNewVerifiedModule();
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, pred_scalar, "param0"));
  HloInstruction* true_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  // Root of the computation: and(operand, true).
  builder.AddInstruction(HloInstruction::CreateBinary(
      pred_scalar, HloOpcode::kAnd, operand, true_constant));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kAnd);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Only the non-constant operand should remain.
  EXPECT_EQ(computation->root_instruction(), operand);
}
2861 
2862 // Test that True && A is simplified to A
TEST_F(AlgebraicSimplifierTest, AndTrue2) {
  auto module = CreateNewVerifiedModule();
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, pred_scalar, "param0"));
  HloInstruction* true_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  // Root of the computation: and(true, operand), constant on the left.
  builder.AddInstruction(HloInstruction::CreateBinary(
      pred_scalar, HloOpcode::kAnd, true_constant, operand));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kAnd);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Only the non-constant operand should remain.
  EXPECT_EQ(computation->root_instruction(), operand);
}
2882 
2883 // Test that A && False is simplified to False
TEST_F(AlgebraicSimplifierTest, AndFalse) {
  auto module = CreateNewVerifiedModule();
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, pred_scalar, "param0"));
  HloInstruction* false_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  // Root of the computation: and(operand, false).
  builder.AddInstruction(HloInstruction::CreateBinary(
      pred_scalar, HloOpcode::kAnd, operand, false_constant));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kAnd);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The whole expression should fold to the false constant.
  EXPECT_EQ(computation->root_instruction(), false_constant);
}
2903 
2904 // Test that False && A is simplified to False
TEST_F(AlgebraicSimplifierTest, AndFalse2) {
  auto module = CreateNewVerifiedModule();
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, pred_scalar, "param0"));
  HloInstruction* false_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  // Root of the computation: and(false, operand), constant on the left.
  builder.AddInstruction(HloInstruction::CreateBinary(
      pred_scalar, HloOpcode::kAnd, false_constant, operand));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kAnd);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The whole expression should fold to the false constant.
  EXPECT_EQ(computation->root_instruction(), false_constant);
}
2924 
2925 // Test that A || True is simplified to True
TEST_F(AlgebraicSimplifierTest, OrTrue) {
  auto module = CreateNewVerifiedModule();
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, pred_scalar, "param0"));
  HloInstruction* true_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  // Root of the computation: or(operand, true).
  builder.AddInstruction(HloInstruction::CreateBinary(
      pred_scalar, HloOpcode::kOr, operand, true_constant));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kOr);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The whole expression should fold to the true constant.
  EXPECT_EQ(computation->root_instruction(), true_constant);
}
2945 
2946 // Test that True || A is simplified to True
TEST_F(AlgebraicSimplifierTest, OrTrue2) {
  auto module = CreateNewVerifiedModule();
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, pred_scalar, "param0"));
  HloInstruction* true_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  // Root of the computation: or(true, operand), constant on the left.
  builder.AddInstruction(HloInstruction::CreateBinary(
      pred_scalar, HloOpcode::kOr, true_constant, operand));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kOr);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The whole expression should fold to the true constant.
  EXPECT_EQ(computation->root_instruction(), true_constant);
}
2966 
2967 // Test that A || False is simplified to A
TEST_F(AlgebraicSimplifierTest, OrFalse) {
  auto module = CreateNewVerifiedModule();
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, pred_scalar, "param0"));
  HloInstruction* false_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  // Root of the computation: or(operand, false).
  builder.AddInstruction(HloInstruction::CreateBinary(
      pred_scalar, HloOpcode::kOr, operand, false_constant));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kOr);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Only the non-constant operand should remain.
  EXPECT_EQ(computation->root_instruction(), operand);
}
2987 
2988 // Test that False || A is simplified to A
TEST_F(AlgebraicSimplifierTest, OrFalse2) {
  auto module = CreateNewVerifiedModule();
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, pred_scalar, "param0"));
  HloInstruction* false_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  // Root of the computation: or(false, operand), constant on the left.
  builder.AddInstruction(HloInstruction::CreateBinary(
      pred_scalar, HloOpcode::kOr, false_constant, operand));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kOr);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Only the non-constant operand should remain.
  EXPECT_EQ(computation->root_instruction(), operand);
}
3008 
3009 // Used for TEST_Ps that test merging (or not) of a kPad instruction into a
3010 // convolution's Window.
3011 struct ConvPaddingTestcase {
ConvPaddingTestcasexla::__anonac242c730111::ConvPaddingTestcase3012   ConvPaddingTestcase(absl::string_view padding,
3013                       absl::string_view orig_conv_window,
3014                       absl::string_view expected_conv_window)
3015       : ConvPaddingTestcase(padding, orig_conv_window, expected_conv_window,
3016                             /*pad_value=*/0) {}
3017 
ConvPaddingTestcasexla::__anonac242c730111::ConvPaddingTestcase3018   ConvPaddingTestcase(absl::string_view padding,
3019                       absl::string_view orig_conv_window,
3020                       absl::string_view expected_conv_window, float pad_value)
3021       : padding(padding),
3022         orig_conv_window(orig_conv_window),
3023         expected_conv_window(expected_conv_window),
3024         pad_value(pad_value) {}
3025 
ToStringxla::__anonac242c730111::ConvPaddingTestcase3026   string ToString() const {
3027     return absl::StrFormat(
3028         "padding=%s, orig_conv_window=%s, expected_conv_window=%s, "
3029         "pad_value=%f",
3030         padding, orig_conv_window, expected_conv_window, pad_value);
3031   }
3032 
3033   string padding;
3034   string orig_conv_window;
3035   string expected_conv_window;
3036   float pad_value;
3037 };
3038 
3039 // ConvInputPaddingTest (and its one associated TEST_P testcase) checks that a
3040 // computation that does
3041 //
3042 //   conv(pad(param0, padding=padding), param1), window=orig_conv_window
3043 //
3044 // gets transformed by AlgebraicSimplifier to
3045 //
3046 //   conv(param0, param1), window=expected_conv_window
3047 //
3048 // or, if expected_conv_window is the empty string, checks that
3049 // AlgebraicSimplifier does *not* transform the original convolution.
3050 class ConvInputPaddingTest
3051     : public AlgebraicSimplifierTest,
3052       public ::testing::WithParamInterface<ConvPaddingTestcase> {};
3053 
3054 INSTANTIATE_TEST_SUITE_P(
3055     ConvInputPaddingTestCases, ConvInputPaddingTest,
3056     ::testing::ValuesIn(std::vector<ConvPaddingTestcase>{
3057         // Merge this edge padding into the conv.
3058         {"0_0x0_0x1_1x2_2", "", "pad=1_1x2_2"},
3059         // Merge this edge padding with the conv's edge padding.
3060         {"0_0x0_0x1_2x3_4", "pad=10_10x20_20", "pad=11_12x23_24"},
3061         // Merge this interior-padded kPad with the unpadded conv.  The 3x6
3062         // interior padding gets transformed to 4x7 conv lhs dilation.
3063         {"0_0x0_0x1_2_3x4_5_6", "", "pad=1_2x4_5 lhs_dilate=4x7"},
3064         // kPad has dilation on one dim, conv has it on the other; merge them.
3065         {"0_0x0_0x0_0_1x0_0_0", "lhs_dilate=1x10", "lhs_dilate=2x10"},
3066         // kPad has dilation and edge padding on one dim, conv has them on the
3067         // other; merge them.
3068         {"0_0x0_0x0_1_1x0_0_0", "pad=0_0x3_0 lhs_dilate=1x10",
3069          "pad=0_1x3_0 lhs_dilate=2x10"},
3070 
3071         // Don't transform if the pad value is nonzero.
3072         {"0_0x0_0x1_1x2_2", "", "", /*pad_value=*/1},
3073 
3074         // We refuse to transform the following because on some dimension, one
3075         // of the kPad and conv has dilation and the other has some sort of
3076         // padding.
3077         {"0_0x0_0x0_0_1x0_0", "pad=1_0x0_0", ""},
3078         {"0_0x0_0x0_0_1x0_0", "pad=0_1x0_0", ""},
3079         {"0_0x0_0x0_0_1x0_0", "lhs_dilate=2x1", ""},
3080         {"0_0x0_0x1_0_0x0_0", "lhs_dilate=2x1", ""},
3081         {"0_0x0_0x0_1_0x0_0", "lhs_dilate=2x1", ""},
3082         {"0_0x0_0x0_0_1x0_0", "lhs_dilate=2x1", ""},
3083 
3084         // We can't merge feature or batch padding into the conv.
3085         {"1_0x0_0x0_0x0_0", "", ""},
3086         {"0_0x1_0x0_0x0_0", "", ""},
3087     }));
3088 
TEST_P(ConvInputPaddingTest,DoTest)3089 TEST_P(ConvInputPaddingTest, DoTest) {
3090   ConvPaddingTestcase testcase = GetParam();
3091 
3092   // It would be better to put the testcase's ToString into the test name, but
3093   // gUnit has constraints on what can go into test names, and any reasonable
3094   // implementation of ToString() seems to violate them.
3095   SCOPED_TRACE(testcase.ToString());
3096 
3097   auto builder = HloComputation::Builder(TestName());
3098   auto* input = builder.AddInstruction(HloInstruction::CreateParameter(
3099       0, ShapeUtil::MakeShape(F32, {1024, 128, 100, 100}),  // bf01
3100       "input"));
3101   auto* pad_value = builder.AddInstruction(HloInstruction::CreateConstant(
3102       LiteralUtil::CreateR0(testcase.pad_value)));
3103 
3104   PaddingConfig padding_config =
3105       ParsePaddingConfig(testcase.padding).ValueOrDie();
3106   auto* lhs_pad = builder.AddInstruction(HloInstruction::CreatePad(
3107       ShapeInference::InferPadShape(input->shape(), pad_value->shape(),
3108                                     padding_config)
3109           .ValueOrDie(),
3110       input, pad_value, padding_config));
3111 
3112   auto* filter = builder.AddInstruction(HloInstruction::CreateParameter(
3113       1,
3114       ShapeUtil::MakeShape(
3115           F32, {lhs_pad->shape().dimensions(1), 256, 3, 3}),  // io01
3116       "input"));
3117 
3118   ConvolutionDimensionNumbers dnums =
3119       ParseConvolutionDimensionNumbers("bf01_io01->bf01").ValueOrDie();
3120   Window window =
3121       ParseWindow(absl::StrCat("size=3x3 ", testcase.orig_conv_window))
3122           .ValueOrDie();
3123   builder.AddInstruction(HloInstruction::CreateConvolve(
3124       ShapeInference::InferConvolveShape(lhs_pad->shape(), filter->shape(),
3125                                          /*feature_group_count=*/1,
3126                                          /*batch_group_count=*/1, window, dnums)
3127           .ValueOrDie(),
3128       lhs_pad, filter, /*feature_group_count=*/1, /*batch_group_count=*/1,
3129       window, dnums, DefaultPrecisionConfig(2)));
3130   auto module = CreateNewVerifiedModule();
3131   module->AddEntryComputation(builder.Build());
3132 
3133   AlgebraicSimplifier simplifier(default_options_);
3134   if (testcase.expected_conv_window.empty()) {
3135     ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
3136   } else {
3137     ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
3138     auto* conv = module->entry_computation()->root_instruction();
3139     SCOPED_TRACE(module->ToString());
3140     ASSERT_THAT(conv,
3141                 GmockMatch(m::Convolution(m::Parameter(), m::Parameter())));
3142     EXPECT_EQ(window_util::ToString(conv->window()),
3143               absl::StrCat("size=3x3 ", testcase.expected_conv_window));
3144   }
3145 }
3146 
3147 // ConvFilterPaddingTest (and its one associated TEST_P) checks that a
3148 // computation that does
3149 //
3150 //   conv(param0, pad(param1, padding=padding)), window=orig_conv_window
3151 //
3152 // gets transformed by AlgebraicSimplifier to
3153 //
3154 //   conv(param0, param1), window=expected_conv_window
3155 //
3156 // or, if expected_conv_window is the empty string, checks that
3157 // AlgebraicSimplifier does *not* transform the original convolution.
3158 class ConvFilterPaddingTest
3159     : public AlgebraicSimplifierTest,
3160       public ::testing::WithParamInterface<ConvPaddingTestcase> {};
3161 
3162 INSTANTIATE_TEST_SUITE_P(
3163     ConvFilterPaddingTestCases, ConvFilterPaddingTest,
3164     ::testing::ValuesIn(std::vector<ConvPaddingTestcase>{
3165         // Can only merge interior padding on the filter's spatial dimensions;
3166         // all
3167         // other paddings (edge padding and interior padding on the channel
3168         // dims)
3169         // should be rejected out of hand.
3170         {"1_0_0x0_0_0x0_0x0_0", "", ""},
3171         {"0_1_0x0_0_0x0_0x0_0", "", ""},
3172         {"0_0_1x0_0_0x0_0x0_0", "", ""},
3173         {"0_0_0x1_0_0x0_0x0_0", "", ""},
3174         {"0_0_0x0_1_0x0_0x0_0", "", ""},
3175         {"0_0_0x0_0_1x0_0x0_0", "", ""},
3176         {"0_0_0x0_0_0x1_0x0_0", "", ""},
3177         {"0_0_0x0_0_0x0_1x0_0", "", ""},
3178         {"0_0_0x0_0_0x0_0x1_0", "", ""},
3179         {"0_0_0x0_0_0x0_0x0_1", "", ""},
3180 
3181         // Interior padding on channel dims can be merged into the conv, so long
3182         // as the conv and pad don't have interior padding on the same dim.
3183         {"0_0x0_0x0_0_5x0_0", "", "rhs_dilate=6x1"},
3184         {"0_0x0_0x0_0x0_0_10", "", "rhs_dilate=1x11"},
3185         {"0_0x0_0x0_0_10x0_0_100", "", "rhs_dilate=11x101"},
3186         {"0_0x0_0x0_0_1x0_0", "rhs_dilate=1x10", "rhs_dilate=2x10"},
3187         {"0_0x0_0x0_0x0_0_5", "rhs_dilate=10x1", "rhs_dilate=10x6"},
3188 
3189         // Can't merge if for a given dim there's interior padding on both the
3190         // pad and conv.
3191         {"0_0x0_0x0_0_1x0_0", "rhs_dilate=2x10", ""},
3192         {"0_0x0_0x0_0x0_0_5", "rhs_dilate=10x2", ""},
3193 
3194         // Don't transform if the pad value is nonzero.
3195         {"0_0x0_0x0_0_5x0_0", "", "", /*pad_value=*/1},
3196     }));
3197 
TEST_P(ConvFilterPaddingTest,DoIt)3198 TEST_P(ConvFilterPaddingTest, DoIt) {
3199   ConvPaddingTestcase testcase = GetParam();
3200 
3201   // It would be better to put the testcase's ToString into the test name, but
3202   // gUnit has constraints on what can go into test names, and any reasonable
3203   // implementation of ToString() seems to violate them.
3204   SCOPED_TRACE(testcase.ToString());
3205 
3206   auto builder = HloComputation::Builder(TestName());
3207   auto* pad_value = builder.AddInstruction(HloInstruction::CreateConstant(
3208       LiteralUtil::CreateR0(testcase.pad_value)));
3209   auto* filter = builder.AddInstruction(HloInstruction::CreateParameter(
3210       1, ShapeUtil::MakeShape(F32, {128, 256, 3, 3}),  // io01
3211       "input"));
3212   PaddingConfig padding_config =
3213       ParsePaddingConfig(testcase.padding).ValueOrDie();
3214   auto* rhs_pad = builder.AddInstruction(HloInstruction::CreatePad(
3215       ShapeInference::InferPadShape(filter->shape(), pad_value->shape(),
3216                                     padding_config)
3217           .ValueOrDie(),
3218       filter, pad_value, padding_config));
3219 
3220   auto* input = builder.AddInstruction(HloInstruction::CreateParameter(
3221       0,
3222       ShapeUtil::MakeShape(
3223           F32, {1024, rhs_pad->shape().dimensions(0), 100, 100}),  // bf01
3224       "input"));
3225 
3226   ConvolutionDimensionNumbers dnums =
3227       ParseConvolutionDimensionNumbers("bf01_io01->bf01").ValueOrDie();
3228   Window window = ParseWindow(absl::StrFormat("size=%dx%d %s",
3229                                               rhs_pad->shape().dimensions(2),
3230                                               rhs_pad->shape().dimensions(3),
3231                                               testcase.orig_conv_window))
3232                       .ValueOrDie();
3233 
3234   // Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place
3235   // after the transformation.
3236   PrecisionConfig precision_config;
3237   precision_config.add_operand_precision(PrecisionConfig::HIGH);
3238   precision_config.add_operand_precision(PrecisionConfig::HIGHEST);
3239 
3240   builder.AddInstruction(HloInstruction::CreateConvolve(
3241       ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(),
3242                                          /*feature_group_count=*/1,
3243                                          /*batch_group_count=*/1, window, dnums)
3244           .ValueOrDie(),
3245       input, rhs_pad, /*feature_group_count=*/1, /*batch_group_count=*/1,
3246       window, dnums, precision_config));
3247 
3248   auto module = CreateNewVerifiedModule();
3249   module->AddEntryComputation(builder.Build());
3250 
3251   AlgebraicSimplifier simplifier(default_options_);
3252   if (testcase.expected_conv_window.empty()) {
3253     ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
3254   } else {
3255     ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
3256     auto* conv = module->entry_computation()->root_instruction();
3257     SCOPED_TRACE(module->ToString());
3258     ASSERT_THAT(conv,
3259                 GmockMatch(m::Convolution(m::Parameter(), m::Parameter())));
3260     EXPECT_EQ(window_util::ToString(conv->window()),
3261               absl::StrFormat("size=%dx%d %s",
3262                               conv->operand(1)->shape().dimensions(2),
3263                               conv->operand(1)->shape().dimensions(3),
3264                               testcase.expected_conv_window));
3265     EXPECT_THAT(Cast<HloConvolutionInstruction>(conv)
3266                     ->precision_config()
3267                     .operand_precision(),
3268                 ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::HIGHEST));
3269   }
3270 }
3271 
TEST_F(AlgebraicSimplifierTest,ConvertConvToMatmul)3272 TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
3273   struct ConvTestOptions {
3274     int in_batch = 10;
3275     int in_height = 2;
3276     int in_width = 2;
3277     int in_channels = 3;
3278     int f_width = 1;
3279     int f_height = 1;
3280     int f_output_channels = 10;
3281     int row_stride = 1;
3282     int row_padding = 0;
3283     int col_stride = 1;
3284     int col_padding = 0;
3285     bool input_minor_to_major_layout = false;
3286     bool filter_minor_to_major_layout = false;
3287     bool output_minor_to_major_layout = false;
3288 
3289     const char* dim_order = "NHWC";         // can use chars NHWC in any order.
3290     const char* kernel_dim_order = "HWIO";  // can use chars HWIO in any order.
3291 
3292     ConvTestOptions& Reset() {
3293       *this = ConvTestOptions();
3294       return *this;
3295     }
3296   };
3297 
3298   ConvTestOptions options;
3299 
3300   // Builds a convolution from <options> and runs algebraic simplification on
3301   // the computation. Returns a string description of the result of
3302   // simplification.
3303   auto build_and_simplify = [&]() -> string {
3304     HloComputation::Builder b(TestName());
3305 
3306     Window window;
3307     auto* f_dim_1 = window.add_dimensions();
3308     f_dim_1->set_size(options.f_height);
3309     f_dim_1->set_stride(options.row_stride);
3310     f_dim_1->set_padding_low(options.row_padding);
3311     f_dim_1->set_padding_high(options.row_padding);
3312     f_dim_1->set_window_dilation(1);
3313     f_dim_1->set_base_dilation(1);
3314     auto* f_dim_2 = window.add_dimensions();
3315     f_dim_2->set_size(options.f_width);
3316     f_dim_2->set_stride(options.col_stride);
3317     f_dim_2->set_padding_low(options.col_padding);
3318     f_dim_2->set_padding_high(options.col_padding);
3319     f_dim_2->set_window_dilation(1);
3320     f_dim_2->set_base_dilation(1);
3321 
3322     ConvolutionDimensionNumbers dnums;
3323     std::vector<int64> in_dims;
3324     int in_channel_idx = -1;
3325     // filled in later
3326     dnums.add_input_spatial_dimensions(-1);
3327     dnums.add_output_spatial_dimensions(-1);
3328     dnums.add_input_spatial_dimensions(-1);
3329     dnums.add_output_spatial_dimensions(-1);
3330     for (int i = 0; i < strlen(options.dim_order); ++i) {
3331       char ch = options.dim_order[i];
3332       if (ch == 'N') {
3333         dnums.set_input_batch_dimension(i);
3334         dnums.set_output_batch_dimension(i);
3335         in_dims.push_back(options.in_batch);
3336       } else if (ch == 'H') {
3337         dnums.set_input_spatial_dimensions(0, i);
3338         dnums.set_output_spatial_dimensions(0, i);
3339         in_dims.push_back(options.in_height);
3340       } else if (ch == 'W') {
3341         dnums.set_input_spatial_dimensions(1, i);
3342         dnums.set_output_spatial_dimensions(1, i);
3343         in_dims.push_back(options.in_width);
3344       } else if (ch == 'C') {
3345         dnums.set_input_feature_dimension(i);
3346         dnums.set_output_feature_dimension(i);
3347         in_dims.push_back(options.in_channels);
3348         in_channel_idx = i;
3349       }
3350     }
3351 
3352     std::vector<int64> f_dims;
3353     dnums.add_kernel_spatial_dimensions(-1);  // filled in later
3354     dnums.add_kernel_spatial_dimensions(-1);  // filled in later
3355     for (int i = 0; i < strlen(options.kernel_dim_order); ++i) {
3356       char ch = options.kernel_dim_order[i];
3357       if (ch == 'H') {
3358         dnums.set_kernel_spatial_dimensions(0, i);
3359         f_dims.push_back(options.f_height);
3360       } else if (ch == 'W') {
3361         dnums.set_kernel_spatial_dimensions(1, i);
3362         f_dims.push_back(options.f_width);
3363       } else if (ch == 'I') {
3364         dnums.set_kernel_input_feature_dimension(i);
3365         f_dims.push_back(options.in_channels);
3366       } else if (ch == 'O') {
3367         dnums.set_kernel_output_feature_dimension(i);
3368         f_dims.push_back(options.f_output_channels);
3369       }
3370     }
3371 
3372     auto out_dims = in_dims;
3373     out_dims[in_channel_idx] = options.f_output_channels;
3374 
3375     auto make_shape = [](absl::Span<const int64> dims,
3376                          bool minor_to_major_layout) {
3377       if (minor_to_major_layout) {
3378         return ShapeUtil::MakeShapeWithLayout(F32, dims, {0, 1, 2, 3});
3379       } else {
3380         return ShapeUtil::MakeShape(F32, dims);
3381       }
3382     };
3383     auto in_shape = make_shape(in_dims, options.input_minor_to_major_layout);
3384     auto f_shape = make_shape(f_dims, options.filter_minor_to_major_layout);
3385     auto out_shape = make_shape(out_dims, options.output_minor_to_major_layout);
3386 
3387     HloInstruction* input =
3388         b.AddInstruction(HloInstruction::CreateParameter(0, in_shape, "input"));
3389     HloInstruction* filter =
3390         b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter"));
3391 
3392     b.AddInstruction(HloInstruction::CreateConvolve(
3393         out_shape, input, filter,
3394         /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums,
3395         DefaultPrecisionConfig(2)));
3396 
3397     // TODO(b/80488902): verify this module.
3398     auto module = CreateNewUnverifiedModule();
3399     auto* computation = module->AddEntryComputation(b.Build());
3400 
3401     AlgebraicSimplifierOptions simplifier_options;
3402     simplifier_options.set_is_layout_sensitive(true);
3403     AlgebraicSimplifier simplifier(simplifier_options);
3404     if (!simplifier.Run(module.get()).ValueOrDie()) {
3405       return "NO_CHANGE";
3406     }
3407     auto* root = computation->root_instruction();
3408     if (root->opcode() == HloOpcode::kBitcast &&
3409         root->operand(0)->opcode() == HloOpcode::kDot) {
3410       auto lhs_shape = root->operand(0)->operand(0)->shape();
3411       auto rhs_shape = root->operand(0)->operand(1)->shape();
3412       return absl::StrCat(absl::StrJoin(lhs_shape.dimensions(), "x"), " DOT ",
3413                           absl::StrJoin(rhs_shape.dimensions(), "x"));
3414     }
3415     return "UNEXPECTED CHANGE";
3416   };
3417 
3418   // Default options are the simplest case and succeed.
3419   options.Reset();
3420   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
3421 
3422   // Swapping dim spatial and batch order works.
3423   options.Reset().dim_order = "NWHC";
3424   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
3425   options.Reset().dim_order = "WHNC";
3426   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
3427   // Channel dimension earlier fails.
3428   options.Reset().dim_order = "HWCN";
3429   EXPECT_EQ("NO_CHANGE", build_and_simplify());
3430   options.Reset().dim_order = "CHWN";
3431   EXPECT_EQ("NO_CHANGE", build_and_simplify());
3432 
3433   // Filtering dims spatial dims can be anywhere, since they are 1x1.
3434   options.Reset().kernel_dim_order = "WHIO";
3435   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
3436   options.Reset().kernel_dim_order = "IWOH";
3437   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
3438   options.Reset().kernel_dim_order = "IWHO";
3439   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
3440   // But moving output channel before input channel fails.
3441   options.Reset().kernel_dim_order = "HWOI";
3442   EXPECT_EQ("NO_CHANGE", build_and_simplify());
3443   options.Reset().kernel_dim_order = "WHOI";
3444   EXPECT_EQ("NO_CHANGE", build_and_simplify());
3445   options.Reset().kernel_dim_order = "OWIH";
3446   EXPECT_EQ("NO_CHANGE", build_and_simplify());
3447   options.Reset().kernel_dim_order = "OWHI";
3448   EXPECT_EQ("NO_CHANGE", build_and_simplify());
3449 
3450   // Combine different dim and kernel dim orders.
3451   options.Reset().kernel_dim_order = "IWHO";
3452   options.dim_order = "WHNC";
3453   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
3454 
3455   // Test invalid cases from wrong filter size, strides, or padding.
3456   options.Reset().f_width = 2;
3457   EXPECT_EQ("NO_CHANGE", build_and_simplify());
3458   options.Reset().f_height = 2;
3459   EXPECT_EQ("NO_CHANGE", build_and_simplify());
3460   options.Reset().row_stride = 2;
3461   EXPECT_EQ("NO_CHANGE", build_and_simplify());
3462   options.Reset().col_stride = 2;
3463   EXPECT_EQ("NO_CHANGE", build_and_simplify());
3464   options.Reset().col_padding = 1;
3465   EXPECT_EQ("NO_CHANGE", build_and_simplify());
3466   options.Reset().row_padding = 1;
3467   EXPECT_EQ("NO_CHANGE", build_and_simplify());
3468 
3469   // The default dim_order is "NHWC". Col-major layout makes C the most major.
3470   options.Reset().input_minor_to_major_layout = true;
3471   options.output_minor_to_major_layout = true;
3472   EXPECT_EQ("NO_CHANGE", build_and_simplify());
3473 
3474   // The input and output have different layouts.
3475   options.Reset().input_minor_to_major_layout = true;
3476   EXPECT_EQ("NO_CHANGE", build_and_simplify());
3477 
3478   // C is most minor, and I is more major than O.
3479   options.Reset().input_minor_to_major_layout = true;
3480   options.filter_minor_to_major_layout = true;
3481   options.output_minor_to_major_layout = true;
3482   options.dim_order = "CHWN";
3483   options.kernel_dim_order = "OIHW";
3484   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
3485 
3486   // C is not the most minor dimension.
3487   options.Reset().input_minor_to_major_layout = true;
3488   options.filter_minor_to_major_layout = true;
3489   options.output_minor_to_major_layout = true;
3490   options.dim_order = "HWNC";
3491   options.kernel_dim_order = "OIHW";
3492   EXPECT_EQ("NO_CHANGE", build_and_simplify());
3493 
3494   // I is more minor than O.
3495   options.Reset().input_minor_to_major_layout = true;
3496   options.filter_minor_to_major_layout = true;
3497   options.output_minor_to_major_layout = true;
3498   options.dim_order = "CHWN";
3499   options.kernel_dim_order = "IOHW";
3500   EXPECT_EQ("NO_CHANGE", build_and_simplify());
3501 }
3502 
3503 // Test that slice(broadcast(/*scalar value*/)) simplifies to a single
3504 // broadcast.
TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
  HloComputation::Builder builder(TestName());

  // A scalar parameter, broadcast to a rank-4 array, then sliced.
  HloInstruction* scalar_param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {}), "scalar_param"));
  HloInstruction* broadcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::MakeShape(F32, {4, 5, 6, 7}), scalar_param, {}));

  Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3});
  HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice(
      slice_shape, broadcast, /*start_indices=*/{0, 1, 2, 3},
      /*limit_indices=*/{2, 3, 5, 6}, /*strides=*/{1, 1, 1, 1}));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  // Before simplification the slice is the root.
  HloInstruction* original_root = computation->root_instruction();
  EXPECT_EQ(original_root, slice);
  EXPECT_TRUE(ShapeUtil::Equal(original_root->shape(), slice_shape));

  AlgebraicSimplifier simplifier(default_options_);
  const bool first_run_changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(first_run_changed);

  // Running simplification again should not result in any further changes.
  const bool second_run_changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_FALSE(second_run_changed);

  // The root must now be a broadcast of the scalar with the slice's shape.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Broadcast(m::Op().Is(scalar_param))
                             .WithShapeEqualTo(&slice_shape)));
}
3536 
3537 // Test that reshape(transpose(broadcast(/*scalar value*/))) simplifies to a
3538 // single broadcast.
TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
  HloComputation::Builder builder(TestName());

  // constant(42) -> broadcast [4,5,6] -> transpose [6,5,4] -> reshape [30,1,4].
  HloInstruction* forty_two = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* broadcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::MakeShape(F32, {4, 5, 6}), forty_two, {}));
  HloInstruction* transpose =
      builder.AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(F32, {6, 5, 4}), broadcast,
          /*dimensions=*/{2, 1, 0}));

  Shape reshape_shape = ShapeUtil::MakeShape(F32, {30, 1, 4});
  HloInstruction* reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(reshape_shape, transpose));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  // Before simplification the reshape is the root.
  HloInstruction* original_root = computation->root_instruction();
  EXPECT_EQ(original_root, reshape);
  EXPECT_TRUE(ShapeUtil::Equal(original_root->shape(), reshape_shape));

  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // The whole chain must collapse into one broadcast of the constant that has
  // the reshape's shape.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Broadcast(m::Op().Is(forty_two))
                             .WithShapeEqualTo(&reshape_shape)));
}
3569 
3570 // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x).
TEST_F(AlgebraicSimplifierTest,FoldPadIntoReduceWindow)3571 TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
3572   // TODO(b/80488902): verify this module.
3573   auto module = CreateNewUnverifiedModule();
3574   HloComputation::Builder builder(TestName());
3575 
3576   // Create operand to the pad.
3577   HloInstruction* operand =
3578       builder.AddInstruction(HloInstruction::CreateParameter(
3579           0, ShapeUtil::MakeShape(F32, {1, 2, 3, 4}), "p0"));
3580 
3581   // Create the pad.
3582   PaddingConfig padding = MakeNoPaddingConfig(4);
3583   padding.mutable_dimensions(1)->set_edge_padding_low(1);
3584   padding.mutable_dimensions(3)->set_edge_padding_high(2);
3585 
3586   HloInstruction* pad_value = builder.AddInstruction(
3587       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
3588   HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
3589       ShapeUtil::MakeShape(F32, {1, 3, 3, 5}), operand, pad_value, padding));
3590 
3591   // Create add computation.
3592   HloComputation* add_computation = nullptr;
3593   {
3594     HloComputation::Builder builder(TestName() + ".add");
3595     const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
3596     HloInstruction* p0 = builder.AddInstruction(
3597         HloInstruction::CreateParameter(0, scalar_shape, "p0"));
3598     HloInstruction* p1 = builder.AddInstruction(
3599         HloInstruction::CreateParameter(1, scalar_shape, "p1"));
3600     builder.AddInstruction(
3601         HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
3602     add_computation = module->AddEmbeddedComputation(builder.Build());
3603   }
3604 
3605   // Create the reduce-window.
3606   Window window;
3607   for (int64 i = 0; i < pad->shape().rank(); ++i) {
3608     auto* dim = window.add_dimensions();
3609     dim->set_size(1);
3610     dim->set_padding_low(10);
3611     dim->set_padding_high(100);
3612     dim->set_window_dilation(1);
3613     dim->set_base_dilation(1);
3614   }
3615   const Shape reduce_window_shape =
3616       ShapeUtil::MakeShape(F32, {111, 113, 113, 115});
3617   HloInstruction* reduce_init_value = builder.AddInstruction(
3618       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
3619   HloInstruction* reduce_window =
3620       builder.AddInstruction(HloInstruction::CreateReduceWindow(
3621           reduce_window_shape, pad, reduce_init_value, window,
3622           add_computation));
3623 
3624   // Build the computation and run the simplifier.
3625   auto computation = module->AddEntryComputation(builder.Build());
3626   HloInstruction* root = computation->root_instruction();
3627   EXPECT_EQ(root, reduce_window);
3628   AlgebraicSimplifier simplifier(default_options_);
3629   ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
3630 
3631   // Running simplification again should not result in any further changes.
3632   ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
3633 
3634   // Verify the result
3635   root = computation->root_instruction();
3636   EXPECT_THAT(root,
3637               GmockMatch(m::ReduceWindow(m::Op().Is(operand), m::Constant())));
3638   EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape))
3639       << ShapeUtil::HumanString(root->shape()) << " vs "
3640       << ShapeUtil::HumanString(reduce_window_shape);
3641   EXPECT_EQ(root->window().dimensions(0).padding_low(), 10);
3642   EXPECT_EQ(root->window().dimensions(1).padding_low(), 11);
3643   EXPECT_EQ(root->window().dimensions(2).padding_low(), 10);
3644   EXPECT_EQ(root->window().dimensions(3).padding_low(), 10);
3645   EXPECT_EQ(root->window().dimensions(0).padding_high(), 100);
3646   EXPECT_EQ(root->window().dimensions(1).padding_high(), 100);
3647   EXPECT_EQ(root->window().dimensions(2).padding_high(), 100);
3648   EXPECT_EQ(root->window().dimensions(3).padding_high(), 102);
3649 }
3650 
3651 // Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to
3652 // ReduceWindow(Convert(op), x).
TEST_F(AlgebraicSimplifierTest,FoldConvertedPadIntoReduceWindow)3653 TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
3654   // TODO(b/80488902): verify this module.
3655   auto module = CreateNewUnverifiedModule();
3656   HloComputation::Builder builder(TestName());
3657 
3658   // Create operand to the pad.
3659   HloInstruction* parameter =
3660       builder.AddInstruction(HloInstruction::CreateParameter(
3661           0, ShapeUtil::MakeShape(BF16, {1, 2, 3, 4}), "p0"));
3662 
3663   // Create the pad.
3664   PaddingConfig padding = MakeNoPaddingConfig(4);
3665   padding.mutable_dimensions(1)->set_edge_padding_low(1);
3666   padding.mutable_dimensions(3)->set_edge_padding_high(2);
3667 
3668   HloInstruction* pad_value = builder.AddInstruction(
3669       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
3670   HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
3671       ShapeUtil::MakeShape(BF16, {1, 3, 3, 5}), parameter, pad_value, padding));
3672 
3673   HloInstruction* convert =
3674       builder.AddInstruction(HloInstruction::CreateConvert(
3675           ShapeUtil::ChangeElementType(pad->shape(), F32), pad));
3676 
3677   // Create add computation.
3678   HloComputation* add_computation = nullptr;
3679   {
3680     HloComputation::Builder builder(TestName() + ".add");
3681     const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
3682     HloInstruction* p0 = builder.AddInstruction(
3683         HloInstruction::CreateParameter(0, scalar_shape, "p0"));
3684     HloInstruction* p1 = builder.AddInstruction(
3685         HloInstruction::CreateParameter(1, scalar_shape, "p1"));
3686     builder.AddInstruction(
3687         HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
3688     add_computation = module->AddEmbeddedComputation(builder.Build());
3689   }
3690 
3691   // Create the reduce-window.
3692   Window window;
3693   for (int64 i = 0; i < pad->shape().rank(); ++i) {
3694     auto* dim = window.add_dimensions();
3695     dim->set_size(1);
3696     dim->set_padding_low(10);
3697     dim->set_padding_high(100);
3698     dim->set_window_dilation(1);
3699     dim->set_base_dilation(1);
3700   }
3701   const Shape reduce_window_shape =
3702       ShapeUtil::MakeShape(F32, {111, 113, 113, 115});
3703   HloInstruction* reduce_init_value = builder.AddInstruction(
3704       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
3705   HloInstruction* reduce_window =
3706       builder.AddInstruction(HloInstruction::CreateReduceWindow(
3707           reduce_window_shape, convert, reduce_init_value, window,
3708           add_computation));
3709 
3710   // Build the computation and run the simplifier.
3711   auto computation = module->AddEntryComputation(builder.Build());
3712   HloInstruction* root = computation->root_instruction();
3713   EXPECT_EQ(root, reduce_window);
3714   AlgebraicSimplifier simplifier(default_options_);
3715   ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
3716 
3717   // Running simplification again should not result in any further changes.
3718   ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
3719 
3720   // Verify the result
3721   root = computation->root_instruction();
3722   EXPECT_THAT(root, GmockMatch(m::ReduceWindow(m::Convert(m::Parameter(0)),
3723                                                m::Constant())));
3724   EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape))
3725       << ShapeUtil::HumanString(root->shape()) << " vs "
3726       << ShapeUtil::HumanString(reduce_window_shape);
3727   EXPECT_EQ(root->window().dimensions(0).padding_low(), 10);
3728   EXPECT_EQ(root->window().dimensions(1).padding_low(), 11);
3729   EXPECT_EQ(root->window().dimensions(2).padding_low(), 10);
3730   EXPECT_EQ(root->window().dimensions(3).padding_low(), 10);
3731   EXPECT_EQ(root->window().dimensions(0).padding_high(), 100);
3732   EXPECT_EQ(root->window().dimensions(1).padding_high(), 100);
3733   EXPECT_EQ(root->window().dimensions(2).padding_high(), 100);
3734   EXPECT_EQ(root->window().dimensions(3).padding_high(), 102);
3735 }
3736 
// Reversing only size-1 dimensions leaves the data unchanged; the test expects
// the reverse to be replaced by its operand.
TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) {
  const Shape shape = ShapeUtil::MakeShape(F32, {448, 2048, 1, 1});

  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
  // Dimensions 2 and 3 both have size 1.
  builder.AddInstruction(
      HloInstruction::CreateReverse(shape, param, /*dimensions=*/{2, 3}));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // The parameter itself must be the new root, with its shape untouched.
  HloInstruction* new_root = computation->root_instruction();
  EXPECT_EQ(param, new_root);
  EXPECT_TRUE(ShapeUtil::Equal(new_root->shape(), shape));
}
3755 
TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
  // Dots add computations to the parent module. Test that, when the HloModule's
  // computations are updated, then iterator invalidation doesn't occur
  // when running on subsequent computations.
  auto module = CreateNewVerifiedModule();
  const Shape r1f32 = ShapeUtil::MakeShape(F32, {1});

  // Embedded computation: a dot of two one-element vectors that uses
  // dimension 0 of both operands as a batch dimension.
  HloComputation::Builder dot_builder(TestName() + ".Dot");
  HloInstruction* lhs = dot_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "x"));
  HloInstruction* rhs = dot_builder.AddInstruction(
      HloInstruction::CreateParameter(1, r1f32, "y"));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_batch_dimensions(0);
  dot_dnums.add_rhs_batch_dimensions(0);
  dot_builder.AddInstruction(HloInstruction::CreateDot(
      r1f32, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
  std::unique_ptr<HloComputation> dot_computation(dot_builder.Build());

  // Entry computation: calls the dot computation on two constants.
  HloComputation::Builder call_builder(TestName() + ".Call");
  HloInstruction* zero = call_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({0.0f})));
  HloInstruction* one = call_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1.0f})));
  call_builder.AddInstruction(
      HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get()));

  module->AddEmbeddedComputation(std::move(dot_computation));
  module->AddEntryComputation(call_builder.Build());

  // The only requirement is that the pass completes and reports a change.
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);
}
3787 
3788 // Test that a constant with tuple shape becomes a tuple of constants.
TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // Build one tuple-shaped constant holding a scalar and a 3-element vector.
  Literal scalar_literal = LiteralUtil::CreateR0<float>(7.3f);
  Literal vector_literal = LiteralUtil::CreateR1<float>({1.1f, 2.0f, 3.3f});
  Literal tuple_literal =
      LiteralUtil::MakeTuple({&scalar_literal, &vector_literal});
  builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(tuple_literal)));

  auto computation = module->AddEntryComputation(builder.Build());

  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // The root must now be a tuple instruction over two separate constants.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Tuple(m::Constant(), m::Constant())));
}
3806 
3807 // A dynamic-slice is trivial if its start indices are all zeroes and the size
3808 // of its input equals the size of its output.  In this case, the dynamic slice
3809 // is equal to its input.
TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  const Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
  const Shape index_shape = ShapeUtil::MakeShape(U32, {});

  // One scalar U32 start index per dimension: parameters 1 through 3.
  std::vector<HloInstruction*> start_indices;
  for (int param_no = 1; param_no <= 3; ++param_no) {
    start_indices.push_back(builder.AddInstruction(
        HloInstruction::CreateParameter(param_no, index_shape,
                                        "slice_indices")));
  }

  // The slice sizes equal the operand's dimensions, so the dynamic-slice
  // covers the whole operand.
  HloInstruction* slice_from = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "slice_from"));
  builder.AddInstruction(HloInstruction::CreateDynamicSlice(
      shape, slice_from, start_indices,
      /*slice_sizes=*/{10, 100, 1000}));

  auto computation = module->AddEntryComputation(builder.Build());
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // The dynamic-slice must be gone, leaving a parameter as the root.
  EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Parameter()));
}
3832 
3833 // A dynamic-update-slice is trivial if its start indices are all zeroes and the
3834 // size of its "update" equals the size of its output.  In this case, the
3835 // dynamic-update-slice is equal to its update.
TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  const Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
  const Shape slice_shape = ShapeUtil::MakeShape(F32, {10, 1, 1000});
  const Shape index_shape = ShapeUtil::MakeShape(U32, {});

  // Scalar U32 indices: parameters 1-3 drive the dynamic-slice and parameters
  // 5-7 drive the dynamic-update-slice.
  std::vector<HloInstruction*> slice_indices;
  std::vector<HloInstruction*> update_indices;
  for (int dim = 0; dim < 3; ++dim) {
    slice_indices.push_back(builder.AddInstruction(
        HloInstruction::CreateParameter(dim + 1, index_shape,
                                        "slice_indices")));
    update_indices.push_back(builder.AddInstruction(
        HloInstruction::CreateParameter(dim + 5, index_shape,
                                        "update_indices")));
  }

  // The update value: a [10,1,1000] dynamic-slice of the full array.
  HloInstruction* slice_from = builder.AddInstruction(
      HloInstruction::CreateParameter(0, full_shape, "slice_from"));
  HloInstruction* slice =
      builder.AddInstruction(HloInstruction::CreateDynamicSlice(
          slice_shape, slice_from, slice_indices,
          /*slice_sizes=*/{10, 1, 1000}));

  // The updated operand has exactly the shape of the update.
  HloInstruction* to_update = builder.AddInstruction(
      HloInstruction::CreateParameter(4, slice_shape, "to_update"));
  builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
      slice_shape, to_update, slice, update_indices));

  auto computation = module->AddEntryComputation(builder.Build());
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // Only the dynamic-slice (the update) must remain as the root.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::DynamicSlice(m::Parameter(), m::Parameter(),
                                         m::Parameter(), m::Parameter())));
}
3873 
3874 // Test that two consecutive broadcasts can be merged to one.
TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // [2] constant -> [2,2] along dimension 1 -> [2,2,2] along dimensions {0,2}.
  HloInstruction* input_array = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({3, 4})));
  HloInstruction* inner_bcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::MakeShape(F32, {2, 2}), input_array, {1}));
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {2, 2, 2}), inner_bcast, {0, 2}));

  auto computation = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kBroadcast);

  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // One broadcast straight from the constant must remain, mapping the
  // constant's only dimension to output dimension 2.
  HloInstruction* merged = computation->root_instruction();
  EXPECT_THAT(merged, GmockMatch(m::Broadcast(m::Constant())));
  EXPECT_THAT(merged->dimensions(), ElementsAre(2));
}
3896 
3897 // Test that two consecutive broadcasts can be merged to one.
TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  HloInstruction* param0 =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {2, 3}), "param0"));

  // The initial dimensions go to places 0 and 2 in the 3-dim array,
  // and to places 1 and 3 in the 4-dim array.
  HloInstruction* inner_bcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::MakeShape(F32, {2, 5, 3}), param0, {0, 2}));
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {4, 2, 5, 3}), inner_bcast, {1, 2, 3}));

  auto computation = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kBroadcast);

  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // One broadcast straight from the parameter must remain, with the composed
  // dimension mapping {1, 3}.
  HloInstruction* merged = computation->root_instruction();
  EXPECT_THAT(merged, GmockMatch(m::Broadcast(m::Parameter(0))));
  EXPECT_THAT(merged->dimensions(), ElementsAre(1, 3));
}
3922 
3923 // Test that a broadcast of an iota can be merged to one iota.
TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // A [2,2] iota along dimension 1, broadcast into [2,2,2] with its
  // dimensions mapped to output dimensions {0, 2}.
  HloInstruction* iota =
      builder.AddInstruction(HloInstruction::CreateIota(
          ShapeUtil::MakeShape(F32, {2, 2}), /*iota_dimension=*/1));
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {2, 2, 2}), iota, {0, 2}));

  auto computation = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kBroadcast);

  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // A single iota must remain; its iota dimension follows the broadcast
  // mapping (input dimension 1 -> output dimension 2).
  HloInstruction* merged = computation->root_instruction();
  EXPECT_THAT(merged, GmockMatch(m::Iota()));
  EXPECT_EQ(Cast<HloIotaInstruction>(merged)->iota_dimension(), 2);
}
3942 
3943 // Test that a broadcast of an iota can be merged to one iota.
TEST_F(AlgebraicSimplifierTest,MergeBroadcastAndIota2)3944 TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) {
3945   auto m = CreateNewVerifiedModule();
3946   HloComputation::Builder builder(TestName());
3947   Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3});
3948   HloInstruction* iota =
3949       builder.AddInstruction(HloInstruction::CreateIota(r3f32, 1));
3950   Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3});
3951   builder.AddInstruction(
3952       HloInstruction::CreateBroadcast(r4f32, iota, {1, 2, 3}));
3953 
3954   auto computation = m->AddEntryComputation(builder.Build());
3955   HloInstruction* root = computation->root_instruction();
3956   EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
3957   AlgebraicSimplifier simplifier(default_options_);
3958   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3959   root = computation->root_instruction();
3960   EXPECT_THAT(root, GmockMatch(m::Iota()));
3961   EXPECT_EQ(Cast<HloIotaInstruction>(root)->iota_dimension(), 2);
3962 }
3963 
// A one-element slice taken entirely from the low padding of a pad is
// expected to become a reshape of a constant.
TEST_F(AlgebraicSimplifierTest, SliceOfPadLow) {
  const char* hlo_string = R"(
    HloModule module

    ENTRY test {
      param = f32[3,4] parameter(0)
      constant = f32[] constant(0.0)
      pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
      ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[2:3],[0:1]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  // Report a failing pass as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
  EXPECT_TRUE(changed);
  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant())));
}
3984 
// A one-element slice taken entirely from the high padding of a pad is
// expected to become a reshape of a constant.
TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) {
  const char* hlo_string = R"(
    HloModule module

    ENTRY test {
      param = f32[3,4] parameter(0)
      constant = f32[] constant(0.0)
      pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
      ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[6:7],[9:10]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  // Report a failing pass as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
  EXPECT_TRUE(changed);
  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant())));
}
4005 
// A one-element slice that lands inside the (non-scalar) padded operand is
// expected to be left untouched by the simplifier.
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) {
  const char* hlo_string = R"(
    HloModule module

    ENTRY test {
      param = f32[3,4] parameter(0)
      constant = f32[] constant(0.0)
      pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
      ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[4:5]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  // Report a failing pass as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
  EXPECT_FALSE(changed);
}
4024 
// A one-element slice whose row index is inside the operand but whose column
// index is in the high padding is expected to become a reshape of a constant.
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalarConstant) {
  const char* hlo_string = R"(
    HloModule module

    ENTRY test {
      param = f32[3,4] parameter(0)
      constant = f32[] constant(0.0)
      pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
      ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[9:10]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  // Report a failing pass as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
  EXPECT_TRUE(changed);
  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant())));
}
4045 
// A slice that selects exactly the single-element operand out of its padded
// form is expected to simplify to the operand (the parameter) itself.
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) {
  const char* hlo_string = R"(
    HloModule module

    ENTRY test {
      param = f32[1,1] parameter(0)
      constant = f32[] constant(0.0)
      pad = f32[8,10] pad(f32[1,1] param, f32[] constant), padding=3_4x4_5
      ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[3:4],[4:5]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  // Report a failing pass as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
  EXPECT_TRUE(changed);
  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Parameter()));
}
4066 
// The slice index is inside the operand along the first two dimensions but in
// the low padding along the last one, so the result is expected to be the
// padding value (-7), not the operand value (4).
TEST_F(AlgebraicSimplifierTest, SliceOfPadSomeDimsInPadding) {
  const char* hlo_string = R"(
    HloModule module

    ENTRY entry () -> f32[1]{0} {
      constant.val = f32[] constant(4)
      constant.pad = f32[] constant(-7)
      reshape.1 = f32[1,1,1]{2,1,0} reshape(f32[] constant.val)
      pad = f32[3,3,3]{2,1,0} pad(f32[1,1,1]{2,1,0} reshape.1, f32[] constant.pad), padding=0_2x0_2x2_0
      slice = f32[1,1,1]{2,1,0} slice(f32[3,3,3]{2,1,0} pad), slice={[0:1], [0:1], [0:1]}
      ROOT reshape.2 = f32[1]{0} reshape(f32[1,1,1]{2,1,0} slice)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  // Report a failing pass as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
  EXPECT_TRUE(changed);
  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Reshape(m::ConstantScalar(-7.0))));
}
4089 
// A slice that covers exactly one single-element concatenate operand is
// expected to simplify to that operand (parameter 1).
TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) {
  const char* hlo_string = R"(
    HloModule module

    ENTRY test {
      param.0 = f32[2] parameter(0)
      param.1 = f32[1] parameter(1)
      param.2 = f32[3] parameter(2)
      concat = f32[6] concatenate(param.0, param.1, param.2), dimensions={0}
      ROOT slice = f32[1] slice(concat), slice={[2:3]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  // Report a failing pass as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
  EXPECT_TRUE(changed);
  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Parameter(1)));
}
4111 
// A slice that falls inside one multi-element concatenate operand is expected
// to become a slice of that operand (parameter 2) with rebased bounds [1:2].
TEST_F(AlgebraicSimplifierTest, SliceOfConcatNonScalarInput) {
  const char* hlo_string = R"(
    HloModule module

    ENTRY test {
      param.0 = f32[2] parameter(0)
      param.1 = f32[1] parameter(1)
      param.2 = f32[3] parameter(2)
      concat = f32[6] concatenate(param.0, param.1, param.2), dimensions={0}
      ROOT slice = f32[1] slice(concat), slice={[4:5]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  // Report a failing pass as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
  EXPECT_TRUE(changed);
  auto root = module->entry_computation()->root_instruction();
  // Fatal on mismatch so the slice bounds below are only read from a root
  // that matched as a slice.
  ASSERT_THAT(root, GmockMatch(m::Slice(m::Parameter(2))));
  EXPECT_EQ(root->slice_starts(0), 1);
  EXPECT_EQ(root->slice_limits(0), 2);
}
4135 
// negate(negate(x)) is expected to simplify to x.
TEST_F(AlgebraicSimplifierTest, NegateNegate) {
  const char* hlo_string = R"(
    HloModule module

    ENTRY test {
      param.0 = f32[2] parameter(0)
      neg.0 = f32[2] negate(param.0)
      ROOT neg.1 = f32[2] negate(neg.0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  // Report a failing pass as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
  EXPECT_TRUE(changed);
  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Parameter(0)));
}
4155 
// not(not(x)) on a pred array is expected to simplify to x.
TEST_F(AlgebraicSimplifierTest, NotNot) {
  const char* hlo_string = R"(
    HloModule module

    ENTRY test {
      param.0 = pred[2] parameter(0)
      not.0 = pred[2] not(param.0)
      ROOT not.1 = pred[2] not(not.0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  // Report a failing pass as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
  EXPECT_TRUE(changed);
  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Parameter(0)));
}
4175 
4176 struct PadReduceWindowEffectiveBroadcastCase {
4177   std::vector<int64> input_spatials;
4178   std::vector<int64> symmetric_pad_spatials;
4179   std::vector<int64> reduce_window_spatials;
4180   // Whether to use `B F S0 S1` form vs `B S0 S1 F` form.
4181   //
4182   // This doesn't test any different functionality but is useful for making sure
4183   // kBroadcast nodes are well formed.
4184   bool prepend_a;
4185   bool should_become_broadcast;
4186 
ToTestCaseNamexla::__anonac242c730111::PadReduceWindowEffectiveBroadcastCase4187   string ToTestCaseName() const {
4188     return absl::StrCat(absl::StrJoin(input_spatials, ","), ";",
4189                         absl::StrJoin(symmetric_pad_spatials, ","), ";",
4190                         absl::StrJoin(reduce_window_spatials, ","), ";",
4191                         prepend_a, ";", should_become_broadcast);
4192   }
4193 };
4194 
PrintTo(const PadReduceWindowEffectiveBroadcastCase & c,std::ostream * os)4195 void PrintTo(const PadReduceWindowEffectiveBroadcastCase& c, std::ostream* os) {
4196   *os << c.ToTestCaseName();
4197 }
4198 
4199 class PadReduceWindowEffectiveBroadcastTest
4200     : public AlgebraicSimplifierTest,
4201       public ::testing::WithParamInterface<
4202           PadReduceWindowEffectiveBroadcastCase> {};
4203 
TEST_P(PadReduceWindowEffectiveBroadcastTest,DoIt)4204 TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) {
4205   auto m = CreateNewVerifiedModule();
4206   const auto& param = GetParam();
4207 
4208   // a and b are parallel bounds we can either turn into a B F S0 S1 or
4209   // `B S0 S1 F` kind of pattern.
4210   auto decorate_spatials = [&param](absl::Span<const int64> spatials, int64 a,
4211                                     int64 b) {
4212     std::vector<int64> result;
4213     if (param.prepend_a) {
4214       result.push_back(a);
4215     }
4216     for (int64 s : spatials) {
4217       result.push_back(s);
4218     }
4219     if (!param.prepend_a) {
4220       result.push_back(a);
4221     }
4222     result.push_back(b);
4223     return result;
4224   };
4225 
4226   HloComputation::Builder builder(TestName());
4227   const Shape input_shape = ShapeUtil::MakeShape(
4228       F32, decorate_spatials(param.input_spatials, 128, 2048));
4229   HloInstruction* input = builder.AddInstruction(
4230       HloInstruction::CreateParameter(0, input_shape, "input"));
4231 
4232   PaddingConfig padding = window_util::MakeSymmetricPadding(
4233       decorate_spatials(param.symmetric_pad_spatials, 0, 0));
4234   TF_ASSERT_OK_AND_ASSIGN(
4235       const Shape pad_shape,
4236       ShapeInference::InferPadShape(input->shape(),
4237                                     ShapeUtil::MakeShape(F32, {}), padding));
4238   HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
4239       pad_shape, input,
4240       builder.AddInstruction(
4241           HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))),
4242       padding));
4243 
4244   HloComputation* add_computation = nullptr;
4245   {
4246     HloComputation::Builder builder(TestName() + ".add");
4247     const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
4248     HloInstruction* p0 = builder.AddInstruction(
4249         HloInstruction::CreateParameter(0, scalar_shape, "p0"));
4250     HloInstruction* p1 = builder.AddInstruction(
4251         HloInstruction::CreateParameter(1, scalar_shape, "p1"));
4252     builder.AddInstruction(
4253         HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
4254     add_computation = m->AddEmbeddedComputation(builder.Build());
4255   }
4256 
4257   Window window = window_util::MakeWindow(
4258       decorate_spatials(param.reduce_window_spatials, 1, 1));
4259   auto zero = builder.AddInstruction(
4260       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
4261   TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape,
4262                           ShapeInference::InferReduceWindowShape(
4263                               pad->shape(), zero->shape(), window,
4264                               add_computation->ComputeProgramShape()));
4265   builder.AddInstruction(HloInstruction::CreateReduceWindow(
4266       output_shape, pad, zero, window, add_computation));
4267 
4268   auto computation = m->AddEntryComputation(builder.Build());
4269   AlgebraicSimplifier simplifier(default_options_);
4270   TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get()));
4271   ASSERT_TRUE(run_successful);
4272 
4273   EXPECT_TRUE(
4274       ShapeUtil::Equal(computation->root_instruction()->shape(), output_shape));
4275 
4276   if (param.should_become_broadcast) {
4277     EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Broadcast()));
4278   } else {
4279     EXPECT_THAT(computation->root_instruction(),
4280                 GmockMatch(m::ReduceWindow(m::Op(), m::Op().Is(zero))));
4281   }
4282 }
4283 
4284 const std::vector<PadReduceWindowEffectiveBroadcastCase>&
PadReduceWindowEffectiveBroadcastCases()4285 PadReduceWindowEffectiveBroadcastCases() {
4286   static auto* cases = new std::vector<PadReduceWindowEffectiveBroadcastCase>{
4287       {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{6, 6},
4288        /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/true,
4289        /*should_become_broadcast=*/true},  //
4290       {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{6, 6},
4291        /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/false,
4292        /*should_become_broadcast=*/true},  //
4293       {/*input_spatials=*/{2, 2}, /*symmetric_pad_amount=*/{6, 6},
4294        /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/true,
4295        /*should_become_broadcast=*/false},  //
4296       {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{2, 2},
4297        /*reduce_window_spatials=*/{1, 1}, /*prepend_a=*/true,
4298        /*should_become_broadcast=*/false},  //
4299       {/*input_spatials=*/{5, 1}, /*symmetric_pad_amount=*/{0, 2},
4300        /*reduce_window_spatials=*/{2, 5}, /*prepend_a=*/true,
4301        /*should_become_broadcast=*/false},  //
4302   };
4303   return *cases;
4304 }
4305 
4306 INSTANTIATE_TEST_SUITE_P(
4307     PadReduceWindowEffectiveBroadcastInstantiation,
4308     PadReduceWindowEffectiveBroadcastTest,
4309     ::testing::ValuesIn(PadReduceWindowEffectiveBroadcastCases()));
4310 
4311 class BatchDotStrengthReductionTest
4312     : public AlgebraicSimplifierTest,
4313       public ::testing::WithParamInterface<
4314           ::testing::tuple<int, int, int, PrimitiveType>> {};
TEST_P(BatchDotStrengthReductionTest,BatchDotStrengthReduction)4315 TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) {
4316   auto module = CreateNewVerifiedModule();
4317   int m, k, n;
4318   PrimitiveType element_type;
4319   std::tie(m, k, n, element_type) = GetParam();
4320   std::vector<int64> lhs_dims = {1, 3, 5};
4321   std::vector<int64> rhs_dims = lhs_dims;
4322   std::vector<int64> output_dims = lhs_dims;
4323   if (m > 0) {
4324     lhs_dims.push_back(m);
4325     output_dims.push_back(m);
4326   }
4327   if (k > 0) {
4328     lhs_dims.push_back(k);
4329     rhs_dims.push_back(k);
4330   }
4331   if (n > 0) {
4332     rhs_dims.push_back(n);
4333     output_dims.push_back(n);
4334   }
4335   Shape dot_shape = ShapeUtil::MakeShape(element_type, output_dims);
4336   Shape lhs_shape = ShapeUtil::MakeShape(element_type, lhs_dims);
4337   Shape rhs_shape = ShapeUtil::MakeShape(element_type, rhs_dims);
4338   HloComputation::Builder builder(TestName());
4339 
4340   auto lhs = builder.AddInstruction(
4341       HloInstruction::CreateParameter(0, lhs_shape, "lhs"));
4342   auto rhs = builder.AddInstruction(
4343       HloInstruction::CreateParameter(1, rhs_shape, "rhs"));
4344   DotDimensionNumbers dot_dnums;
4345   dot_dnums.add_lhs_batch_dimensions(0);
4346   dot_dnums.add_lhs_batch_dimensions(1);
4347   dot_dnums.add_lhs_batch_dimensions(2);
4348   dot_dnums.add_rhs_batch_dimensions(0);
4349   dot_dnums.add_rhs_batch_dimensions(1);
4350   dot_dnums.add_rhs_batch_dimensions(2);
4351   if (k > 0) {
4352     dot_dnums.add_lhs_contracting_dimensions(m > 0 ? 4 : 3);
4353     dot_dnums.add_rhs_contracting_dimensions(3);
4354   }
4355   builder.AddInstruction(HloInstruction::CreateDot(
4356       dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
4357   auto computation = module->AddEntryComputation(builder.Build());
4358   AlgebraicSimplifier simplifier(default_options_);
4359   TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
4360   const bool dot_should_be_transformed =
4361       m == 1 || k == 1 || n == 1 || m == -1 || k == -1 || n == -1;
4362   EXPECT_EQ(changed, dot_should_be_transformed);
4363   bool has_no_dot = true;
4364   for (const auto& hlo : computation->instructions()) {
4365     if (hlo->opcode() == HloOpcode::kDot) {
4366       has_no_dot = false;
4367       break;
4368     }
4369   }
4370   EXPECT_EQ(has_no_dot, dot_should_be_transformed);
4371 }
4372 
4373 INSTANTIATE_TEST_SUITE_P(BatchDotStrengthReductionTestInstantiation,
4374                          BatchDotStrengthReductionTest,
4375                          ::testing::Combine(::testing::Values(-1, 1, 2),
4376                                             ::testing::Values(-1, 1, 2),
4377                                             ::testing::Values(-1, 1, 2),
4378                                             ::testing::Values(F32, BF16)));
4379 
4380 class DotStrengthReductionTest
4381     : public AlgebraicSimplifierTest,
4382       public ::testing::WithParamInterface<
4383           ::testing::tuple<int, int, int, bool, bool, PrimitiveType>> {};
TEST_P(DotStrengthReductionTest,DotStrengthReduction)4384 TEST_P(DotStrengthReductionTest, DotStrengthReduction) {
4385   auto module = CreateNewVerifiedModule();
4386   int m, k, n;
4387   bool transpose_lhs, transpose_rhs;
4388   PrimitiveType element_type;
4389   std::tie(m, k, n, transpose_lhs, transpose_rhs, element_type) = GetParam();
4390 
4391   Shape dot_shape = ShapeUtil::MakeShape(element_type, {m, n});
4392   Shape lhs_shape = ShapeUtil::MakeShape(element_type, {m, k});
4393   Shape transposed_lhs_shape = ShapeUtil::MakeShape(element_type, {k, m});
4394   Shape rhs_shape = ShapeUtil::MakeShape(element_type, {k, n});
4395   Shape transposed_rhs_shape = ShapeUtil::MakeShape(element_type, {n, k});
4396   HloComputation::Builder builder(TestName());
4397 
4398   auto lhs = builder.AddInstruction(HloInstruction::CreateParameter(
4399       0, transpose_lhs ? transposed_lhs_shape : lhs_shape, "lhs"));
4400   if (transpose_lhs) {
4401     lhs = builder.AddInstruction(
4402         HloInstruction::CreateTranspose(lhs_shape, lhs, {1, 0}));
4403   }
4404   auto rhs = builder.AddInstruction(HloInstruction::CreateParameter(
4405       1, transpose_rhs ? transposed_rhs_shape : rhs_shape, "rhs"));
4406   if (transpose_rhs) {
4407     rhs = builder.AddInstruction(
4408         HloInstruction::CreateTranspose(rhs_shape, rhs, {1, 0}));
4409   }
4410   DotDimensionNumbers dot_dnums;
4411   dot_dnums.add_lhs_contracting_dimensions(1);
4412   dot_dnums.add_rhs_contracting_dimensions(0);
4413   builder.AddInstruction(HloInstruction::CreateDot(
4414       dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
4415   auto computation = module->AddEntryComputation(builder.Build());
4416   AlgebraicSimplifier simplifier(default_options_);
4417   TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
4418   const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1;
4419   const bool computation_should_be_modified =
4420       dot_should_be_transformed || (transpose_lhs && transpose_rhs);
4421   EXPECT_EQ(changed, computation_should_be_modified);
4422   bool has_no_dot = true;
4423   for (const auto& hlo : computation->instructions()) {
4424     if (hlo->opcode() == HloOpcode::kDot) {
4425       has_no_dot = false;
4426       break;
4427     }
4428   }
4429   EXPECT_EQ(has_no_dot, dot_should_be_transformed);
4430 }
4431 
4432 INSTANTIATE_TEST_SUITE_P(
4433     DotStrengthReductionTestInstantiation, DotStrengthReductionTest,
4434     ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2),
4435                        ::testing::Values(1, 2), ::testing::Bool(),
4436                        ::testing::Bool(), ::testing::Values(F32, BF16)));
4437 
4438 struct DotOfConcatTestSpec {
4439   int64 m;
4440   int64 k;
4441   int64 n;
4442 };
4443 
4444 class DotOfConcatSimplificationTest
4445     : public AlgebraicSimplifierTest,
4446       public ::testing::WithParamInterface<DotOfConcatTestSpec> {};
4447 
4448 // Test that we transform
4449 //  dot(const, concat(A, B, C))
4450 // to
4451 //  add(dot(const_0, A), dot(const_1, B),  dot(const_2, C))
TEST_P(DotOfConcatSimplificationTest,ConstantLHS)4452 TEST_P(DotOfConcatSimplificationTest, ConstantLHS) {
4453   auto m = CreateNewVerifiedModule();
4454   HloComputation::Builder builder(TestName());
4455 
4456   DotOfConcatTestSpec spec = GetParam();
4457 
4458   ASSERT_GE(spec.k, 3);
4459 
4460   int64 k0 = spec.k / 3;
4461   int64 k1 = spec.k / 3;
4462   int64 k2 = spec.k - k0 - k1;
4463 
4464   Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k});
4465   auto* lhs = builder.AddInstruction(
4466       HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
4467           /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.m, /*cols=*/spec.k)));
4468 
4469   Shape rhs0_shape = ShapeUtil::MakeShape(F32, {k0, spec.n});
4470   Shape rhs1_shape = ShapeUtil::MakeShape(F32, {k1, spec.n});
4471   Shape rhs2_shape = ShapeUtil::MakeShape(F32, {k2, spec.n});
4472 
4473   HloInstruction* rhs0 = builder.AddInstruction(
4474       HloInstruction::CreateParameter(0, rhs0_shape, "rhs0"));
4475   HloInstruction* rhs1 = builder.AddInstruction(
4476       HloInstruction::CreateParameter(1, rhs1_shape, "rhs1"));
4477   HloInstruction* rhs2 = builder.AddInstruction(
4478       HloInstruction::CreateParameter(2, rhs2_shape, "rhs2"));
4479 
4480   Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n});
4481   HloInstruction* rhs = builder.AddInstruction(
4482       HloInstruction::CreateConcatenate(rhs_shape, {rhs0, rhs1, rhs2}, 0));
4483 
4484   DotDimensionNumbers dot_dnums;
4485   dot_dnums.add_lhs_contracting_dimensions(1);
4486   dot_dnums.add_rhs_contracting_dimensions(0);
4487 
4488   Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n});
4489   builder.AddInstruction(HloInstruction::CreateDot(
4490       dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
4491 
4492   auto computation = m->AddEntryComputation(builder.Build());
4493   AlgebraicSimplifier simplifier(default_options_);
4494   TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get()));
4495   ASSERT_TRUE(run_successful);
4496 
4497   EXPECT_TRUE(
4498       ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
4499 
4500   auto match_dot_0 = m::Dot(m::Slice(m::Constant()), m::Parameter(0));
4501   auto match_dot_1 = m::Dot(m::Slice(m::Constant()), m::Parameter(1));
4502   auto match_dot_2 = m::Dot(m::Slice(m::Constant()), m::Parameter(2));
4503   EXPECT_THAT(
4504       computation->root_instruction(),
4505       GmockMatch(m::Add(m::Add(match_dot_0, match_dot_1), match_dot_2)));
4506 }
4507 
4508 // Test that we transform
4509 //  dot(concat(A, B, C), const)
4510 // to
4511 //  add(dot(A, const_0), dot(B, const_1),  dot(C, const_2))
TEST_P(DotOfConcatSimplificationTest,ConstantRHS)4512 TEST_P(DotOfConcatSimplificationTest, ConstantRHS) {
4513   auto m = CreateNewVerifiedModule();
4514   HloComputation::Builder builder(TestName());
4515 
4516   DotOfConcatTestSpec spec = GetParam();
4517 
4518   ASSERT_GE(spec.k, 4);
4519 
4520   int64 k0 = spec.k / 4;
4521   int64 k1 = spec.k / 4;
4522   int64 k2 = spec.k / 4;
4523   int64 k3 = spec.k - k0 - k1 - k2;
4524 
4525   Shape lhs0_shape = ShapeUtil::MakeShape(F32, {spec.m, k0});
4526   Shape lhs1_shape = ShapeUtil::MakeShape(F32, {spec.m, k1});
4527   Shape lhs2_shape = ShapeUtil::MakeShape(F32, {spec.m, k2});
4528   Shape lhs3_shape = ShapeUtil::MakeShape(F32, {spec.m, k3});
4529 
4530   HloInstruction* lhs0 = builder.AddInstruction(
4531       HloInstruction::CreateParameter(0, lhs0_shape, "lhs0"));
4532   HloInstruction* lhs1 = builder.AddInstruction(
4533       HloInstruction::CreateParameter(1, lhs1_shape, "lhs1"));
4534   HloInstruction* lhs2 = builder.AddInstruction(
4535       HloInstruction::CreateParameter(2, lhs2_shape, "lhs2"));
4536   HloInstruction* lhs3 = builder.AddInstruction(
4537       HloInstruction::CreateParameter(3, lhs3_shape, "lhs3"));
4538 
4539   Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k});
4540   HloInstruction* lhs =
4541       builder.AddInstruction(HloInstruction::CreateConcatenate(
4542           lhs_shape, {lhs0, lhs1, lhs2, lhs3}, 1));
4543 
4544   Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n});
4545   auto* rhs = builder.AddInstruction(
4546       HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
4547           /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.k, /*cols=*/spec.n)));
4548 
4549   DotDimensionNumbers dot_dnums;
4550   dot_dnums.add_lhs_contracting_dimensions(1);
4551   dot_dnums.add_rhs_contracting_dimensions(0);
4552 
4553   Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n});
4554   builder.AddInstruction(HloInstruction::CreateDot(
4555       dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
4556 
4557   auto computation = m->AddEntryComputation(builder.Build());
4558   AlgebraicSimplifier simplifier(default_options_);
4559   TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get()));
4560   ASSERT_TRUE(run_successful);
4561   EXPECT_TRUE(
4562       ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
4563 
4564   auto match_dot_0 = m::Dot(m::Parameter(0), m::Slice(m::Constant()));
4565   auto match_dot_1 = m::Dot(m::Parameter(1), m::Slice(m::Constant()));
4566   auto match_dot_2 = m::Dot(m::Parameter(2), m::Slice(m::Constant()));
4567   auto match_dot_3 = m::Dot(m::Parameter(3), m::Slice(m::Constant()));
4568   EXPECT_THAT(
4569       computation->root_instruction(),
4570       GmockMatch(m::Add(m::Add(m::Add(match_dot_0, match_dot_1), match_dot_2),
4571                         match_dot_3)));
4572 }
4573 
4574 DotOfConcatTestSpec kDotOfConcatTestSpecs[] = {
4575     {/*m=*/3, /*k=*/9, /*n=*/3},    //
4576     {/*m=*/3, /*k=*/20, /*n=*/3},   //
4577     {/*m=*/1, /*k=*/18, /*n=*/5},   //
4578     {/*m=*/20, /*k=*/20, /*n=*/1},  //
4579     {/*m=*/1, /*k=*/16, /*n=*/1},   //
4580 };
4581 
4582 // Test that DynamicUpdateSlice update param with any dimension equal to zero
4583 // gets removed.
TEST_F(AlgebraicSimplifierTest,DynamicUpdateSliceZeroUpdate)4584 TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) {
4585   auto m = CreateNewVerifiedModule();
4586   HloComputation::Builder builder(TestName());
4587   const Shape dslice_shape = ShapeUtil::MakeShape(F32, {10});
4588   HloInstruction* const operand = builder.AddInstruction(
4589       HloInstruction::CreateParameter(0, dslice_shape, "operand"));
4590   const Shape update_shape = ShapeUtil::MakeShape(F32, {0});
4591   HloInstruction* const update = builder.AddInstruction(
4592       HloInstruction::CreateParameter(1, update_shape, "update"));
4593   HloInstruction* const start_indices = builder.AddInstruction(
4594       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>({})));
4595   builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
4596       dslice_shape, operand, update,
4597       std::initializer_list<HloInstruction*>({start_indices})));
4598   const HloComputation* const computation =
4599       m->AddEntryComputation(builder.Build());
4600 
4601   AlgebraicSimplifier simplifier(default_options_);
4602   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
4603   EXPECT_THAT(computation->root_instruction(), operand);
4604 }
4605 
4606 INSTANTIATE_TEST_SUITE_P(DotOfConcatSimplificationTestInstantiation,
4607                          DotOfConcatSimplificationTest,
4608                          ::testing::ValuesIn(kDotOfConcatTestSpecs));
4609 
4610 struct DotOfGatherTestSpec {
4611   int64 m;
4612   int64 k;
4613   int64 n;
4614   int s;      // start index for dynamic slice on the non-contracting dimension
4615   int64 lcd;  // left contracting dimension
4616   int64 rcd;  // right contracting dimension
4617   bool neg;   // is negative testcase
4618 };
4619 
4620 class DotOfGatherSimplificationTest
4621     : public AlgebraicSimplifierTest,
4622       public ::testing::WithParamInterface<DotOfGatherTestSpec> {};
4623 
// input: dot(DS(ctA), ctB)
4625 // where DS(ctA) = DS({M x K}, {s, 0}, {1, K}) and ctB = {K x N}.
4626 // => input dimensions: dot({1 x K}, {K x N}) => {1 x N}.
4627 // output: DS(dot(ctA, ctB))
4628 // => output dimensions: DS ({M x N}, {s, 0}, {1, N}) => {1 x N}.
TEST_P(DotOfGatherSimplificationTest,ConstantRHS)4629 TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {
4630   auto m = CreateNewVerifiedModule();
4631   HloComputation::Builder builder(TestName());
4632 
4633   DotOfGatherTestSpec spec = GetParam();
4634 
4635   ASSERT_LE(spec.s, spec.m);
4636 
4637   // For negative tests, increase k of the dynamic slice argument to prevent the
4638   // optimization (constants ctA, ctB must have equal contracting dimensions).
4639   int64 k_increase = spec.neg ? 5 : 0;
4640   int64 lhs_rows = (spec.lcd == 0) ? (spec.k + k_increase) : spec.m;
4641   int64 lhs_cols = (spec.lcd == 0) ? spec.m : (spec.k + k_increase);
4642   Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols});
4643   auto* lhs = builder.AddInstruction(
4644       HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
4645           /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows,
4646           /*cols=*/lhs_cols)));
4647 
4648   int32 start_row = (spec.lcd == 0) ? 0 : spec.s;
4649   int32 start_col = (spec.lcd == 0) ? spec.s : 0;
4650   std::vector<HloInstruction*> start_indices = {
4651       builder.AddInstruction(HloInstruction::CreateConstant(
4652           LiteralUtil::CreateR0<int32>(start_row))),
4653       builder.AddInstruction(HloInstruction::CreateConstant(
4654           LiteralUtil::CreateR0<int32>(start_col)))};
4655   int64 slice_row_size = (spec.lcd == 0) ? spec.k : 1;
4656   int64 slice_col_size = (spec.lcd == 0) ? 1 : spec.k;
4657   std::vector<int64> slice_sizes = {slice_row_size, slice_col_size};
4658   Shape ds_shape = ShapeUtil::MakeShape(F32, slice_sizes);
4659   auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
4660       ds_shape, lhs, start_indices, slice_sizes));
4661 
4662   int64 rhs_rows = (spec.rcd == 0) ? spec.k : spec.n;
4663   int64 rhs_cols = (spec.rcd == 0) ? spec.n : spec.k;
4664   Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols});
4665   auto* rhs = builder.AddInstruction(
4666       HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
4667           /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows,
4668           /*cols=*/rhs_cols)));
4669 
4670   DotDimensionNumbers dot_dnums;
4671   dot_dnums.add_lhs_contracting_dimensions(spec.lcd);
4672   dot_dnums.add_rhs_contracting_dimensions(spec.rcd);
4673 
4674   int64 dot_row_size = 1;
4675   int64 dot_col_size = spec.n;
4676   Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size});
4677   builder.AddInstruction(HloInstruction::CreateDot(
4678       dot_shape, ds, rhs, dot_dnums, DefaultPrecisionConfig(2)));
4679 
4680   auto computation = m->AddEntryComputation(builder.Build());
4681   AlgebraicSimplifier simplifier(default_options_);
4682   TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get()));
4683   ASSERT_TRUE(run_successful);
4684   EXPECT_TRUE(
4685       ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
4686 
4687   if (spec.neg) {
4688     EXPECT_NE(computation->root_instruction()->opcode(),
4689               HloOpcode::kDynamicSlice);
4690   } else {
4691     EXPECT_THAT(computation->root_instruction(),
4692                 GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()),
4693                                            m::Constant(), m::Constant())));
4694   }
4695 }
4696 
4697 // input: dot(ctA, DS(ctB))
4698 // where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, s}, {K, 1}).
4699 // => input dimensions: dot({M x K}, {K x 1}) => {M x 1}.
4700 // output: DS(dot(ctA, ctB))
4701 // => output dimensions: DS ({M x N}, {0, s}, {M, 1}) => {M x 1}.
TEST_P(DotOfGatherSimplificationTest,ConstantLHS)4702 TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {
4703   auto m = CreateNewVerifiedModule();
4704   HloComputation::Builder builder(TestName());
4705 
4706   DotOfGatherTestSpec spec = GetParam();
4707 
4708   ASSERT_LE(spec.s, spec.n);
4709 
4710   int64 lhs_rows = (spec.lcd == 0) ? spec.k : spec.m;
4711   int64 lhs_cols = (spec.lcd == 0) ? spec.m : spec.k;
4712   Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols});
4713   auto* lhs = builder.AddInstruction(
4714       HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
4715           /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows,
4716           /*cols=*/lhs_cols)));
4717 
4718   // For negative tests increase k of the dynamic slice argument to prevent the
4719   // optimization
4720   int64 k_increase = spec.neg ? 5 : 0;
4721   int64 rhs_rows = (spec.rcd == 0) ? (spec.k + k_increase) : spec.n;
4722   int64 rhs_cols = (spec.rcd == 0) ? spec.n : (spec.k + k_increase);
4723   Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols});
4724   auto* rhs = builder.AddInstruction(
4725       HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
4726           /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows,
4727           /*cols=*/rhs_cols)));
4728 
4729   int32 start_row = (spec.rcd == 0) ? 0 : spec.s;
4730   int32 start_col = (spec.rcd == 0) ? spec.s : 0;
4731   std::vector<HloInstruction*> start_indices = {
4732       builder.AddInstruction(HloInstruction::CreateConstant(
4733           LiteralUtil::CreateR0<int32>(start_row))),
4734       builder.AddInstruction(HloInstruction::CreateConstant(
4735           LiteralUtil::CreateR0<int32>(start_col)))};
4736   int64 slice_row_size = (spec.rcd == 0) ? spec.k : 1;
4737   int64 slice_col_size = (spec.rcd == 0) ? 1 : spec.k;
4738   std::vector<int64> slice_sizes = {slice_row_size, slice_col_size};
4739   Shape ds_shape = ShapeUtil::MakeShape(F32, slice_sizes);
4740   auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
4741       ds_shape, rhs, start_indices, slice_sizes));
4742 
4743   DotDimensionNumbers dot_dnums;
4744   dot_dnums.add_lhs_contracting_dimensions(spec.lcd);
4745   dot_dnums.add_rhs_contracting_dimensions(spec.rcd);
4746 
4747   int64 dot_row_size = spec.m;
4748   int64 dot_col_size = 1;
4749   Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size});
4750   builder.AddInstruction(HloInstruction::CreateDot(
4751       dot_shape, lhs, ds, dot_dnums, DefaultPrecisionConfig(2)));
4752 
4753   auto computation = m->AddEntryComputation(builder.Build());
4754   AlgebraicSimplifier simplifier(default_options_);
4755   TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get()));
4756   ASSERT_TRUE(run_successful);
4757   EXPECT_TRUE(
4758       ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
4759 
4760   if (spec.neg) {
4761     EXPECT_NE(computation->root_instruction()->opcode(),
4762               HloOpcode::kDynamicSlice);
4763   } else {
4764     EXPECT_THAT(computation->root_instruction(),
4765                 GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()),
4766                                            m::Constant(), m::Constant())));
4767   }
4768 }
4769 
DotOfGatherPositiveNegativeTests()4770 std::vector<DotOfGatherTestSpec> DotOfGatherPositiveNegativeTests() {
4771   std::vector<DotOfGatherTestSpec> positives = {
4772       // "Classical dot", i.e. matrix multiply:
4773       {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/1, /*rcd=*/0,
4774        /*neg=*/false},
4775       {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/1, /*rcd=*/0,
4776        /*neg=*/false},
4777       {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/1, /*rcd=*/0,
4778        /*neg=*/false},
4779       // Note: testing for m=1 and n=1 is unnecessary, as this optimizes to
4780       // dot(ct, ct) before DotOfGather optimization kicks in.
4781       // Contract on rows:
4782       {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/0, /*rcd=*/0,
4783        /*neg=*/false},
4784       {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/0, /*rcd=*/0,
4785        /*neg=*/false},
4786       {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/0, /*rcd=*/0,
4787        /*neg=*/false},
4788       // Reverse matrix multiply:
4789       {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/0, /*rcd=*/1,
4790        /*neg=*/false},
4791       {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/0, /*rcd=*/1,
4792        /*neg=*/false},
4793       {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/0, /*rcd=*/1,
4794        /*neg=*/false},
4795       // Contract on columns:
4796       {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/1, /*rcd=*/1,
4797        /*neg=*/false},
4798       {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/1, /*rcd=*/1,
4799        /*neg=*/false},
4800       {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/1, /*rcd=*/1,
4801        /*neg=*/false},
4802   };
4803   std::vector<DotOfGatherTestSpec> all;
4804   for (int i = 0; i < positives.size(); i++) {
4805     DotOfGatherTestSpec positive_test = positives[i];
4806     all.push_back(positive_test);
4807     DotOfGatherTestSpec negative_test = positive_test;
4808     negative_test.neg = true;
4809     all.push_back(negative_test);
4810   }
4811   return all;
4812 }
4813 
4814 INSTANTIATE_TEST_SUITE_P(
4815     DotOfGatherSimplificationTestInstantiation, DotOfGatherSimplificationTest,
4816     ::testing::ValuesIn(DotOfGatherPositiveNegativeTests()));
4817 
// A two-operand reduce with an empty `dimensions` list is expected to be
// rewritten into a tuple of reshapes of its two inputs.
TEST_F(AlgebraicSimplifierTest, TupleReduceReshape) {
  const char* const kHloText = R"(
HloModule module

reducer {
  parameter.1 = f32[] parameter(0)
  parameter.3 = f32[] parameter(2)
  add.2 = f32[] add(parameter.1, parameter.3)
  parameter.0 = f32[] parameter(1)
  parameter.2 = f32[] parameter(3)
  add.3 = f32[] add(parameter.0, parameter.2)
  ROOT tuple.4 = (f32[], f32[]) tuple(add.2, add.3)
}

ENTRY entry {
  parameter.6 = (f32[], f32[]) parameter(0)
  get-tuple-element.10 = f32[] get-tuple-element(parameter.6), index=0
  get-tuple-element.11 = f32[] get-tuple-element(parameter.6), index=1
  constant = f32[] constant(0)
  ROOT reduce = (f32[], f32[]) reduce(get-tuple-element.10, get-tuple-element.11, constant, constant), dimensions={}, to_apply=reducer
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(kHloText));

  // Run the pass with default-constructed options; it must report a change.
  AlgebraicSimplifierOptions pass_options;
  AlgebraicSimplifier pass(pass_options);
  const bool changed = pass.Run(parsed_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  HloInstruction* entry_root =
      parsed_module->entry_computation()->root_instruction();
  EXPECT_THAT(
      entry_root,
      GmockMatch(m::Tuple(m::Reshape(m::GetTupleElement(m::Parameter(), 0)),
                          m::Reshape(m::GetTupleElement(m::Parameter(), 1)))));
}
4851 
// A two-operand reduce over a zero-sized dimension is expected to be rewritten
// into a tuple of broadcasts of its two init values (the constants 0 and 1).
TEST_F(AlgebraicSimplifierTest, TupleReduceBroadcast) {
  const char* const kHloText = R"(
HloModule module

reducer {
  parameter.1 = f32[] parameter(0)
  parameter.3 = f32[] parameter(2)
  mul.2 = f32[] add(parameter.1, parameter.3)
  parameter.0 = f32[] parameter(1)
  parameter.2 = f32[] parameter(3)
  add.3 = f32[] add(parameter.0, parameter.2)
  ROOT tuple.4 = (f32[], f32[]) tuple(mul.2, add.3)
}

ENTRY entry {
  parameter.6 = (f32[0, 10, 10], f32[0, 10, 10]) parameter(0)
  get-tuple-element.10 = f32[0, 10, 10] get-tuple-element(parameter.6), index=0
  get-tuple-element.11 = f32[0, 10, 10] get-tuple-element(parameter.6), index=1
  constant.0 = f32[] constant(0)
  constant.1 = f32[] constant(1)
  ROOT reduce = (f32[10, 10], f32[10, 10]) reduce(get-tuple-element.10, get-tuple-element.11, constant.0, constant.1), dimensions={0}, to_apply=reducer
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(kHloText));

  // Run the pass with default-constructed options; it must report a change.
  AlgebraicSimplifierOptions pass_options;
  AlgebraicSimplifier pass(pass_options);
  const bool changed = pass.Run(parsed_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  HloInstruction* entry_root =
      parsed_module->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root,
              GmockMatch(m::Tuple(m::Broadcast(m::ConstantScalar(0)),
                                  m::Broadcast(m::ConstantScalar(1)))));
}
4885 
// A reshape whose result shape is zero-sized and has had its layout cleared is
// expected to be replaced by a constant.
TEST_F(AlgebraicSimplifierTest, ZeroSizedReshapeWithoutLayout) {
  HloComputation::Builder builder(TestName());

  // f32[1] parameter broadcast into a zero-sized f32[0, 1] array.
  const Shape input_shape = ShapeUtil::MakeShape(F32, {1});
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param"));
  const Shape broadcast_shape = ShapeUtil::MakeShape(F32, {0, 1});
  HloInstruction* zero_sized_broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(broadcast_shape, input, {1}));

  // Create a reshape with zero sized result and without layout.
  Shape layoutless_shape = ShapeUtil::MakeShape(F32, {0});
  layoutless_shape.clear_layout();
  builder.AddInstruction(
      HloInstruction::CreateReshape(layoutless_shape, zero_sized_broadcast));

  std::unique_ptr<VerifiedHloModule> verified_module =
      CreateNewVerifiedModule();
  verified_module->AddEntryComputation(builder.Build());

  // Run the pass with default-constructed options; it must report a change.
  AlgebraicSimplifierOptions pass_options;
  AlgebraicSimplifier pass(pass_options);
  const bool changed = pass.Run(verified_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  HloInstruction* entry_root =
      verified_module->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, GmockMatch(m::Constant()));
}
4910 
// Dividing a layout-less scalar parameter by a constant is expected to be
// rewritten so that the root becomes a multiply.
TEST_F(AlgebraicSimplifierTest, DividedByConstantInstructionWithoutLayout) {
  // Scalar f32 shape with its layout explicitly removed.
  Shape layoutless_scalar = ShapeUtil::MakeShape(F32, {});
  layoutless_scalar.clear_layout();

  HloComputation::Builder builder(TestName());
  HloInstruction* dividend = builder.AddInstruction(
      HloInstruction::CreateParameter(0, layoutless_scalar, "param"));
  HloInstruction* divisor = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(20.0f)));
  builder.AddInstruction(HloInstruction::CreateBinary(
      layoutless_scalar, HloOpcode::kDivide, dividend, divisor));

  std::unique_ptr<VerifiedHloModule> verified_module =
      CreateNewVerifiedModule();
  verified_module->AddEntryComputation(builder.Build());

  // Run the pass with default-constructed options; it must report a change.
  AlgebraicSimplifierOptions pass_options;
  AlgebraicSimplifier pass(pass_options);
  const bool changed = pass.Run(verified_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  HloInstruction* entry_root =
      verified_module->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, GmockMatch(m::Multiply()));
}
4932 
4933 // Test that 1/sqrt(X) is simplified to rsqrt(X).
TEST_F(AlgebraicSimplifierTest,RecipSqrt)4934 TEST_F(AlgebraicSimplifierTest, RecipSqrt) {
4935   const char* kModuleStr = R"(
4936     HloModule m
4937     test {
4938       p0 = f32[] parameter(0)
4939       p1 = f32[] parameter(1)
4940       sqrt = f32[] sqrt(p0)
4941       ROOT div = f32[] divide(p1, sqrt)
4942     }
4943   )";
4944   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
4945   ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
4946   EXPECT_THAT(m->entry_computation()->root_instruction(),
4947               GmockMatch(m::MultiplyAnyOrder(m::Parameter(1),
4948                                              m::Rsqrt(m::Parameter(0)))));
4949 }
4950 
4951 // Test that 1/rsqrt(X) is simplified to sqrt(X).
TEST_F(AlgebraicSimplifierTest,RecipRsqrt)4952 TEST_F(AlgebraicSimplifierTest, RecipRsqrt) {
4953   const char* kModuleStr = R"(
4954     HloModule m
4955     test {
4956       p0 = f32[] parameter(0)
4957       p1 = f32[] parameter(1)
4958       rsqrt = f32[] rsqrt(p0)
4959       ROOT div = f32[] divide(p1, rsqrt)
4960     }
4961   )";
4962   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
4963   ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
4964   EXPECT_THAT(m->entry_computation()->root_instruction(),
4965               GmockMatch(m::MultiplyAnyOrder(m::Parameter(1),
4966                                              m::Sqrt(m::Parameter(0)))));
4967 }
4968 
4969 }  // namespace
4970 }  // namespace xla
4971