1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
17 
18 #include <memory>
19 #include <utility>
20 
21 #include "absl/memory/memory.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/compiler/xla/layout_util.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
27 #include "tensorflow/compiler/xla/service/hlo_computation.h"
28 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
30 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
31 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
32 #include "tensorflow/compiler/xla/service/hlo_parser.h"
33 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
34 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
35 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
36 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
37 #include "tensorflow/compiler/xla/service/shape_inference.h"
38 #include "tensorflow/compiler/xla/shape_util.h"
39 #include "tensorflow/compiler/xla/test.h"
40 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
41 #include "tensorflow/compiler/xla/types.h"
42 #include "tensorflow/compiler/xla/window_util.h"
43 #include "tensorflow/compiler/xla/xla_data.pb.h"
44 #include "tensorflow/core/lib/core/status_test_util.h"
45 
46 namespace xla {
47 namespace {
48 
49 using ::testing::ElementsAre;
50 namespace m = match;
51 
52 class AlgebraicSimplifierTest : public HloTestBase {
53  protected:
54   AlgebraicSimplifierOptions default_options_;
55 };
56 
57 // Test that A + 0 is simplified to A
TEST_F(AlgebraicSimplifierTest,AddZero)58 TEST_F(AlgebraicSimplifierTest, AddZero) {
59   auto m = CreateNewVerifiedModule();
60   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
61   HloComputation::Builder builder(TestName());
62   HloInstruction* param0 = builder.AddInstruction(
63       HloInstruction::CreateParameter(0, r0f32, "param0"));
64   HloInstruction* zero = builder.AddInstruction(
65       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
66   builder.AddInstruction(
67       HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero));
68 
69   auto computation = m->AddEntryComputation(builder.Build());
70   HloInstruction* root = computation->root_instruction();
71   EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
72   AlgebraicSimplifier simplifier(default_options_);
73   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
74   root = computation->root_instruction();
75   EXPECT_EQ(root, param0);
76 }
77 
// A*C + B*C => (A+B)*C for S32 operands. Here C is a parameter, not a
// constant, and the factoring is still expected to happen.
TEST_F(AlgebraicSimplifierTest, FactorIntegerAddition) {
  constexpr char kHloText[] = R"(
    HloModule m
    test {
      p0 = s32[8] parameter(0)
      p1 = s32[8] parameter(1)
      p2 = s32[8] parameter(2)
      x = s32[8] multiply(p0, p2)
      y = s32[8] multiply(p1, p2)
      ROOT sum = s32[8] add(x, y)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // (p0 + p1) * p2, with either operand order accepted at both levels.
  auto sum_of_params = m::AddAnyOrder(m::Parameter(0), m::Parameter(1));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::MultiplyAnyOrder(sum_of_params,
                                                   m::Parameter(2))));
}
98 
99 // A*C + B*C => (A+B)*C if C is a floating-point power of 2.
TEST_F(AlgebraicSimplifierTest,FactorFpAddition)100 TEST_F(AlgebraicSimplifierTest, FactorFpAddition) {
101   const char* kModuleStr = R"(
102     HloModule m
103     test {
104       p0 = f32[] parameter(0)
105       p1 = f32[] parameter(1)
106       c = f32[] constant(0.125)
107       x = f32[] multiply(p0, c)
108       y = f32[] multiply(p1, c)
109       ROOT sum = f32[] add(x, y)
110     }
111   )";
112   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
113   ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
114   EXPECT_THAT(m->entry_computation()->root_instruction(),
115               GmockMatch(m::MultiplyAnyOrder(
116                   m::AddAnyOrder(m::Parameter(0), m::Parameter(1)),
117                   m::ConstantScalar(0.125))));
118 }
119 
120 // (Abs(A)) * (Abs(A)) => (A*A)
TEST_F(AlgebraicSimplifierTest,SquareOfAbs)121 TEST_F(AlgebraicSimplifierTest, SquareOfAbs) {
122   const char* kModuleStr = R"(
123     HloModule m
124     test {
125       p = f32[] parameter(0)
126       a = f32[] abs(p)
127       ROOT z = f32[] multiply(a, a)
128     }
129   )";
130   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
131   ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
132   EXPECT_THAT(m->entry_computation()->root_instruction(),
133               GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
134 }
135 
136 // (A*C1) * (B*C2) => (A*B)*(C1*C2)
TEST_F(AlgebraicSimplifierTest,MultiplyChain)137 TEST_F(AlgebraicSimplifierTest, MultiplyChain) {
138   const char* kModuleStr = R"(
139     HloModule m
140     test {
141       p0 = f32[] parameter(0)
142       p1 = f32[] parameter(1)
143       c = f32[] constant(2)
144       d = f32[] constant(4)
145       x = f32[] multiply(p0, c)
146       y = f32[] multiply(p1, d)
147       ROOT z = f32[] multiply(x, y)
148     }
149   )";
150   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
151   ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
152   EXPECT_THAT(
153       m->entry_computation()->root_instruction(),
154       GmockMatch(m::MultiplyAnyOrder(
155           m::MultiplyAnyOrder(m::Parameter(0), m::Parameter(1)),
156           m::MultiplyAnyOrder(m::ConstantScalar(2), m::ConstantScalar(4)))));
157 }
158 
159 // (a*C1)*C2 => a*(C1*C2)
TEST_F(AlgebraicSimplifierTest,MultiplyChain2)160 TEST_F(AlgebraicSimplifierTest, MultiplyChain2) {
161   const char* kModuleStr = R"(
162     HloModule m
163     test {
164       p0 = f32[] parameter(0)
165       a = f32[] constant(2)
166       b = f32[] constant(4)
167       c = f32[] multiply(p0, a)
168       ROOT y = f32[] multiply(c, b)
169     }
170   )";
171   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
172   ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
173   EXPECT_THAT(m->entry_computation()->root_instruction(),
174               GmockMatch(m::MultiplyAnyOrder(
175                   m::Parameter(0), m::MultiplyAnyOrder(m::ConstantScalar(2),
176                                                        m::ConstantScalar(4)))));
177 }
178 
179 // MUL(MUL(X, BROADCAST(constant)), BROADCAST(Y)) ==>
180 // MUL(X, BROADCAST(MUL(Y, BROADCAST(constant))))
TEST_F(AlgebraicSimplifierTest,MultiplyBroadcastReassoc)181 TEST_F(AlgebraicSimplifierTest, MultiplyBroadcastReassoc) {
182   const char* kModuleStr = R"(
183     HloModule m
184     test {
185       p0 = f32[2,2] parameter(0)
186       p1 = f32[] parameter(1)
187       b = f32[] constant(2)
188       c = f32[2, 2] broadcast(b), dimensions={}
189       x = f32[2,2] multiply(p0, c)
190       y = f32[2,2] broadcast(p1), dimensions={}
191       ROOT z = f32[2,2] multiply(y, x)
192     }
193   )";
194   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
195   ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
196   EXPECT_THAT(m->entry_computation()->root_instruction(),
197               GmockMatch(m::MultiplyAnyOrder(
198                   m::Parameter(0), m::Broadcast(m::MultiplyAnyOrder(
199                                        m::Parameter(1), m::Constant())))));
200 }
201 
202 // A*C + B*C => (A+B)*C if C is a broadcast of a floating-point power of 2.
TEST_F(AlgebraicSimplifierTest,FactorFpAdditionWithBroadcast)203 TEST_F(AlgebraicSimplifierTest, FactorFpAdditionWithBroadcast) {
204   const char* kModuleStr = R"(
205     HloModule m
206     test {
207       p0 = f32[4] parameter(0)
208       p1 = f32[4] parameter(1)
209       c = f32[] constant(0.125)
210       b = f32[4] broadcast(c), dimensions={}
211       x = f32[4] multiply(p0, b)
212       y = f32[4] multiply(p1, b)
213       ROOT sum = f32[4] add(x, y)
214     }
215   )";
216   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
217   ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
218   EXPECT_THAT(m->entry_computation()->root_instruction(),
219               GmockMatch(m::MultiplyAnyOrder(
220                   m::AddAnyOrder(m::Parameter(0), m::Parameter(1)),
221                   m::Broadcast(m::ConstantScalar(0.125)))));
222 }
223 
224 // A*C + B*C => (A+B)*C simplification should not happen if C is not a
225 // floating-point power of 2.
TEST_F(AlgebraicSimplifierTest,FactorFpAdditionNotPowerOf2)226 TEST_F(AlgebraicSimplifierTest, FactorFpAdditionNotPowerOf2) {
227   const char* kModuleStr = R"(
228     HloModule m
229     test {
230       p0 = f32[] parameter(0)
231       p1 = f32[] parameter(1)
232       c = f32[] constant(0.3)
233       x = f32[] multiply(p0, c)
234       y = f32[] multiply(p1, c)
235       ROOT sum = f32[] add(x, y)
236     }
237   )";
238   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
239   EXPECT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
240 }
241 
242 // A*C + B*C => (A+B)*C simplification should not happen if A, B, and C are
243 // complex numbers.
TEST_F(AlgebraicSimplifierTest,FactorFpAdditionComplex)244 TEST_F(AlgebraicSimplifierTest, FactorFpAdditionComplex) {
245   const char* kModuleStr = R"(
246     HloModule m
247     test {
248       p0 = c64[8] parameter(0)
249       p1 = c64[8] parameter(1)
250       p2 = c64[8] parameter(2)
251       x = c64[8] multiply(p0, p2)
252       y = c64[8] multiply(p1, p2)
253       ROOT sum = c64[8] add(x, y)
254     }
255   )";
256   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
257   EXPECT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
258 }
259 
260 // A*C + B*C => (A+B)*C simplification is OK if A, B, and C are complex.
TEST_F(AlgebraicSimplifierTest,FactorFpAdditionBfloat16)261 TEST_F(AlgebraicSimplifierTest, FactorFpAdditionBfloat16) {
262   const char* kModuleStr = R"(
263     HloModule m
264     test {
265       p0 = bf16[4] parameter(0)
266       p1 = bf16[4] parameter(1)
267       c = bf16[] constant(0.125)
268       b = bf16[4] broadcast(c), dimensions={}
269       x = bf16[4] multiply(p0, b)
270       y = bf16[4] multiply(p1, b)
271       ROOT sum = bf16[4] add(x, y)
272     }
273   )";
274   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
275   ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
276   EXPECT_THAT(m->entry_computation()->root_instruction(),
277               GmockMatch(m::MultiplyAnyOrder(
278                   m::AddAnyOrder(m::Parameter(0), m::Parameter(1)),
279                   m::Broadcast(m::ConstantScalar(0.125)))));
280 }
281 
// Unsigned x / 8 is rewritten as a logical right shift by 3 (8 == 1 << 3).
TEST_F(AlgebraicSimplifierTest, UnsignedDivideByPowerOf2) {
  constexpr char kHloText[] = R"(
    HloModule m
    test {
      p = u32[4] parameter(0)
      c = u32[] constant(8)
      b = u32[4] broadcast(c), dimensions={}
      ROOT d = u32[4] divide(p, b)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto shift_amount = m::Broadcast(m::ConstantScalar(3));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::ShiftRightLogical(m::Parameter(0),
                                                    shift_amount)));
}
298 
// Signed x / 8 is rewritten in terms of |x|. The quotient magnitude is
// |x| >> 3 (logical shift), and the sign is restored with a select on x < 0.
TEST_F(AlgebraicSimplifierTest, SignedDivideByPowerOf2) {
  constexpr char kHloText[] = R"(
    HloModule m
    test {
      p = s32[4] parameter(0)
      c = s32[] constant(8)
      b = s32[4] broadcast(c), dimensions={}
      ROOT d = s32[4] divide(p, b)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // x < 0, comparing against a broadcast zero.
  auto is_negative =
      m::Lt(m::Parameter(0), m::Broadcast(m::ConstantScalar(0)));
  // |x| as select(x < 0, -x, x).
  auto magnitude =
      m::Select(is_negative, m::Negate(m::Parameter(0)), m::Parameter(0));
  // |x| / 8 computed as |x| >> 3.
  auto shifted =
      m::ShiftRightLogical(magnitude, m::Broadcast(m::ConstantScalar(3)));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Select(is_negative, m::Negate(shifted),
                                         shifted)));
}
321 
// Unsigned x % 8 is rewritten as a bitwise AND with the mask 7 (8 - 1).
TEST_F(AlgebraicSimplifierTest, UnsignedRemainderByPowerOf2) {
  constexpr char kHloText[] = R"(
    HloModule m
    test {
      p = u32[4] parameter(0)
      c = u32[] constant(8)
      b = u32[4] broadcast(c), dimensions={}
      ROOT r = u32[4] remainder(p, b)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto low_bits_mask = m::Broadcast(m::ConstantScalar(7));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::AndAnyOrder(m::Parameter(0), low_bits_mask)));
}
338 
// Signed x % 8 is rewritten in terms of |x|. The remainder magnitude is
// |x| & 7, and the sign of x is restored with a select on x < 0.
TEST_F(AlgebraicSimplifierTest, SignedRemainderByPowerOf2) {
  constexpr char kHloText[] = R"(
    HloModule m
    test {
      p = s32[4] parameter(0)
      c = s32[] constant(8)
      b = s32[4] broadcast(c), dimensions={}
      ROOT r = s32[4] remainder(p, b)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // x < 0, comparing against a broadcast zero.
  auto is_negative =
      m::Lt(m::Parameter(0), m::Broadcast(m::ConstantScalar(0)));
  // |x| as select(x < 0, -x, x).
  auto magnitude =
      m::Select(is_negative, m::Negate(m::Parameter(0)), m::Parameter(0));
  // |x| % 8 computed as |x| & 7.
  auto masked =
      m::AndAnyOrder(magnitude, m::Broadcast(m::ConstantScalar(7)));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Select(is_negative, m::Negate(masked),
                                         masked)));
}
361 
362 // Test that A * 0 is simplified to 0
TEST_F(AlgebraicSimplifierTest,MulZero)363 TEST_F(AlgebraicSimplifierTest, MulZero) {
364   auto m = CreateNewVerifiedModule();
365   Shape r0s32 = ShapeUtil::MakeShape(S32, {});
366   HloComputation::Builder builder(TestName());
367   HloInstruction* param0 = builder.AddInstruction(
368       HloInstruction::CreateParameter(0, r0s32, "param0"));
369   HloInstruction* zero = builder.AddInstruction(
370       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
371   builder.AddInstruction(
372       HloInstruction::CreateBinary(r0s32, HloOpcode::kMultiply, param0, zero));
373 
374   auto computation = m->AddEntryComputation(builder.Build());
375   HloInstruction* root = computation->root_instruction();
376   EXPECT_EQ(root->opcode(), HloOpcode::kMultiply);
377   AlgebraicSimplifier simplifier(default_options_);
378   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
379   EXPECT_EQ(computation->root_instruction(), zero);
380 }
381 
// (p0 * 2.0) * 3.0 => p0 * (2.0 * 3.0): the constants are reassociated into a
// single multiply (operand order checked exactly).
TEST_F(AlgebraicSimplifierTest, MultiplyReassociateMergeConstants) {
  constexpr char kHloText[] = R"(
    HloModule m
    test {
      p0 = f32[] parameter(0)
      c0 = f32[] constant(2.0)
      c1 = f32[] constant(3.0)
      multiply0 = f32[] multiply(p0, c0)
      ROOT multiply1 = f32[] multiply(multiply0, c1)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto merged_constants =
      m::Multiply(m::ConstantScalar(2.0), m::ConstantScalar(3.0));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Multiply(m::Parameter(0), merged_constants)));
}
400 
// (p0 * broadcast(2.0)) * broadcast(3.0) => p0 * broadcast(2.0 * 3.0): the
// scalar constants are multiplied before a single broadcast.
TEST_F(AlgebraicSimplifierTest, MultiplyReassociateMergeBroadcastedConstants) {
  constexpr char kHloText[] = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      c0 = f32[] constant(2.0)
      c1 = f32[] constant(3.0)
      b0 = f32[4] broadcast(c0), dimensions={}
      b1 = f32[4] broadcast(c1), dimensions={}
      multiply0 = f32[4] multiply(p0, b0)
      ROOT multiply1 = f32[4] multiply(multiply0, b1)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto merged_constants = m::Broadcast(
      m::Multiply(m::ConstantScalar(2.0), m::ConstantScalar(3.0)));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Multiply(m::Parameter(0), merged_constants)));
}
422 
// multiply(broadcast(p1), broadcast(p0)) of two scalar broadcasts. The
// expected root is a broadcast wrapping a multiply whose operands are still
// broadcasts of p1 and p0, in that order.
TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsScalar) {
  constexpr char kHloText[] = R"(
    HloModule m
    test {
      p0 = f32[] parameter(0)
      p1 = f32[] parameter(1)
      b0 = f32[4] broadcast(p0), dimensions={}
      b1 = f32[4] broadcast(p1), dimensions={}
      ROOT multiply = f32[4] multiply(b1, b0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto inner_product = m::Multiply(m::Broadcast(m::Parameter(1)),
                                   m::Broadcast(m::Parameter(0)));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(inner_product)));
}
441 
// A broadcast of p0 (f32[4] -> f32[4,2]) multiplied by a broadcast scalar
// constant. The multiply is expected to move below the outer broadcast: it is
// applied to p0 and a broadcast of the constant, and the result is broadcast.
TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsConstantMix) {
  constexpr char kHloText[] = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      c0 = f32[] constant(2.0)
      b0 = f32[4,2] broadcast(c0), dimensions={}
      b1 = f32[4,2] broadcast(p0), dimensions={0}
      ROOT multiply = f32[4,2] multiply(b1, b0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto inner_product =
      m::Multiply(m::Parameter(0), m::Broadcast(m::ConstantScalar(2.0)));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(inner_product)));
}
459 
// Two f32[4] operands broadcast along the same dimension ({0}) into f32[4,2].
// The multiply is expected to happen on the unbroadcast operands, followed by
// a single broadcast.
TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsNonScalar) {
  constexpr char kHloText[] = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      p1 = f32[4] parameter(1)
      b0 = f32[4,2] broadcast(p0), dimensions={0}
      b1 = f32[4,2] broadcast(p1), dimensions={0}
      ROOT multiply = f32[4,2] multiply(b1, b0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto inner_product = m::Multiply(m::Parameter(1), m::Parameter(0));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(inner_product)));
}
477 
// Broadcasts that map their operands onto different output dimensions ({0}
// vs {1}) are NOT sunk. The pass reports no change and the multiply of the two
// broadcasts stays as the root.
TEST_F(AlgebraicSimplifierTest, ElementwiseNoSinkBroadcastsDifferentDims) {
  constexpr char kHloText[] = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      p1 = f32[8] parameter(1)
      b0 = f32[4,8] broadcast(p0), dimensions={0}
      b1 = f32[4,8] broadcast(p1), dimensions={1}
      ROOT multiply = f32[4,8] multiply(b1, b0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Multiply(m::Broadcast(m::Parameter(1)),
                                           m::Broadcast(m::Parameter(0)))));
}
495 
// (C0 * broadcast(3.0)) * broadcast(4.0), where C0 is a non-scalar f32[4]
// constant, becomes C0 * broadcast(3.0 * 4.0).
TEST_F(AlgebraicSimplifierTest,
       MultiplyReassociateMultiplyOfConstantAndBroadcast) {
  constexpr char kHloText[] = R"(
    HloModule m
    test {
      c0 = f32[4] constant({2.0, 3.0, 4.0, 5.0})
      c1 = f32[] constant(3.0)
      c2 = f32[] constant(4.0)
      b0 = f32[4] broadcast(c1), dimensions={}
      b1 = f32[4] broadcast(c2), dimensions={}
      multiply0 = f32[4] multiply(c0, b0)
      ROOT multiply1 = f32[4] multiply(multiply0, b1)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto merged_scalars = m::Broadcast(
      m::Multiply(m::ConstantScalar(3.0), m::ConstantScalar(4.0)));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Multiply(m::Constant(), merged_scalars)));
}
518 
519 // Test that select(true, a, b) is simplified to a
TEST_F(AlgebraicSimplifierTest,SelectTrue)520 TEST_F(AlgebraicSimplifierTest, SelectTrue) {
521   Shape r0s32 = ShapeUtil::MakeShape(S32, {});
522   HloComputation::Builder builder(TestName());
523   HloInstruction* param0 = builder.AddInstruction(
524       HloInstruction::CreateParameter(0, r0s32, "param0"));
525   HloInstruction* param1 = builder.AddInstruction(
526       HloInstruction::CreateParameter(1, r0s32, "param1"));
527   HloInstruction* one = builder.AddInstruction(
528       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
529   builder.AddInstruction(HloInstruction::CreateTernary(
530       r0s32, HloOpcode::kSelect, one, param0, param1));
531 
532   auto module = CreateNewVerifiedModule();
533   auto computation = module->AddEntryComputation(builder.Build());
534   HloInstruction* root = computation->root_instruction();
535   EXPECT_EQ(root->opcode(), HloOpcode::kSelect);
536   AlgebraicSimplifier simplifier(default_options_);
537   ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
538   EXPECT_EQ(computation->root_instruction(), param0);
539 }
540 
541 // Test that select(false, a, b) is simplified to b
TEST_F(AlgebraicSimplifierTest,SelectFalse)542 TEST_F(AlgebraicSimplifierTest, SelectFalse) {
543   Shape r0s32 = ShapeUtil::MakeShape(S32, {});
544   HloComputation::Builder builder(TestName());
545   HloInstruction* param0 = builder.AddInstruction(
546       HloInstruction::CreateParameter(0, r0s32, "param0"));
547   HloInstruction* param1 = builder.AddInstruction(
548       HloInstruction::CreateParameter(1, r0s32, "param1"));
549   HloInstruction* zero = builder.AddInstruction(
550       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
551   builder.AddInstruction(HloInstruction::CreateTernary(
552       r0s32, HloOpcode::kSelect, zero, param0, param1));
553 
554   auto module = CreateNewVerifiedModule();
555   auto computation = module->AddEntryComputation(builder.Build());
556   HloInstruction* root = computation->root_instruction();
557   EXPECT_EQ(root->opcode(), HloOpcode::kSelect);
558   AlgebraicSimplifier simplifier(default_options_);
559   ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
560   EXPECT_EQ(computation->root_instruction(), param1);
561 }
562 
563 // Test that select(a, b, b) is simplified to b
TEST_F(AlgebraicSimplifierTest,SelectIdentical)564 TEST_F(AlgebraicSimplifierTest, SelectIdentical) {
565   Shape r0s32 = ShapeUtil::MakeShape(S32, {});
566   HloComputation::Builder builder(TestName());
567   HloInstruction* param0 = builder.AddInstruction(
568       HloInstruction::CreateParameter(0, r0s32, "param0"));
569   HloInstruction* param1 = builder.AddInstruction(
570       HloInstruction::CreateParameter(1, r0s32, "param1"));
571   builder.AddInstruction(HloInstruction::CreateTernary(
572       r0s32, HloOpcode::kSelect, param0, param1, param1));
573 
574   auto module = CreateNewVerifiedModule();
575   auto computation = module->AddEntryComputation(builder.Build());
576   HloInstruction* root = computation->root_instruction();
577   EXPECT_EQ(root->opcode(), HloOpcode::kSelect);
578   AlgebraicSimplifier simplifier(default_options_);
579   ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
580   EXPECT_EQ(computation->root_instruction(), param1);
581 }
582 
583 // Test that Reduce(Reduce(A)) -> Reduce(A)
TEST_F(AlgebraicSimplifierTest,TwoReducesToOne)584 TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) {
585   auto m = CreateNewVerifiedModule();
586   HloComputation::Builder builder(TestName());
587   // Create add computation.
588   HloInstruction* zero = builder.AddInstruction(
589       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
590   HloComputation* add_computation = nullptr;
591   {
592     HloComputation::Builder builder(TestName() + ".add");
593     const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
594     HloInstruction* p0 = builder.AddInstruction(
595         HloInstruction::CreateParameter(0, scalar_shape, "p0"));
596     HloInstruction* p1 = builder.AddInstruction(
597         HloInstruction::CreateParameter(1, scalar_shape, "p1"));
598     builder.AddInstruction(
599         HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
600     add_computation = m->AddEmbeddedComputation(builder.Build());
601   }
602   Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
603   HloInstruction* param = builder.AddInstruction(
604       HloInstruction::CreateParameter(0, r4f32, "param"));
605   std::vector<int64> dims0({0});
606   Shape r3f32 = ShapeUtil::MakeShape(F32, {5, 6, 7});
607   HloInstruction* reduce0 = builder.AddInstruction(
608       HloInstruction::CreateReduce(r3f32, param, zero, dims0, add_computation));
609   std::vector<int64> dims1({1, 2});
610   Shape r1f32 = ShapeUtil::MakeShape(F32, {5});
611   builder.AddInstruction(HloInstruction::CreateReduce(r1f32, reduce0, zero,
612                                                       dims1, add_computation));
613   m->AddEntryComputation(builder.Build());
614   AlgebraicSimplifier simplifier(default_options_);
615   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
616   HloInstruction* root = m->entry_computation()->root_instruction();
617   EXPECT_THAT(root, GmockMatch(m::Reduce(m::Parameter(0), m::Op().Is(zero))));
618   EXPECT_EQ(root->dimensions(), std::vector<int64>({0, 2, 3}));
619 }
620 
621 // Test that Const + A is canonicalized to A + Const.
TEST_F(AlgebraicSimplifierTest,AddConstOnLHS)622 TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) {
623   auto m = CreateNewVerifiedModule();
624   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
625   HloComputation::Builder builder(TestName());
626   HloInstruction* param0 = builder.AddInstruction(
627       HloInstruction::CreateParameter(0, r0f32, "param0"));
628   HloInstruction* constant = builder.AddInstruction(
629       HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f)));
630   builder.AddInstruction(
631       HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, constant, param0));
632 
633   auto computation = m->AddEntryComputation(builder.Build());
634   HloInstruction* root = computation->root_instruction();
635   EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
636   AlgebraicSimplifier simplifier(default_options_);
637   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
638   root = computation->root_instruction();
639   EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(0), m::Constant())));
640 }
641 
642 // Test that [(A + C1) + C2] => [A + (C1 + C2)] for constants C1 and C2.
TEST_F(AlgebraicSimplifierTest,AddReassociateMergeConstants)643 TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) {
644   auto m = CreateNewVerifiedModule();
645   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
646   HloComputation::Builder builder(TestName());
647   HloInstruction* param0 = builder.AddInstruction(
648       HloInstruction::CreateParameter(0, r0f32, "param0"));
649   HloInstruction* constant1 = builder.AddInstruction(
650       HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f)));
651   HloInstruction* constant2 = builder.AddInstruction(
652       HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.14159f)));
653 
654   HloInstruction* add1 = builder.AddInstruction(
655       HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, constant1));
656   builder.AddInstruction(
657       HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, add1, constant2));
658 
659   auto computation = m->AddEntryComputation(builder.Build());
660   HloInstruction* root = computation->root_instruction();
661   EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
662   AlgebraicSimplifier simplifier(default_options_);
663   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
664   root = computation->root_instruction();
665   EXPECT_THAT(root, GmockMatch(m::Add(
666                         m::Op().Is(param0),
667                         m::Add(m::Op().Is(constant1), m::Op().Is(constant2)))));
668 }
669 
// (p0 + broadcast(1.0)) + broadcast(2.0) => p0 + broadcast(1.0 + 2.0): the
// scalar constants are added before a single broadcast.
TEST_F(AlgebraicSimplifierTest, AddReassociateMergeBroadcastedConstants) {
  constexpr char kHloText[] = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      c0 = f32[] constant(1.0)
      c1 = f32[] constant(2.0)
      b0 = f32[4] broadcast(c0), dimensions={}
      b1 = f32[4] broadcast(c1), dimensions={}
      add0 = f32[4] add(p0, b0)
      ROOT add1 = f32[4] add(add0, b1)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto merged_constants =
      m::Broadcast(m::Add(m::ConstantScalar(1.0), m::ConstantScalar(2.0)));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(0), merged_constants)));
}
690 
TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) {
  // Test that broadcast(0) + A, where the zero is a scalar (R0) constant, is
  // simplified to A.
  auto m = CreateNewVerifiedModule();
  Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r2f32, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  // The broadcast dimension list must have one entry per operand dimension;
  // a rank-0 operand therefore takes an empty list.
  HloInstruction* bcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(r2f32, zero, {}));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0));

  auto computation = m->AddEntryComputation(builder.Build());
  HloInstruction* root = computation->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  root = computation->root_instruction();
  EXPECT_EQ(root, param0);
}
712 
TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) {
  // A map whose to_apply is a single scalar add is expected to be inlined as
  // an elementwise add of the mapped operands.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // Scalar (p0, p1) -> p0 + p1, used as the map's to_apply computation.
  HloComputation* add_computation = nullptr;
  {
    HloComputation::Builder add_builder(TestName() + ".add");
    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
    auto* lhs = add_builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "p0"));
    auto* rhs = add_builder.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "p1"));
    add_builder.AddInstruction(
        HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, lhs, rhs));
    add_computation = module->AddEmbeddedComputation(add_builder.Build());
  }

  const Shape shape = ShapeUtil::MakeShape(F32, {32, 1});
  auto* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param0"));
  auto* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  auto* zeros =
      builder.AddInstruction(HloInstruction::CreateBroadcast(shape, zero, {}));
  builder.AddInstruction(
      HloInstruction::CreateMap(shape, {input, zeros}, add_computation));

  HloComputation* entry = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kMap);
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Add(m::Parameter(0),
                                m::Broadcast(m::Op().Is(zero)))));
}
749 
TEST_F(AlgebraicSimplifierTest, KeepNontrivialMap) {
  // Here the map's to_apply wraps a fusion instead of a single elementwise
  // op; the simplifier is expected to report that nothing changed.
  const char* kHloText = R"(
    HloModule m
    fusion {
      x = f32[] parameter(0)
      c = f32[] constant(42)
      m = f32[] multiply(x, x)
      ROOT a = f32[] add(m, c)
    }

    map {
      x = f32[] parameter(0)
      ROOT f = f32[] fusion(x), kind=kLoop, calls=fusion
    }

    ENTRY test {
      p = f32[2,2] parameter(0)
      ROOT map = f32[2,2] map(p), dimensions={0,1}, to_apply=map
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}
773 
TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
  // Test that broadcast(zeros) + A, where the zeros come from an R1 constant,
  // is simplified to A.
  auto m = CreateNewVerifiedModule();
  Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r2f32, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({0, 0, 0})));
  // The 3-element operand must map onto output dimension 0 (size 3); output
  // dimension 1 has size 2 and cannot hold it.
  HloInstruction* bcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, zero, {0}));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0));

  auto computation = m->AddEntryComputation(builder.Build());
  HloInstruction* root = computation->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  root = computation->root_instruction();
  EXPECT_EQ(root, param0);
}
795 
TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) {
  // An R1 constant whose elements are all equal is expected to become a
  // broadcast of a scalar constant holding that value.
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({3.14f, 3.14f, 3.14f})));

  auto computation = m->AddEntryComputation(builder.Build());
  HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Constant()));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  root = computation->root_instruction();
  // Fatal assertion: the next line dereferences root->operand(0), which only
  // exists if the root really is a broadcast.
  ASSERT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
  EXPECT_EQ(3.14f, root->operand(0)->literal().GetFirstElement<float>());
}
811 
TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) {
  // The constant's elements are not all equal, so it cannot be turned into a
  // broadcast; the pass is expected to report no change and keep the constant.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({3.14, 3.14, 4})));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Constant()));
  ASSERT_FALSE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Constant()));
}
826 
TEST_F(AlgebraicSimplifierTest, IotaToBroadcast) {
  // The constant {0, 1, 2} is an iota sequence and is expected to be replaced
  // by an iota instruction.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f})));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Constant()));
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Iota()));
}
841 
842 // Test that A - 0 is simplified to A
TEST_F(AlgebraicSimplifierTest, SubZero) {
  auto module = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  auto* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  auto* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kSubtract, a, zero));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kSubtract);
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());
  // The subtraction is gone; the parameter itself is now the result.
  EXPECT_EQ(entry->root_instruction(), a);
}
862 
863 // Test that A - Const is canonicalized to A + (-Const).
TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) {
  auto module = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  auto* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  auto* c = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kSubtract, a, c));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kSubtract);
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());
  // The subtraction is rewritten as an addition of the negated constant.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Add(m::Parameter(0), m::Negate(m::Op().Is(c)))));
}
884 
885 // Test that A - Broadcast(Const) is canonicalized to A + Broadcast(-Const).
TEST_F(AlgebraicSimplifierTest, SubBroadcastConstCanonicalization) {
  const char* kHloText = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      c = f32[] constant(0.125)
      b = f32[4] broadcast(c), dimensions={}
      ROOT sub = f32[4] subtract(p0, b)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The negation is applied to the scalar, under the broadcast.
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Add(m::Parameter(0),
                                m::Broadcast(m::Negate(
                                    m::ConstantScalar(0.125))))));
}
903 
904 // Test that Broadcast(x) where x has degenerate dimensions first removes the
905 // degenerate dimensions.
TEST_F(AlgebraicSimplifierTest, DegenerateDimsInOperandRemovedFromBroadcast) {
  const char* kHloText = R"(
    HloModule m
    test {
      c = f32[1,4] parameter(0)
      ROOT b = f32[5,1,4,3] broadcast(c), dimensions={1,2}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());

  // A reshape now sits between the parameter and the broadcast.
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));
}
919 
920 // Test that (A/B)/C is simplified to A/(B*C).
TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) {
  auto module = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  auto* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  auto* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar, "param1"));
  auto* c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, scalar, "param2"));
  auto* a_over_b = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kDivide, a, b));
  builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kDivide, a_over_b, c));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  // Before: (A / B) / C.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Divide(m::Divide(m::Parameter(0), m::Parameter(1)),
                                   m::Parameter(2))));

  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());

  // After: A / (B * C).
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Divide(m::Parameter(0),
                           m::Multiply(m::Parameter(1), m::Parameter(2)))));
}
950 
951 // Test that A/(B/C) is simplified to (A*C)/B.
TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) {
  auto module = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  auto* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  auto* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar, "param1"));
  auto* c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, scalar, "param2"));
  auto* b_over_c = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kDivide, b, c));
  builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kDivide, a, b_over_c));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  // Before: A / (B / C).
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Divide(m::Parameter(0),
                           m::Divide(m::Parameter(1), m::Parameter(2)))));

  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());

  // After: (A * C) / B.
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Divide(m::Multiply(m::Parameter(0), m::Parameter(2)),
                           m::Parameter(1))));
}
982 
983 // Test that (A/B)/(C/D) is simplified to (A*D)/(B*C).
TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) {
  auto module = CreateNewVerifiedModule();
  const Shape matrix = ShapeUtil::MakeShape(F32, {42, 123});
  HloComputation::Builder builder(TestName());
  auto* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix, "param0"));
  auto* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, matrix, "param1"));
  auto* c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, matrix, "param2"));
  auto* d = builder.AddInstruction(
      HloInstruction::CreateParameter(3, matrix, "param3"));
  auto* a_over_b = builder.AddInstruction(
      HloInstruction::CreateBinary(matrix, HloOpcode::kDivide, a, b));
  auto* c_over_d = builder.AddInstruction(
      HloInstruction::CreateBinary(matrix, HloOpcode::kDivide, c, d));
  builder.AddInstruction(HloInstruction::CreateBinary(
      matrix, HloOpcode::kDivide, a_over_b, c_over_d));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  // Before: (A / B) / (C / D).
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Divide(m::Divide(m::Parameter(0), m::Parameter(1)),
                           m::Divide(m::Parameter(2), m::Parameter(3)))));

  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());

  // After: (A * D) / (B * C).
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Divide(m::Multiply(m::Parameter(0), m::Parameter(3)),
                           m::Multiply(m::Parameter(1), m::Parameter(2)))));
}
1018 
1019 // Test that A/exp(B) is simplified to A*exp(-B).
TEST_F(AlgebraicSimplifierTest, DivOfExp) {
  auto module = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  auto* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  auto* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar, "param1"));
  auto* exp_b = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kExp, b));
  builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kDivide, a, exp_b));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  // Before: A / exp(B).
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Divide(m::Parameter(0), m::Exp(m::Parameter(1)))));

  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());

  // After: A * exp(-B).
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(m::Parameter(0),
                                     m::Exp(m::Negate(m::Parameter(1))))));
}
1045 
1046 // Test that A/pow(B,C) is simplified to A*pow(B,-C).
TEST_F(AlgebraicSimplifierTest, DivOfPower) {
  auto module = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  auto* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  auto* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar, "param1"));
  auto* c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, scalar, "param2"));
  auto* b_pow_c = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kPower, b, c));
  builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kDivide, a, b_pow_c));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  // Before: A / pow(B, C).
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Divide(m::Parameter(0),
                           m::Power(m::Parameter(1), m::Parameter(2)))));

  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());

  // After: A * pow(B, -C).
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(
                  m::Parameter(0),
                  m::Power(m::Parameter(1), m::Negate(m::Parameter(2))))));
}
1077 
1078 // Test that broadcasting is done on the right step when simplifying A/pow(B,C)
1079 // to A*pow(B,-C).
TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) {
  // NOTE(review): every operand here has the same f32[7] shape, so no
  // broadcast instruction is actually involved; this checks the
  // A / pow(B, C) => A * pow(B, -C) rewrite on rank-1 operands.
  auto module = CreateNewVerifiedModule();
  const Shape vec7 = ShapeUtil::MakeShape(F32, {7});
  HloComputation::Builder builder(TestName());
  auto* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec7, "param0"));
  auto* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, vec7, "param1"));
  auto* c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, vec7, "param2"));
  auto* b_pow_c = builder.AddInstruction(
      HloInstruction::CreateBinary(vec7, HloOpcode::kPower, b, c));
  builder.AddInstruction(
      HloInstruction::CreateBinary(vec7, HloOpcode::kDivide, a, b_pow_c));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Divide(m::Parameter(0),
                           m::Power(m::Parameter(1), m::Parameter(2)))));

  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());

  ASSERT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(
                  m::Parameter(0),
                  m::Power(m::Parameter(1), m::Negate(m::Parameter(2))))));
}
1110 
1111 // A / Const => A * InvertedConst
TEST_F(AlgebraicSimplifierTest, DivideByConstant) {
  auto module = CreateNewVerifiedModule();
  const Shape vec3 = ShapeUtil::MakeShape(F32, {3});
  HloComputation::Builder builder(TestName());
  auto* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec3, "param0"));
  auto* divisor = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.f, 2.f, 3.f})));
  builder.AddInstruction(
      HloInstruction::CreateBinary(vec3, HloOpcode::kDivide, a, divisor));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());

  // The divide becomes a multiply by a constant. The constant's element
  // values are not inspected here.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(m::Parameter(0), m::Constant())));
}
1132 
1133 // A / Broadcast(Const) => A * Broadcast(InvertedConst)
TEST_F(AlgebraicSimplifierTest, DivideByBroadcastedConstant) {
  const char* kHloText = R"(
    HloModule m
    test {
      p = f32[4] parameter(0)
      c = f32[] constant(256.0)
      b = f32[4] broadcast(c), dimensions={}
      ROOT d = f32[4] divide(p, b)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // p / broadcast(256) becomes p * broadcast(1/256).
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Multiply(
                        m::Parameter(0),
                        m::Broadcast(
                            m::Op().IsConstantScalar(1.0f / 256.0f)))));
}
1152 
1153 // pow(pow(A, X), Y) => pow(A, X*Y)
TEST_F(AlgebraicSimplifierTest,PowerOfPower)1154 TEST_F(AlgebraicSimplifierTest, PowerOfPower) {
1155   auto m = CreateNewVerifiedModule();
1156   Shape r1f32 = ShapeUtil::MakeShape(F32, {7});
1157   HloComputation::Builder builder(TestName());
1158   HloInstruction* base = builder.AddInstruction(
1159       HloInstruction::CreateParameter(0, r1f32, "param0"));
1160   HloInstruction* exp1 = builder.AddInstruction(
1161       HloInstruction::CreateParameter(1, r1f32, "param1"));
1162   HloInstruction* exp2 = builder.AddInstruction(
1163       HloInstruction::CreateParameter(2, r1f32, "param2"));
1164   HloInstruction* inner_power = builder.AddInstruction(
1165       HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, base, exp1));
1166   builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower,
1167                                                       inner_power, exp2));
1168 
1169   AlgebraicSimplifier simplifier(default_options_);
1170   ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
1171 }
1172 
1173 // Don't simplify pow(pow(A, X), Y) => pow(A, X*Y) if X and Y are complex
1174 // numbers.
TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) {
  auto module = CreateNewVerifiedModule();
  const Shape c64_vec = ShapeUtil::MakeShape(C64, {7});
  HloComputation::Builder builder(TestName());
  auto* base = builder.AddInstruction(
      HloInstruction::CreateParameter(0, c64_vec, "param0"));
  auto* x = builder.AddInstruction(
      HloInstruction::CreateParameter(1, c64_vec, "param1"));
  auto* y = builder.AddInstruction(
      HloInstruction::CreateParameter(2, c64_vec, "param2"));
  auto* base_pow_x = builder.AddInstruction(
      HloInstruction::CreateBinary(c64_vec, HloOpcode::kPower, base, x));
  builder.AddInstruction(
      HloInstruction::CreateBinary(c64_vec, HloOpcode::kPower, base_pow_x, y));
  module->AddEntryComputation(builder.Build());

  // The pass is expected to leave the nested complex power untouched.
  ASSERT_FALSE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());
}
1194 
1195 // Test that A/1 is simplified to A for a scalar.
TEST_F(AlgebraicSimplifierTest, DivOneScalar) {
  auto module = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  auto* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  auto* one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  auto* quotient = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kDivide, a, one));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  EXPECT_EQ(entry->root_instruction(), quotient);
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());
  EXPECT_EQ(entry->root_instruction(), a);
}
1215 
1216 // Test that A/1 is simplified to A for an array.
TEST_F(AlgebraicSimplifierTest, DivOneArray) {
  auto module = CreateNewVerifiedModule();
  const Shape matrix = ShapeUtil::MakeShape(F32, {2, 2});
  HloComputation::Builder builder(TestName());
  auto* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix, "param0"));
  auto* ones = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 1.0}, {1.0, 1.0}})));
  auto* quotient = builder.AddInstruction(
      HloInstruction::CreateBinary(matrix, HloOpcode::kDivide, a, ones));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  EXPECT_EQ(entry->root_instruction(), quotient);
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());
  EXPECT_EQ(entry->root_instruction(), a);
}
1236 
1237 // Test that complex(real(c), imag(c)) is simplified to c.
TEST_F(AlgebraicSimplifierTest, ComplexOfRealImagC) {
  auto module = CreateNewVerifiedModule();
  const Shape real_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape complex_shape = ShapeUtil::MakeShape(C64, {2, 2});
  HloComputation::Builder builder(TestName());
  auto* z = builder.AddInstruction(
      HloInstruction::CreateParameter(0, complex_shape, "param0"));
  auto* re = builder.AddInstruction(
      HloInstruction::CreateUnary(real_shape, HloOpcode::kReal, z));
  auto* im = builder.AddInstruction(
      HloInstruction::CreateUnary(real_shape, HloOpcode::kImag, z));
  auto* rebuilt = builder.AddInstruction(HloInstruction::CreateBinary(
      complex_shape, HloOpcode::kComplex, re, im));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  EXPECT_EQ(entry->root_instruction(), rebuilt);
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());
  EXPECT_EQ(entry->root_instruction(), z);
}
1260 
1261 // Test that real(complex(r,i)) is simplified to r.
TEST_F(AlgebraicSimplifierTest, RealOfComplex) {
  auto module = CreateNewVerifiedModule();
  const Shape real_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape complex_shape = ShapeUtil::ChangeElementType(real_shape, C64);
  HloComputation::Builder builder(TestName());
  auto* re = builder.AddInstruction(
      HloInstruction::CreateParameter(0, real_shape, "param0"));
  auto* im = builder.AddInstruction(
      HloInstruction::CreateParameter(1, real_shape, "param1"));
  auto* z = builder.AddInstruction(HloInstruction::CreateBinary(
      complex_shape, HloOpcode::kComplex, re, im));
  auto* real_part = builder.AddInstruction(
      HloInstruction::CreateUnary(real_shape, HloOpcode::kReal, z));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  EXPECT_EQ(entry->root_instruction(), real_part);
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());
  EXPECT_EQ(entry->root_instruction(), re);
}
1284 
1285 // Test that imag(complex(r,i)) is simplified to i.
TEST_F(AlgebraicSimplifierTest, ImagOfComplex) {
  auto module = CreateNewVerifiedModule();
  const Shape real_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape complex_shape = ShapeUtil::ChangeElementType(real_shape, C64);
  HloComputation::Builder builder(TestName());
  auto* re = builder.AddInstruction(
      HloInstruction::CreateParameter(0, real_shape, "param0"));
  auto* im = builder.AddInstruction(
      HloInstruction::CreateParameter(1, real_shape, "param1"));
  auto* z = builder.AddInstruction(HloInstruction::CreateBinary(
      complex_shape, HloOpcode::kComplex, re, im));
  auto* imag_part = builder.AddInstruction(
      HloInstruction::CreateUnary(real_shape, HloOpcode::kImag, z));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  EXPECT_EQ(entry->root_instruction(), imag_part);
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());
  EXPECT_EQ(entry->root_instruction(), im);
}
1308 
// Test that get-tuple-element(tuple(A, B), 1) is simplified to B.
TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) {
  auto module = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  auto* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  auto* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar, "param1"));
  auto* c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, scalar, "param2"));
  auto* pair = builder.AddInstruction(HloInstruction::CreateTuple({a, b}));
  auto* second = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar, pair, 1));
  auto* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kAdd, second, c));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  EXPECT_EQ(entry->root_instruction(), sum);
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());
  // The tuple round-trip is gone: the add now reads param1 directly.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Add(m::Parameter(1), m::Parameter(2))));
}
1335 
1336 // Test that exp(A)/exp(B) is simplified to exp(A-B)
TEST_F(AlgebraicSimplifierTest, ExpDiv) {
  // exp(p0) / exp(p1) should fold into a single exp(p0 - p1).
  auto m = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  auto make_param = [&](int64 index, const char* name) {
    return builder.AddInstruction(
        HloInstruction::CreateParameter(index, scalar, name));
  };
  HloInstruction* a = make_param(0, "param0");
  HloInstruction* b = make_param(1, "param1");
  HloInstruction* exp_a = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kExp, a));
  HloInstruction* exp_b = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kExp, b));
  builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kDivide, exp_a, exp_b));
  HloComputation* entry = m->AddEntryComputation(builder.Build());

  // Before: a quotient of two exponentials.
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Divide(m::Exp(m::Parameter(0)), m::Exp(m::Parameter(1)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // After: one exponential of the difference.
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Exp(m::Subtract(m::Parameter(0), m::Parameter(1)))));
}
1365 
1366 // Test that exp(A)*exp(B) is simplified to exp(A+B)
TEST_F(AlgebraicSimplifierTest, ExpMul) {
  // exp(p0) * exp(p1) should fold into a single exp(p0 + p1).
  auto m = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  auto make_param = [&](int64 index, const char* name) {
    return builder.AddInstruction(
        HloInstruction::CreateParameter(index, scalar, name));
  };
  HloInstruction* a = make_param(0, "param0");
  HloInstruction* b = make_param(1, "param1");
  HloInstruction* exp_a = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kExp, a));
  HloInstruction* exp_b = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kExp, b));
  builder.AddInstruction(HloInstruction::CreateBinary(
      scalar, HloOpcode::kMultiply, exp_a, exp_b));
  HloComputation* entry = m->AddEntryComputation(builder.Build());

  // Before: a product of two exponentials.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(m::Exp(m::Parameter(0)),
                                     m::Exp(m::Parameter(1)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // After: one exponential of the sum.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Exp(m::Add(m::Parameter(0), m::Parameter(1)))));
}
1394 
1395 // Test that pow(exp(A), B) is simplified to exp(A*B)
TEST_F(AlgebraicSimplifierTest, PowExp) {
  // pow(exp(p0), p1) should be rewritten as exp(p0 * p1).
  auto m = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* base_arg = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  HloInstruction* exponent = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar, "param1"));
  HloInstruction* exp_base = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kExp, base_arg));
  builder.AddInstruction(HloInstruction::CreateBinary(
      scalar, HloOpcode::kPower, exp_base, exponent));
  HloComputation* entry = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Power(m::Exp(m::Parameter(0)), m::Parameter(1))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Exp(m::Multiply(m::Parameter(0), m::Parameter(1)))));
}
1421 
// Test that ln(pow(A, B)) is simplified to ln(|A|)*B.
TEST_F(AlgebraicSimplifierTest, LnPow) {
  // log(pow(p0, p1)) is expected to become log(abs(p0)) * p1.
  auto m = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* base_arg = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  HloInstruction* exponent = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar, "param1"));
  HloInstruction* power = builder.AddInstruction(HloInstruction::CreateBinary(
      scalar, HloOpcode::kPower, base_arg, exponent));
  builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kLog, power));
  HloComputation* entry = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Log(m::Power(m::Parameter(0), m::Parameter(1)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(m::Log(m::Abs(m::Parameter(0))),
                                     m::Parameter(1))));
}
1448 
TEST_F(AlgebraicSimplifierTest, LnSqrt) {
  // log(sqrt(p0)) should be rewritten as log(p0) * 0.5.
  auto m = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* x = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  HloInstruction* root_of_x = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kSqrt, x));
  builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kLog, root_of_x));
  HloComputation* entry = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Log(m::Sqrt(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Multiply(m::Log(m::Parameter(0)), m::ConstantScalar(0.5))));
}
1472 
TEST_F(AlgebraicSimplifierTest, LnRsqrt) {
  // log(rsqrt(p0)) should be rewritten as log(p0) * -0.5.
  auto m = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* x = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  HloInstruction* inv_root = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kRsqrt, x));
  builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kLog, inv_root));
  HloComputation* entry = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Log(m::Rsqrt(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(m::Log(m::Parameter(0)),
                                     m::ConstantScalar(-0.5))));
}
1496 
1497 // Test that ln(exp(A)) is simplified to A
TEST_F(AlgebraicSimplifierTest, LnExp) {
  // log(exp(p0)) cancels out, leaving p0 itself as the root.
  auto m = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* x = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  HloInstruction* exp_x = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kExp, x));
  builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kLog, exp_x));
  HloComputation* entry = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Log(m::Exp(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  EXPECT_EQ(entry->root_instruction(), x);
}
1519 
1520 // Test that ln(exp(A)/exp(B)) is simplified to A-B
TEST_F(AlgebraicSimplifierTest, LnExpDiv) {
  // log(exp(p0) / exp(p1)) should collapse all the way down to p0 - p1.
  auto m = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  auto make_param = [&](int64 index, const char* name) {
    return builder.AddInstruction(
        HloInstruction::CreateParameter(index, scalar, name));
  };
  auto make_exp = [&](HloInstruction* operand) {
    return builder.AddInstruction(
        HloInstruction::CreateUnary(scalar, HloOpcode::kExp, operand));
  };
  HloInstruction* a = make_param(0, "param0");
  HloInstruction* b = make_param(1, "param1");
  HloInstruction* exp_a = make_exp(a);
  HloInstruction* exp_b = make_exp(b);
  HloInstruction* quotient = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kDivide, exp_a, exp_b));
  builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kLog, quotient));
  HloComputation* entry = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Log(m::Divide(m::Exp(m::Parameter(0)),
                                          m::Exp(m::Parameter(1))))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Subtract(m::Parameter(0), m::Parameter(1))));
}
1550 
1551 // Test that pow(A, 0) where A is a scalar is simplified to the scalar
1552 // constant 1.
TEST_F(AlgebraicSimplifierTest, Pow0Scalar) {
  // pow(p0, 0) on a scalar should be replaced by the scalar constant 1.
  auto m = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* base_arg = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  HloInstruction* zero_exponent = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
  builder.AddInstruction(HloInstruction::CreateBinary(
      scalar, HloOpcode::kPower, base_arg, zero_exponent));
  HloComputation* entry = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Power(m::Parameter(0), m::Op().Is(zero_exponent))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  HloInstruction* result = entry->root_instruction();
  EXPECT_THAT(result, GmockMatch(m::Constant()));
  EXPECT_EQ(result->literal().GetFirstElement<float>(), 1);
}
1576 
1577 // Test that pow(A, 0) where A is not a scalar is simplified to broadcast(1).
TEST_F(AlgebraicSimplifierTest, Pow0Vector) {
  // pow(p0, 0) on a rank-1 operand should be replaced by a broadcast of the
  // scalar constant 1 to the operand's shape.
  auto m = CreateNewVerifiedModule();
  Shape r1f32 = ShapeUtil::MakeShape(F32, {42});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Power(m::Parameter(0), m::Op().Is(zero))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  // Match the constant operand explicitly. The literal() call below is only
  // valid on a constant, so a wrong rewrite should fail here rather than
  // crash later.
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), r1f32))
      << ShapeUtil::HumanString(root->shape());
  // A scalar-to-vector broadcast maps no operand dimensions. Checking empty()
  // avoids comparing the unsigned size() against a signed literal.
  EXPECT_TRUE(root->dimensions().empty());
  EXPECT_TRUE(ShapeUtil::IsScalar(root->operand(0)->shape()));
  EXPECT_EQ(root->operand(0)->literal().GetFirstElement<float>(), 1);
}
1605 
1606 // Test that pow(A, 1) is simplified to A.
TEST_F(AlgebraicSimplifierTest, Pow1) {
  // pow(p0, 1) is the identity, so the root should become p0 itself.
  auto m = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* base_arg = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  HloInstruction* unit_exponent = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
  builder.AddInstruction(HloInstruction::CreateBinary(
      scalar, HloOpcode::kPower, base_arg, unit_exponent));
  HloComputation* entry = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Power(m::Parameter(0), m::Op().Is(unit_exponent))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  EXPECT_EQ(entry->root_instruction(), base_arg);
}
1628 
1629 // Test that pow(A, 2) is simplified to A*A.
TEST_F(AlgebraicSimplifierTest, Pow2) {
  // pow(p0, 2) should be strength-reduced to p0 * p0.
  auto m = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* base_arg = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  HloInstruction* two_exponent = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2)));
  builder.AddInstruction(HloInstruction::CreateBinary(
      scalar, HloOpcode::kPower, base_arg, two_exponent));
  HloComputation* entry = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Power(m::Parameter(0), m::Op().Is(two_exponent))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
}
1652 
1653 // Test that pow(A, 3) is simplified to A*A*A.
TEST_F(AlgebraicSimplifierTest, Pow3) {
  // pow(p0, 3) should be strength-reduced to p0 * (p0 * p0).
  auto m = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* base_arg = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  HloInstruction* three_exponent = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3)));
  builder.AddInstruction(HloInstruction::CreateBinary(
      scalar, HloOpcode::kPower, base_arg, three_exponent));
  HloComputation* entry = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Power(m::Parameter(0), m::Op().Is(three_exponent))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Multiply(m::Parameter(0),
                             m::Multiply(m::Parameter(0), m::Parameter(0)))));
}
1678 
1679 // Test that pow(A, -1) is simplified to 1/A.
TEST_F(AlgebraicSimplifierTest, PowNegative1) {
  // pow(p0, -1) should become the reciprocal 1 / p0.
  auto m = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* base_arg = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  HloInstruction* minus_one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(-1)));
  builder.AddInstruction(HloInstruction::CreateBinary(
      scalar, HloOpcode::kPower, base_arg, minus_one));
  HloComputation* entry = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Power(m::Parameter(0), m::Op().Is(minus_one))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  HloInstruction* result = entry->root_instruction();
  EXPECT_THAT(result, GmockMatch(m::Divide(m::Constant(), m::Parameter(0))));
  // The numerator must be the constant 1.
  EXPECT_EQ(result->operand(0)->literal().GetFirstElement<float>(), 1);
}
1703 
TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) {
  // A convolution whose input and kernel have a zero-sized feature dimension
  // should be replaced by a broadcast of a constant.
  auto m = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {3, 3, 0}), "lhs"));

  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {3, 0, 3}), "rhs"));

  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(0);
  dnums.add_input_spatial_dimensions(1);
  dnums.set_input_feature_dimension(2);

  dnums.set_output_batch_dimension(0);
  dnums.add_output_spatial_dimensions(1);
  dnums.set_output_feature_dimension(2);

  dnums.add_kernel_spatial_dimensions(0);
  dnums.set_kernel_input_feature_dimension(1);
  dnums.set_kernel_output_feature_dimension(2);
  Window window;
  WindowDimension* dim = window.add_dimensions();
  dim->set_size(3);
  dim->set_padding_low(0);
  dim->set_padding_high(0);
  dim->set_stride(1);
  dim->set_window_dilation(1);
  dim->set_base_dilation(1);
  dim->set_window_reversal(false);
  // Create the convolution. A size-3 window with stride 1 and no padding over
  // an input spatial extent of 3 yields an output spatial extent of
  // (3 - 3) / 1 + 1 = 1. The result shape is therefore
  // {batch 3, spatial 1, output features 3}, which agrees with shape
  // inference.
  builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeUtil::MakeShape(F32, {3, 1, 3}), lhs, rhs, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
  m->AddEntryComputation(builder.Build());
  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Convolution(m::Op().Is(lhs), m::Op().Is(rhs))));
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Constant())));
}
1746 
TEST_F(AlgebraicSimplifierTest, ReduceWindowIsReduceAndReshape) {
  // The window below spans dimensions 0-2 of the {1, 2, 3, 4} input
  // completely, with unit stride and no padding. The reduce-window should
  // therefore be rewritten as a reduce followed by a reshape.
  auto m = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());
  const Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 3, 4});
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param"));

  // Window sizes cycle 1, 2, 3, 1 across the four dimensions (a 1x2x3x1
  // window).
  Window window;
  for (int64 d = 0; d < 4; ++d) {
    WindowDimension* window_dim = window.add_dimensions();
    window_dim->set_size((d % 3) + 1);
    window_dim->set_stride(1);
    window_dim->set_padding_low(0);
    window_dim->set_padding_high(0);
    window_dim->set_window_dilation(1);
    window_dim->set_base_dilation(1);
  }

  // Scalar F32 addition used as the reduction function.
  HloComputation* add_computation = nullptr;
  {
    HloComputation::Builder add_builder(TestName() + ".add");
    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
    HloInstruction* lhs = add_builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "p0"));
    HloInstruction* rhs = add_builder.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "p1"));
    add_builder.AddInstruction(
        HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, lhs, rhs));
    add_computation = m->AddEmbeddedComputation(add_builder.Build());
  }

  HloInstruction* init = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  builder.AddInstruction(HloInstruction::CreateReduceWindow(
      ShapeUtil::MakeShape(F32, {1, 1, 1, 4}), input, init, window,
      add_computation));
  m->AddEntryComputation(builder.Build());

  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::ReduceWindow(m::Parameter(0), m::Constant())));

  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  EXPECT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(m::Reshape(m::Reduce(m::Parameter(0), m::Constant()))));
}
1791 
TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) {
  // A reduce-window over an operand with zero elements should be replaced by
  // a broadcast of its init value.
  auto m = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {3, 0}), "op"));
  // Size-1 windows with one element of padding on each side map the {3, 0}
  // operand to a {5, 2} result. Set the stride explicitly, as the other window
  // tests in this file do: an unset proto field leaves it at 0, which does not
  // describe a valid window.
  Window window;
  for (int64 i = 0; i < 2; ++i) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(1);
    dim->set_stride(1);
    dim->set_padding_low(1);
    dim->set_padding_high(1);
    dim->set_window_dilation(1);
    dim->set_base_dilation(1);
  }
  // Create add computation. Use a distinct builder name so it does not shadow
  // the entry computation's builder.
  HloComputation* add_computation = nullptr;
  {
    HloComputation::Builder add_builder(TestName() + ".add");
    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
    HloInstruction* p0 = add_builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "p0"));
    HloInstruction* p1 = add_builder.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "p1"));
    add_builder.AddInstruction(
        HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
    add_computation = m->AddEmbeddedComputation(add_builder.Build());
  }
  builder.AddInstruction(HloInstruction::CreateReduceWindow(
      ShapeUtil::MakeShape(F32, {5, 2}), param,
      builder.AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f))),
      window, add_computation));
  m->AddEntryComputation(builder.Build());
  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::ReduceWindow(m::Parameter(0), m::Constant())));
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Constant())));
}
1833 
TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) {
  // Padding an operand that has zero elements should simplify to a broadcast
  // of the padding value.
  auto m = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());
  const Shape empty_shape = ShapeUtil::MakeShape(F32, {3, 0});
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, empty_shape, "op"));
  HloInstruction* pad_value = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));

  // One element of edge padding on both sides of both dimensions, no interior
  // padding: {3, 0} -> {5, 2}.
  PaddingConfig padding;
  for (int d = 0; d < 2; ++d) {
    auto* dim_config = padding.add_dimensions();
    dim_config->set_edge_padding_low(1);
    dim_config->set_edge_padding_high(1);
    dim_config->set_interior_padding(0);
  }
  builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {5, 2}), operand, pad_value, padding));
  m->AddEntryComputation(builder.Build());

  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Constant())));

  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Constant())));
}
1860 
TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) {
  // reshape{3,2}(broadcast{1,6}(reshape{6}(op))) only round-trips op back to
  // its own {3, 2} shape, so the whole chain should simplify away to op.
  auto m = CreateNewVerifiedModule();

  auto builder = HloComputation::Builder(TestName());
  auto op = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {3, 2}), "op"));
  auto reshape1 = builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6}), op));
  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {1, 6}), reshape1, {1}));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {3, 2}), broadcast));

  m->AddEntryComputation(builder.Build());

  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Reshape(m::Op().Is(op))))));

  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  EXPECT_EQ(m->entry_computation()->root_instruction(), op);
}
1886 
1887 // Test that convert(A, $TYPE) is simplified to A if A is of type $TYPE.
TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) {
  // An F32 -> F32 convert is a no-op and should be removed, leaving the
  // constant as the root.
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  builder.AddInstruction(
      HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Convert(m::Op().Is(input))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // Pointer identity check, consistent with the other tests in this file.
  EXPECT_EQ(computation->root_instruction(), input);
}
1906 
// Test that convert(convert(A, $TYPE1), $TYPE2) is simplified to A if A is of
// $TYPE2 and convert(A, $TYPE1) is an upcast.
TEST_F(AlgebraicSimplifierTest, EliminateConvertPairUpCast) {
  // F16 -> F32 -> F16: the inner convert widens the type, so the pair is a
  // round trip and should be removed, leaving the parameter as the root.
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* input =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShapeWithLayout(F16, {1, 14, 14, 64}, {3, 2, 1, 0}),
          "param"));
  HloInstruction* convert_1 =
      builder.AddInstruction(HloInstruction::CreateConvert(
          ShapeUtil::ChangeElementType(input->shape(), F32), input));
  builder.AddInstruction(HloInstruction::CreateConvert(
      ShapeUtil::ChangeElementType(convert_1->shape(), F16), convert_1));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Convert(m::Convert(m::Op().Is(input)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // Pointer identity check, consistent with the other tests in this file.
  EXPECT_EQ(computation->root_instruction(), input);
}
1932 
// Test that convert(convert(A, $TYPE1), $TYPE2) is NOT simplified to A even if
// A is of $TYPE2, since convert(A, $TYPE1) is a downcast.
TEST_F(AlgebraicSimplifierTest, DoNotEliminateConvertPairDownCast) {
  // F32 -> F16 -> F32: the inner convert narrows the type, so the pair is not
  // a no-op. The simplifier must leave the graph alone and report no change.
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  const Shape wide_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {3, 2, 1, 0});
  HloInstruction* source = builder.AddInstruction(
      HloInstruction::CreateParameter(0, wide_shape, "param"));
  HloInstruction* narrowed = builder.AddInstruction(
      HloInstruction::CreateConvert(
          ShapeUtil::ChangeElementType(source->shape(), F16), source));
  builder.AddInstruction(HloInstruction::CreateConvert(
      ShapeUtil::ChangeElementType(narrowed->shape(), F32), narrowed));
  HloComputation* entry = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Convert(m::Convert(m::Op().Is(source)))));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Convert(m::Convert(m::Op().Is(source)))));
}
1957 
// Test that Tuple(convert(A, $TYPE1), convert(convert(A, $TYPE1), $TYPE2),
// floor(convert(convert(A, $TYPE1), $TYPE2))) is simplified to
// Tuple(convert(A, $TYPE1), A, floor(A)), showing a case where the first
// convert has a fan-out.
TEST_F(AlgebraicSimplifierTest, EliminateConvertPairMultiOut) {
  // The F16 -> F32 upcast (convert_1) has two users: the tuple and the
  // F32 -> F16 convert (convert_2). The convert pair must still be eliminated
  // for convert_2's users, while convert_1 itself stays live for the tuple.
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* input =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShapeWithLayout(F16, {1, 14, 14, 64}, {3, 2, 1, 0}),
          "param"));
  HloInstruction* convert_1 =
      builder.AddInstruction(HloInstruction::CreateConvert(
          ShapeUtil::ChangeElementType(input->shape(), F32), input));
  HloInstruction* convert_2 =
      builder.AddInstruction(HloInstruction::CreateConvert(
          ShapeUtil::ChangeElementType(convert_1->shape(), F16), convert_1));

  HloInstruction* floor = builder.AddInstruction(HloInstruction::CreateUnary(
      convert_2->shape(), HloOpcode::kFloor, convert_2));

  // Collect both converts and the floor into a tuple so they are not dead.
  builder.AddInstruction(
      HloInstruction::CreateTuple({convert_1, convert_2, floor}));

  auto computation = m->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Tuple(m::Op().Is(convert_1), m::Op().Is(convert_2),
                                  m::Op().Is(floor))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // convert_2 is replaced by the original input everywhere it was used.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Tuple(m::Op().Is(convert_1), m::Op().Is(input),
                                  m::Floor(m::Op().Is(input)))));
}
1995 
1996 // Test that copies are removed.
TEST_F(AlgebraicSimplifierTest,RemoveCopy)1997 TEST_F(AlgebraicSimplifierTest, RemoveCopy) {
1998   auto m = CreateNewVerifiedModule();
1999   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
2000   HloComputation::Builder builder(TestName());
2001   HloInstruction* param0 = builder.AddInstruction(
2002       HloInstruction::CreateParameter(0, r0f32, "param0"));
2003   builder.AddInstruction(
2004       HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
2005 
2006   auto computation = m->AddEntryComputation(builder.Build());
2007 
2008   EXPECT_THAT(computation->root_instruction(),
2009               GmockMatch(m::Copy(m::Parameter(0))));
2010 
2011   AlgebraicSimplifier simplifier(default_options_);
2012   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2013 
2014   EXPECT_THAT(computation->root_instruction(), param0);
2015 }
2016 
// In layout-sensitive mode, copy(reshape(copy(param))) with the layouts below
// is expected to collapse into a single bitcast of the parameter.
TEST_F(AlgebraicSimplifierTest, CopyOfReshapeOfCopyEqualsBitcast) {
  auto hlo_module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  const Shape param_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {3, 2, 1, 0});
  const Shape inner_copy_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {0, 1, 2, 3});
  const Shape reshaped_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {14 * 14, 64}, {0, 1});
  const Shape outer_copy_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {14 * 14, 64}, {1, 0});

  HloInstruction* p = b.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param"));
  HloInstruction* inner_copy = b.AddInstruction(
      HloInstruction::CreateUnary(inner_copy_shape, HloOpcode::kCopy, p));
  HloInstruction* reshaped = b.AddInstruction(
      HloInstruction::CreateReshape(reshaped_shape, inner_copy));
  b.AddInstruction(
      HloInstruction::CreateUnary(outer_copy_shape, HloOpcode::kCopy, reshaped));

  HloComputation* entry = hlo_module->AddEntryComputation(b.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Copy(m::Reshape(m::Copy(m::Parameter(0))))));

  AlgebraicSimplifierOptions layout_sensitive;
  layout_sensitive.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(layout_sensitive);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  // The whole copy/reshape/copy chain is now one bitcast of the parameter.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
2045 
// In layout-sensitive mode, reshape(copy(param)) with the layouts below is
// expected to be replaced by a single bitcast of the parameter.
TEST_F(AlgebraicSimplifierTest, ReshapeOfCopyEqualsBitcast) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {3, 2, 1, 0}),
          "param"));
  // Layout-changing copy of the parameter.
  HloInstruction* copy = builder.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {0, 1, 2, 3}),
      HloOpcode::kCopy, param));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShapeWithLayout(F32, {14 * 14, 64}, {1, 0}), copy));

  auto computation = m->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Copy(m::Parameter(0)))));

  AlgebraicSimplifierOptions options;
  options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(options);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  // Verify that the reshape of copy is replaced.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
2071 
// The copy below changes only where the size-1 dimension 0 sits in the
// layout. It must stay a copy when the bitcast callback refuses everything,
// and become a bitcast with the default options.
TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) {
  auto hlo_module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());
  HloInstruction* p = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {0, 1, 2, 3}),
      "param"));
  b.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {1, 2, 0, 3}),
      HloOpcode::kCopy, p));
  HloComputation* entry = hlo_module->AddEntryComputation(b.Build());
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Copy(m::Parameter(0))));

  // Pass 1: a callback that always returns false leaves the copy untouched
  // and reports no change.
  AlgebraicSimplifierOptions reject_all(
      [](const Shape&, const Shape&) { return false; });
  reject_all.set_is_layout_sensitive(true);
  AlgebraicSimplifier conservative(reject_all);
  ASSERT_FALSE(conservative.Run(hlo_module.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Copy(m::Parameter(0))));

  // Pass 2: with default options the copy is rewritten as a bitcast.
  AlgebraicSimplifierOptions defaults;
  defaults.set_is_layout_sensitive(true);
  AlgebraicSimplifier permissive(defaults);
  EXPECT_TRUE(permissive.Run(hlo_module.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
2103 
2104 // Test that unary concatenates are removed.
TEST_F(AlgebraicSimplifierTest,RemoveUnaryConcatenate)2105 TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) {
2106   auto m = CreateNewVerifiedModule();
2107   Shape r1f32 = ShapeUtil::MakeShape(F32, {100});
2108   HloComputation::Builder builder(TestName());
2109   HloInstruction* param0 = builder.AddInstruction(
2110       HloInstruction::CreateParameter(0, r1f32, "param0"));
2111   builder.AddInstruction(
2112       HloInstruction::CreateConcatenate(param0->shape(), {param0}, 0));
2113 
2114   auto computation = m->AddEntryComputation(builder.Build());
2115 
2116   EXPECT_THAT(computation->root_instruction(),
2117               GmockMatch(m::Concatenate(m::Parameter(0))));
2118 
2119   AlgebraicSimplifier simplifier(default_options_);
2120   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2121 
2122   EXPECT_THAT(computation->root_instruction(), param0);
2123 }
2124 
// slice(reverse(x)) is rewritten as reverse(slice(x)). In every reversed
// dimension the slice bounds are mirrored; other dimensions keep their bounds.
TEST_F(AlgebraicSimplifierTest, SliceReverse) {
  const char* const kHloText = R"(
HloModule module

ENTRY test {
  param = f32[6,7,32] parameter(0)
  constant = f32[] constant(0)
  pad = f32[8,7,32] pad(param, constant), padding=1_1x0_0x0_0
  rev = f32[8,7,32] reverse(pad), dimensions={0,2}
  slice = f32[1,7,32] slice(rev), slice={[2:3:1], [0:7:1], [0:32:1]}
  ROOT tuple = (f32[1,7,32]) tuple(slice)
})";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  HloComputation* entry = hlo_module->entry_computation();
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Tuple(m::Reverse(m::Slice(m::Pad())))));

  // Root is tuple(reverse(slice(pad))); dig out the new slice.
  const HloInstruction* new_slice =
      entry->root_instruction()->operand(0)->operand(0);
  EXPECT_TRUE(ShapeUtil::Equal(new_slice->shape(),
                               ShapeUtil::MakeShape(F32, {1, 7, 32})));

  // Dimension 0 (size 8, reversed): [2, 3) mirrors to [8 - 3, 8 - 2) = [5, 6).
  // Dimension 1 is not reversed and keeps [0, 7).
  // Dimension 2 (reversed) is taken in full, so [0, 32) mirrors onto itself.
  const int kStarts[] = {5, 0, 0};
  const int kLimits[] = {6, 7, 32};
  for (int dim = 0; dim < 3; ++dim) {
    EXPECT_EQ(new_slice->slice_starts(dim), kStarts[dim]);
    EXPECT_EQ(new_slice->slice_limits(dim), kLimits[dim]);
  }
  // All strides remain unit strides.
  for (int dim = 0; dim < 3; ++dim) {
    EXPECT_EQ(new_slice->slice_strides(dim), 1);
  }
}
2162 
// Like SliceReverse, but every dimension is reversed and the slice uses
// non-unit strides, both even and odd.
TEST_F(AlgebraicSimplifierTest, SliceReverseNonUnitEvenOddStrides) {
  const char* const kHloText = R"(
HloModule module

ENTRY test {
  param = f32[6,7,32] parameter(0)
  constant = f32[] constant(0)
  pad = f32[8,7,32] pad(param, constant), padding=1_1x0_0x0_0
  rev = f32[8,7,32] reverse(pad), dimensions={0,1,2}
  slice = f32[1,2,7] slice(rev), slice={[2:3:2], [0:7:4], [0:32:5]}
  ROOT tuple = (f32[1,2,7]) tuple(slice)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  HloComputation* entry = hlo_module->entry_computation();
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Tuple(m::Reverse(m::Slice(m::Pad())))));

  // Root is tuple(reverse(slice(pad))); dig out the new slice.
  const HloInstruction* new_slice =
      entry->root_instruction()->operand(0)->operand(0);
  EXPECT_TRUE(ShapeUtil::Equal(new_slice->shape(),
                               ShapeUtil::MakeShape(F32, {1, 2, 7})));

  // In a reversed dimension of size N, index i maps to N - 1 - i. The new
  // slice starts at the mirror of the last picked index and ends one past the
  // mirror of the first, keeping the original stride:
  //   dim 0 (N=8,  stride 2): picks {2}          -> [5, 6)
  //   dim 1 (N=7,  stride 4): picks {0, 4}       -> [2, 7)
  //   dim 2 (N=32, stride 5): picks {0, ..., 30} -> [1, 32)
  const int kStarts[] = {5, 2, 1};
  const int kLimits[] = {6, 7, 32};
  const int kStrides[] = {2, 4, 5};
  for (int dim = 0; dim < 3; ++dim) {
    EXPECT_EQ(new_slice->slice_starts(dim), kStarts[dim]);
    EXPECT_EQ(new_slice->slice_limits(dim), kLimits[dim]);
  }
  for (int dim = 0; dim < 3; ++dim) {
    EXPECT_EQ(new_slice->slice_strides(dim), kStrides[dim]);
  }
}
2198 
2199 // Test that empty operands of concatenates are removed.
TEST_F(AlgebraicSimplifierTest,RemoveEmptyConcatenateOperands)2200 TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) {
2201   auto m = CreateNewVerifiedModule();
2202   const int kParamLength = 100;
2203   Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength});
2204   HloComputation::Builder builder(TestName());
2205   HloInstruction* param0 = builder.AddInstruction(
2206       HloInstruction::CreateParameter(0, r1f32, "param0"));
2207   HloInstruction* param1 = builder.AddInstruction(
2208       HloInstruction::CreateParameter(1, r1f32, "param1"));
2209   HloInstruction* empty_literal = builder.AddInstruction(
2210       HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
2211   HloInstruction* empty_slice =
2212       builder.AddInstruction(HloInstruction::CreateSlice(
2213           ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}, {1}));
2214   Shape result_shape = ShapeUtil::MakeShape(F32, {3 * kParamLength});
2215   builder.AddInstruction(HloInstruction::CreateConcatenate(
2216       result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0));
2217 
2218   auto computation = m->AddEntryComputation(builder.Build());
2219 
2220   EXPECT_THAT(computation->root_instruction(),
2221               GmockMatch(m::Concatenate(
2222                   m::Op().Is(empty_literal), m::Parameter(0), m::Parameter(0),
2223                   m::Op().Is(empty_slice), m::Parameter(1))));
2224 
2225   AlgebraicSimplifier simplifier(default_options_);
2226   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2227 
2228   EXPECT_THAT(computation->root_instruction(),
2229               GmockMatch(m::Concatenate(m::Parameter(0), m::Parameter(0),
2230                                         m::Parameter(1))));
2231 }
2232 
2233 // Test that reduce of concat is simplified.
TEST_F(AlgebraicSimplifierTest,SimplifyReduceOfConcat)2234 TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) {
2235   auto m = CreateNewVerifiedModule();
2236   const int kParamLength = 100;
2237   Shape r3f32 =
2238       ShapeUtil::MakeShape(F32, {kParamLength, kParamLength, kParamLength});
2239   HloComputation::Builder builder(TestName());
2240   HloInstruction* param0 = builder.AddInstruction(
2241       HloInstruction::CreateParameter(0, r3f32, "param0"));
2242   HloInstruction* param1 = builder.AddInstruction(
2243       HloInstruction::CreateParameter(1, r3f32, "param1"));
2244   HloInstruction* param2 = builder.AddInstruction(
2245       HloInstruction::CreateParameter(2, r3f32, "param2"));
2246   Shape concat_shape =
2247       ShapeUtil::MakeShape(F32, {kParamLength, 3 * kParamLength, kParamLength});
2248   HloInstruction* Concatenate =
2249       builder.AddInstruction(HloInstruction::CreateConcatenate(
2250           concat_shape, {param0, param1, param2}, 1));
2251   HloComputation* add_computation = nullptr;
2252   {
2253     HloComputation::Builder builder(TestName() + ".add");
2254     const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
2255     HloInstruction* p0 = builder.AddInstruction(
2256         HloInstruction::CreateParameter(0, scalar_shape, "p0"));
2257     HloInstruction* p1 = builder.AddInstruction(
2258         HloInstruction::CreateParameter(1, scalar_shape, "p1"));
2259     builder.AddInstruction(
2260         HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
2261     add_computation = m->AddEmbeddedComputation(builder.Build());
2262   }
2263   Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
2264   Shape reduce_shape = ShapeUtil::MakeShape(F32, {kParamLength});
2265 
2266   HloInstruction* zero = builder.AddInstruction(
2267       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
2268   builder.AddInstruction(HloInstruction::CreateReduce(
2269       reduce_shape, Concatenate, zero, {1, 2}, add_computation));
2270 
2271   auto computation = m->AddEntryComputation(builder.Build());
2272 
2273   AlgebraicSimplifier simplifier(default_options_);
2274   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2275 
2276   EXPECT_THAT(
2277       computation->root_instruction(),
2278       GmockMatch(m::Map(m::Map(m::Reduce(m::Parameter(0), m::Op().Is(zero)),
2279                                m::Reduce(m::Parameter(1), m::Op().Is(zero))),
2280                         m::Reduce(m::Parameter(2), m::Op().Is(zero)))));
2281 }
2282 
2283 // Test a concatenate with only empty operands is removed.
TEST_F(AlgebraicSimplifierTest,OnlyEmptyConcatenateOperands)2284 TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) {
2285   auto m = CreateNewVerifiedModule();
2286   const int kParamLength = 100;
2287   Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength});
2288   HloComputation::Builder builder(TestName());
2289   HloInstruction* param0 = builder.AddInstruction(
2290       HloInstruction::CreateParameter(0, r1f32, "param0"));
2291   HloInstruction* empty_literal = builder.AddInstruction(
2292       HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
2293   HloInstruction* empty_slice =
2294       builder.AddInstruction(HloInstruction::CreateSlice(
2295           ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42}, {1}));
2296   Shape result_shape = ShapeUtil::MakeShape(F32, {0});
2297   builder.AddInstruction(HloInstruction::CreateConcatenate(
2298       result_shape, {empty_literal, empty_slice}, 0));
2299 
2300   auto computation = m->AddEntryComputation(builder.Build());
2301 
2302   EXPECT_THAT(computation->root_instruction(),
2303               GmockMatch(m::Concatenate(m::Op().Is(empty_literal),
2304                                         m::Op().Is(empty_slice))));
2305 
2306   AlgebraicSimplifier simplifier(default_options_);
2307   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2308 
2309   EXPECT_EQ(computation->root_instruction(), empty_literal);
2310 }
2311 
2312 // Test that concat with a scalar broadcast becomes a pad.
TEST_F(AlgebraicSimplifierTest,ConcatenateOfBroadcastBecomesPad)2313 TEST_F(AlgebraicSimplifierTest, ConcatenateOfBroadcastBecomesPad) {
2314   auto m = CreateNewVerifiedModule();
2315   Shape r1f32 = ShapeUtil::MakeShape(F32, {100});
2316   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
2317   HloComputation::Builder builder(TestName());
2318   HloInstruction* param0 = builder.AddInstruction(
2319       HloInstruction::CreateParameter(0, r1f32, "param0"));
2320   HloInstruction* param1 = builder.AddInstruction(
2321       HloInstruction::CreateParameter(1, r0f32, "param1"));
2322   HloInstruction* broadcast = builder.AddInstruction(
2323       HloInstruction::CreateBroadcast(r1f32, param1, {}));
2324   builder.AddInstruction(HloInstruction::CreateConcatenate(
2325       ShapeUtil::MakeShape(F32, {200}), {broadcast, param0}, 0));
2326 
2327   auto computation = m->AddEntryComputation(builder.Build());
2328 
2329   AlgebraicSimplifier simplifier(default_options_);
2330   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2331   EXPECT_THAT(computation->root_instruction(),
2332               GmockMatch(m::Pad(m::Parameter(0), m::Parameter(1))));
2333 }
2334 
TEST_F(AlgebraicSimplifierTest,SimplifyConcatenateOfSlices)2335 TEST_F(AlgebraicSimplifierTest, SimplifyConcatenateOfSlices) {
2336   auto m = CreateNewVerifiedModule();
2337   Shape r2f32 = ShapeUtil::MakeShape(F32, {100, 99});
2338   Shape concat_shape = ShapeUtil::MakeShape(F32, {50, 90});
2339   HloComputation::Builder builder(TestName());
2340   HloInstruction* param0 = builder.AddInstruction(
2341       HloInstruction::CreateParameter(0, r2f32, "param0"));
2342   HloInstruction* param1 = builder.AddInstruction(
2343       HloInstruction::CreateParameter(1, r2f32, "param1"));
2344 
2345   HloInstruction* slice0 = builder.AddInstruction(HloInstruction::CreateSlice(
2346       ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{0, 0},
2347       /*limit_indices=*/{50, 10}, /*strides=*/{1, 1}));
2348 
2349   // Cannot merge 'slice0' and 'slice1' because of different start indices in
2350   // dimension 0.
2351   HloInstruction* slice1 = builder.AddInstruction(HloInstruction::CreateSlice(
2352       ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 10},
2353       /*limit_indices=*/{100, 20}, /*strides=*/{1, 1}));
2354 
2355   // Cannot merge 'slice1' and 'slice2' because of stride in dimension 2.
2356   HloInstruction* slice2 = builder.AddInstruction(HloInstruction::CreateSlice(
2357       ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 20},
2358       /*limit_indices=*/{100, 40}, /*strides=*/{1, 2}));
2359 
2360   // Cannot merge 'slice2' and 'slice3' because of stride in dimension 2.
2361   HloInstruction* slice3 = builder.AddInstruction(HloInstruction::CreateSlice(
2362       ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 40},
2363       /*limit_indices=*/{100, 50}, /*strides=*/{1, 1}));
2364 
2365   // Can merge 'slice3' and 'slice4'.
2366   HloInstruction* slice4 = builder.AddInstruction(HloInstruction::CreateSlice(
2367       ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 50},
2368       /*limit_indices=*/{100, 60}, /*strides=*/{1, 1}));
2369 
2370   // Can merge 'slice4' and 'slice5'.
2371   HloInstruction* slice5 = builder.AddInstruction(HloInstruction::CreateSlice(
2372       ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 60},
2373       /*limit_indices=*/{100, 70}, /*strides=*/{1, 1}));
2374 
2375   // Cannot merge 'slice5' and 'slice6' because of overlap.
2376   HloInstruction* slice6 = builder.AddInstruction(HloInstruction::CreateSlice(
2377       ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 69},
2378       /*limit_indices=*/{100, 79}, /*strides=*/{1, 1}));
2379 
2380   // Cannot merge 'slice6' and 'slice7' because of slicing from a different
2381   // parameter.
2382   HloInstruction* slice7 = builder.AddInstruction(HloInstruction::CreateSlice(
2383       ShapeUtil::MakeShape(F32, {50, 10}), param1, /*start_indices=*/{50, 79},
2384       /*limit_indices=*/{100, 89}, /*strides=*/{1, 1}));
2385   // Can merge 'slice7' and 'slice8'.
2386   HloInstruction* slice8 = builder.AddInstruction(HloInstruction::CreateSlice(
2387       ShapeUtil::MakeShape(F32, {50, 10}), param1, /*start_indices=*/{50, 89},
2388       /*limit_indices=*/{100, 99}, /*strides=*/{1, 1}));
2389 
2390   builder.AddInstruction(HloInstruction::CreateConcatenate(
2391       concat_shape,
2392       {slice0, slice1, slice2, slice3, slice4, slice5, slice6, slice7, slice8},
2393       1));
2394   auto computation = m->AddEntryComputation(builder.Build());
2395 
2396   AlgebraicSimplifier simplifier(default_options_);
2397   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2398   auto s = m::Slice(m::Parameter(0));
2399   EXPECT_THAT(
2400       computation->root_instruction(),
2401       GmockMatch(m::Concatenate(s, s, s, s, s, m::Slice(m::Parameter(1)))));
2402   // The operand 3 should be a merge of 'slice3', 'slice4' and 'slice5', so its
2403   // shape should have dimensions {50, 30}.
2404   EXPECT_TRUE(
2405       ShapeUtil::Equal(computation->root_instruction()->operand(3)->shape(),
2406                        ShapeUtil::MakeShape(F32, {50, 30})));
2407   EXPECT_EQ(computation->root_instruction()->operand(3)->slice_starts(1), 40);
2408 
2409   // The operand 6 should be  merge of 'slice7' and 'slice8', so its
2410   // shape should have dimensions {50, 20}
2411   EXPECT_TRUE(
2412       ShapeUtil::Equal(computation->root_instruction()->operand(5)->shape(),
2413                        ShapeUtil::MakeShape(F32, {50, 20})));
2414 }
2415 
2416 // Test that a simplification which changes layouts is not performed if layout
2417 // sensitive is true.
TEST_F(AlgebraicSimplifierTest,CopyWithDifferentLayout)2418 TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) {
2419   auto m = CreateNewVerifiedModule();
2420   HloComputation::Builder builder(TestName());
2421   HloInstruction* param0 =
2422       builder.AddInstruction(HloInstruction::CreateParameter(
2423           0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
2424   HloInstruction* copy = builder.AddInstruction(
2425       HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
2426 
2427   auto computation = m->AddEntryComputation(builder.Build());
2428 
2429   // Set to different layouts.
2430   *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
2431   *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0});
2432 
2433   EXPECT_THAT(computation->root_instruction(),
2434               GmockMatch(m::Copy(m::Parameter(0))));
2435 
2436   AlgebraicSimplifierOptions options;
2437   options.set_is_layout_sensitive(true);
2438   AlgebraicSimplifier simplifier(options);
2439   EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie());
2440 
2441   // Copy has not been removed.
2442   EXPECT_THAT(computation->root_instruction(),
2443               GmockMatch(m::Copy(m::Parameter(0))));
2444 }
2445 
2446 // Test that a simplification which preserves layouts is performed if layout
2447 // sensitive is true.
TEST_F(AlgebraicSimplifierTest,CopyWithSameLayout)2448 TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) {
2449   auto m = CreateNewVerifiedModule();
2450   HloComputation::Builder builder(TestName());
2451   HloInstruction* param0 =
2452       builder.AddInstruction(HloInstruction::CreateParameter(
2453           0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
2454   HloInstruction* copy = builder.AddInstruction(
2455       HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
2456 
2457   auto computation = m->AddEntryComputation(builder.Build());
2458 
2459   // Set to same layouts.
2460   *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
2461   *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
2462 
2463   EXPECT_THAT(computation->root_instruction(),
2464               GmockMatch(m::Copy(m::Parameter(0))));
2465 
2466   AlgebraicSimplifierOptions options;
2467   options.set_is_layout_sensitive(true);
2468   AlgebraicSimplifier simplifier(options);
2469   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2470 
2471   // Copy has been removed.
2472   EXPECT_THAT(computation->root_instruction(), param0);
2473 }
2474 
2475 // Test that a reshape which could be replaced with a bitcast is not if
2476 // add_bitcasts is false.
TEST_F(AlgebraicSimplifierTest,NoBitcastAdded)2477 TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) {
2478   auto m = CreateNewVerifiedModule();
2479   HloComputation::Builder builder(TestName());
2480   HloInstruction* param0 =
2481       builder.AddInstruction(HloInstruction::CreateParameter(
2482           0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
2483   HloInstruction* reshape =
2484       builder.AddInstruction(HloInstruction::CreateReshape(
2485           ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));
2486 
2487   *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
2488   *reshape->mutable_shape()->mutable_layout() =
2489       LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
2490 
2491   auto computation = m->AddEntryComputation(builder.Build());
2492 
2493   EXPECT_THAT(computation->root_instruction(),
2494               GmockMatch(m::Reshape(m::Parameter(0))));
2495 
2496   AlgebraicSimplifierOptions options(
2497       [](const Shape&, const Shape&) { return false; });
2498   options.set_is_layout_sensitive(true);
2499   AlgebraicSimplifier simplifier(options);
2500   EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie());
2501 
2502   // Reshape is not replaced with a bitcast.
2503   EXPECT_THAT(computation->root_instruction(),
2504               GmockMatch(m::Reshape(m::Parameter(0))));
2505 }
2506 
2507 // Test transforming reshapes and transposes of rng.
TEST_F(AlgebraicSimplifierTest,ReshapeOfTransposeOfRngToRng)2508 TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) {
2509   auto m = CreateNewVerifiedModule();
2510   HloComputation::Builder builder(TestName());
2511   HloInstruction* zero = builder.AddInstruction(
2512       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
2513   HloInstruction* one = builder.AddInstruction(
2514       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
2515   HloInstruction* rng0 = builder.AddInstruction(
2516       HloInstruction::CreateRng(ShapeUtil::MakeShape(F32, {2, 2}),
2517                                 RandomDistribution::RNG_UNIFORM, {zero, one}));
2518 
2519   HloInstruction* transpose = builder.AddInstruction(
2520       HloInstruction::CreateTranspose(rng0->shape(), rng0, {1, 0}));
2521   Shape reshape_shape = builder
2522                             .AddInstruction(HloInstruction::CreateReshape(
2523                                 ShapeUtil::MakeShape(F32, {4}), transpose))
2524                             ->shape();
2525 
2526   auto computation = m->AddEntryComputation(builder.Build());
2527 
2528   AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
2529   EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2530 
2531   // Verify that reshape(transpose(rng)) is replace by a single rng of the
2532   // same shape as the reshape.
2533   EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Rng()));
2534   EXPECT_TRUE(ShapeUtil::Equal(computation->root_instruction()->shape(),
2535                                reshape_shape));
2536 }
2537 
2538 // Test transforming reshapes to bitcasts under various conditions.
TEST_F(AlgebraicSimplifierTest,ReshapeReplacedWithBitcast)2539 TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) {
2540   auto m = CreateNewVerifiedModule();
2541   HloComputation::Builder builder(TestName());
2542   HloInstruction* param0 =
2543       builder.AddInstruction(HloInstruction::CreateParameter(
2544           0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
2545   *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
2546 
2547   // Reshape which can be transformed into a bitcast.
2548   HloInstruction* transformable_reshape =
2549       builder.AddInstruction(HloInstruction::CreateReshape(
2550           ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));
2551   *transformable_reshape->mutable_shape()->mutable_layout() =
2552       LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
2553 
2554   // Reshape does not just add degenerate dimensions.
2555   HloInstruction* dimensions_wrong_reshape =
2556       builder.AddInstruction(HloInstruction::CreateReshape(
2557           ShapeUtil::MakeShape(F32, {1, 4, 1, 1, 1, 1}), param0));
2558   *dimensions_wrong_reshape->mutable_shape()->mutable_layout() =
2559       LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
2560 
2561   // Reshape has wrong layout.
2562   HloInstruction* layout_wrong_reshape =
2563       builder.AddInstruction(HloInstruction::CreateReshape(
2564           ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));
2565   *layout_wrong_reshape->mutable_shape()->mutable_layout() =
2566       LayoutUtil::MakeLayout({5, 4, 3, 2, 1, 0});
2567 
2568   // Collect all the reshapes into a tuple so they are not dead.
2569   builder.AddInstruction(HloInstruction::CreateTuple(
2570       {transformable_reshape, dimensions_wrong_reshape, layout_wrong_reshape}));
2571 
2572   auto computation = m->AddEntryComputation(builder.Build());
2573 
2574   EXPECT_THAT(computation->root_instruction(),
2575               GmockMatch(m::Tuple(m::Op().Is(transformable_reshape),
2576                                   m::Op().Is(dimensions_wrong_reshape),
2577                                   m::Op().Is(layout_wrong_reshape))));
2578 
2579   AlgebraicSimplifierOptions options;
2580   options.set_is_layout_sensitive(true);
2581   AlgebraicSimplifier simplifier(options);
2582   simplifier.Run(m.get()).ValueOrDie();
2583 
2584   // Verify that only the first reshape is replaced.
2585   EXPECT_THAT(
2586       computation->root_instruction(),
2587       GmockMatch(m::Tuple(m::Bitcast(), m::Op().Is(dimensions_wrong_reshape),
2588                           m::Op().Is(layout_wrong_reshape))));
2589 }
2590 
2591 // Regression test for a bug where if we failed to sink a reshape, we'd set the
2592 // 'changed' bit in AlgebraicSimplifier to false.
TEST_F(AlgebraicSimplifierTest,FailureToSinkReshapeDoesntAffectChangedBit)2593 TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) {
2594   auto m = CreateNewVerifiedModule();
2595   HloComputation::Builder builder(TestName());
2596 
2597   // This add (param0 + 0) can be simplified.
2598   Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
2599   HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
2600       shape, HloOpcode::kAdd,
2601       builder.AddInstruction(
2602           HloInstruction::CreateParameter(0, shape, "param0")),
2603       builder.AddInstruction(HloInstruction::CreateConstant(
2604           LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}})))));
2605 
2606   builder.AddInstruction(
2607       HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {4}), add));
2608 
2609   AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
2610   m->AddEntryComputation(builder.Build());
2611   EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2612 }
2613 
2614 // Regression test for a bug where if we failed to sink a reshape, we'd set the
2615 // 'changed' bit in AlgebraicSimplifier to false.
TEST_F(AlgebraicSimplifierTest,FailureToSinkBroadcastDoesntAffectChangedBit)2616 TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) {
2617   auto m = CreateNewVerifiedModule();
2618   HloComputation::Builder builder(TestName());
2619 
2620   // This add (param0 + 0) can be simplified.
2621   Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
2622   HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
2623       shape, HloOpcode::kAdd,
2624       builder.AddInstruction(
2625           HloInstruction::CreateParameter(0, shape, "param0")),
2626       builder.AddInstruction(HloInstruction::CreateConstant(
2627           LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}})))));
2628 
2629   builder.AddInstruction(
2630       HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {2, 2, 2}), add,
2631                                       /*broadcast_dimensions=*/{0, 1}));
2632 
2633   AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
2634   m->AddEntryComputation(builder.Build());
2635   EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2636 }
2637 
TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder hlo_builder(TestName());

  // Parameter f32[50,14,14,64] with minor-to-major layout {1,2,0,3}.
  HloInstruction* input =
      hlo_builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {50, 14, 14, 64}), "param"));
  *input->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({1, 2, 0, 3});

  // Transpose to f32[14,14,50,64]{0,1,2,3}. Both shapes store the elements in
  // the same physical order, so the transpose can be expressed as a bitcast.
  HloInstruction* transposed =
      hlo_builder.AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(F32, {14, 14, 50, 64}), input, {1, 2, 0, 3}));
  *transposed->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({0, 1, 2, 3});

  HloComputation* entry = module->AddEntryComputation(hlo_builder.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Transpose(m::Parameter(0))));

  AlgebraicSimplifierOptions layout_sensitive;
  layout_sensitive.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(layout_sensitive);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The transpose has been replaced by a bitcast of the parameter.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
2667 
TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder hlo_builder(TestName());

  // Parameter f32[5,2,3,4] with minor-to-major layout {1,2,3,0}.
  HloInstruction* input =
      hlo_builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {5, 2, 3, 4}), "param"));
  *input->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({1, 2, 3, 0});

  // Transpose to f32[5,3,4,2]{3,1,2,0}; the physical element order matches
  // the parameter's, making this transpose a bitcast candidate.
  HloInstruction* transposed =
      hlo_builder.AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(F32, {5, 3, 4, 2}), input, {0, 2, 3, 1}));
  *transposed->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({3, 1, 2, 0});

  HloComputation* entry = module->AddEntryComputation(hlo_builder.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Transpose(m::Parameter(0))));

  // Two layout-sensitive configurations that differ only in whether a
  // transpose may be replaced with a bitcast.
  AlgebraicSimplifierOptions keep_transpose_options;
  keep_transpose_options.set_is_layout_sensitive(true);
  keep_transpose_options.set_replace_transpose_with_bitcast(false);
  AlgebraicSimplifierOptions bitcast_options = keep_transpose_options;
  bitcast_options.set_replace_transpose_with_bitcast(true);

  // With the rewrite disabled, the module is left untouched...
  AlgebraicSimplifier keep_transpose_simplifier(keep_transpose_options);
  ASSERT_FALSE(keep_transpose_simplifier.Run(module.get()).ValueOrDie());

  // ...and once it is enabled, the transpose becomes a bitcast.
  AlgebraicSimplifier bitcast_simplifier(bitcast_options);
  ASSERT_TRUE(bitcast_simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
2704 
TEST_F(AlgebraicSimplifierTest, ReshapesMerged) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder hlo_builder(TestName());

  // f32[2,2] -> f32[2,1,2] -> f32[1,2,1,1,2,1]: two reshapes in a row.
  HloInstruction* input = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {2, 2}),
                                      "param0"));
  HloInstruction* inner_reshape = hlo_builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {2, 1, 2}),
                                    input));
  hlo_builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), inner_reshape));

  HloComputation* entry = module->AddEntryComputation(hlo_builder.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Reshape(m::Reshape(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The chain collapses into a single reshape of the parameter.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Reshape(m::Parameter(0))));
}
2730 
TEST_F(AlgebraicSimplifierTest, CopiesMerged) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder hlo_builder(TestName());

  // Two back-to-back copies of an f32[2,2,2] parameter that differ only in
  // layout; the simplifier should fold them into a single copy.
  const Shape first_copy_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2});
  const Shape second_copy_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 2, 1});

  HloInstruction* input =
      hlo_builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShapeWithDescendingLayout(F32, {2, 2, 2}),
          "param0"));
  HloInstruction* first_copy = hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(first_copy_shape, HloOpcode::kCopy, input));
  hlo_builder.AddInstruction(HloInstruction::CreateUnary(
      second_copy_shape, HloOpcode::kCopy, first_copy));

  HloComputation* entry = module->AddEntryComputation(hlo_builder.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Copy(m::Copy(m::Parameter(0)))));

  // The copies differ only in layout, so run in layout-sensitive mode.
  AlgebraicSimplifierOptions layout_sensitive;
  layout_sensitive.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(layout_sensitive);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Copy(m::Parameter(0))));
}
2760 
TEST_F(AlgebraicSimplifierTest, TransposesMerged) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {2, 3, 4}), "param0"));

  HloInstruction* transpose1 =
      builder.AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(F32, {3, 4, 2}), param0, {1, 2, 0}));

  builder.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(F32, {4, 3, 2}), transpose1, {1, 0, 2}));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Transpose(m::Op().Is(transpose1))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // transpose({1,0,2}) of transpose({1,2,0}) composes into a single
  // transpose({2,1,0}) of the parameter. ASSERT the structure first so that
  // dimensions() is only read once the root is known to be a transpose.
  const HloInstruction* root = computation->root_instruction();
  ASSERT_THAT(root, GmockMatch(m::Transpose(m::Parameter(0))));
  EXPECT_THAT(root->dimensions(), ::testing::ElementsAre(2, 1, 0));
}
2788 
TEST_F(AlgebraicSimplifierTest, SliceOfBroadcast) {
  // slice(broadcast(p0)) is expected to become broadcast(slice(p0)).
  constexpr char kHloText[] = R"(
    HloModule module

    ENTRY test {
      p0 = f32[10,20] parameter(0)
      b = f32[10,30,20] broadcast(p0), dimensions={0,2}
      ROOT s = f32[5,5,5] slice(b), slice={[0:5:1], [5:25:4], [5:15:2]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  HloPassFix<AlgebraicSimplifier> fixed_point_simplifier(default_options_);
  EXPECT_TRUE(fixed_point_simplifier.Run(hlo_module.get()).ValueOrDie());

  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Slice(m::Parameter(0)))));
}
2807 
TEST_F(AlgebraicSimplifierTest, SliceOfBroadcastPreserveLayout) {
  // Same rewrite as SliceOfBroadcast, but the slice carries a tiled,
  // non-default layout that the new root must keep.
  constexpr char kHloText[] = R"(
    HloModule module

    ENTRY test {
      p0 = f32[10,20] parameter(0)
      b = f32[10,30,20]{2,0,1:T(256)} broadcast(p0), dimensions={0,2}
      ROOT s = f32[5,5,5]{2,0,1:T(256)} slice(b), slice={[0:5:1], [5:25:4], [5:15:2]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  // Snapshot the root shape (including its layout) before the rewrite.
  const Shape expected_root_shape =
      hlo_module->entry_computation()->root_instruction()->shape();

  HloPassFix<AlgebraicSimplifier> fixed_point_simplifier(default_options_);
  EXPECT_TRUE(fixed_point_simplifier.Run(hlo_module.get()).ValueOrDie());

  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Slice(m::Parameter(0)))));
  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), expected_root_shape));
}
2829 
TEST_F(AlgebraicSimplifierTest, DynamicSliceOfBroadcast) {
  // dynamic-slice(broadcast(p0)) is expected to become
  // broadcast(dynamic-slice(p0)). Broadcast dimension 1 is not backed by p0,
  // so only the start indices i0 and i2 (parameters 1 and 3) remain.
  constexpr char kHloText[] = R"(
    HloModule module

    ENTRY test {
      p0 = f32[10,20] parameter(0)
      i0 = s32[] parameter(1)
      i1 = s32[] parameter(2)
      i2 = s32[] parameter(3)
      b = f32[10,30,20] broadcast(p0), dimensions={0,2}
      ROOT ds = f32[5,5,5] dynamic-slice(b, i0, i1, i2), dynamic_slice_sizes={5,5,5}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  HloPassFix<AlgebraicSimplifier> fixed_point_simplifier(default_options_);
  EXPECT_TRUE(fixed_point_simplifier.Run(hlo_module.get()).ValueOrDie());

  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::DynamicSlice(
                        m::Parameter(0), m::Parameter(1), m::Parameter(3)))));
}
2852 
TEST_F(AlgebraicSimplifierTest, DynamicSliceOfBroadcastPreserveLayout) {
  // Same rewrite as DynamicSliceOfBroadcast, but the tiled, non-default
  // layout of the dynamic-slice must survive onto the new root.
  constexpr char kHloText[] = R"(
    HloModule module

    ENTRY test {
      p0 = f32[10,20] parameter(0)
      i0 = s32[] parameter(1)
      i1 = s32[] parameter(2)
      i2 = s32[] parameter(3)
      b = f32[10,30,20]{2,0,1:T(256)} broadcast(p0), dimensions={0,2}
      ROOT ds = f32[5,5,5]{2,0,1:T(256)} dynamic-slice(b, i0, i1, i2), dynamic_slice_sizes={5,5,5}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  // Snapshot the root shape (including its layout) before the rewrite.
  const Shape expected_root_shape =
      hlo_module->entry_computation()->root_instruction()->shape();

  HloPassFix<AlgebraicSimplifier> fixed_point_simplifier(default_options_);
  EXPECT_TRUE(fixed_point_simplifier.Run(hlo_module.get()).ValueOrDie());

  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::DynamicSlice(
                        m::Parameter(0), m::Parameter(1), m::Parameter(3)))));
  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), expected_root_shape));
}
2878 
TEST_F(AlgebraicSimplifierTest, TransposeIsReshape) {
  // Every dimension but one has size 1, so the transpose in this
  // reshape -> transpose -> reshape chain acts like a reshape. The whole
  // chain is expected to simplify back to the parameter.
  constexpr char kHloText[] = R"(
    HloModule module

    ENTRY test {
      param = f32[10] parameter(0)
      reshaped = f32[1,1,10] reshape(f32[10] param)
      transposed = f32[10,1,1] transpose(f32[1,1,10] reshaped), dimensions={2,1,0}
      ROOT reshaped_again = f32[10] reshape(f32[10,1,1] transposed)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  HloPassFix<AlgebraicSimplifier> fixed_point_simplifier(default_options_);
  EXPECT_TRUE(fixed_point_simplifier.Run(hlo_module.get()).ValueOrDie());

  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Parameter()));
}
2898 
2899 // Test merging reshape and broadcast.
TEST_F(AlgebraicSimplifierTest,ReshapeAndBroadcastMerged)2900 TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) {
2901   auto m = CreateNewVerifiedModule();
2902   HloComputation::Builder builder(TestName());
2903   auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
2904       0, ShapeUtil::MakeShape(F32, {5}), "param0"));
2905   auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
2906       ShapeUtil::MakeShape(F32, {1, 5, 1}), param0));
2907   builder.AddInstruction(HloInstruction::CreateBroadcast(
2908       ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 3, 2}));
2909 
2910   auto computation = m->AddEntryComputation(builder.Build());
2911 
2912   EXPECT_THAT(computation->root_instruction(),
2913               GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));
2914 
2915   AlgebraicSimplifier simplifier(default_options_);
2916   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2917 
2918   EXPECT_THAT(computation->root_instruction(),
2919               GmockMatch(m::Broadcast(m::Parameter(0))));
2920 }
2921 
2922 // Test merging broadcast and reshape.
TEST_F(AlgebraicSimplifierTest,BroadcastAndReshapeMerged)2923 TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) {
2924   auto m = CreateNewVerifiedModule();
2925   HloComputation::Builder builder(TestName());
2926   auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
2927       0, ShapeUtil::MakeShape(F32, {2, 3}), "param0"));
2928   auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast(
2929       ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1}), param0, {1, 2}));
2930   builder.AddInstruction(HloInstruction::CreateReshape(
2931       ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}), broadcast1));
2932 
2933   auto computation = m->AddEntryComputation(builder.Build());
2934 
2935   EXPECT_THAT(computation->root_instruction(),
2936               GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));
2937 
2938   AlgebraicSimplifier simplifier(default_options_);
2939   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2940 
2941   EXPECT_THAT(computation->root_instruction(),
2942               GmockMatch(m::Broadcast(m::Parameter(0))));
2943 }
2944 
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1}), "param"));
  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {3, 1}), param, {1}));
  builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), broadcast));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  // ASSERT, like the sibling BroadcastAndReshape_* tests: if nothing was
  // simplified, the structural check below would only add a redundant failure.
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // reshape(broadcast(p0)) is reordered into broadcast(reshape(p0)).
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));
}
2966 
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {4}), "param"));
  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {3, 2, 4}), param, {2}));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), broadcast));

  HloComputation* computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // ASSERT the structure first so dimensions() is only queried once the root
  // is known to be a broadcast.
  const HloInstruction* root = computation->root_instruction();
  ASSERT_THAT(root, GmockMatch(m::Broadcast(m::Parameter(0))));
  // The parameter's only dimension maps to the trailing size-4 dimension.
  EXPECT_THAT(root->dimensions(), ::testing::ElementsAre(3));
}
2990 
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1}), "param"));
  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {3, 2, 1}), param, {2}));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), broadcast));

  HloComputation* computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // ASSERT the structure first so dimensions() is only queried once the root
  // is known to be a broadcast.
  const HloInstruction* root = computation->root_instruction();
  ASSERT_THAT(root, GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));
  // No mapped dimensions: the broadcast operand is a scalar. Using empty()
  // also avoids the signed/unsigned comparison of EXPECT_EQ(0, size()).
  EXPECT_TRUE(root->dimensions().empty());
}
3013 
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder hlo_builder(TestName());

  // f32[4] -> broadcast f32[3,2,4,2] -> reshape f32[6,8]. The size-8 output
  // dimension mixes the parameter's dimension (4) with a broadcast-added one
  // (2), so the pair is expected to be left alone.
  HloInstruction* input = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {4}),
                                      "param"));
  HloInstruction* broadcasted = hlo_builder.AddInstruction(
      HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {3, 2, 4, 2}),
                                      input, {2}));
  hlo_builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {6, 8}), broadcasted));

  HloComputation* entry = module->AddEntryComputation(hlo_builder.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));
}
3035 
TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder hlo_builder(TestName());

  // Iota along dimension 2 (size 3) of f32[1,2,3,7,12,1]. The reshape to
  // f32[2,3,7,2,1,3,2] keeps that dimension intact, so the pair can fold into
  // a single iota of the result shape.
  const Shape result_shape = ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2});
  HloInstruction* iota =
      hlo_builder.AddInstruction(HloInstruction::CreateIota(
          ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1}), 2));
  hlo_builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota));

  HloComputation* entry = module->AddEntryComputation(hlo_builder.Build());
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Iota()));
  EXPECT_TRUE(
      ShapeUtil::Equal(entry->root_instruction()->shape(), result_shape));
}
3056 
TEST_F(AlgebraicSimplifierTest, IotaAndReshapeToMixedRadix) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder hlo_builder(TestName());

  // Reshaping iota(f32[21]) to f32[7,3] gives element (i, j) = 3 * i + j.
  // That is expressed as iota + iota * broadcast(constant) in the new shape.
  const Shape result_shape = ShapeUtil::MakeShape(F32, {7, 3});
  HloInstruction* iota = hlo_builder.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {21}), 0));
  hlo_builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota));

  HloComputation* entry = module->AddEntryComputation(hlo_builder.Build());
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Add(
                  m::Iota(),
                  m::Multiply(m::Iota(), m::Broadcast(m::ConstantScalar())))));
  EXPECT_TRUE(
      ShapeUtil::Equal(entry->root_instruction()->shape(), result_shape));
}
TEST_F(AlgebraicSimplifierTest, IotaAndReshapeToMixedRadixExtraDims) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder hlo_builder(TestName());

  // The size-24 iota dimension of f32[42,24,15] is split into 4x3x2 by the
  // reshape, so each element becomes 6 * a + 2 * b + c. That is two nested
  // adds of scaled iotas.
  const Shape result_shape = ShapeUtil::MakeShape(F32, {3, 14, 4, 3, 2, 5, 3});
  HloInstruction* iota = hlo_builder.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {42, 24, 15}), 1));
  hlo_builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota));

  HloComputation* entry = module->AddEntryComputation(hlo_builder.Build());
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Add(
          m::Add(m::Iota(),
                 m::Multiply(m::Iota(), m::Broadcast(m::ConstantScalar()))),
          m::Multiply(m::Iota(), m::Broadcast(m::ConstantScalar())))));
  EXPECT_TRUE(
      ShapeUtil::Equal(entry->root_instruction()->shape(), result_shape));
}
TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto iota = builder.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {1, 1}), 0));
  // Copy the shape rather than holding a reference: the pass replaces the iota.
  const Shape result_shape = iota->shape();

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Iota()));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // An iota over a single-element shape is a broadcast of the constant 0.
  // ASSERT the structure first so operand(0)->literal() is only evaluated
  // once operand(0) is known to be a constant.
  const HloInstruction* root = computation->root_instruction();
  ASSERT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
  EXPECT_EQ(0.0f, root->operand(0)->literal().GetFirstElement<float>());
  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), result_shape));
}
3125 
TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2_6) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder hlo_builder(TestName());

  // Flattening an iota along dimension 1 of f32[3,2] gives 0,1,0,1,0,1, which
  // is not a plain iota of f32[6], so the module must stay unchanged.
  HloInstruction* iota = hlo_builder.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2}), 1));
  hlo_builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6}), iota));

  HloComputation* entry = module->AddEntryComputation(hlo_builder.Build());
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Reshape(m::Iota())));
}
3145 
TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto iota = builder.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4}), 2));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), iota));

  HloComputation* computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // ASSERT the opcode first so Cast<HloIotaInstruction> is only applied to an
  // actual iota instruction.
  const HloInstruction* root = computation->root_instruction();
  ASSERT_THAT(root, GmockMatch(m::Iota()));
  // The size-4 iota dimension becomes output dimension 3 of f32[6,1,1,4].
  EXPECT_EQ(Cast<HloIotaInstruction>(root)->iota_dimension(), 3);
}
3167 
TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x2_6x1x1x2) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto iota = builder.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 2}), 2));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {6, 1, 1, 2}), iota));

  HloComputation* computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // ASSERT the opcode first so Cast<HloIotaInstruction> is only applied to an
  // actual iota instruction.
  const HloInstruction* root = computation->root_instruction();
  ASSERT_THAT(root, GmockMatch(m::Iota()));
  const int64 iota_dim = Cast<HloIotaInstruction>(root)->iota_dimension();
  EXPECT_THAT(iota_dim, ::testing::AnyOf(1, 2, 3));
}
3190 
TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4x2_6x8) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder hlo_builder(TestName());

  // The reshape to f32[6,8] merges the size-4 iota dimension with a trailing
  // size-2 dimension, so the result is not a single iota and nothing changes.
  HloInstruction* iota = hlo_builder.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4, 2}), 2));
  hlo_builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6, 8}), iota));

  HloComputation* entry = module->AddEntryComputation(hlo_builder.Build());
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Reshape(m::Iota())));
}
3210 
TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {2, 2}), "param"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  // All-zero padding in both dimensions: the pad returns its f32[2,2] operand
  // unchanged.
  PaddingConfig no_padding;
  for (int i = 0; i < 2; ++i) {
    auto dimension = no_padding.add_dimensions();
    dimension->set_edge_padding_low(0);
    dimension->set_edge_padding_high(0);
    dimension->set_interior_padding(0);
  }
  builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {2, 2}), param, zero, no_padding));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The no-op pad is removed, leaving the parameter as the root. Compare the
  // pointers directly instead of passing a raw pointer as a gMock matcher.
  EXPECT_EQ(computation->root_instruction(), param);
}
3239 
TEST_F(AlgebraicSimplifierTest, RemoveNoopSliceOfPad) {
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {2, 2}), "param"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  // Per dimension: 2 low, 1 interior, 0 high padding. That grows size 2 to
  // 2 + 2 + 1 = 5 and places the original elements at indices 2 and 4.
  PaddingConfig padding;
  for (int i = 0; i < 2; ++i) {
    auto dimension = padding.add_dimensions();
    dimension->set_edge_padding_low(2);
    dimension->set_edge_padding_high(0);
    dimension->set_interior_padding(1);
  }
  auto pad = builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {5, 5}), param, zero, padding));
  // Slicing [2:5:2] in each dimension picks exactly indices 2 and 4, which is
  // the unpadded parameter, so slice(pad(param)) is a no-op.
  builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {2, 2}), pad, /*start_indices=*/{2, 2},
      /*limit_indices=*/{5, 5}, /*strides=*/{2, 2}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Slice(m::Pad(m::Parameter(0), m::Op().Is(zero)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Compare the pointers directly instead of passing a raw pointer as a gMock
  // matcher.
  EXPECT_EQ(computation->root_instruction(), param);
}
3271 
TEST_F(AlgebraicSimplifierTest, NegativePadding) {
  // Verify that a pad instruction with negative padding is replaced with a
  // pad with non-negative padding followed by a slice.
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {10, 10}), "param"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  // Resulting sizes: 10 - 1 + 2 = 11 and 10 - 2 - 3 = 5.
  PaddingConfig padding;
  const int64 low_padding[2] = {-1, -2};
  const int64 high_padding[2] = {2, -3};
  for (int i = 0; i < 2; ++i) {
    auto dimension = padding.add_dimensions();
    dimension->set_edge_padding_low(low_padding[i]);
    dimension->set_edge_padding_high(high_padding[i]);
    dimension->set_interior_padding(0);
  }
  HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {11, 5}), param, zero, padding));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  AlgebraicSimplifier simplifier(default_options_);

  // Returns true if any edge padding of the pad instruction `instr` is
  // negative. Named `instr` so it does not shadow the outer `pad`.
  auto has_negative_padding = [](const HloInstruction* instr) {
    for (const auto& dim : instr->padding_config().dimensions()) {
      if (dim.edge_padding_low() < 0 || dim.edge_padding_high() < 0) {
        return true;
      }
    }
    return false;
  };

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero))));
  EXPECT_TRUE(has_negative_padding(pad));

  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // ASSERT the slice(pad) structure first, so operand(0) is known to be a pad
  // before has_negative_padding() reads its padding_config().
  const HloInstruction* root = computation->root_instruction();
  ASSERT_THAT(root,
              GmockMatch(m::Slice(m::Pad(m::Parameter(0), m::Op().Is(zero)))));
  EXPECT_FALSE(has_negative_padding(root->operand(0)));
}
3319 
TEST_F(AlgebraicSimplifierTest, CanDisableNegativePadding) {
  // Verify that when enable_negative_padding_replacement is turned off, a pad
  // instruction with negative padding is left untouched instead of being
  // rewritten into a non-negative pad followed by a slice.
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {10, 10}), "param"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  PaddingConfig padding;
  // Resulting extents: dim 0: -1 + 10 + 2 = 11, dim 1: -2 + 10 - 3 = 5.
  int64 low_padding[2] = {-1, -2};
  int64 high_padding[2] = {2, -3};
  for (int i = 0; i < 2; ++i) {
    auto dimension = padding.add_dimensions();
    dimension->set_edge_padding_low(low_padding[i]);
    dimension->set_edge_padding_high(high_padding[i]);
    dimension->set_interior_padding(0);
  }
  HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {11, 5}), param, zero, padding));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  // Verify that we can disable the negative padding optimization.
  AlgebraicSimplifierOptions opts = default_options_;
  opts.set_enable_negative_padding_replacement(false);

  AlgebraicSimplifier simplifier(opts);

  // True iff some dimension of `pad`'s padding config has negative edge
  // padding on either side.
  auto has_negative_padding = [](const HloInstruction* pad) {
    for (auto& padding_dimension : pad->padding_config().dimensions()) {
      if (padding_dimension.edge_padding_low() < 0 ||
          padding_dimension.edge_padding_high() < 0) {
        return true;
      }
    }
    return false;
  };

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero))));
  EXPECT_TRUE(has_negative_padding(pad));

  // Nothing has changed since the negative padding replacement is disabled.
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
  // The original pad must still be the root and still carry its negative
  // padding; a reported "no change" alone would not catch a silent rewrite.
  EXPECT_EQ(computation->root_instruction(), pad);
  EXPECT_TRUE(has_negative_padding(computation->root_instruction()));
}
3367 
TEST_F(AlgebraicSimplifierTest, TrivialInteriorPadding) {
  // Interior padding on a dimension of size one has no adjacent elements to
  // separate, so the simplifier should strip it from the pad.
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {2, 1}), "param"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));

  // Edge padding of 3 on both sides of each dimension; interior padding 0 on
  // dimension 0 and 3 on the size-one dimension 1.
  //   dim 0: 3 + 2 + 3 = 8, dim 1: 3 + 1 + 3 = 7.
  PaddingConfig padding;
  for (int64 interior = 0; interior <= 3; interior += 3) {
    auto* dim_config = padding.add_dimensions();
    dim_config->set_edge_padding_low(3);
    dim_config->set_edge_padding_high(3);
    dim_config->set_interior_padding(interior);
  }
  HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {8, 7}), param, zero, padding));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  ASSERT_THAT(computation->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero))));
  ASSERT_TRUE(HasInteriorPadding(pad->padding_config()));

  AlgebraicSimplifier simplifier(default_options_);
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero))));
  EXPECT_FALSE(HasInteriorPadding(root->padding_config()));
}
3403 
TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) {
  // A reshape whose result shape equals its operand's shape is a no-op and
  // should be replaced by the operand itself.
  const Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param"));
  builder.AddInstruction(HloInstruction::CreateReshape(shape, param));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Parameter(0))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_EQ(computation->root_instruction(), param);
}
3423 
TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) {
  // A unit-stride slice covering the full extent of every dimension returns
  // its operand unchanged, so the simplifier should drop it.
  const int64 dim0 = 2;
  const int64 dim1 = 3;
  const Shape shape = ShapeUtil::MakeShape(F32, {dim0, dim1});
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param"));
  builder.AddInstruction(HloInstruction::CreateSlice(
      shape, param, /*start_indices=*/{0, 0},
      /*limit_indices=*/{dim0, dim1}, /*strides=*/{1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Slice(m::Parameter(0))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_EQ(computation->root_instruction(), param);
}
3446 
TEST_F(AlgebraicSimplifierTest, SliceOfSliceToSlice) {
  // Two stacked unit-stride slices should fold into a single slice of the
  // original parameter whose window is the composition of both windows.
  const int64 dim0 = 11;
  const int64 dim1 = 12;
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {dim0, dim1}), "param"));

  // Inner slice: [1, 10) x [2, 10) of the parameter, shape {9, 8}.
  HloInstruction* inner_slice =
      builder.AddInstruction(HloInstruction::CreateSlice(
          ShapeUtil::MakeShape(F32, {dim0 - 2, dim1 - 4}), param,
          /*start_indices=*/{1, 2},
          /*limit_indices=*/{dim0 - 1, dim1 - 2}, /*strides=*/{1, 1}));

  // Outer slice: [2, 8) x [3, 6) of the inner slice, shape {6, 3}.
  builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {dim0 - 5, dim1 - 9}), inner_slice,
      /*start_indices=*/{2, 3},
      /*limit_indices=*/{dim0 - 3, dim1 - 6}, /*strides=*/{1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Slice(m::Slice(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Slice(m::Parameter(0))));
  // Combined window: starts {1 + 2, 2 + 3}, limits {1 + 8, 2 + 6}.
  EXPECT_EQ(root->slice_starts(0), 3);
  EXPECT_EQ(root->slice_starts(1), 5);
  EXPECT_EQ(root->slice_limits(0), dim0 - 2);
  EXPECT_EQ(root->slice_limits(1), dim1 - 4);
}
3480 
TEST_F(AlgebraicSimplifierTest, SliceOfBroadcastToBroadcast) {
  // A slice that keeps dimension 0 (the one fed by the parameter) whole and
  // only trims dimension 1 (introduced by the broadcast) should become a
  // single broadcast of the parameter to the sliced shape.
  HloComputation::Builder builder(TestName());
  const int64 dim0 = 11;
  const int64 dim1 = 12;
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {dim0}), "param"));
  HloInstruction* broadcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::MakeShape(F32, {dim0, dim1}), param, {0}));
  const Shape slice_shape = ShapeUtil::MakeShape(F32, {dim0, dim1 - 9});
  builder.AddInstruction(HloInstruction::CreateSlice(
      slice_shape, broadcast,
      /*start_indices=*/{0, 3},
      /*limit_indices=*/{dim0, dim1 - 6}, /*strides=*/{1, 1}));
  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Slice(m::Broadcast(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Broadcast(m::Parameter(0))));
  // The replacement broadcast must produce exactly the slice's result shape,
  // not the original (unsliced) broadcast shape.
  EXPECT_TRUE(ShapeUtil::Compatible(computation->root_instruction()->shape(),
                                    slice_shape));
}
3507 
TEST_F(AlgebraicSimplifierTest, SliceOfReshapeToReshapeOfSlice) {
  // The parameter {dim0 * dim1, dim2} is reshaped to {dim0, dim1, dim2}, and
  // the slice trims only the leading dimension. The simplifier should move
  // the slice ahead of the reshape.
  const int64 dim0 = 11;
  const int64 dim1 = 12;
  const int64 dim2 = 13;
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {dim0 * dim1, dim2}), "param"));
  HloInstruction* reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(F32, {dim0, dim1, dim2}), param));
  builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {dim0 - 2, dim1, dim2}), reshape,
      /*start_indices=*/{0, 0, 0},
      /*limit_indices=*/{dim0 - 2, dim1, dim2}, /*strides=*/{1, 1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Slice(m::Reshape(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Slice(m::Parameter(0)))));
}
3536 
TEST_F(AlgebraicSimplifierTest, SliceOfReshapeUnchanged) {
  // {1, 144, 25, 1, 512} is flattened to {3600, 512} and then the first 960
  // rows are sliced. 960 is not a multiple of 25, so (presumably) the slice
  // cannot be expressed on the reshape's operand. The simplifier is expected
  // to report no change.
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {1, 144, 25, 1, 512}), "param"));
  HloInstruction* flattened = builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3600, 512}),
                                    param));
  builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {960, 512}), flattened,
      /*start_indices=*/{0, 0},
      /*limit_indices=*/{960, 512}, /*strides=*/{1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Slice(m::Reshape(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
}
3559 
TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) {
  // Sorting a one-element array cannot reorder anything, so the sort should
  // be replaced by its key operand.
  auto module = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());

  const Shape keys_shape = ShapeUtil::MakeShape(F32, {1});
  HloInstruction* keys = builder.AddInstruction(
      HloInstruction::CreateParameter(0, keys_shape, "keys"));
  TF_ASSERT_OK(MakeSortHlo(keys_shape, {keys}, 0, /*is_stable=*/false,
                           &builder, module.get())
                   .status());

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_EQ(computation->root_instruction(), keys);
}
3575 
TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) {
  // A key/value sort whose operands have no elements (shape {5, 0}) should be
  // replaced by a tuple of its unmodified operands.
  auto module = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());

  const Shape keys_shape = ShapeUtil::MakeShape(F32, {5, 0});
  const Shape values_shape = ShapeUtil::MakeShape(S32, {5, 0});
  HloInstruction* keys = builder.AddInstruction(
      HloInstruction::CreateParameter(0, keys_shape, "keys"));
  HloInstruction* values0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, values_shape, "values0"));
  HloInstruction* values1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, values_shape, "values1"));

  const Shape sort_shape =
      ShapeUtil::MakeTupleShape({keys_shape, values_shape, values_shape});
  TF_ASSERT_OK(MakeSortHlo(sort_shape, {keys, values0, values1}, 0,
                           /*is_stable=*/false, &builder, module.get())
                   .status());

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Tuple(m::Op().Is(keys), m::Op().Is(values0),
                                  m::Op().Is(values1))));
}
3600 
3601 // Test that A && True is simplified to A
TEST_F(AlgebraicSimplifierTest,AndTrue)3602 TEST_F(AlgebraicSimplifierTest, AndTrue) {
3603   auto m = CreateNewVerifiedModule();
3604   Shape r0pred = ShapeUtil::MakeShape(PRED, {});
3605   HloComputation::Builder builder(TestName());
3606   HloInstruction* param0 = builder.AddInstruction(
3607       HloInstruction::CreateParameter(0, r0pred, "param0"));
3608   HloInstruction* const_true = builder.AddInstruction(
3609       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
3610   builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kAnd,
3611                                                       param0, const_true));
3612 
3613   auto computation = m->AddEntryComputation(builder.Build());
3614   HloInstruction* root = computation->root_instruction();
3615   EXPECT_EQ(root->opcode(), HloOpcode::kAnd);
3616   AlgebraicSimplifier simplifier(default_options_);
3617   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3618   root = computation->root_instruction();
3619   EXPECT_EQ(root, param0);
3620 }
3621 
3622 // Test that True && A is simplified to A
TEST_F(AlgebraicSimplifierTest,AndTrue2)3623 TEST_F(AlgebraicSimplifierTest, AndTrue2) {
3624   auto m = CreateNewVerifiedModule();
3625   Shape r0pred = ShapeUtil::MakeShape(PRED, {});
3626   HloComputation::Builder builder(TestName());
3627   HloInstruction* param0 = builder.AddInstruction(
3628       HloInstruction::CreateParameter(0, r0pred, "param0"));
3629   HloInstruction* const_true = builder.AddInstruction(
3630       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
3631   builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kAnd,
3632                                                       const_true, param0));
3633 
3634   auto computation = m->AddEntryComputation(builder.Build());
3635   HloInstruction* root = computation->root_instruction();
3636   EXPECT_EQ(root->opcode(), HloOpcode::kAnd);
3637   AlgebraicSimplifier simplifier(default_options_);
3638   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3639   root = computation->root_instruction();
3640   EXPECT_EQ(root, param0);
3641 }
3642 
3643 // Test that A && False is simplified to False
TEST_F(AlgebraicSimplifierTest,AndFalse)3644 TEST_F(AlgebraicSimplifierTest, AndFalse) {
3645   auto m = CreateNewVerifiedModule();
3646   Shape r0pred = ShapeUtil::MakeShape(PRED, {});
3647   HloComputation::Builder builder(TestName());
3648   HloInstruction* param0 = builder.AddInstruction(
3649       HloInstruction::CreateParameter(0, r0pred, "param0"));
3650   HloInstruction* const_false = builder.AddInstruction(
3651       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
3652   builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kAnd,
3653                                                       param0, const_false));
3654 
3655   auto computation = m->AddEntryComputation(builder.Build());
3656   HloInstruction* root = computation->root_instruction();
3657   EXPECT_EQ(root->opcode(), HloOpcode::kAnd);
3658   AlgebraicSimplifier simplifier(default_options_);
3659   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3660   root = computation->root_instruction();
3661   EXPECT_EQ(root, const_false);
3662 }
3663 
3664 // Test that False && A is simplified to False
TEST_F(AlgebraicSimplifierTest,AndFalse2)3665 TEST_F(AlgebraicSimplifierTest, AndFalse2) {
3666   auto m = CreateNewVerifiedModule();
3667   Shape r0pred = ShapeUtil::MakeShape(PRED, {});
3668   HloComputation::Builder builder(TestName());
3669   HloInstruction* param0 = builder.AddInstruction(
3670       HloInstruction::CreateParameter(0, r0pred, "param0"));
3671   HloInstruction* const_false = builder.AddInstruction(
3672       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
3673   builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kAnd,
3674                                                       const_false, param0));
3675 
3676   auto computation = m->AddEntryComputation(builder.Build());
3677   HloInstruction* root = computation->root_instruction();
3678   EXPECT_EQ(root->opcode(), HloOpcode::kAnd);
3679   AlgebraicSimplifier simplifier(default_options_);
3680   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3681   root = computation->root_instruction();
3682   EXPECT_EQ(root, const_false);
3683 }
3684 
3685 // Test that A || True is simplified to True
TEST_F(AlgebraicSimplifierTest,OrTrue)3686 TEST_F(AlgebraicSimplifierTest, OrTrue) {
3687   auto m = CreateNewVerifiedModule();
3688   Shape r0pred = ShapeUtil::MakeShape(PRED, {});
3689   HloComputation::Builder builder(TestName());
3690   HloInstruction* param0 = builder.AddInstruction(
3691       HloInstruction::CreateParameter(0, r0pred, "param0"));
3692   HloInstruction* const_true = builder.AddInstruction(
3693       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
3694   builder.AddInstruction(
3695       HloInstruction::CreateBinary(r0pred, HloOpcode::kOr, param0, const_true));
3696 
3697   auto computation = m->AddEntryComputation(builder.Build());
3698   HloInstruction* root = computation->root_instruction();
3699   EXPECT_EQ(root->opcode(), HloOpcode::kOr);
3700   AlgebraicSimplifier simplifier(default_options_);
3701   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3702   root = computation->root_instruction();
3703   EXPECT_EQ(root, const_true);
3704 }
3705 
3706 // Test that True || A is simplified to True
TEST_F(AlgebraicSimplifierTest,OrTrue2)3707 TEST_F(AlgebraicSimplifierTest, OrTrue2) {
3708   auto m = CreateNewVerifiedModule();
3709   Shape r0pred = ShapeUtil::MakeShape(PRED, {});
3710   HloComputation::Builder builder(TestName());
3711   HloInstruction* param0 = builder.AddInstruction(
3712       HloInstruction::CreateParameter(0, r0pred, "param0"));
3713   HloInstruction* const_true = builder.AddInstruction(
3714       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
3715   builder.AddInstruction(
3716       HloInstruction::CreateBinary(r0pred, HloOpcode::kOr, const_true, param0));
3717 
3718   auto computation = m->AddEntryComputation(builder.Build());
3719   HloInstruction* root = computation->root_instruction();
3720   EXPECT_EQ(root->opcode(), HloOpcode::kOr);
3721   AlgebraicSimplifier simplifier(default_options_);
3722   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3723   root = computation->root_instruction();
3724   EXPECT_EQ(root, const_true);
3725 }
3726 
3727 // Test that A || False is simplified to A
TEST_F(AlgebraicSimplifierTest,OrFalse)3728 TEST_F(AlgebraicSimplifierTest, OrFalse) {
3729   auto m = CreateNewVerifiedModule();
3730   Shape r0pred = ShapeUtil::MakeShape(PRED, {});
3731   HloComputation::Builder builder(TestName());
3732   HloInstruction* param0 = builder.AddInstruction(
3733       HloInstruction::CreateParameter(0, r0pred, "param0"));
3734   HloInstruction* const_false = builder.AddInstruction(
3735       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
3736   builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kOr,
3737                                                       param0, const_false));
3738 
3739   auto computation = m->AddEntryComputation(builder.Build());
3740   HloInstruction* root = computation->root_instruction();
3741   EXPECT_EQ(root->opcode(), HloOpcode::kOr);
3742   AlgebraicSimplifier simplifier(default_options_);
3743   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3744   root = computation->root_instruction();
3745   EXPECT_EQ(root, param0);
3746 }
3747 
3748 // Test that False || A is simplified to A
TEST_F(AlgebraicSimplifierTest,OrFalse2)3749 TEST_F(AlgebraicSimplifierTest, OrFalse2) {
3750   auto m = CreateNewVerifiedModule();
3751   Shape r0pred = ShapeUtil::MakeShape(PRED, {});
3752   HloComputation::Builder builder(TestName());
3753   HloInstruction* param0 = builder.AddInstruction(
3754       HloInstruction::CreateParameter(0, r0pred, "param0"));
3755   HloInstruction* const_false = builder.AddInstruction(
3756       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
3757   builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kOr,
3758                                                       const_false, param0));
3759 
3760   auto computation = m->AddEntryComputation(builder.Build());
3761   HloInstruction* root = computation->root_instruction();
3762   EXPECT_EQ(root->opcode(), HloOpcode::kOr);
3763   AlgebraicSimplifier simplifier(default_options_);
3764   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3765   root = computation->root_instruction();
3766   EXPECT_EQ(root, param0);
3767 }
3768 
3769 // Used for TEST_Ps that test merging (or not) of a kPad instruction into a
3770 // convolution's Window.
3771 struct ConvPaddingTestcase {
ConvPaddingTestcasexla::__anonac242c730111::ConvPaddingTestcase3772   ConvPaddingTestcase(absl::string_view padding,
3773                       absl::string_view orig_conv_window,
3774                       absl::string_view expected_conv_window)
3775       : ConvPaddingTestcase(padding, orig_conv_window, expected_conv_window,
3776                             /*pad_value=*/0) {}
3777 
ConvPaddingTestcasexla::__anonac242c730111::ConvPaddingTestcase3778   ConvPaddingTestcase(absl::string_view padding,
3779                       absl::string_view orig_conv_window,
3780                       absl::string_view expected_conv_window, float pad_value)
3781       : padding(padding),
3782         orig_conv_window(orig_conv_window),
3783         expected_conv_window(expected_conv_window),
3784         pad_value(pad_value) {}
3785 
ToStringxla::__anonac242c730111::ConvPaddingTestcase3786   string ToString() const {
3787     return absl::StrFormat(
3788         "padding=%s, orig_conv_window=%s, expected_conv_window=%s, "
3789         "pad_value=%f",
3790         padding, orig_conv_window, expected_conv_window, pad_value);
3791   }
3792 
3793   string padding;
3794   string orig_conv_window;
3795   string expected_conv_window;
3796   float pad_value;
3797 };
3798 
3799 // ConvInputPaddingTest (and its one associated TEST_P testcase) checks that a
3800 // computation that does
3801 //
3802 //   conv(pad(param0, padding=padding), param1), window=orig_conv_window
3803 //
3804 // gets transformed by AlgebraicSimplifier to
3805 //
3806 //   conv(param0, param1), window=expected_conv_window
3807 //
3808 // or, if expected_conv_window is the empty string, checks that
3809 // AlgebraicSimplifier does *not* transform the original convolution.
3810 class ConvInputPaddingTest
3811     : public AlgebraicSimplifierTest,
3812       public ::testing::WithParamInterface<ConvPaddingTestcase> {};
3813 
3814 INSTANTIATE_TEST_SUITE_P(
3815     ConvInputPaddingTestCases, ConvInputPaddingTest,
3816     ::testing::ValuesIn(std::vector<ConvPaddingTestcase>{
3817         // Merge this edge padding into the conv.
3818         {"0_0x0_0x1_1x2_2", "", "pad=1_1x2_2"},
3819         // Merge this edge padding with the conv's edge padding.
3820         {"0_0x0_0x1_2x3_4", "pad=10_10x20_20", "pad=11_12x23_24"},
3821         // Merge this interior-padded kPad with the unpadded conv.  The 3x6
3822         // interior padding gets transformed to 4x7 conv lhs dilation.
3823         {"0_0x0_0x1_2_3x4_5_6", "", "pad=1_2x4_5 lhs_dilate=4x7"},
3824         // kPad has dilation on one dim, conv has it on the other; merge them.
3825         {"0_0x0_0x0_0_1x0_0_0", "lhs_dilate=1x10", "lhs_dilate=2x10"},
3826         // kPad has dilation and edge padding on one dim, conv has them on the
3827         // other; merge them.
3828         {"0_0x0_0x0_1_1x0_0_0", "pad=0_0x3_0 lhs_dilate=1x10",
3829          "pad=0_1x3_0 lhs_dilate=2x10"},
3830 
3831         // Don't transform if the pad value is nonzero.
3832         {"0_0x0_0x1_1x2_2", "", "", /*pad_value=*/1},
3833 
3834         // We refuse to transform the following because on some dimension, one
3835         // of the kPad and conv has dilation and the other has some sort of
3836         // padding.
3837         {"0_0x0_0x0_0_1x0_0", "pad=1_0x0_0", ""},
3838         {"0_0x0_0x0_0_1x0_0", "pad=0_1x0_0", ""},
3839         {"0_0x0_0x0_0_1x0_0", "lhs_dilate=2x1", ""},
3840         {"0_0x0_0x1_0_0x0_0", "lhs_dilate=2x1", ""},
3841         {"0_0x0_0x0_1_0x0_0", "lhs_dilate=2x1", ""},
3842         {"0_0x0_0x0_0_1x0_0", "lhs_dilate=2x1", ""},
3843 
3844         // We can't merge feature or batch padding into the conv.
3845         {"1_0x0_0x0_0x0_0", "", ""},
3846         {"0_0x1_0x0_0x0_0", "", ""},
3847     }));
3848 
TEST_P(ConvInputPaddingTest,DoTest)3849 TEST_P(ConvInputPaddingTest, DoTest) {
3850   ConvPaddingTestcase testcase = GetParam();
3851 
3852   // It would be better to put the testcase's ToString into the test name, but
3853   // gUnit has constraints on what can go into test names, and any reasonable
3854   // implementation of ToString() seems to violate them.
3855   SCOPED_TRACE(testcase.ToString());
3856 
3857   auto builder = HloComputation::Builder(TestName());
3858   auto* input = builder.AddInstruction(HloInstruction::CreateParameter(
3859       0, ShapeUtil::MakeShape(F32, {1024, 128, 100, 100}),  // bf01
3860       "input"));
3861   auto* pad_value = builder.AddInstruction(HloInstruction::CreateConstant(
3862       LiteralUtil::CreateR0(testcase.pad_value)));
3863 
3864   PaddingConfig padding_config =
3865       ParsePaddingConfig(testcase.padding).ValueOrDie();
3866   auto* lhs_pad = builder.AddInstruction(HloInstruction::CreatePad(
3867       ShapeInference::InferPadShape(input->shape(), pad_value->shape(),
3868                                     padding_config)
3869           .ValueOrDie(),
3870       input, pad_value, padding_config));
3871 
3872   auto* filter = builder.AddInstruction(HloInstruction::CreateParameter(
3873       1,
3874       ShapeUtil::MakeShape(
3875           F32, {lhs_pad->shape().dimensions(1), 256, 3, 3}),  // io01
3876       "input"));
3877 
3878   ConvolutionDimensionNumbers dnums =
3879       ParseConvolutionDimensionNumbers("bf01_io01->bf01").ValueOrDie();
3880   Window window =
3881       ParseWindow(absl::StrCat("size=3x3 ", testcase.orig_conv_window))
3882           .ValueOrDie();
3883   builder.AddInstruction(HloInstruction::CreateConvolve(
3884       ShapeInference::InferConvolveShape(
3885           lhs_pad->shape(), filter->shape(),
3886           /*feature_group_count=*/1,
3887           /*batch_group_count=*/1, window, dnums,
3888           /*preferred_element_type=*/absl::nullopt)
3889           .ValueOrDie(),
3890       lhs_pad, filter, /*feature_group_count=*/1, /*batch_group_count=*/1,
3891       window, dnums, DefaultPrecisionConfig(2)));
3892   auto module = CreateNewVerifiedModule();
3893   module->AddEntryComputation(builder.Build());
3894 
3895   AlgebraicSimplifier simplifier(default_options_);
3896   if (testcase.expected_conv_window.empty()) {
3897     ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
3898   } else {
3899     ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
3900     auto* conv = module->entry_computation()->root_instruction();
3901     SCOPED_TRACE(module->ToString());
3902     ASSERT_THAT(conv,
3903                 GmockMatch(m::Convolution(m::Parameter(), m::Parameter())));
3904     EXPECT_EQ(window_util::ToString(conv->window()),
3905               absl::StrCat("size=3x3 ", testcase.expected_conv_window));
3906   }
3907 }
3908 
3909 // ConvFilterPaddingTest (and its one associated TEST_P) checks that a
3910 // computation that does
3911 //
3912 //   conv(param0, pad(param1, padding=padding)), window=orig_conv_window
3913 //
3914 // gets transformed by AlgebraicSimplifier to
3915 //
3916 //   conv(param0, param1), window=expected_conv_window
3917 //
3918 // or, if expected_conv_window is the empty string, checks that
3919 // AlgebraicSimplifier does *not* transform the original convolution.
3920 class ConvFilterPaddingTest
3921     : public AlgebraicSimplifierTest,
3922       public ::testing::WithParamInterface<ConvPaddingTestcase> {};
3923 
3924 INSTANTIATE_TEST_SUITE_P(
3925     ConvFilterPaddingTestCases, ConvFilterPaddingTest,
3926     ::testing::ValuesIn(std::vector<ConvPaddingTestcase>{
3927         // Can only merge interior padding on the filter's spatial dimensions;
3928         // all
3929         // other paddings (edge padding and interior padding on the channel
3930         // dims)
3931         // should be rejected out of hand.
3932         {"1_0_0x0_0_0x0_0x0_0", "", ""},
3933         {"0_1_0x0_0_0x0_0x0_0", "", ""},
3934         {"0_0_1x0_0_0x0_0x0_0", "", ""},
3935         {"0_0_0x1_0_0x0_0x0_0", "", ""},
3936         {"0_0_0x0_1_0x0_0x0_0", "", ""},
3937         {"0_0_0x0_0_1x0_0x0_0", "", ""},
3938         {"0_0_0x0_0_0x1_0x0_0", "", ""},
3939         {"0_0_0x0_0_0x0_1x0_0", "", ""},
3940         {"0_0_0x0_0_0x0_0x1_0", "", ""},
3941         {"0_0_0x0_0_0x0_0x0_1", "", ""},
3942 
3943         // Interior padding on channel dims can be merged into the conv, so long
3944         // as the conv and pad don't have interior padding on the same dim.
3945         {"0_0x0_0x0_0_5x0_0", "", "rhs_dilate=6x1"},
3946         {"0_0x0_0x0_0x0_0_10", "", "rhs_dilate=1x11"},
3947         {"0_0x0_0x0_0_10x0_0_100", "", "rhs_dilate=11x101"},
3948         {"0_0x0_0x0_0_1x0_0", "rhs_dilate=1x10", "rhs_dilate=2x10"},
3949         {"0_0x0_0x0_0x0_0_5", "rhs_dilate=10x1", "rhs_dilate=10x6"},
3950 
3951         // Can't merge if for a given dim there's interior padding on both the
3952         // pad and conv.
3953         {"0_0x0_0x0_0_1x0_0", "rhs_dilate=2x10", ""},
3954         {"0_0x0_0x0_0x0_0_5", "rhs_dilate=10x2", ""},
3955 
3956         // Don't transform if the pad value is nonzero.
3957         {"0_0x0_0x0_0_5x0_0", "", "", /*pad_value=*/1},
3958     }));
3959 
TEST_P(ConvFilterPaddingTest,DoIt)3960 TEST_P(ConvFilterPaddingTest, DoIt) {
3961   ConvPaddingTestcase testcase = GetParam();
3962 
3963   // It would be better to put the testcase's ToString into the test name, but
3964   // gUnit has constraints on what can go into test names, and any reasonable
3965   // implementation of ToString() seems to violate them.
3966   SCOPED_TRACE(testcase.ToString());
3967 
3968   auto builder = HloComputation::Builder(TestName());
3969   auto* pad_value = builder.AddInstruction(HloInstruction::CreateConstant(
3970       LiteralUtil::CreateR0(testcase.pad_value)));
3971   auto* filter = builder.AddInstruction(HloInstruction::CreateParameter(
3972       1, ShapeUtil::MakeShape(F32, {128, 256, 3, 3}),  // io01
3973       "input"));
3974   PaddingConfig padding_config =
3975       ParsePaddingConfig(testcase.padding).ValueOrDie();
3976   auto* rhs_pad = builder.AddInstruction(HloInstruction::CreatePad(
3977       ShapeInference::InferPadShape(filter->shape(), pad_value->shape(),
3978                                     padding_config)
3979           .ValueOrDie(),
3980       filter, pad_value, padding_config));
3981 
3982   auto* input = builder.AddInstruction(HloInstruction::CreateParameter(
3983       0,
3984       ShapeUtil::MakeShape(
3985           F32, {1024, rhs_pad->shape().dimensions(0), 100, 100}),  // bf01
3986       "input"));
3987 
3988   ConvolutionDimensionNumbers dnums =
3989       ParseConvolutionDimensionNumbers("bf01_io01->bf01").ValueOrDie();
3990   Window window = ParseWindow(absl::StrFormat("size=%dx%d %s",
3991                                               rhs_pad->shape().dimensions(2),
3992                                               rhs_pad->shape().dimensions(3),
3993                                               testcase.orig_conv_window))
3994                       .ValueOrDie();
3995 
3996   // Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place
3997   // after the transformation.
3998   PrecisionConfig precision_config;
3999   precision_config.add_operand_precision(PrecisionConfig::HIGH);
4000   precision_config.add_operand_precision(PrecisionConfig::HIGHEST);
4001 
4002   builder.AddInstruction(HloInstruction::CreateConvolve(
4003       ShapeInference::InferConvolveShape(
4004           input->shape(), rhs_pad->shape(),
4005           /*feature_group_count=*/1,
4006           /*batch_group_count=*/1, window, dnums,
4007           /*preferred_element_type=*/absl::nullopt)
4008           .ValueOrDie(),
4009       input, rhs_pad, /*feature_group_count=*/1, /*batch_group_count=*/1,
4010       window, dnums, precision_config));
4011 
4012   auto module = CreateNewVerifiedModule();
4013   module->AddEntryComputation(builder.Build());
4014 
4015   AlgebraicSimplifier simplifier(default_options_);
4016   if (testcase.expected_conv_window.empty()) {
4017     ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
4018   } else {
4019     ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
4020     auto* conv = module->entry_computation()->root_instruction();
4021     SCOPED_TRACE(module->ToString());
4022     ASSERT_THAT(conv,
4023                 GmockMatch(m::Convolution(m::Parameter(), m::Parameter())));
4024     EXPECT_EQ(window_util::ToString(conv->window()),
4025               absl::StrFormat("size=%dx%d %s",
4026                               conv->operand(1)->shape().dimensions(2),
4027                               conv->operand(1)->shape().dimensions(3),
4028                               testcase.expected_conv_window));
4029     EXPECT_THAT(Cast<HloConvolutionInstruction>(conv)
4030                     ->precision_config()
4031                     .operand_precision(),
4032                 ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::HIGHEST));
4033   }
4034 }
4035 
TEST_F(AlgebraicSimplifierTest,ConvertConvToMatmul)4036 TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
4037   struct ConvTestOptions {
4038     int in_batch = 10;
4039     int in_height = 2;
4040     int in_width = 2;
4041     int in_channels = 3;
4042     int f_width = 1;
4043     int f_height = 1;
4044     int f_output_channels = 10;
4045     int row_stride = 1;
4046     int row_padding = 0;
4047     int col_stride = 1;
4048     int col_padding = 0;
4049     bool input_minor_to_major_layout = false;
4050     bool filter_minor_to_major_layout = false;
4051     bool output_minor_to_major_layout = false;
4052 
4053     const char* dim_order = "NHWC";         // can use chars NHWC in any order.
4054     const char* kernel_dim_order = "HWIO";  // can use chars HWIO in any order.
4055 
4056     ConvTestOptions& Reset() {
4057       *this = ConvTestOptions();
4058       return *this;
4059     }
4060   };
4061 
4062   ConvTestOptions options;
4063 
4064   // Builds a convolution from <options> and runs algebraic simplification on
4065   // the computation. Returns a string description of the result of
4066   // simplification.
4067   auto build_and_simplify = [&]() -> string {
4068     HloComputation::Builder b(TestName());
4069 
4070     Window window;
4071     auto* f_dim_1 = window.add_dimensions();
4072     f_dim_1->set_size(options.f_height);
4073     f_dim_1->set_stride(options.row_stride);
4074     f_dim_1->set_padding_low(options.row_padding);
4075     f_dim_1->set_padding_high(options.row_padding);
4076     f_dim_1->set_window_dilation(1);
4077     f_dim_1->set_base_dilation(1);
4078     auto* f_dim_2 = window.add_dimensions();
4079     f_dim_2->set_size(options.f_width);
4080     f_dim_2->set_stride(options.col_stride);
4081     f_dim_2->set_padding_low(options.col_padding);
4082     f_dim_2->set_padding_high(options.col_padding);
4083     f_dim_2->set_window_dilation(1);
4084     f_dim_2->set_base_dilation(1);
4085 
4086     ConvolutionDimensionNumbers dnums;
4087     std::vector<int64> in_dims;
4088     int in_channel_idx = -1;
4089     // filled in later
4090     dnums.add_input_spatial_dimensions(-1);
4091     dnums.add_output_spatial_dimensions(-1);
4092     dnums.add_input_spatial_dimensions(-1);
4093     dnums.add_output_spatial_dimensions(-1);
4094     for (int i = 0; i < strlen(options.dim_order); ++i) {
4095       char ch = options.dim_order[i];
4096       if (ch == 'N') {
4097         dnums.set_input_batch_dimension(i);
4098         dnums.set_output_batch_dimension(i);
4099         in_dims.push_back(options.in_batch);
4100       } else if (ch == 'H') {
4101         dnums.set_input_spatial_dimensions(0, i);
4102         dnums.set_output_spatial_dimensions(0, i);
4103         in_dims.push_back(options.in_height);
4104       } else if (ch == 'W') {
4105         dnums.set_input_spatial_dimensions(1, i);
4106         dnums.set_output_spatial_dimensions(1, i);
4107         in_dims.push_back(options.in_width);
4108       } else if (ch == 'C') {
4109         dnums.set_input_feature_dimension(i);
4110         dnums.set_output_feature_dimension(i);
4111         in_dims.push_back(options.in_channels);
4112         in_channel_idx = i;
4113       }
4114     }
4115 
4116     std::vector<int64> f_dims;
4117     dnums.add_kernel_spatial_dimensions(-1);  // filled in later
4118     dnums.add_kernel_spatial_dimensions(-1);  // filled in later
4119     for (int i = 0; i < strlen(options.kernel_dim_order); ++i) {
4120       char ch = options.kernel_dim_order[i];
4121       if (ch == 'H') {
4122         dnums.set_kernel_spatial_dimensions(0, i);
4123         f_dims.push_back(options.f_height);
4124       } else if (ch == 'W') {
4125         dnums.set_kernel_spatial_dimensions(1, i);
4126         f_dims.push_back(options.f_width);
4127       } else if (ch == 'I') {
4128         dnums.set_kernel_input_feature_dimension(i);
4129         f_dims.push_back(options.in_channels);
4130       } else if (ch == 'O') {
4131         dnums.set_kernel_output_feature_dimension(i);
4132         f_dims.push_back(options.f_output_channels);
4133       }
4134     }
4135 
4136     auto make_shape = [](absl::Span<const int64> dims,
4137                          bool minor_to_major_layout) {
4138       if (minor_to_major_layout) {
4139         return ShapeUtil::MakeShapeWithLayout(F32, dims, {0, 1, 2, 3});
4140       } else {
4141         return ShapeUtil::MakeShape(F32, dims);
4142       }
4143     };
4144     auto in_shape = make_shape(in_dims, options.input_minor_to_major_layout);
4145     auto f_shape = make_shape(f_dims, options.filter_minor_to_major_layout);
4146 
4147     HloInstruction* input =
4148         b.AddInstruction(HloInstruction::CreateParameter(0, in_shape, "input"));
4149     HloInstruction* filter =
4150         b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter"));
4151     Shape out_shape = ShapeInference::InferConvolveShape(
4152                           in_shape, f_shape, /*feature_group_count=*/1,
4153                           /*batch_group_count=*/1, window, dnums,
4154                           /*preferred_element_type=*/absl::nullopt)
4155                           .ValueOrDie();
4156     if (options.output_minor_to_major_layout) {
4157       out_shape = ShapeUtil::MakeShapeWithLayout(F32, out_shape.dimensions(),
4158                                                  {0, 1, 2, 3});
4159     }
4160 
4161     b.AddInstruction(HloInstruction::CreateConvolve(
4162         out_shape, input, filter,
4163         /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums,
4164         DefaultPrecisionConfig(2)));
4165 
4166     auto module = CreateNewVerifiedModule();
4167     auto* computation = module->AddEntryComputation(b.Build());
4168 
4169     AlgebraicSimplifierOptions simplifier_options;
4170     simplifier_options.set_is_layout_sensitive(true);
4171     AlgebraicSimplifier simplifier(simplifier_options);
4172     if (!simplifier.Run(module.get()).ValueOrDie()) {
4173       return "NO_CHANGE";
4174     }
4175     auto* root = computation->root_instruction();
4176     if (root->opcode() == HloOpcode::kBitcast &&
4177         root->operand(0)->opcode() == HloOpcode::kDot) {
4178       auto lhs_shape = root->operand(0)->operand(0)->shape();
4179       auto rhs_shape = root->operand(0)->operand(1)->shape();
4180       return absl::StrCat(absl::StrJoin(lhs_shape.dimensions(), "x"), " DOT ",
4181                           absl::StrJoin(rhs_shape.dimensions(), "x"));
4182     }
4183     return "UNEXPECTED CHANGE";
4184   };
4185 
4186   // Default options are the simplest case and succeed.
4187   options.Reset();
4188   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
4189 
4190   // Swapping dim spatial and batch order works.
4191   options.Reset().dim_order = "NWHC";
4192   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
4193   options.Reset().dim_order = "WHNC";
4194   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
4195   // Channel dimension earlier fails.
4196   options.Reset().dim_order = "HWCN";
4197   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4198   options.Reset().dim_order = "CHWN";
4199   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4200 
4201   // Filtering dims spatial dims can be anywhere, since they are 1x1.
4202   options.Reset().kernel_dim_order = "WHIO";
4203   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
4204   options.Reset().kernel_dim_order = "IWOH";
4205   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
4206   options.Reset().kernel_dim_order = "IWHO";
4207   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
4208   // But moving output channel before input channel fails.
4209   options.Reset().kernel_dim_order = "HWOI";
4210   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4211   options.Reset().kernel_dim_order = "WHOI";
4212   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4213   options.Reset().kernel_dim_order = "OWIH";
4214   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4215   options.Reset().kernel_dim_order = "OWHI";
4216   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4217 
4218   // Combine different dim and kernel dim orders.
4219   options.Reset().kernel_dim_order = "IWHO";
4220   options.dim_order = "WHNC";
4221   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
4222 
4223   // Test invalid cases from wrong filter size, strides, or padding.
4224   options.Reset().f_width = 2;
4225   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4226   options.Reset().f_height = 2;
4227   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4228   options.Reset().row_stride = 2;
4229   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4230   options.Reset().col_stride = 2;
4231   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4232   options.Reset().col_padding = 1;
4233   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4234   options.Reset().row_padding = 1;
4235   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4236 
4237   // The default dim_order is "NHWC". Col-major layout makes C the most major.
4238   options.Reset().input_minor_to_major_layout = true;
4239   options.output_minor_to_major_layout = true;
4240   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4241 
4242   // The input and output have different layouts.
4243   options.Reset().input_minor_to_major_layout = true;
4244   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4245 
4246   // C is most minor, and I is more major than O.
4247   options.Reset().input_minor_to_major_layout = true;
4248   options.filter_minor_to_major_layout = true;
4249   options.output_minor_to_major_layout = true;
4250   options.dim_order = "CHWN";
4251   options.kernel_dim_order = "OIHW";
4252   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
4253 
4254   // C is not the most minor dimension.
4255   options.Reset().input_minor_to_major_layout = true;
4256   options.filter_minor_to_major_layout = true;
4257   options.output_minor_to_major_layout = true;
4258   options.dim_order = "HWNC";
4259   options.kernel_dim_order = "OIHW";
4260   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4261 
4262   // I is more minor than O.
4263   options.Reset().input_minor_to_major_layout = true;
4264   options.filter_minor_to_major_layout = true;
4265   options.output_minor_to_major_layout = true;
4266   options.dim_order = "CHWN";
4267   options.kernel_dim_order = "IOHW";
4268   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4269 }
4270 
4271 // Test that slice(broadcast(/*scalar value*/)) simplifies to a single
4272 // broadcast.
TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
  // Graph under test: slice(broadcast(scalar_param)).
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* scalar = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "scalar_param"));

  const Shape wide_shape = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
  HloInstruction* bcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      wide_shape, scalar, /*broadcast_dimensions=*/{}));

  Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3});
  HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice(
      slice_shape, bcast, /*start_indices=*/{0, 1, 2, 3},
      /*limit_indices=*/{2, 3, 5, 6}, /*strides=*/{1, 1, 1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  // Before simplification the slice is the root.
  HloInstruction* old_root = entry->root_instruction();
  EXPECT_EQ(old_root, slice);
  EXPECT_TRUE(ShapeUtil::Equal(old_root->shape(), slice_shape));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The result must be a fixed point: a second pass changes nothing.
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());

  // The slice collapses into a single broadcast of the scalar, directly at
  // the sliced shape.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Broadcast(m::Op().Is(scalar))
                             .WithShapeEqualTo(&slice_shape)));
}
4304 
4305 // Test that reshape(transpose(broadcast(/*scalar value*/))) simplifies to a
4306 // single broadcast.
TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
  // Graph under test: reshape(transpose(broadcast(42.0f))).
  HloComputation::Builder builder(TestName());
  HloInstruction* scalar_const = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));

  const Shape bcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6});
  HloInstruction* bcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      bcast_shape, scalar_const, /*broadcast_dimensions=*/{}));

  const Shape transposed_shape = ShapeUtil::MakeShape(F32, {6, 5, 4});
  HloInstruction* transposed = builder.AddInstruction(
      HloInstruction::CreateTranspose(transposed_shape, bcast, {2, 1, 0}));

  Shape reshape_shape = ShapeUtil::MakeShape(F32, {30, 1, 4});
  HloInstruction* reshaped = builder.AddInstruction(
      HloInstruction::CreateReshape(reshape_shape, transposed));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  // Before simplification the reshape is the root.
  HloInstruction* old_root = entry->root_instruction();
  EXPECT_EQ(old_root, reshaped);
  EXPECT_TRUE(ShapeUtil::Equal(old_root->shape(), reshape_shape));

  // Afterwards the whole chain is a single broadcast of the constant.
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Broadcast(m::Op().Is(scalar_const))
                             .WithShapeEqualTo(&reshape_shape)));
}
4337 
4338 // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x).
TEST_F(AlgebraicSimplifierTest,FoldPadIntoReduceWindow)4339 TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
4340   auto module = CreateNewVerifiedModule();
4341   HloComputation::Builder builder(TestName());
4342 
4343   // Create operand to the pad.
4344   HloInstruction* operand =
4345       builder.AddInstruction(HloInstruction::CreateParameter(
4346           0, ShapeUtil::MakeShape(F32, {1, 2, 3, 4}), "p0"));
4347 
4348   // Create the pad.
4349   PaddingConfig padding = MakeNoPaddingConfig(4);
4350   padding.mutable_dimensions(1)->set_edge_padding_low(1);
4351   padding.mutable_dimensions(3)->set_edge_padding_high(2);
4352 
4353   HloInstruction* pad_value = builder.AddInstruction(
4354       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
4355   HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
4356       ShapeUtil::MakeShape(F32, {1, 3, 3, 5}), operand, pad_value, padding));
4357 
4358   // Create add computation.
4359   HloComputation* add_computation = nullptr;
4360   {
4361     HloComputation::Builder builder(TestName() + ".add");
4362     const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
4363     HloInstruction* p0 = builder.AddInstruction(
4364         HloInstruction::CreateParameter(0, scalar_shape, "p0"));
4365     HloInstruction* p1 = builder.AddInstruction(
4366         HloInstruction::CreateParameter(1, scalar_shape, "p1"));
4367     builder.AddInstruction(
4368         HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
4369     add_computation = module->AddEmbeddedComputation(builder.Build());
4370   }
4371 
4372   // Create the reduce-window.
4373   Window window;
4374   for (int64 i = 0; i < pad->shape().rank(); ++i) {
4375     auto* dim = window.add_dimensions();
4376     dim->set_size(1);
4377     dim->set_padding_low(10);
4378     dim->set_padding_high(100);
4379     dim->set_window_dilation(1);
4380     dim->set_base_dilation(1);
4381     dim->set_stride(1);
4382   }
4383   const Shape reduce_window_shape =
4384       ShapeUtil::MakeShape(F32, {111, 113, 113, 116});
4385   HloInstruction* reduce_init_value = builder.AddInstruction(
4386       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
4387   HloInstruction* reduce_window =
4388       builder.AddInstruction(HloInstruction::CreateReduceWindow(
4389           reduce_window_shape, pad, reduce_init_value, window,
4390           add_computation));
4391 
4392   // Build the computation and run the simplifier.
4393   auto computation = module->AddEntryComputation(builder.Build());
4394   HloInstruction* root = computation->root_instruction();
4395   EXPECT_EQ(root, reduce_window);
4396   AlgebraicSimplifier simplifier(default_options_);
4397   ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
4398 
4399   // Running simplification again should not result in any further changes.
4400   ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
4401 
4402   // Verify the result
4403   root = computation->root_instruction();
4404   EXPECT_THAT(root,
4405               GmockMatch(m::ReduceWindow(m::Op().Is(operand), m::Constant())));
4406   EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape))
4407       << ShapeUtil::HumanString(root->shape()) << " vs "
4408       << ShapeUtil::HumanString(reduce_window_shape);
4409   EXPECT_EQ(root->window().dimensions(0).padding_low(), 10);
4410   EXPECT_EQ(root->window().dimensions(1).padding_low(), 11);
4411   EXPECT_EQ(root->window().dimensions(2).padding_low(), 10);
4412   EXPECT_EQ(root->window().dimensions(3).padding_low(), 10);
4413   EXPECT_EQ(root->window().dimensions(0).padding_high(), 100);
4414   EXPECT_EQ(root->window().dimensions(1).padding_high(), 100);
4415   EXPECT_EQ(root->window().dimensions(2).padding_high(), 100);
4416   EXPECT_EQ(root->window().dimensions(3).padding_high(), 102);
4417 }
4418 
4419 // Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to
4420 // ReduceWindow(Convert(op), x).
TEST_F(AlgebraicSimplifierTest,FoldConvertedPadIntoReduceWindow)4421 TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
4422   auto module = CreateNewVerifiedModule();
4423   HloComputation::Builder builder(TestName());
4424 
4425   // Create operand to the pad.
4426   HloInstruction* parameter =
4427       builder.AddInstruction(HloInstruction::CreateParameter(
4428           0, ShapeUtil::MakeShape(BF16, {1, 2, 3, 4}), "p0"));
4429 
4430   // Create the pad.
4431   PaddingConfig padding = MakeNoPaddingConfig(4);
4432   padding.mutable_dimensions(1)->set_edge_padding_low(1);
4433   padding.mutable_dimensions(3)->set_edge_padding_high(2);
4434 
4435   HloInstruction* pad_value = builder.AddInstruction(
4436       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
4437   HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
4438       ShapeUtil::MakeShape(BF16, {1, 3, 3, 5}), parameter, pad_value, padding));
4439 
4440   HloInstruction* convert =
4441       builder.AddInstruction(HloInstruction::CreateConvert(
4442           ShapeUtil::ChangeElementType(pad->shape(), F32), pad));
4443 
4444   // Create add computation.
4445   HloComputation* add_computation = nullptr;
4446   {
4447     HloComputation::Builder builder(TestName() + ".add");
4448     const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
4449     HloInstruction* p0 = builder.AddInstruction(
4450         HloInstruction::CreateParameter(0, scalar_shape, "p0"));
4451     HloInstruction* p1 = builder.AddInstruction(
4452         HloInstruction::CreateParameter(1, scalar_shape, "p1"));
4453     builder.AddInstruction(
4454         HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
4455     add_computation = module->AddEmbeddedComputation(builder.Build());
4456   }
4457 
4458   // Create the reduce-window.
4459   Window window;
4460   for (int64 i = 0; i < pad->shape().rank(); ++i) {
4461     auto* dim = window.add_dimensions();
4462     dim->set_size(1);
4463     dim->set_padding_low(10);
4464     dim->set_padding_high(100);
4465     dim->set_window_dilation(1);
4466     dim->set_base_dilation(1);
4467     dim->set_stride(1);
4468   }
4469   const Shape reduce_window_shape =
4470       ShapeUtil::MakeShape(F32, {111, 113, 113, 116});
4471   HloInstruction* reduce_init_value = builder.AddInstruction(
4472       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
4473   HloInstruction* reduce_window =
4474       builder.AddInstruction(HloInstruction::CreateReduceWindow(
4475           reduce_window_shape, convert, reduce_init_value, window,
4476           add_computation));
4477 
4478   // Build the computation and run the simplifier.
4479   auto computation = module->AddEntryComputation(builder.Build());
4480   HloInstruction* root = computation->root_instruction();
4481   EXPECT_EQ(root, reduce_window);
4482   AlgebraicSimplifier simplifier(default_options_);
4483   ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
4484 
4485   // Running simplification again should not result in any further changes.
4486   ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
4487 
4488   // Verify the result
4489   root = computation->root_instruction();
4490   EXPECT_THAT(root, GmockMatch(m::ReduceWindow(m::Convert(m::Parameter(0)),
4491                                                m::Constant())));
4492   EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape))
4493       << ShapeUtil::HumanString(root->shape()) << " vs "
4494       << ShapeUtil::HumanString(reduce_window_shape);
4495   EXPECT_EQ(root->window().dimensions(0).padding_low(), 10);
4496   EXPECT_EQ(root->window().dimensions(1).padding_low(), 11);
4497   EXPECT_EQ(root->window().dimensions(2).padding_low(), 10);
4498   EXPECT_EQ(root->window().dimensions(3).padding_low(), 10);
4499   EXPECT_EQ(root->window().dimensions(0).padding_high(), 100);
4500   EXPECT_EQ(root->window().dimensions(1).padding_high(), 100);
4501   EXPECT_EQ(root->window().dimensions(2).padding_high(), 100);
4502   EXPECT_EQ(root->window().dimensions(3).padding_high(), 102);
4503 }
4504 
TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) {
  // Reversing only size-1 dimensions (dims 2 and 3 below) leaves the data
  // unchanged; the test expects the reverse to disappear entirely.
  const Shape shape = ShapeUtil::MakeShape(F32, {448, 2048, 1, 1});
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
  builder.AddInstruction(
      HloInstruction::CreateReverse(shape, param, /*dimensions=*/{2, 3}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The parameter itself should now be the root, with its shape intact.
  HloInstruction* new_root = entry->root_instruction();
  EXPECT_EQ(param, new_root);
  EXPECT_TRUE(ShapeUtil::Equal(new_root->shape(), shape));
}
4523 
TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
  // Simplifying a dot may add computations to the module. This test checks
  // that updating the module's computation list while the pass runs does not
  // invalidate the iteration over the remaining computations.
  auto module = CreateNewVerifiedModule();
  const Shape r1f32 = ShapeUtil::MakeShape(F32, {1});

  // Callee: a batch dot over dimension 0 of two rank-1 parameters.
  HloComputation::Builder dot_builder(TestName() + ".Dot");
  HloInstruction* lhs =
      dot_builder.AddInstruction(HloInstruction::CreateParameter(0, r1f32, "x"));
  HloInstruction* rhs =
      dot_builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y"));
  DotDimensionNumbers batch_dnums;
  batch_dnums.add_lhs_batch_dimensions(0);
  batch_dnums.add_rhs_batch_dimensions(0);
  dot_builder.AddInstruction(HloInstruction::CreateDot(
      r1f32, lhs, rhs, batch_dnums, DefaultPrecisionConfig(2)));
  std::unique_ptr<HloComputation> callee = dot_builder.Build();

  // Caller (entry): calls the dot computation on two constants.
  HloComputation::Builder caller_builder(TestName() + ".Call");
  HloInstruction* zero = caller_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({0.0f})));
  HloInstruction* one = caller_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1.0f})));
  caller_builder.AddInstruction(
      HloInstruction::CreateCall(r1f32, {zero, one}, callee.get()));

  module->AddEmbeddedComputation(std::move(callee));
  module->AddEntryComputation(caller_builder.Build());

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
}
4555 
4556 // Test that a constant with tuple shape becomes a tuple of constants.
TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) {
  // Named `module` (not `m`) so it does not shadow the `m` pattern-matcher
  // namespace used in the expectation below.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  const float constant_scalar = 7.3f;
  std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
  Literal elements[] = {LiteralUtil::CreateR0<float>(constant_scalar),
                        LiteralUtil::CreateR1<float>(constant_vector)};
  // A single constant whose literal is a (scalar, vector) tuple.
  Literal value = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
  builder.AddInstruction(HloInstruction::CreateConstant(std::move(value)));

  auto computation = module->AddEntryComputation(builder.Build());

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  // The tuple-shaped constant is split into a tuple of per-element constants.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Tuple(m::Constant(), m::Constant())));
}
4574 
// A dynamic-slice is trivial if the size of its input equals the size of its
// output. Start indices are clamped to keep the slice in bounds, so such a
// slice must start at zero whatever the index values are. In this case, the
// dynamic slice is equal to its input.
TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) {
  // Named `module` (not `m`) so it does not shadow the `m` pattern-matcher
  // namespace used in the expectation below.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
  // Start indices are runtime parameters, not constants: the slice is trivial
  // because it covers the whole operand, whatever the indices are.
  std::vector<HloInstruction*> params;
  for (int i = 0; i < 3; ++i) {
    params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
        i + 1, ShapeUtil::MakeShape(U32, {}), "slice_indices")));
  }
  builder.AddInstruction(HloInstruction::CreateDynamicSlice(
      shape,
      builder.AddInstruction(
          HloInstruction::CreateParameter(0, shape, "slice_from")),
      params,
      /*slice_sizes=*/{10, 100, 1000}));

  auto computation = module->AddEntryComputation(builder.Build());
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Parameter()));
}
4600 
TEST_F(AlgebraicSimplifierTest, ConstantDynamicSlice) {
  // Named `module` (not `m`) so it does not shadow the `m` pattern-matcher
  // namespace used in the expectation below.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
  // Constant start indices 2 << 1, 2 << 2, 2 << 3 = {4, 8, 16}, all in bounds
  // for the {2, 20, 200} slice below.
  std::vector<HloInstruction*> params;
  for (int i = 0; i < 3; ++i) {
    params.push_back(builder.AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::CreateR0<int32>(2 << (i + 1)))));
  }
  Shape ds_shape = ShapeUtil::MakeShape(F32, {2, 20, 200});
  builder.AddInstruction(HloInstruction::CreateDynamicSlice(
      ds_shape,
      builder.AddInstruction(
          HloInstruction::CreateParameter(0, shape, "operand")),
      params,
      /*slice_sizes=*/{2, 20, 200}));

  auto computation = module->AddEntryComputation(builder.Build());
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  // With constant indices the dynamic-slice becomes a static slice.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Slice(m::Parameter())));
}
4625 
4626 // A dynamic-update-slice is trivial if its start indices are all zeroes and the
4627 // size of its "update" equals the size of its output.  In this case, the
4628 // dynamic-update-slice is equal to its update.
TEST_F(AlgebraicSimplifierTest,TrivialDynamicUpdateSlice)4629 TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) {
4630   auto m = CreateNewVerifiedModule();
4631   HloComputation::Builder builder(TestName());
4632 
4633   Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
4634   Shape slice_shape = ShapeUtil::MakeShape(F32, {10, 1, 1000});
4635 
4636   std::vector<HloInstruction*> slice_indices, update_indices;
4637   for (int i = 0; i < 3; ++i) {
4638     slice_indices.push_back(
4639         builder.AddInstruction(HloInstruction::CreateParameter(
4640             i + 1, ShapeUtil::MakeShape(U32, {}), "slice_indices")));
4641     update_indices.push_back(
4642         builder.AddInstruction(HloInstruction::CreateParameter(
4643             i + 5, ShapeUtil::MakeShape(U32, {}), "update_indices")));
4644   }
4645   HloInstruction* slice =
4646       builder.AddInstruction(HloInstruction::CreateDynamicSlice(
4647           slice_shape,
4648           builder.AddInstruction(
4649               HloInstruction::CreateParameter(0, full_shape, "slice_from")),
4650           slice_indices,
4651           /*slice_sizes=*/{10, 1, 1000}));
4652 
4653   builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
4654       slice_shape,
4655       builder.AddInstruction(
4656           HloInstruction::CreateParameter(4, slice_shape, "to_update")),
4657       slice, update_indices));
4658 
4659   auto computation = m->AddEntryComputation(builder.Build());
4660   AlgebraicSimplifier simplifier(default_options_);
4661   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
4662   EXPECT_THAT(computation->root_instruction(),
4663               GmockMatch(m::DynamicSlice(m::Parameter(), m::Parameter(),
4664                                          m::Parameter(), m::Parameter())));
4665 }
4666 
4667 // Test that two consecutive broadcasts can be merged to one.
TEST_F(AlgebraicSimplifierTest,MergeBroadcasts)4668 TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) {
4669   auto m = CreateNewVerifiedModule();
4670   HloComputation::Builder builder(TestName());
4671   Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
4672   HloInstruction* input_array = builder.AddInstruction(
4673       HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({3, 4})));
4674   HloInstruction* inner_bcast = builder.AddInstruction(
4675       HloInstruction::CreateBroadcast(r2f32, input_array, {1}));
4676   Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2});
4677   builder.AddInstruction(
4678       HloInstruction::CreateBroadcast(r3f32, inner_bcast, {0, 2}));
4679 
4680   auto computation = m->AddEntryComputation(builder.Build());
4681   HloInstruction* root = computation->root_instruction();
4682   EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
4683   AlgebraicSimplifier simplifier(default_options_);
4684   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
4685   root = computation->root_instruction();
4686   EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
4687   EXPECT_THAT(root->dimensions(), ElementsAre(2));
4688 }
4689 
4690 // Test that two consecutive broadcasts can be merged to one.
TEST_F(AlgebraicSimplifierTest,MergeBroadcasts2)4691 TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) {
4692   auto m = CreateNewVerifiedModule();
4693   HloComputation::Builder builder(TestName());
4694   Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 3});
4695   Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3});
4696   HloInstruction* param0 = builder.AddInstruction(
4697       HloInstruction::CreateParameter(0, r2f32, "param0"));
4698   // The initial dimensions go to places 0 and 2 in the 3-dim array,
4699   // and to places 1 and 3 in the 4-dim array,
4700   HloInstruction* inner_bcast = builder.AddInstruction(
4701       HloInstruction::CreateBroadcast(r3f32, param0, {0, 2}));
4702   Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3});
4703   builder.AddInstruction(
4704       HloInstruction::CreateBroadcast(r4f32, inner_bcast, {1, 2, 3}));
4705 
4706   auto computation = m->AddEntryComputation(builder.Build());
4707   HloInstruction* root = computation->root_instruction();
4708   EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
4709   AlgebraicSimplifier simplifier(default_options_);
4710   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
4711   root = computation->root_instruction();
4712   EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Parameter(0))));
4713   EXPECT_THAT(root->dimensions(), ElementsAre(1, 3));
4714 }
4715 
4716 // Test that a broadcast of an iota can be merged to one iota.
TEST_F(AlgebraicSimplifierTest,MergeBroadcastAndIota)4717 TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) {
4718   auto m = CreateNewVerifiedModule();
4719   HloComputation::Builder builder(TestName());
4720   Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
4721   HloInstruction* iota =
4722       builder.AddInstruction(HloInstruction::CreateIota(r2f32, 1));
4723   Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2});
4724   builder.AddInstruction(HloInstruction::CreateBroadcast(r3f32, iota, {0, 2}));
4725 
4726   auto computation = m->AddEntryComputation(builder.Build());
4727   HloInstruction* root = computation->root_instruction();
4728   EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
4729   AlgebraicSimplifier simplifier(default_options_);
4730   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
4731   root = computation->root_instruction();
4732   EXPECT_THAT(root, GmockMatch(m::Iota()));
4733   EXPECT_EQ(Cast<HloIotaInstruction>(root)->iota_dimension(), 2);
4734 }
4735 
4736 // Test that a broadcast of an iota can be merged to one iota.
TEST_F(AlgebraicSimplifierTest,MergeBroadcastAndIota2)4737 TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) {
4738   auto m = CreateNewVerifiedModule();
4739   HloComputation::Builder builder(TestName());
4740   Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3});
4741   HloInstruction* iota =
4742       builder.AddInstruction(HloInstruction::CreateIota(r3f32, 1));
4743   Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3});
4744   builder.AddInstruction(
4745       HloInstruction::CreateBroadcast(r4f32, iota, {1, 2, 3}));
4746 
4747   auto computation = m->AddEntryComputation(builder.Build());
4748   HloInstruction* root = computation->root_instruction();
4749   EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
4750   AlgebraicSimplifier simplifier(default_options_);
4751   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
4752   root = computation->root_instruction();
4753   EXPECT_THAT(root, GmockMatch(m::Iota()));
4754   EXPECT_EQ(Cast<HloIotaInstruction>(root)->iota_dimension(), 2);
4755 }
4756 
// transpose(dot(lhs, rhs)) should be rewritten as a dot whose operands are
// swapped, leaving no transpose at the root.
TEST_F(AlgebraicSimplifierTest, TransposeOfDot) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      lhs = f32[3,4,5] parameter(0)
      rhs = f32[6,3,4] parameter(1)
      dot = f32[5,6] dot(lhs,rhs), lhs_contracting_dims={0,1}, rhs_contracting_dims={1,2}
      ROOT transpose = f32[6,5] transpose(dot), dimensions={1,0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Dot(m::Parameter(1), m::Parameter(0))));
}
4777 
// With padding 3_2x1_5, rows 0-2 and column 0 of the padded array are low
// padding. Slicing row 2, column 0 therefore reads only the pad value, so the
// slice should become a broadcast of the pad constant.
TEST_F(AlgebraicSimplifierTest, SliceOfPadLow) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      param = f32[3,4] parameter(0)
      constant = f32[] constant(0.0)
      pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
      ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[2:3],[0:1]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
}
4798 
// With padding 3_2x1_5, rows 6-7 and columns 5-9 of the padded array are high
// padding. Slicing row 6, column 9 reads only the pad value, so the slice
// should become a broadcast of the pad constant.
TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      param = f32[3,4] parameter(0)
      constant = f32[] constant(0.0)
      pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
      ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[6:7],[9:10]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
}
4819 
// With padding 3_2x1_5 the operand occupies rows 3-5 and columns 1-4. The
// element at row 5, column 4 lies inside it, so the slice should be moved
// onto the unpadded parameter.
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      param = f32[3,4] parameter(0)
      constant = f32[] constant(0.0)
      pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
      ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[4:5]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Slice(m::Parameter(0))));
}
4840 
// The region [4:6] x [2:5] of the padded array lies entirely within the
// operand (rows 3-5, columns 1-4 under padding 3_2x1_5). The slice should be
// applied to the parameter directly, with its starts shifted back by the low
// padding to (1, 1).
TEST_F(AlgebraicSimplifierTest, SliceOfPad) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      param = f32[3,4] parameter(0)
      constant = f32[] constant(0.0)
      pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
      ROOT slice = f32[2,3] slice(f32[8,10] pad), slice={[4:6],[2:5]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Slice(m::Parameter(0))));
  EXPECT_THAT(root->slice_starts(), ElementsAre(1, 1));
}
4862 
// Row 5 is inside the operand, but column 9 is high padding (columns 5-9 under
// padding 3_2x1_5). The sliced element is therefore a pad value, and the
// slice should become a broadcast of the pad constant.
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalarConstant) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      param = f32[3,4] parameter(0)
      constant = f32[] constant(0.0)
      pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
      ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[9:10]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
}
4883 
// Padding 3_4x4_5 places the single element of the [1,1] operand at (3, 4).
// Slicing exactly that element reproduces the operand, so the whole
// pad + slice should reduce to the parameter itself.
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      param = f32[1,1] parameter(0)
      constant = f32[] constant(0.0)
      pad = f32[8,10] pad(f32[1,1] param, f32[] constant), padding=3_4x4_5
      ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[3:4],[4:5]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Parameter()));
}
4904 
// Padding 0_2x0_2x2_0 places the operand at index 2 of dimension 2, so the
// slice [0:1] in that dimension lies entirely in padding. This holds even
// though dimensions 0 and 1 overlap the operand. The result should be a
// broadcast of the pad value -7.
TEST_F(AlgebraicSimplifierTest, SliceOfPadSomeDimsInPadding) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY entry () -> f32[1]{0} {
      constant.val = f32[] constant(4)
      constant.pad = f32[] constant(-7)
      reshape.1 = f32[1,1,1]{2,1,0} reshape(f32[] constant.val)
      pad = f32[3,3,3]{2,1,0} pad(f32[1,1,1]{2,1,0} reshape.1, f32[] constant.pad), padding=0_2x0_2x2_0
      slice = f32[1,1,1]{2,1,0} slice(f32[3,3,3]{2,1,0} pad), slice={[0:1], [0:1], [0:1]}
      ROOT reshape.2 = f32[1]{0} reshape(f32[1,1,1]{2,1,0} slice)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::ConstantScalar(-7.0))));
}
4927 
// In the concatenation, param.0 covers [0,2), param.1 covers [2,3) and
// param.2 covers [3,6). Slicing [2:3] selects exactly param.1, so the root
// should become that parameter.
TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      param.0 = f32[2] parameter(0)
      param.1 = f32[1] parameter(1)
      param.2 = f32[3] parameter(2)
      concat = f32[6] concatenate(param.0, param.1, param.2), dimensions={0}
      ROOT slice = f32[1] slice(concat), slice={[2:3]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Parameter(1)));
}
4949 
// Slicing [4:5] of the concatenation falls inside param.2, which covers
// [3,6). The slice should be retargeted to param.2 with bounds [1:2].
TEST_F(AlgebraicSimplifierTest, SliceOfConcatNonScalarInput) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      param.0 = f32[2] parameter(0)
      param.1 = f32[1] parameter(1)
      param.2 = f32[3] parameter(2)
      concat = f32[6] concatenate(param.0, param.1, param.2), dimensions={0}
      ROOT slice = f32[1] slice(concat), slice={[4:5]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Slice(m::Parameter(2))));
  EXPECT_EQ(root->slice_starts(0), 1);
  EXPECT_EQ(root->slice_limits(0), 2);
}
4973 
// Concatenating the same operand six times along its size-1 dimension is
// equivalent to a broadcast. The expected form is a broadcast of the operand
// reshaped to drop that dimension.
TEST_F(AlgebraicSimplifierTest, ConcatToBroadcast) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      p = f32[2,1,4] parameter(0)
      ROOT concat = f32[2,6,4] concatenate(p,p,p,p,p,p), dimensions={1}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));
}
4992 
// negate(negate(x)) should simplify back to x.
TEST_F(AlgebraicSimplifierTest, NegateNegate) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      param.0 = f32[2] parameter(0)
      neg.0 = f32[2] negate(param.0)
      ROOT neg.1 = f32[2] negate(neg.0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Parameter(0)));
}
5012 
// not(not(x)) on predicates should simplify back to x.
TEST_F(AlgebraicSimplifierTest, NotNot) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      param.0 = pred[2] parameter(0)
      not.0 = pred[2] not(param.0)
      ROOT not.1 = pred[2] not(not.0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Parameter(0)));
}
5032 
5033 struct PadReduceWindowEffectiveBroadcastCase {
5034   std::vector<int64> input_spatials;
5035   std::vector<int64> symmetric_pad_spatials;
5036   std::vector<int64> reduce_window_spatials;
5037   // Whether to use `B F S0 S1` form vs `B S0 S1 F` form.
5038   //
5039   // This doesn't test any different functionality but is useful for making sure
5040   // kBroadcast nodes are well formed.
5041   bool prepend_a;
5042   bool should_become_broadcast;
5043 
ToTestCaseNamexla::__anonac242c730111::PadReduceWindowEffectiveBroadcastCase5044   string ToTestCaseName() const {
5045     return absl::StrCat(absl::StrJoin(input_spatials, ","), ";",
5046                         absl::StrJoin(symmetric_pad_spatials, ","), ";",
5047                         absl::StrJoin(reduce_window_spatials, ","), ";",
5048                         prepend_a, ";", should_become_broadcast);
5049   }
5050 };
5051 
PrintTo(const PadReduceWindowEffectiveBroadcastCase & c,std::ostream * os)5052 void PrintTo(const PadReduceWindowEffectiveBroadcastCase& c, std::ostream* os) {
5053   *os << c.ToTestCaseName();
5054 }
5055 
5056 class PadReduceWindowEffectiveBroadcastTest
5057     : public AlgebraicSimplifierTest,
5058       public ::testing::WithParamInterface<
5059           PadReduceWindowEffectiveBroadcastCase> {};
5060 
TEST_P(PadReduceWindowEffectiveBroadcastTest,DoIt)5061 TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) {
5062   auto m = CreateNewVerifiedModule();
5063   const auto& param = GetParam();
5064 
5065   // a and b are parallel bounds we can either turn into a B F S0 S1 or
5066   // `B S0 S1 F` kind of pattern.
5067   auto decorate_spatials = [&param](absl::Span<const int64> spatials, int64 a,
5068                                     int64 b) {
5069     std::vector<int64> result;
5070     if (param.prepend_a) {
5071       result.push_back(a);
5072     }
5073     for (int64 s : spatials) {
5074       result.push_back(s);
5075     }
5076     if (!param.prepend_a) {
5077       result.push_back(a);
5078     }
5079     result.push_back(b);
5080     return result;
5081   };
5082 
5083   HloComputation::Builder builder(TestName());
5084   const Shape input_shape = ShapeUtil::MakeShape(
5085       F32, decorate_spatials(param.input_spatials, 128, 2048));
5086   HloInstruction* input = builder.AddInstruction(
5087       HloInstruction::CreateParameter(0, input_shape, "input"));
5088 
5089   PaddingConfig padding = window_util::MakeSymmetricPadding(
5090       decorate_spatials(param.symmetric_pad_spatials, 0, 0));
5091   TF_ASSERT_OK_AND_ASSIGN(
5092       const Shape pad_shape,
5093       ShapeInference::InferPadShape(input->shape(),
5094                                     ShapeUtil::MakeShape(F32, {}), padding));
5095   HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
5096       pad_shape, input,
5097       builder.AddInstruction(
5098           HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))),
5099       padding));
5100 
5101   HloComputation* add_computation = nullptr;
5102   {
5103     HloComputation::Builder builder(TestName() + ".add");
5104     const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
5105     HloInstruction* p0 = builder.AddInstruction(
5106         HloInstruction::CreateParameter(0, scalar_shape, "p0"));
5107     HloInstruction* p1 = builder.AddInstruction(
5108         HloInstruction::CreateParameter(1, scalar_shape, "p1"));
5109     builder.AddInstruction(
5110         HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
5111     add_computation = m->AddEmbeddedComputation(builder.Build());
5112   }
5113 
5114   Window window = window_util::MakeWindow(
5115       decorate_spatials(param.reduce_window_spatials, 1, 1));
5116   auto zero = builder.AddInstruction(
5117       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
5118   TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape,
5119                           ShapeInference::InferReduceWindowShape(
5120                               pad->shape(), zero->shape(), window,
5121                               add_computation->ComputeProgramShape()));
5122   builder.AddInstruction(HloInstruction::CreateReduceWindow(
5123       output_shape, pad, zero, window, add_computation));
5124 
5125   auto computation = m->AddEntryComputation(builder.Build());
5126   AlgebraicSimplifier simplifier(default_options_);
5127   TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get()));
5128   ASSERT_TRUE(run_successful);
5129 
5130   EXPECT_TRUE(
5131       ShapeUtil::Equal(computation->root_instruction()->shape(), output_shape));
5132 
5133   if (param.should_become_broadcast) {
5134     EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Broadcast()));
5135   } else {
5136     EXPECT_THAT(computation->root_instruction(),
5137                 GmockMatch(m::ReduceWindow(m::Op(), m::Op().Is(zero))));
5138   }
5139 }
5140 
5141 const std::vector<PadReduceWindowEffectiveBroadcastCase>&
PadReduceWindowEffectiveBroadcastCases()5142 PadReduceWindowEffectiveBroadcastCases() {
5143   static auto* cases = new std::vector<PadReduceWindowEffectiveBroadcastCase>{
5144       {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{6, 6},
5145        /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/true,
5146        /*should_become_broadcast=*/true},  //
5147       {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{6, 6},
5148        /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/false,
5149        /*should_become_broadcast=*/true},  //
5150       {/*input_spatials=*/{2, 2}, /*symmetric_pad_amount=*/{6, 6},
5151        /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/true,
5152        /*should_become_broadcast=*/false},  //
5153       {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{2, 2},
5154        /*reduce_window_spatials=*/{1, 1}, /*prepend_a=*/true,
5155        /*should_become_broadcast=*/false},  //
5156       {/*input_spatials=*/{5, 1}, /*symmetric_pad_amount=*/{0, 2},
5157        /*reduce_window_spatials=*/{2, 5}, /*prepend_a=*/true,
5158        /*should_become_broadcast=*/false},  //
5159   };
5160   return *cases;
5161 }
5162 
5163 INSTANTIATE_TEST_SUITE_P(
5164     PadReduceWindowEffectiveBroadcastInstantiation,
5165     PadReduceWindowEffectiveBroadcastTest,
5166     ::testing::ValuesIn(PadReduceWindowEffectiveBroadcastCases()));
5167 
5168 class BatchDotStrengthReductionTest
5169     : public AlgebraicSimplifierTest,
5170       public ::testing::WithParamInterface<
5171           ::testing::tuple<int, int, int, PrimitiveType>> {};
TEST_P(BatchDotStrengthReductionTest,BatchDotStrengthReduction)5172 TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) {
5173   auto module = CreateNewVerifiedModule();
5174   int m, k, n;
5175   PrimitiveType element_type;
5176   std::tie(m, k, n, element_type) = GetParam();
5177   std::vector<int64> lhs_dims = {2, 3, 5};
5178   std::vector<int64> rhs_dims = lhs_dims;
5179   std::vector<int64> output_dims = lhs_dims;
5180   if (m > 0) {
5181     lhs_dims.push_back(m);
5182     output_dims.push_back(m);
5183   }
5184   if (k > 0) {
5185     lhs_dims.push_back(k);
5186     rhs_dims.push_back(k);
5187   }
5188   if (n > 0) {
5189     rhs_dims.push_back(n);
5190     output_dims.push_back(n);
5191   }
5192   Shape dot_shape = ShapeUtil::MakeShape(element_type, output_dims);
5193   Shape lhs_shape = ShapeUtil::MakeShape(element_type, lhs_dims);
5194   Shape rhs_shape = ShapeUtil::MakeShape(element_type, rhs_dims);
5195   HloComputation::Builder builder(TestName());
5196 
5197   auto lhs = builder.AddInstruction(
5198       HloInstruction::CreateParameter(0, lhs_shape, "lhs"));
5199   auto rhs = builder.AddInstruction(
5200       HloInstruction::CreateParameter(1, rhs_shape, "rhs"));
5201   DotDimensionNumbers dot_dnums;
5202   dot_dnums.add_lhs_batch_dimensions(0);
5203   dot_dnums.add_lhs_batch_dimensions(1);
5204   dot_dnums.add_lhs_batch_dimensions(2);
5205   dot_dnums.add_rhs_batch_dimensions(0);
5206   dot_dnums.add_rhs_batch_dimensions(1);
5207   dot_dnums.add_rhs_batch_dimensions(2);
5208   if (k > 0) {
5209     dot_dnums.add_lhs_contracting_dimensions(m > 0 ? 4 : 3);
5210     dot_dnums.add_rhs_contracting_dimensions(3);
5211   }
5212   builder.AddInstruction(HloInstruction::CreateDot(
5213       dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
5214   auto computation = module->AddEntryComputation(builder.Build());
5215   AlgebraicSimplifier simplifier(default_options_);
5216   TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
5217   const bool dot_should_be_transformed =
5218       m == 1 || k == 1 || n == 1 || m == -1 || k == -1 || n == -1;
5219   EXPECT_EQ(changed, dot_should_be_transformed);
5220   TF_ASSERT_OK_AND_ASSIGN(changed, simplifier.Run(module.get()));
5221   bool has_no_dot = true;
5222   for (const auto& hlo : computation->instructions()) {
5223     if (hlo->opcode() == HloOpcode::kDot) {
5224       has_no_dot = false;
5225       break;
5226     }
5227   }
5228   EXPECT_EQ(has_no_dot, dot_should_be_transformed);
5229 }
5230 
5231 INSTANTIATE_TEST_SUITE_P(
5232     BatchDotStrengthReductionTestInstantiation, BatchDotStrengthReductionTest,
5233     ::testing::Combine(::testing::Values(-1, 1, 2), ::testing::Values(-1, 1, 2),
5234                        ::testing::Values(-1, 1, 2),
5235                        ::testing::Values(C128, C64, F64, F32, BF16)));
5236 
5237 class DotStrengthReductionTest
5238     : public AlgebraicSimplifierTest,
5239       public ::testing::WithParamInterface<
5240           ::testing::tuple<int, int, int, bool, bool, PrimitiveType>> {};
// Builds dot([m, k], [k, n]), optionally feeding either operand through a
// transpose. The dot should disappear whenever any of m, k, n is 1. The module
// should also change when both operands are transposed.
TEST_P(DotStrengthReductionTest, DotStrengthReduction) {
  auto module = CreateNewVerifiedModule();
  int m, k, n;
  bool transpose_lhs, transpose_rhs;
  PrimitiveType element_type;
  std::tie(m, k, n, transpose_lhs, transpose_rhs, element_type) = GetParam();

  const Shape dot_shape = ShapeUtil::MakeShape(element_type, {m, n});
  const Shape lhs_shape = ShapeUtil::MakeShape(element_type, {m, k});
  const Shape transposed_lhs_shape = ShapeUtil::MakeShape(element_type, {k, m});
  const Shape rhs_shape = ShapeUtil::MakeShape(element_type, {k, n});
  const Shape transposed_rhs_shape = ShapeUtil::MakeShape(element_type, {n, k});
  HloComputation::Builder builder(TestName());

  // Adds parameter `number` feeding the dot with `shape`. When `transpose` is
  // set, the parameter is declared with `transposed_shape` and a {1, 0}
  // transpose restores `shape`.
  auto add_operand = [&builder](int64 number, const Shape& shape,
                                const Shape& transposed_shape, bool transpose,
                                const char* name) -> HloInstruction* {
    HloInstruction* operand =
        builder.AddInstruction(HloInstruction::CreateParameter(
            number, transpose ? transposed_shape : shape, name));
    if (!transpose) {
      return operand;
    }
    return builder.AddInstruction(
        HloInstruction::CreateTranspose(shape, operand, {1, 0}));
  };
  HloInstruction* lhs =
      add_operand(0, lhs_shape, transposed_lhs_shape, transpose_lhs, "lhs");
  HloInstruction* rhs =
      add_operand(1, rhs_shape, transposed_rhs_shape, transpose_rhs, "rhs");

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  builder.AddInstruction(HloInstruction::CreateDot(
      dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
  auto computation = module->AddEntryComputation(builder.Build());

  const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1;
  const bool computation_should_be_modified =
      dot_should_be_transformed || (transpose_lhs && transpose_rhs);
  AlgebraicSimplifier simplifier(default_options_);

  // First pass: removes degenerate dimensions and optimizes
  // dot(transpose(x), transpose(y)).
  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
  EXPECT_EQ(changed, computation_should_be_modified);
  // Second pass: removes dots without non-contracting or contracting
  // dimensions.
  TF_ASSERT_OK_AND_ASSIGN(changed, simplifier.Run(module.get()));
  EXPECT_EQ(changed, computation_should_be_modified);

  bool found_dot = false;
  for (const auto& hlo : computation->instructions()) {
    if (hlo->opcode() == HloOpcode::kDot) {
      found_dot = true;
      break;
    }
  }
  EXPECT_EQ(!found_dot, dot_should_be_transformed);
}
5294 
5295 INSTANTIATE_TEST_SUITE_P(
5296     DotStrengthReductionTestInstantiation, DotStrengthReductionTest,
5297     ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2),
5298                        ::testing::Values(1, 2), ::testing::Bool(),
5299                        ::testing::Bool(),
5300                        ::testing::Values(C128, C64, F64, F32, BF16)));
5301 
5302 struct DotOfConcatTestSpec {
5303   int64 m;
5304   int64 k;
5305   int64 n;
5306 };
5307 
5308 class DotOfConcatSimplificationTest
5309     : public AlgebraicSimplifierTest,
5310       public ::testing::WithParamInterface<DotOfConcatTestSpec> {};
5311 
// Test that we transform
//  dot(const, concat(A, B, C))
// to
//  add(dot(const_0, A), dot(const_1, B), dot(const_2, C))
// where const_i is the slice of `const` that multiplies the i-th concatenated
// operand.
TEST_P(DotOfConcatSimplificationTest,ConstantLHS)5316 TEST_P(DotOfConcatSimplificationTest, ConstantLHS) {
5317   auto m = CreateNewVerifiedModule();
5318   HloComputation::Builder builder(TestName());
5319 
5320   DotOfConcatTestSpec spec = GetParam();
5321 
5322   ASSERT_GE(spec.k, 3);
5323 
5324   int64 k0 = spec.k / 3;
5325   int64 k1 = spec.k / 3;
5326   int64 k2 = spec.k - k0 - k1;
5327 
5328   Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k});
5329   auto* lhs = builder.AddInstruction(
5330       HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
5331           /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.m, /*cols=*/spec.k)));
5332 
5333   Shape rhs0_shape = ShapeUtil::MakeShape(F32, {k0, spec.n});
5334   Shape rhs1_shape = ShapeUtil::MakeShape(F32, {k1, spec.n});
5335   Shape rhs2_shape = ShapeUtil::MakeShape(F32, {k2, spec.n});
5336 
5337   HloInstruction* rhs0 = builder.AddInstruction(
5338       HloInstruction::CreateParameter(0, rhs0_shape, "rhs0"));
5339   HloInstruction* rhs1 = builder.AddInstruction(
5340       HloInstruction::CreateParameter(1, rhs1_shape, "rhs1"));
5341   HloInstruction* rhs2 = builder.AddInstruction(
5342       HloInstruction::CreateParameter(2, rhs2_shape, "rhs2"));
5343 
5344   Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n});
5345   HloInstruction* rhs = builder.AddInstruction(
5346       HloInstruction::CreateConcatenate(rhs_shape, {rhs0, rhs1, rhs2}, 0));
5347 
5348   DotDimensionNumbers dot_dnums;
5349   dot_dnums.add_lhs_contracting_dimensions(1);
5350   dot_dnums.add_rhs_contracting_dimensions(0);
5351 
5352   Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n});
5353   builder.AddInstruction(HloInstruction::CreateDot(
5354       dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
5355 
5356   auto computation = m->AddEntryComputation(builder.Build());
5357   AlgebraicSimplifier simplifier(default_options_);
5358   TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get()));
5359   ASSERT_TRUE(run_successful);
5360 
5361   EXPECT_TRUE(
5362       ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
5363 
5364   auto match_dot_0 = m::Dot(m::Slice(m::Constant()), m::Parameter(0));
5365   auto match_dot_1 = m::Dot(m::Slice(m::Constant()), m::Parameter(1));
5366   auto match_dot_2 = m::Dot(m::Slice(m::Constant()), m::Parameter(2));
5367   EXPECT_THAT(
5368       computation->root_instruction(),
5369       GmockMatch(m::Add(m::Add(match_dot_0, match_dot_1), match_dot_2)));
5370 }
5371 
5372 // Test that we transform
5373 //  dot(concat(A, B, C), const)
5374 // to
5375 //  add(dot(A, const_0), dot(B, const_1),  dot(C, const_2))
TEST_P(DotOfConcatSimplificationTest,ConstantRHS)5376 TEST_P(DotOfConcatSimplificationTest, ConstantRHS) {
5377   auto m = CreateNewVerifiedModule();
5378   HloComputation::Builder builder(TestName());
5379 
5380   DotOfConcatTestSpec spec = GetParam();
5381 
5382   ASSERT_GE(spec.k, 4);
5383 
5384   int64 k0 = spec.k / 4;
5385   int64 k1 = spec.k / 4;
5386   int64 k2 = spec.k / 4;
5387   int64 k3 = spec.k - k0 - k1 - k2;
5388 
5389   Shape lhs0_shape = ShapeUtil::MakeShape(F32, {spec.m, k0});
5390   Shape lhs1_shape = ShapeUtil::MakeShape(F32, {spec.m, k1});
5391   Shape lhs2_shape = ShapeUtil::MakeShape(F32, {spec.m, k2});
5392   Shape lhs3_shape = ShapeUtil::MakeShape(F32, {spec.m, k3});
5393 
5394   HloInstruction* lhs0 = builder.AddInstruction(
5395       HloInstruction::CreateParameter(0, lhs0_shape, "lhs0"));
5396   HloInstruction* lhs1 = builder.AddInstruction(
5397       HloInstruction::CreateParameter(1, lhs1_shape, "lhs1"));
5398   HloInstruction* lhs2 = builder.AddInstruction(
5399       HloInstruction::CreateParameter(2, lhs2_shape, "lhs2"));
5400   HloInstruction* lhs3 = builder.AddInstruction(
5401       HloInstruction::CreateParameter(3, lhs3_shape, "lhs3"));
5402 
5403   Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k});
5404   HloInstruction* lhs =
5405       builder.AddInstruction(HloInstruction::CreateConcatenate(
5406           lhs_shape, {lhs0, lhs1, lhs2, lhs3}, 1));
5407 
5408   Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n});
5409   auto* rhs = builder.AddInstruction(
5410       HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
5411           /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.k, /*cols=*/spec.n)));
5412 
5413   DotDimensionNumbers dot_dnums;
5414   dot_dnums.add_lhs_contracting_dimensions(1);
5415   dot_dnums.add_rhs_contracting_dimensions(0);
5416 
5417   Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n});
5418   builder.AddInstruction(HloInstruction::CreateDot(
5419       dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
5420 
5421   auto computation = m->AddEntryComputation(builder.Build());
5422   AlgebraicSimplifier simplifier(default_options_);
5423   TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get()));
5424   ASSERT_TRUE(run_successful);
5425   EXPECT_TRUE(
5426       ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
5427 
5428   auto match_dot_0 = m::Dot(m::Parameter(0), m::Slice(m::Constant()));
5429   auto match_dot_1 = m::Dot(m::Parameter(1), m::Slice(m::Constant()));
5430   auto match_dot_2 = m::Dot(m::Parameter(2), m::Slice(m::Constant()));
5431   auto match_dot_3 = m::Dot(m::Parameter(3), m::Slice(m::Constant()));
5432   EXPECT_THAT(
5433       computation->root_instruction(),
5434       GmockMatch(m::Add(m::Add(m::Add(match_dot_0, match_dot_1), match_dot_2),
5435                         match_dot_3)));
5436 }
5437 
5438 DotOfConcatTestSpec kDotOfConcatTestSpecs[] = {
5439     {/*m=*/3, /*k=*/9, /*n=*/3},    //
5440     {/*m=*/3, /*k=*/20, /*n=*/3},   //
5441     {/*m=*/1, /*k=*/18, /*n=*/5},   //
5442     {/*m=*/20, /*k=*/20, /*n=*/1},  //
5443     {/*m=*/1, /*k=*/16, /*n=*/1},   //
5444 };
5445 
// A dot of two rank-1 operands (scalar result) whose LHS is a concatenate and
// whose RHS is a constant. With dot strength reduction disabled, the
// simplifier is expected to report that nothing changed.
TEST_F(DotOfConcatSimplificationTest, ConcatIntoScalarDot) {
  const char* kModuleStr = R"(
    HloModule m
    test {
      param0 = f32[4] parameter(0)
      param1 = f32[1] parameter(1)
      constant = f32[5] constant({-0.38, 0.07, -0.62, 0.66, 0.20})
      concat = f32[5] concatenate(param0, param1), dimensions={0}
      ROOT dot = f32[] dot(concat, constant), lhs_contracting_dims={0},
                                              rhs_contracting_dims={0}
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr));

  AlgebraicSimplifierOptions no_strength_reduction = default_options_;
  no_strength_reduction.set_enable_dot_strength_reduction(false);
  AlgebraicSimplifier simplifier(no_strength_reduction);

  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}
5462 
5463 // Test that DynamicUpdateSlice update param with any dimension equal to zero
5464 // gets removed.
TEST_F(AlgebraicSimplifierTest,DynamicUpdateSliceZeroUpdate)5465 TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) {
5466   auto m = CreateNewVerifiedModule();
5467   HloComputation::Builder builder(TestName());
5468   const Shape dslice_shape = ShapeUtil::MakeShape(F32, {10});
5469   HloInstruction* const operand = builder.AddInstruction(
5470       HloInstruction::CreateParameter(0, dslice_shape, "operand"));
5471   const Shape update_shape = ShapeUtil::MakeShape(F32, {0});
5472   HloInstruction* const update = builder.AddInstruction(
5473       HloInstruction::CreateParameter(1, update_shape, "update"));
5474   HloInstruction* const start_indices = builder.AddInstruction(
5475       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>({})));
5476   builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
5477       dslice_shape, operand, update,
5478       std::initializer_list<HloInstruction*>({start_indices})));
5479   const HloComputation* const computation =
5480       m->AddEntryComputation(builder.Build());
5481 
5482   AlgebraicSimplifier simplifier(default_options_);
5483   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
5484   EXPECT_THAT(computation->root_instruction(), operand);
5485 }
5486 
5487 // Test that dynamic-update-slice with a scalar broadcast becomes a pad.
TEST_F(AlgebraicSimplifierTest,DynamicUpdateSliceOfBroadcastToPad)5488 TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceOfBroadcastToPad) {
5489   const char* hlo_string = R"(
5490 HloModule AddBroadcastZeroWithDynamicSlice
5491 
5492 ENTRY AddBroadcastZeroWithDynamicSlice {
5493   param0 = f32[1800,12,512]{2,1,0} parameter(0)
5494   constant = f32[] constant(0)
5495   broadcast = f32[1800,12,512]{2,1,0} broadcast(constant), dimensions={}
5496   param1 = f32[1,12,512]{2,1,0} parameter(1)
5497   constant.1 = s32[] constant(0)
5498   dynamic-update-slice = f32[1800,12,512]{2,1,0} dynamic-update-slice(broadcast, param1, constant.1, constant.1, constant.1)
5499   ROOT add = f32[1800,12,512]{2,1,0} add(param0, dynamic-update-slice)
5500 }
5501 )";
5502   TF_ASSERT_OK_AND_ASSIGN(auto module,
5503                           ParseAndReturnVerifiedModule(hlo_string));
5504   VLOG(2) << "Before rewrite dus->pad\n" << module->ToString();
5505   AlgebraicSimplifier simplifier(default_options_);
5506   ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
5507   VLOG(2) << "After rewrite dus->pad\n" << module->ToString();
5508   auto root = module->entry_computation()->root_instruction();
5509   EXPECT_THAT(root->opcode(), HloOpcode::kAdd);
5510   EXPECT_THAT(root->operand(1)->opcode(), HloOpcode::kPad);
5511 }
5512 
// With scalar multiply reduction enabled, both convolution inputs are scaled
// by broadcast scalar constants (0.5 and 1.109). The pass is expected to
// produce a root multiply by a broadcast of their product.
TEST_F(AlgebraicSimplifierTest, ScalarMultiplyReduction) {
  const char* hlo_string = R"(
HloModule ConstScalarMultiply
ENTRY ConstScalarMultiply {
  param0 = f32[16,512,4096]{2,1,0} parameter(0)
  constant.0 = f32[] constant(0.5)
  broadcast.0 = f32[16,512,4096] broadcast(constant.0), dimensions={}
  multiply.0 = f32[16,512,4096]{2,1,0} multiply(param0, broadcast.0)
  param1 = f32[16,512,4096]{2,1,0} parameter(1)
  multiply.1 = f32[16,512,4096]{2,1,0} multiply(multiply.0, param1)
  param2 = f32[16,512,1024]{2,1,0} parameter(2)
  constant.1 = f32[] constant(1.109)
  broadcast.1 = f32[16,512,1024] broadcast(constant.1), dimensions={}
  multiply.2 = f32[16,512,1024]{2,1,0} multiply(param2, broadcast.1)
  ROOT convolution = f32[4096,1024,1]{1,0,2} convolution(multiply.1, multiply.2), window={size=16}, dim_labels=0fb_0io->bf0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AlgebraicSimplifierOptions options;
  options.set_enable_scalar_multiply_reduction(true);
  ASSERT_TRUE(AlgebraicSimplifier(options).Run(module.get()).ValueOrDie());

  const float expected_scale = 0.5f * 1.109f;
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kMultiply);
  EXPECT_THAT(root, GmockMatch(m::MultiplyAnyOrder(
                        m::Op(),
                        m::Broadcast(m::ConstantScalar(expected_scale)))));
}
5542 
// The convolution result here has two users: a multiply by a broadcast scalar
// constant and a multiply by a parameter. With scalar multiply reduction
// enabled, the simplifier is expected to leave the module unchanged.
TEST_F(AlgebraicSimplifierTest, ScalarMultiplyReductionMultiUser) {
  const char* hlo_string = R"(
HloModule ConstScalarMultiply
ENTRY ConstScalarMultiply {
  param0 = f32[16,512,1024] parameter(0)
  param1 = f32[4096,1024,1] parameter(1)
  convolution = f32[16,512,4096] convolution(param0, param1), window={size=1}, dim_labels=0bf_oi0->0bf
  constant.1 = f32[] constant(0.5)
  broadcast.1 = f32[16,512,4096] broadcast(constant.1), dimensions={}
  multiply.1 = f32[16,512,4096] multiply(convolution, broadcast.1)
  param2 = f32[16,512,4096] parameter(2)
  multiply.2 = f32[16,512,4096] multiply(convolution, param2)
  ROOT add.1 = f32[16,512,4096] add(multiply.1, multiply.2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AlgebraicSimplifierOptions options;
  options.set_enable_scalar_multiply_reduction(true);
  const bool changed =
      AlgebraicSimplifier(options).Run(module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}
5565 
// Runs every TEST_P of DotOfConcatSimplificationTest (ConstantLHS and
// ConstantRHS) once per entry of kDotOfConcatTestSpecs.
INSTANTIATE_TEST_SUITE_P(DotOfConcatSimplificationTestInstantiation,
                         DotOfConcatSimplificationTest,
                         ::testing::ValuesIn(kDotOfConcatTestSpecs));
5569 
// One DotOfGatherSimplificationTest case: a dot between a full constant and a
// one-row/one-column dynamic slice of another constant. The logical dot is
// [m, k] x [k, n]; lcd/rcd choose which physical dimension (0 or 1) of each
// operand is contracted, so operands may be stored transposed.
struct DotOfGatherTestSpec {
  int64 m;
  int64 k;
  int64 n;
  int s;      // start index for dynamic slice on the non-contracting dimension
  int64 lcd;  // left contracting dimension (the tests compare it against 0)
  int64 rcd;  // right contracting dimension (the tests compare it against 0)
  bool neg;   // is negative testcase: the sliced constant's contracting
              // dimension is widened by 5 so the rewrite cannot apply
};
5579 
// Fixture for the dot-of-dynamic-slice ("gather") simplification tests; each
// TEST_P receives a DotOfGatherTestSpec via GetParam().
class DotOfGatherSimplificationTest
    : public AlgebraicSimplifierTest,
      public ::testing::WithParamInterface<DotOfGatherTestSpec> {};
5583 
5584 // input: dot(DS(ctA), ctB))
5585 // where DS(ctA) = DS({M x K}, {s, 0}, {1, K}) and ctB = {K x N}.
5586 // => input dimensions: dot({1 x K}, {K x N}) => {1 x N}.
5587 // output: DS(dot(ctA, ctB))
5588 // => output dimensions: DS ({M x N}, {s, 0}, {1, N}) => {1 x N}.
TEST_P(DotOfGatherSimplificationTest,ConstantRHS)5589 TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {
5590   auto m = CreateNewVerifiedModule();
5591   HloComputation::Builder builder(TestName());
5592 
5593   DotOfGatherTestSpec spec = GetParam();
5594 
5595   ASSERT_LE(spec.s, spec.m);
5596 
5597   // For negative tests, increase k of the dynamic slice argument to prevent the
5598   // optimization (constants ctA, ctB must have equal contracting dimensions).
5599   int64 k_increase = spec.neg ? 5 : 0;
5600   int64 lhs_rows = (spec.lcd == 0) ? (spec.k + k_increase) : spec.m;
5601   int64 lhs_cols = (spec.lcd == 0) ? spec.m : (spec.k + k_increase);
5602   Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols});
5603   auto* lhs = builder.AddInstruction(
5604       HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
5605           /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows,
5606           /*cols=*/lhs_cols)));
5607 
5608   int32 start_row = (spec.lcd == 0) ? 0 : spec.s;
5609   int32 start_col = (spec.lcd == 0) ? spec.s : 0;
5610   std::vector<HloInstruction*> start_indices = {
5611       builder.AddInstruction(HloInstruction::CreateConstant(
5612           LiteralUtil::CreateR0<int32>(start_row))),
5613       builder.AddInstruction(HloInstruction::CreateConstant(
5614           LiteralUtil::CreateR0<int32>(start_col)))};
5615   int64 slice_row_size = (spec.lcd == 0) ? spec.k : 1;
5616   int64 slice_col_size = (spec.lcd == 0) ? 1 : spec.k;
5617   std::vector<int64> slice_sizes = {slice_row_size, slice_col_size};
5618   Shape ds_shape = ShapeUtil::MakeShape(F32, slice_sizes);
5619   auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
5620       ds_shape, lhs, start_indices, slice_sizes));
5621 
5622   int64 rhs_rows = (spec.rcd == 0) ? spec.k : spec.n;
5623   int64 rhs_cols = (spec.rcd == 0) ? spec.n : spec.k;
5624   Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols});
5625   auto* rhs = builder.AddInstruction(
5626       HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
5627           /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows,
5628           /*cols=*/rhs_cols)));
5629 
5630   DotDimensionNumbers dot_dnums;
5631   dot_dnums.add_lhs_contracting_dimensions(spec.lcd);
5632   dot_dnums.add_rhs_contracting_dimensions(spec.rcd);
5633 
5634   int64 dot_row_size = 1;
5635   int64 dot_col_size = spec.n;
5636   Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size});
5637   builder.AddInstruction(HloInstruction::CreateDot(
5638       dot_shape, ds, rhs, dot_dnums, DefaultPrecisionConfig(2)));
5639 
5640   auto computation = m->AddEntryComputation(builder.Build());
5641   AlgebraicSimplifier simplifier(default_options_);
5642   TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get()));
5643   ASSERT_TRUE(run_successful);
5644   EXPECT_TRUE(
5645       ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
5646 
5647   if (spec.neg) {
5648     EXPECT_NE(computation->root_instruction()->opcode(),
5649               HloOpcode::kDynamicSlice);
5650   } else {
5651     EXPECT_THAT(computation->root_instruction(),
5652                 GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()),
5653                                            m::Constant(), m::Constant())));
5654   }
5655 }
5656 
5657 // input: dot(ctA, DS(ctB))
5658 // where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, s}, {K, 1}).
5659 // => input dimensions: dot({M x K}, {K x 1}) => {M x 1}.
5660 // output: DS(dot(ctA, ctB))
5661 // => output dimensions: DS ({M x N}, {0, s}, {M, 1}) => {M x 1}.
TEST_P(DotOfGatherSimplificationTest,ConstantLHS)5662 TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {
5663   auto m = CreateNewVerifiedModule();
5664   HloComputation::Builder builder(TestName());
5665 
5666   DotOfGatherTestSpec spec = GetParam();
5667 
5668   ASSERT_LE(spec.s, spec.n);
5669 
5670   int64 lhs_rows = (spec.lcd == 0) ? spec.k : spec.m;
5671   int64 lhs_cols = (spec.lcd == 0) ? spec.m : spec.k;
5672   Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols});
5673   auto* lhs = builder.AddInstruction(
5674       HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
5675           /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows,
5676           /*cols=*/lhs_cols)));
5677 
5678   // For negative tests increase k of the dynamic slice argument to prevent the
5679   // optimization
5680   int64 k_increase = spec.neg ? 5 : 0;
5681   int64 rhs_rows = (spec.rcd == 0) ? (spec.k + k_increase) : spec.n;
5682   int64 rhs_cols = (spec.rcd == 0) ? spec.n : (spec.k + k_increase);
5683   Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols});
5684   auto* rhs = builder.AddInstruction(
5685       HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
5686           /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows,
5687           /*cols=*/rhs_cols)));
5688 
5689   int32 start_row = (spec.rcd == 0) ? 0 : spec.s;
5690   int32 start_col = (spec.rcd == 0) ? spec.s : 0;
5691   std::vector<HloInstruction*> start_indices = {
5692       builder.AddInstruction(HloInstruction::CreateConstant(
5693           LiteralUtil::CreateR0<int32>(start_row))),
5694       builder.AddInstruction(HloInstruction::CreateConstant(
5695           LiteralUtil::CreateR0<int32>(start_col)))};
5696   int64 slice_row_size = (spec.rcd == 0) ? spec.k : 1;
5697   int64 slice_col_size = (spec.rcd == 0) ? 1 : spec.k;
5698   std::vector<int64> slice_sizes = {slice_row_size, slice_col_size};
5699   Shape ds_shape = ShapeUtil::MakeShape(F32, slice_sizes);
5700   auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
5701       ds_shape, rhs, start_indices, slice_sizes));
5702 
5703   DotDimensionNumbers dot_dnums;
5704   dot_dnums.add_lhs_contracting_dimensions(spec.lcd);
5705   dot_dnums.add_rhs_contracting_dimensions(spec.rcd);
5706 
5707   int64 dot_row_size = spec.m;
5708   int64 dot_col_size = 1;
5709   Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size});
5710   builder.AddInstruction(HloInstruction::CreateDot(
5711       dot_shape, lhs, ds, dot_dnums, DefaultPrecisionConfig(2)));
5712 
5713   auto computation = m->AddEntryComputation(builder.Build());
5714   AlgebraicSimplifier simplifier(default_options_);
5715   TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get()));
5716   ASSERT_TRUE(run_successful);
5717   EXPECT_TRUE(
5718       ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
5719 
5720   if (spec.neg) {
5721     EXPECT_NE(computation->root_instruction()->opcode(),
5722               HloOpcode::kDynamicSlice);
5723   } else {
5724     EXPECT_THAT(computation->root_instruction(),
5725                 GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()),
5726                                            m::Constant(), m::Constant())));
5727   }
5728 }
5729 
DotOfGatherPositiveNegativeTests()5730 std::vector<DotOfGatherTestSpec> DotOfGatherPositiveNegativeTests() {
5731   std::vector<DotOfGatherTestSpec> positives = {
5732       // "Classical dot", i.e. matrix multiply:
5733       {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/1, /*rcd=*/0,
5734        /*neg=*/false},
5735       {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/1, /*rcd=*/0,
5736        /*neg=*/false},
5737       {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/1, /*rcd=*/0,
5738        /*neg=*/false},
5739       // Note: testing for m=1 and n=1 is unnecessary, as this optimizes to
5740       // dot(ct, ct) before DotOfGather optimization kicks in.
5741       // Contract on rows:
5742       {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/0, /*rcd=*/0,
5743        /*neg=*/false},
5744       {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/0, /*rcd=*/0,
5745        /*neg=*/false},
5746       {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/0, /*rcd=*/0,
5747        /*neg=*/false},
5748       // Reverse matrix multiply:
5749       {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/0, /*rcd=*/1,
5750        /*neg=*/false},
5751       {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/0, /*rcd=*/1,
5752        /*neg=*/false},
5753       {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/0, /*rcd=*/1,
5754        /*neg=*/false},
5755       // Contract on columns:
5756       {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/1, /*rcd=*/1,
5757        /*neg=*/false},
5758       {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/1, /*rcd=*/1,
5759        /*neg=*/false},
5760       {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/1, /*rcd=*/1,
5761        /*neg=*/false},
5762   };
5763   std::vector<DotOfGatherTestSpec> all;
5764   for (int i = 0; i < positives.size(); i++) {
5765     DotOfGatherTestSpec positive_test = positives[i];
5766     all.push_back(positive_test);
5767     DotOfGatherTestSpec negative_test = positive_test;
5768     negative_test.neg = true;
5769     all.push_back(negative_test);
5770   }
5771   return all;
5772 }
5773 
// Runs every TEST_P of DotOfGatherSimplificationTest (ConstantRHS and
// ConstantLHS) over all specs from DotOfGatherPositiveNegativeTests(), which
// pairs each positive spec with its negative twin.
INSTANTIATE_TEST_SUITE_P(
    DotOfGatherSimplificationTestInstantiation, DotOfGatherSimplificationTest,
    ::testing::ValuesIn(DotOfGatherPositiveNegativeTests()));
5777 
// The gather operand is f32[1,1], which holds a single element. The gather is
// expected to become a broadcast of that element, reshaped to a scalar.
TEST_F(AlgebraicSimplifierTest, GatherOfScalarToBroadcast) {
  const char* hlo_string = R"(
  HloModule repeat

  ENTRY main {
    o = f32[1,1] parameter(0)
    i = s32[100,2] parameter(1)
    ROOT g = f32[100] gather(o, i), collapsed_slice_dims={0,1},
                                  start_index_map={0,1},
                                  index_vector_dim=1,
                                  offset_dims={},
                                  slice_sizes={1,1}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AlgebraicSimplifierOptions options;
  const bool changed =
      AlgebraicSimplifier(options).Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  HloComputation* const entry = module->entry_computation();
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));
}
5801 
// A tuple-shaped (variadic) reduce of two scalar inputs with dimensions={}
// reduces nothing. Each tuple output is expected to become a reshape of the
// corresponding input element of the entry parameter.
TEST_F(AlgebraicSimplifierTest, TupleReduceReshape) {
  const char* hlo_string = R"(
HloModule module

reducer {
  parameter.1 = f32[] parameter(0)
  parameter.3 = f32[] parameter(2)
  add.2 = f32[] add(parameter.1, parameter.3)
  parameter.0 = f32[] parameter(1)
  parameter.2 = f32[] parameter(3)
  add.3 = f32[] add(parameter.0, parameter.2)
  ROOT tuple.4 = (f32[], f32[]) tuple(add.2, add.3)
}

ENTRY entry {
  parameter.6 = (f32[], f32[]) parameter(0)
  get-tuple-element.10 = f32[] get-tuple-element(parameter.6), index=0
  get-tuple-element.11 = f32[] get-tuple-element(parameter.6), index=1
  constant = f32[] constant(0)
  ROOT reduce = (f32[], f32[]) reduce(get-tuple-element.10, get-tuple-element.11, constant, constant), dimensions={}, to_apply=reducer
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AlgebraicSimplifierOptions options;
  const bool changed =
      AlgebraicSimplifier(options).Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Tuple(
                  m::Reshape(m::GetTupleElement(m::Parameter(), 0)),
                  m::Reshape(m::GetTupleElement(m::Parameter(), 1)))));
}
5835 
// A tuple-shaped (variadic) reduce over a zero-sized dimension: the inputs are
// f32[0,10,10], reduced along dimension 0. Each output is expected to become
// a broadcast of its init value (0 and 1 respectively).
TEST_F(AlgebraicSimplifierTest, TupleReduceBroadcast) {
  const char* hlo_string = R"(
HloModule module

reducer {
  parameter.1 = f32[] parameter(0)
  parameter.3 = f32[] parameter(2)
  mul.2 = f32[] add(parameter.1, parameter.3)
  parameter.0 = f32[] parameter(1)
  parameter.2 = f32[] parameter(3)
  add.3 = f32[] add(parameter.0, parameter.2)
  ROOT tuple.4 = (f32[], f32[]) tuple(mul.2, add.3)
}

ENTRY entry {
  parameter.6 = (f32[0, 10, 10], f32[0, 10, 10]) parameter(0)
  get-tuple-element.10 = f32[0, 10, 10] get-tuple-element(parameter.6), index=0
  get-tuple-element.11 = f32[0, 10, 10] get-tuple-element(parameter.6), index=1
  constant.0 = f32[] constant(0)
  constant.1 = f32[] constant(1)
  ROOT reduce = (f32[10, 10], f32[10, 10]) reduce(get-tuple-element.10, get-tuple-element.11, constant.0, constant.1), dimensions={0}, to_apply=reducer
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AlgebraicSimplifierOptions options;
  const bool changed =
      AlgebraicSimplifier(options).Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Tuple(m::Broadcast(m::ConstantScalar(0)),
                                        m::Broadcast(m::ConstantScalar(1)))));
}
5869 
// A reshape whose result has zero elements and carries no layout is still
// simplified. Its input is a broadcast of f32[1] to f32[0,1], and the root is
// expected to become a constant.
TEST_F(AlgebraicSimplifierTest, ZeroSizedReshapeWithoutLayout) {
  HloComputation::Builder builder(TestName());
  HloInstruction* const param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {1}),
                                      "param"));
  HloInstruction* const broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {0, 1}),
                                      param, {1}));

  // Zero-element result shape with its layout deliberately stripped.
  Shape empty_shape = ShapeUtil::MakeShape(F32, {0});
  empty_shape.clear_layout();
  builder.AddInstruction(HloInstruction::CreateReshape(empty_shape, broadcast));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Constant()));
}
5894 
// Dividing a layout-less scalar parameter by the constant 20.0 is expected to
// be rewritten as a multiply.
TEST_F(AlgebraicSimplifierTest, DividedByConstantInstructionWithoutLayout) {
  Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
  scalar_shape.clear_layout();

  HloComputation::Builder builder(TestName());
  HloInstruction* const param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "param"));
  HloInstruction* const divisor = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(20.0f)));
  builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_shape, HloOpcode::kDivide, param, divisor));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Multiply()));
}
5916 
5917 // Test that 1/sqrt(X) is simplified to rsqrt(X).
TEST_F(AlgebraicSimplifierTest,RecipSqrt)5918 TEST_F(AlgebraicSimplifierTest, RecipSqrt) {
5919   const char* kModuleStr = R"(
5920     HloModule m
5921     test {
5922       p0 = f32[] parameter(0)
5923       p1 = f32[] parameter(1)
5924       sqrt = f32[] sqrt(p0)
5925       ROOT div = f32[] divide(p1, sqrt)
5926     }
5927   )";
5928   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
5929   ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
5930   EXPECT_THAT(m->entry_computation()->root_instruction(),
5931               GmockMatch(m::MultiplyAnyOrder(m::Parameter(1),
5932                                              m::Rsqrt(m::Parameter(0)))));
5933 }
5934 
5935 // Test that 1/rsqrt(X) is simplified to sqrt(X).
TEST_F(AlgebraicSimplifierTest,RecipRsqrt)5936 TEST_F(AlgebraicSimplifierTest, RecipRsqrt) {
5937   const char* kModuleStr = R"(
5938     HloModule m
5939     test {
5940       p0 = f32[] parameter(0)
5941       p1 = f32[] parameter(1)
5942       rsqrt = f32[] rsqrt(p0)
5943       ROOT div = f32[] divide(p1, rsqrt)
5944     }
5945   )";
5946   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
5947   ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
5948   EXPECT_THAT(m->entry_computation()->root_instruction(),
5949               GmockMatch(m::MultiplyAnyOrder(m::Parameter(1),
5950                                              m::Sqrt(m::Parameter(0)))));
5951 }
5952 
// In layout-sensitive mode, copy(reshape(p0)) is expected to collapse into a
// single reshape of p0 that directly yields the copy's shape, including its
// {3,0,2,1} layout. The options callback always returns false
// (NOTE(review): presumably the reshape-is-bitcast predicate; its role is
// defined outside this file).
TEST_F(AlgebraicSimplifierTest, CopyReshape) {
  const char* kModuleStr = R"(
    HloModule m
    test {
      p0 = f32[168,168,48,48]{3,2,1,0} parameter(0)
      r0 = f32[1,168,168,2304]{3,2,1,0} reshape(p0)
      ROOT c0 = f32[1,168,168,2304]{3,0,2,1} copy(r0)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  // Capture the copy's shape (with layout) before the pass rewrites the root.
  Shape result_shape =
      module->entry_computation()->root_instruction()->shape();

  auto always_false = [](const Shape&, const Shape&) { return false; };
  AlgebraicSimplifierOptions options(always_false);
  options.set_is_layout_sensitive(true);
  ASSERT_TRUE(AlgebraicSimplifier(options).Run(module.get()).ValueOrDie());

  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Reshape(m::Parameter(0)).WithShapeEqualTo(&result_shape)));
}
5971 
// The dot's LHS is reshape(transpose(param)) and its RHS is a constant. The
// simplifier is expected to move the transpose/reshape over to the constant
// side. Afterwards the LHS is a plain reshape of the parameter, and the RHS
// is reshape(transpose(reshape(constant))) with permutation {1, 0, 2}.
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_RL) {
  const char* kModuleStr = R"(
    HloModule m
    test {
      rhs = f32[6, 2] constant({{1, 2},{3, 4},{5, 6},{1, 1},{1, 1},{1, 1}})
      t0 = f32[2, 2, 3] parameter(0)
      t1 = f32[2, 3, 2] transpose(t0), dimensions={0, 2, 1}
      lhs = f32[2, 6] reshape(t1)
      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
  auto shape1 = ShapeUtil::MakeShape(F32, {2, 6});
  auto shape2 = ShapeUtil::MakeShape(F32, {3, 2, 2});
  auto shape3 = ShapeUtil::MakeShape(F32, {2, 3, 2});
  // The transformation of moving transpose and reshape to the constant side
  // is layout insensitive. We ignore layout when checking shapes.
  // Bound by the matcher below; initialized so it never holds garbage.
  const HloInstruction* transpose = nullptr;
  ASSERT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Dot(
                  m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&shape1),
                  m::Reshape(m::Transpose(&transpose,
                                          m::Reshape(m::Constant())
                                              .WithShapeCompatibleTo(&shape2))
                                 .WithShapeCompatibleTo(&shape3)))));
  EXPECT_THAT(transpose->dimensions(), ElementsAre(1, 0, 2));
}
6000 
// Same reorder as the RL case, but the constant rhs contracts on dimension 1.
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_RR) {
  const char* const kHloText = R"(
    HloModule m
    test {
      rhs = f32[2, 6] constant({{1, 2, 3, 4, 5, 6},
                                {1, 1, 1, 1, 1, 1}})
      t0 = f32[2, 2, 3] parameter(0)
      t1 = f32[2, 3, 2] transpose(t0), dimensions={0, 2, 1}
      lhs = f32[2, 6] reshape(t1)
      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={1}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 6});
  const Shape const_reshape_shape = ShapeUtil::MakeShape(F32, {2, 3, 2});
  const Shape const_transpose_shape = ShapeUtil::MakeShape(F32, {2, 2, 3});
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root,
      GmockMatch(m::Dot(
          m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&lhs_shape),
          m::Reshape(m::Transpose(m::Reshape(m::Constant())
                                      .WithShapeCompatibleTo(
                                          &const_reshape_shape))
                         .WithShapeCompatibleTo(&const_transpose_shape)))));
}
6025 
// The lhs contracts on dimension 0 and the constant rhs on dimension 1; the
// transpose+reshape is still expected to move onto the constant.
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_LR) {
  const char* const kHloText = R"(
    HloModule m
    test {
      rhs = f32[2, 6] constant({{1, 2, 3, 4, 5, 6},
                                {1, 1, 1, 1, 1, 1}})
      t0 = f32[2, 3, 2] parameter(0)
      t1 = f32[3, 2, 2] transpose(t0), dimensions={1, 0, 2}
      lhs = f32[6, 2] reshape(t1)
      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={0}, rhs_contracting_dims={1}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {6, 2});
  const Shape const_reshape_shape = ShapeUtil::MakeShape(F32, {2, 3, 2});
  const Shape const_transpose_shape = ShapeUtil::MakeShape(F32, {2, 2, 3});
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root,
      GmockMatch(m::Dot(
          m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&lhs_shape),
          m::Reshape(m::Transpose(m::Reshape(m::Constant())
                                      .WithShapeCompatibleTo(
                                          &const_reshape_shape))
                         .WithShapeCompatibleTo(&const_transpose_shape)))));
}
6050 
// Three contracting dimensions are permuted by the transpose; the constant side
// should receive the inverse-style permutation checked at the end.
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_LR2) {
  const char* const kHloText = R"(
    HloModule m
    test {
      rhs = f32[8, 2] constant({{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6},{7, 7},{8, 8}})
      t0 = f32[2, 2, 2, 2] parameter(0)
      t1 = f32[2, 2, 2, 2] transpose(t0), dimensions={0, 2, 3, 1}
      lhs = f32[2, 8] reshape(t1)
      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1},
                                            rhs_contracting_dims={0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 8});
  const Shape const_reshape_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2});
  const HloInstruction* transpose = nullptr;
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  ASSERT_THAT(
      root,
      GmockMatch(m::Dot(
          m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&lhs_shape),
          m::Reshape(m::Transpose(&transpose,
                                  m::Reshape(m::Constant())
                                      .WithShapeCompatibleTo(
                                          &const_reshape_shape))))));
  EXPECT_THAT(transpose->dimensions(), ElementsAre(2, 0, 1, 3));
}
6077 
// Batched dot: the reorder must keep the batch dimension (0) fixed while moving
// the contracting-dimension transpose onto the constant operand.
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_MM) {
  const char* const kHloText = R"(
    HloModule m
    test {
      rhs = f32[2, 6, 2] constant({{{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6}},
                                   {{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6}}})
      t0 = f32[2, 2, 3, 2] parameter(0)
      t1 = f32[2, 3, 2, 2] transpose(t0), dimensions={0, 2, 1, 3}
      lhs = f32[2, 6, 2] reshape(t1)
      ROOT dot.5 = f32[2, 2, 2] dot(lhs, rhs), lhs_batch_dims={0}, lhs_contracting_dims={1},
                                               rhs_batch_dims={0}, rhs_contracting_dims={1}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 6, 2});
  const Shape const_reshape_shape = ShapeUtil::MakeShape(F32, {2, 3, 2, 2});
  const Shape const_transpose_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 2});
  const HloInstruction* transpose = nullptr;
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  ASSERT_THAT(
      root,
      GmockMatch(m::Dot(
          m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&lhs_shape),
          m::Reshape(m::Transpose(&transpose,
                                  m::Reshape(m::Constant())
                                      .WithShapeCompatibleTo(
                                          &const_reshape_shape))
                         .WithShapeCompatibleTo(&const_transpose_shape)))));
  EXPECT_THAT(transpose->dimensions(), ElementsAre(0, 2, 1, 3));
}
6106 
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_NegTranspose) {
  const char* const kHloText = R"(
    HloModule m
    test {
      rhs = f32[12, 2] constant({{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6},{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6}})
      t0 = f32[3, 4, 2] parameter(0)
      t1 = f32[2, 3, 4] transpose(t0), dimensions={2, 0, 1}
      lhs = f32[2, 12] reshape(t1)
      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  // Here the transpose also moves the non-contracting dimension, so neither
  // the transpose nor the reshape may be pushed onto the constant operand.
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(hlo_module.get()).ValueOrDie());
}
6123 
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_NegReshape) {
  const char* const kHloText = R"(
    HloModule m
    test {
      rhs = f32[8, 2] constant({{1, 1},{2, 2},{3, 3},{4, 4},{1, 1},{2, 2},{3, 3},{4, 4}})
      t0 = f32[2, 4, 3] parameter(0)
      t1 = f32[2, 3, 4] transpose(t0), dimensions={0, 2, 1}
      lhs = f32[3, 8] reshape(t1)
      ROOT dot.5 = f32[3, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  // The reshape mixes non-contracting dimensions into the result, so the
  // transpose/reshape pair must stay on the parameter side.
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(hlo_module.get()).ValueOrDie());
}
6140 
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_NegConstant) {
  const char* const kHloText = R"(
    HloModule m
    test {
      t0 = f32[2, 3, 4] parameter(0)
      t1 = f32[2, 4, 3] transpose(t0), dimensions={0, 2, 1}
      lhs = f32[2, 12] reshape(t1)
      rhs = f32[12, 2] parameter(1)
      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  // With no constant operand there is nowhere to move the reorder, so the pass
  // must report no change.
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(hlo_module.get()).ValueOrDie());
}
6156 
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_NegLayout) {
  const char* const kHloText = R"(
    HloModule m
    test {
      rhs = f32[6, 2] constant({{1, 2},{3, 4},{5, 6},{1, 1},{1, 1},{1, 1}})
      t0 = f32[2, 2, 3] parameter(0)
      t1 = f32[2, 3, 2] transpose(t0), dimensions={0, 2, 1}
      lhs = f32[2, 6] reshape(t1)
      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  // Reshape-to-bitcast conversion is turned off (the callback always answers
  // false) so that no unrelated rewrite of the reshape can flip the result;
  // the only question is whether the dot reorder fires.
  AlgebraicSimplifierOptions layout_sensitive_options(
      [](const Shape&, const Shape&) { return false; });
  layout_sensitive_options.set_is_layout_sensitive(true);
  // Moving the transpose/reshape onto the constant is a layout-insensitive
  // rewrite, so a layout-sensitive simplifier must leave the module unchanged.
  AlgebraicSimplifier simplifier(layout_sensitive_options);
  ASSERT_FALSE(simplifier.Run(hlo_module.get()).ValueOrDie());
}
6180 
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_SizeOneDimsNoChange) {
  // The transpose keeps the `2` and `3` dimensions in the same relative order,
  // so there is nothing to reorder. The size-1 leading dimension makes this a
  // useful negative case anyway.
  const char* const kHloText = R"(
    HloModule m
    test {
     param = f32[1,2,5,3] parameter(0)
     transpose = f32[1,5,2,3] transpose(param), dimensions={0,2,1,3}
     reshape = f32[5,6] reshape(transpose)
     constant = f32[6,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4}})
     ROOT dot = f32[5,4] dot(reshape, constant),
       lhs_contracting_dims={1}, rhs_contracting_dims={0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(hlo_module.get()).ValueOrDie());
}
6199 
// A size-1 leading dimension must survive the reorder: the transpose placed on
// the constant side keeps dimension 0 in place.
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_SizeOneDims) {
  const char* const kHloText = R"(
    HloModule m
    test {
     param = f32[1,2,3,5] parameter(0)
     transpose = f32[1,3,2,5] transpose(param), dimensions={0,2,1,3}
     reshape = f32[6,5] reshape(transpose)
     constant = f32[6,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4}})
     ROOT dot = f32[5,4] dot(reshape, constant),
       lhs_contracting_dims={0}, rhs_contracting_dims={0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {6, 5});
  const Shape const_reshape_shape = ShapeUtil::MakeShape(F32, {1, 3, 2, 4});
  const Shape const_transpose_shape = ShapeUtil::MakeShape(F32, {1, 2, 3, 4});
  const HloInstruction* transpose = nullptr;
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  ASSERT_THAT(
      root,
      GmockMatch(m::Dot(
          m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&lhs_shape),
          m::Reshape(m::Transpose(&transpose,
                                  m::Reshape(m::Constant())
                                      .WithShapeCompatibleTo(
                                          &const_reshape_shape))
                         .WithShapeCompatibleTo(&const_transpose_shape)))));
  EXPECT_THAT(transpose->dimensions(), ElementsAre(0, 2, 1, 3));
}
6227 
TEST_F(AlgebraicSimplifierTest,
       DotContractingReorder_NoChangeInContractingDimsOrder) {
  // The transpose shuffles dimensions, but the contracting ones keep their
  // relative order, so there is no reorder to push onto the constant.
  const char* const kHloText = R"(
    HloModule m
    test {
      param = f32[2,5,1,3] parameter(0)
      transpose = f32[1,5,2,3] transpose(param), dimensions={2,1,0,3}
      reshape = f32[5,6] reshape(transpose)
      constant = f32[6,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4}})
      ROOT dot = f32[5,4] dot(reshape, constant),
        lhs_contracting_dims={1}, rhs_contracting_dims={0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(hlo_module.get()).ValueOrDie());
}
6246 
// `iota < broadcast(0)` is expected to fold to a broadcast of `false`.
TEST_F(AlgebraicSimplifierTest, CompareIota) {
  const char* const kHloText = R"(
    HloModule m
    test {
      zero = s32[] constant(0)
      iota = s32[128] iota(), iota_dimension=0
      broad = s32[128] broadcast(zero), dimensions={}
      ROOT compare = pred[128] compare(iota, broad), direction=LT
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::ConstantScalar(false))));
}
6261 
// For an unsigned operand, `x < 0` is expected to fold to the constant `false`.
TEST_F(AlgebraicSimplifierTest, CompareLtZero) {
  const char* const kHloText = R"(
    HloModule m
    test {
      zero = u32[] constant(0)
      param = u32[] parameter(0)
      ROOT compare = pred[] compare(param, zero), direction=LT
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::ConstantScalar(false)));
}
6275 
// For an unsigned operand, `x <= 0` has no constant answer and must be left as
// is.
TEST_F(AlgebraicSimplifierTest, CompareLeZero) {
  const char* const kHloText = R"(
    HloModule m
    test {
      zero = u32[] constant(0)
      param = u32[] parameter(0)
      ROOT compare = pred[] compare(param, zero), direction=LE
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(hlo_module.get()).ValueOrDie());
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Le(m::Parameter(0),
                                     m::ConstantEffectiveScalar(0))));
}
6290 
// For an unsigned operand, `x >= 0` is expected to fold to the constant `true`.
TEST_F(AlgebraicSimplifierTest, CompareGeZero) {
  const char* const kHloText = R"(
    HloModule m
    test {
      zero = u32[] constant(0)
      param = u32[] parameter(0)
      ROOT compare = pred[] compare(param, zero), direction=GE
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::ConstantScalar(true)));
}
6304 
// For an unsigned operand, `x > 0` has no constant answer, so the simplifier
// must leave the compare untouched (mirrors CompareLeZero).
TEST_F(AlgebraicSimplifierTest, CompareGtZero) {
  const char* kModuleStr = R"(
    HloModule m
    test {
      zero = u32[] constant(0)
      param = u32[] parameter(0)
      ROOT compare = pred[] compare(param, zero), direction=GT
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  // Actually run the pass. Without this step the expectation below would only
  // inspect the freshly parsed module and could never detect a bad rewrite.
  ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
  EXPECT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(m::Gt(m::Parameter(0), m::ConstantEffectiveScalar(0))));
}
6318 
// For an unsigned operand, `0 > x` is expected to fold to the constant `false`.
TEST_F(AlgebraicSimplifierTest, CompareZeroGt) {
  const char* const kHloText = R"(
    HloModule m
    test {
      zero = u32[] constant(0)
      param = u32[] parameter(0)
      ROOT compare = pred[] compare(zero, param), direction=GT
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::ConstantScalar(false)));
}
6332 
// For an unsigned operand, `0 >= x` has no constant answer and must be left as
// is.
TEST_F(AlgebraicSimplifierTest, CompareZeroGe) {
  const char* const kHloText = R"(
    HloModule m
    test {
      zero = u32[] constant(0)
      param = u32[] parameter(0)
      ROOT compare = pred[] compare(zero, param), direction=GE
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(hlo_module.get()).ValueOrDie());
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Ge(m::ConstantEffectiveScalar(0),
                                     m::Parameter(0))));
}
6347 
// For an unsigned operand, `0 <= x` is expected to fold to the constant `true`.
TEST_F(AlgebraicSimplifierTest, CompareZeroLe) {
  const char* const kHloText = R"(
    HloModule m
    test {
      zero = u32[] constant(0)
      param = u32[] parameter(0)
      ROOT compare = pred[] compare(zero, param), direction=LE
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::ConstantScalar(true)));
}
6361 
// For an unsigned operand, `0 < x` has no constant answer and must be left as
// is.
TEST_F(AlgebraicSimplifierTest, CompareZeroLt) {
  const char* const kHloText = R"(
    HloModule m
    test {
      zero = u32[] constant(0)
      param = u32[] parameter(0)
      ROOT compare = pred[] compare(zero, param), direction=LT
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(hlo_module.get()).ValueOrDie());
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Lt(m::ConstantEffectiveScalar(0),
                                     m::Parameter(0))));
}
6376 
// Comparing an integer array with itself using GE is expected to fold to a
// broadcast of `true`.
TEST_F(AlgebraicSimplifierTest, CompareSame) {
  const char* const kHloText = R"(
    HloModule m
    test {
      param = s32[123] parameter(0)
      ROOT compare = pred[123] compare(param, param), direction=GE
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::ConstantScalar(true))));
}
6389 
// `(x < 10) && (x < 100)` is expected to keep only the tighter bound `x < 10`.
TEST_F(AlgebraicSimplifierTest, CompareSimplified) {
  const char* const kHloText = R"(
    HloModule m
    test {
      param = s32[] parameter(0)
      c1 = s32[] constant(10)
      c2 = s32[] constant(100)
      cmp1 = pred[] compare(param, c1), direction=LT
      cmp2 = pred[] compare(param, c2), direction=LT
      ROOT out = pred[] and(cmp1, cmp2)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Compare(m::Op(), m::Op().IsConstantScalar(10))
                             .WithComparisonDirection(ComparisonDirection::kLt)));
}
6408 
// Same as CompareSimplified, but the looser bound is written as `100 > x`.
TEST_F(AlgebraicSimplifierTest, CompareSimplifiedReversed) {
  const char* const kHloText = R"(
    HloModule m
    test {
      param = s32[] parameter(0)
      c1 = s32[] constant(10)
      c2 = s32[] constant(100)
      cmp1 = pred[] compare(param, c1), direction=LT
      cmp2 = pred[] compare(c2, param), direction=GT
      ROOT out = pred[] and(cmp1, cmp2)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Compare(m::Op(), m::Op().IsConstantScalar(10))
                             .WithComparisonDirection(ComparisonDirection::kLt)));
}
6427 
TEST_F(AlgebraicSimplifierTest, CanDisableDotToMultiplyRewrite) {
  // An outer product may run faster as a Dot than as a broadcast Multiply on
  // some backends, so this rewrite has an off switch.
  const char* const kHloText = R"(
    HloModule m
    test {
      param1 = f32[64] parameter(0)
      param2 = f32[64] parameter(1)
      ROOT compare = f32[64, 64] dot(param1, param2),
        lhs_contracting_dims={}, rhs_contracting_dims={}
    })";

  // Default options: the outer-product dot is rewritten into a multiply.
  TF_ASSERT_OK_AND_ASSIGN(auto rewritten_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier default_simplifier(default_options_);
  ASSERT_TRUE(default_simplifier.Run(rewritten_module.get()).ValueOrDie());
  EXPECT_THAT(rewritten_module->entry_computation()->root_instruction(),
              GmockMatch(m::Multiply(m::Op(), m::Op())));

  // With the option disabled, the pass must report no change.
  AlgebraicSimplifierOptions no_rewrite_options = default_options_;
  no_rewrite_options.set_enable_dot_to_multiply_rewrite(false);
  TF_ASSERT_OK_AND_ASSIGN(auto untouched_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier no_rewrite_simplifier(no_rewrite_options);
  ASSERT_FALSE(no_rewrite_simplifier.Run(untouched_module.get()).ValueOrDie());
}
6452 
// The iota runs along a dimension of size 5 and the divisor is 5, so
// `iota % 5` is expected to simplify to the iota itself.
TEST_F(AlgebraicSimplifierTest, RemainderOfIota) {
  const char* const kHloText = R"(
    HloModule m
    test {
      iota = s32[5,1000] iota(), iota_dimension=0
      five = s32[] constant(5)
      five_bcast = s32[5,1000] broadcast(s32[] five), dimensions={}
      ROOT remainder = s32[5,1000] remainder(iota, s32[5,1000] five_bcast)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Iota()));
}
6467 
// `(iota + 5) % 5` is expected to drop the added multiple of the divisor,
// leaving a remainder of the iota by the broadcast divisor.
TEST_F(AlgebraicSimplifierTest, RemainderOfNPlusIota) {
  const char* const kHloText = R"(
    HloModule m
    test {
      iota = s32[5,1000] iota(), iota_dimension=0
      five = s32[] constant(5)
      five_bcast = s32[5,1000] broadcast(five), dimensions={}
      sum = s32[5,1000] add(iota, five_bcast)
      ROOT remainder = s32[5,1000] remainder(sum, s32[5,1000] five_bcast)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Remainder(m::Iota(), m::Broadcast())));
}
6483 
6484 // No simplification because 125 + 5 overflows S8.
TEST_F(AlgebraicSimplifierTest, RemainderOfNPlusIotaOverflow) {
  const char* const kHloText = R"(
    HloModule m
    test {
      iota = s8[126] iota(), iota_dimension=0
      five = s8[] constant(5)
      five_bcast = s8[126] broadcast(five), dimensions={}
      sum = s8[126] add(iota, five_bcast)
      ROOT remainder = s8[126] remainder(sum, s8[126] five_bcast)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  // The largest iota value (125) plus 5 does not fit in s8, so dropping the
  // addend would be unsound; the pass must not change anything.
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(hlo_module.get()).ValueOrDie());
}
6498 
// `(p % q) % q` is expected to collapse to a single remainder of the two
// parameters.
TEST_F(AlgebraicSimplifierTest, RepeatedRemainder) {
  const char* const kHloText = R"(
    HloModule m
    test {
      p = s32[1000] parameter(0)
      q = s32[1000] parameter(1)
      r = s32[1000] remainder(p, q)
      ROOT rr = s32[1000] remainder(r, q)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Remainder(m::Parameter(), m::Parameter())));
}
6513 
// In layout-sensitive mode, a pad with negative low padding applied to a slice
// is expected to become a single slice that keeps the pad's shape and layout.
TEST_F(AlgebraicSimplifierTest, SlicePadLayout) {
  const char* const kHloText = R"(
    HloModule m
    test {
      %param.0 = f32[128,9,9,1024]{0,3,2,1} parameter(0)
      %param.1 = f32[] parameter(1)
      %slice = f32[128,9,9,1024]{0,3,2,1} slice(%param.0),
        slice={[0:128], [0:9], [0:9], [0:1024]}
      ROOT %pad = f32[128,8,9,1024]{0,3,2,1} pad(%slice, %param.1),
        padding=0_0x-1_0x0_0x0_0
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  // Record the pad's shape (with layout) before the pass replaces it.
  const Shape expected_shape =
      hlo_module->entry_computation()->root_instruction()->shape();
  AlgebraicSimplifierOptions layout_sensitive_options;
  layout_sensitive_options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(layout_sensitive_options);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Slice().WithShapeEqualTo(&expected_shape)));
}
6533 
// `min(max(3, p0), 4)` with broadcast constant bounds is expected to become
// `clamp(3, p0, 4)`.
TEST_F(AlgebraicSimplifierTest, MinOfMaxToClamp) {
  const char* const kHloText = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      c0 = f32[] constant(3.0)
      c1 = f32[] constant(4.0)
      b0 = f32[4] broadcast(c0), dimensions={}
      b1 = f32[4] broadcast(c1), dimensions={}
      m0 = f32[4] maximum(b0, p0)
      ROOT m1 = f32[4] minimum(m0, b1)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Clamp(m::Broadcast(m::ConstantScalar(3.0)),
                                  m::Parameter(0),
                                  m::Broadcast(m::ConstantScalar(4.0)))));
}
6554 
// `max(3, min(p0, 4))` with broadcast constant bounds is expected to become
// `clamp(3, p0, 4)`.
TEST_F(AlgebraicSimplifierTest, MaxOfMinToClamp) {
  const char* const kHloText = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      c0 = f32[] constant(3.0)
      c1 = f32[] constant(4.0)
      b0 = f32[4] broadcast(c0), dimensions={}
      b1 = f32[4] broadcast(c1), dimensions={}
      m0 = f32[4] minimum(p0, b1)
      ROOT m1 = f32[4] maximum(b0, m0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Clamp(m::Broadcast(m::ConstantScalar(3.0)),
                                  m::Parameter(0),
                                  m::Broadcast(m::ConstantScalar(4.0)))));
}
6575 
// `clamp(p0, clamp(p0, p1, p2), p2)` is expected to collapse to the inner
// `clamp(p0, p1, p2)`.
TEST_F(AlgebraicSimplifierTest, ClampOfClamp) {
  const char* const kHloText = R"(
    HloModule m
    test {
      p0 = f32[] parameter(0)
      p1 = f32[] parameter(1)
      p2 = f32[] parameter(2)
      c0 = f32[] clamp(p0, p1, p2)
      ROOT c1 = f32[] clamp(p0, c0, p2)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Clamp(m::Parameter(0), m::Parameter(1),
                                        m::Parameter(2))));
}
6593 
// `max(p0, clamp(p0, p1, p2))` is expected to simplify to the clamp itself.
TEST_F(AlgebraicSimplifierTest, MaxOfClamp) {
  const char* const kHloText = R"(
    HloModule m
    test {
      p0 = f32[] parameter(0)
      p1 = f32[] parameter(1)
      p2 = f32[] parameter(2)
      c0 = f32[] clamp(p0, p1, p2)
      ROOT m0 = f32[] maximum(p0, c0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Clamp(m::Parameter(0), m::Parameter(1),
                                        m::Parameter(2))));
}
6611 
// Slicing out exactly the rows contributed by p1 from `concatenate(p0, p1)` is
// expected to yield p1 directly.
TEST_F(AlgebraicSimplifierTest, SliceOfConcat) {
  const char* const kHloText = R"(
    HloModule m
    test {
      p0 = f32[100,50] parameter(0)
      p1 = f32[50,50] parameter(1)
      c0 = f32[150,50] concatenate(p0, p1), dimensions={0}
      ROOT s0 = f32[50,50] slice(c0), slice={[100:150], [0:50]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Parameter(1)));
}
6627 
// `sqrt(p0 * p0)` is expected to simplify to `abs(p0)`.
TEST_F(AlgebraicSimplifierTest, SqrtOfSelfMultiply) {
  const char* const kHloText = R"(
    HloModule m
    test {
      p0 = f32[32]{0} parameter(0)
      m0 = f32[32]{0} multiply(f32[32]{0} p0, f32[32]{0} p0)
      ROOT s0 = f32[32]{0} sqrt(f32[32]{0} m0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Abs(m::Parameter(0))));
}
6642 
// An add-reduce over batch dimension 0 of a batched dot is expected to be
// absorbed into the dot, so the root becomes a dot of the two parameters.
TEST_F(AlgebraicSimplifierTest, ReduceOfBatchDotToContractingDimension) {
  const char* const kHloText = R"(
    HloModule m
    a {
      p0 = f32[] parameter(0)
      p1 = f32[] parameter(1)
      ROOT r = f32[] add(p0, p1)
    }
    test {
      p0 = f32[32,8,5,6] parameter(0)
      p1 = f32[8,32,6,7] parameter(1)
      d = f32[32,8,5,7] dot(p0, p1),
        lhs_batch_dims={0,1},
        rhs_batch_dims={1,0},
        rhs_contracting_dims={2},
        lhs_contracting_dims={3}
     c = f32[] constant(0)
     ROOT r = f32[8,5,7] reduce(d,c), dimensions={0}, to_apply=a
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Dot(m::Parameter(0), m::Parameter(1))));
}
6668 
// With cuDNN batchnorm-training metadata configured, rsqrt(power(x, -2)) on
// the batchnorm custom-call output is expected to become abs(x).
TEST_F(AlgebraicSimplifierTest, RsqrtOfRPower) {
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
      p1 = f32[32]{0} parameter(1)
      p2 = f32[32]{0} parameter(2)
      c0 = f32[] constant(0.001)
      c1 = s64[] constant(1)
      custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, c0, c1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
      get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
      get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
      get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
      c2 = f32[] constant(-2)
      broadcast = f32[32]{0} broadcast(f32[] c2), dimensions={}
      power = f32[32]{0} power(get-tuple-element, broadcast)
      rsqrt = f32[32]{0} rsqrt(f32[32]{0} power)
      ROOT tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, rsqrt)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  default_options_.set_cudnn_batchnorm_forward_training_metadata(
      "__cudnn$batchNormalizationForwardTraining");
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Expect rsqrt(power(x, -2)) -> abs(x), where x is the get-tuple-element
  // with tuple index 2 (named `get-tuple-element` in the HLO above).
  EXPECT_EQ(FindInstruction(module.get(), HloOpcode::kPower), nullptr);
  EXPECT_EQ(FindInstruction(module.get(), HloOpcode::kRsqrt), nullptr);
  const auto* root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
  const auto* rewritten = root->operand(2);
  EXPECT_EQ(rewritten->opcode(), HloOpcode::kAbs);
  EXPECT_EQ(rewritten->operand(0)->opcode(), HloOpcode::kGetTupleElement);
}
6703 
// With cuDNN batchnorm-training metadata configured, rsqrt(1 / x) on the
// batchnorm custom-call output is expected to become sqrt(x).
TEST_F(AlgebraicSimplifierTest, RsqrtDivide) {
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
      p1 = f32[32]{0} parameter(1)
      p2 = f32[32]{0} parameter(2)
      constant = f32[] constant(0.001)
      constant.1 = s64[] constant(1)
      custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
      get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
      get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
      get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
      constant.2 = f32[] constant(1)
      broadcast.1 = f32[32]{0} broadcast(constant.2), dimensions={}
      divide = f32[32]{0} divide(broadcast.1, get-tuple-element)
      rsqrt = f32[32]{0} rsqrt(divide)
      ROOT tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, rsqrt)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  default_options_.set_cudnn_batchnorm_forward_training_metadata(
      "__cudnn$batchNormalizationForwardTraining");
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Expect rsqrt(divide(1, x)) -> sqrt(x), where x is the get-tuple-element
  // with tuple index 2 (named `get-tuple-element` in the HLO above).
  EXPECT_EQ(FindInstruction(module.get(), HloOpcode::kDivide), nullptr);
  EXPECT_EQ(FindInstruction(module.get(), HloOpcode::kRsqrt), nullptr);
  const auto* root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
  const auto* rewritten = root->operand(2);
  EXPECT_EQ(rewritten->opcode(), HloOpcode::kSqrt);
  EXPECT_EQ(rewritten->operand(0)->opcode(), HloOpcode::kGetTupleElement);
}
6738 
// With cuDNN batchnorm-training metadata configured, rsqrt(x) * rsqrt(x) on
// the batchnorm custom-call output is expected to become a divide by x.
TEST_F(AlgebraicSimplifierTest, MultiplySelfRsqrt) {
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
      p1 = f32[32]{0} parameter(1)
      p2 = f32[32]{0} parameter(2)
      constant = f32[] constant(0.001)
      constant.1 = s64[] constant(1)
      custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
      get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
      get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
      get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
      rsqrt = f32[32]{0} rsqrt(get-tuple-element)
      multiply = f32[32]{0} multiply(rsqrt, rsqrt)
      ROOT tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, multiply)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  default_options_.set_cudnn_batchnorm_forward_training_metadata(
      "__cudnn$batchNormalizationForwardTraining");
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Expect multiply(rsqrt(x), rsqrt(x)) -> divide(broadcast(...), x), where x
  // is the get-tuple-element with tuple index 2 and the broadcast presumably
  // carries the constant 1 (only its opcode is checked here).
  EXPECT_EQ(FindInstruction(module.get(), HloOpcode::kMultiply), nullptr);
  EXPECT_EQ(FindInstruction(module.get(), HloOpcode::kRsqrt), nullptr);

  const auto* root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
  const auto* rewritten = root->operand(2);
  EXPECT_EQ(rewritten->opcode(), HloOpcode::kDivide);
  EXPECT_EQ(rewritten->operand(0)->opcode(), HloOpcode::kBroadcast);
  EXPECT_EQ(rewritten->operand(1)->opcode(), HloOpcode::kGetTupleElement);
}
6775 
// Same graph as MultiplySelfRsqrt, but the configured batchnorm metadata does
// not match the custom-call target, so nothing may be rewritten.
TEST_F(AlgebraicSimplifierTest, MultiplySelfRsqrt_NegativeTestCase) {
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
      p1 = f32[32]{0} parameter(1)
      p2 = f32[32]{0} parameter(2)
      constant = f32[] constant(0.001)
      constant.1 = s64[] constant(1)
      custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
      get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
      get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
      get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
      rsqrt = f32[32]{0} rsqrt(get-tuple-element)
      multiply = f32[32]{0} multiply(rsqrt, rsqrt)
      ROOT tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, multiply)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  // "...ForwardTraining" (the custom-call target) != "...Forward".
  default_options_.set_cudnn_batchnorm_forward_training_metadata(
      "__cudnn$batchNormalizationForward");
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());

  // The multiply of rsqrts must survive, and no divide/broadcast may appear.
  EXPECT_NE(FindInstruction(module.get(), HloOpcode::kMultiply), nullptr);
  EXPECT_NE(FindInstruction(module.get(), HloOpcode::kRsqrt), nullptr);
  EXPECT_EQ(FindInstruction(module.get(), HloOpcode::kDivide), nullptr);
  EXPECT_EQ(FindInstruction(module.get(), HloOpcode::kBroadcast), nullptr);
  const auto* root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root->operand(2)->opcode(), HloOpcode::kMultiply);
}
6805 
// With matching cuDNN batchnorm-training metadata, abs() applied to the
// custom-call output with tuple index 2 is expected to be removed.
TEST_F(AlgebraicSimplifierTest, AbsEliminationBatchnormTraining) {
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
      p1 = f32[32]{0} parameter(1)
      p2 = f32[32]{0} parameter(2)
      constant = f32[] constant(0.001)
      constant.1 = s64[] constant(1)
      custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
      get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
      get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
      get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
      abs = f32[32]{0} abs(get-tuple-element)
      ROOT %tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, abs)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  default_options_.set_cudnn_batchnorm_forward_training_metadata(
      "__cudnn$batchNormalizationForwardTraining");
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  // No abs may remain; the tuple consumes the get-tuple-element directly.
  EXPECT_EQ(FindInstruction(module.get(), HloOpcode::kAbs), nullptr);
  const auto* root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root->operand(2)->opcode(), HloOpcode::kGetTupleElement);
}
6832 
// Same graph as AbsEliminationBatchnormTraining, but the configured metadata
// names the inference kernel, which does not match the custom-call target, so
// the abs must be kept.
TEST_F(AlgebraicSimplifierTest,
       AbsEliminationBatchnormTraining_NegativeTestCase) {
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
      p1 = f32[32]{0} parameter(1)
      p2 = f32[32]{0} parameter(2)
      constant = f32[] constant(0.001)
      constant.1 = s64[] constant(1)
      custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
      get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
      get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
      get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
      abs = f32[32]{0} abs(get-tuple-element)
      ROOT %tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, abs)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  default_options_.set_cudnn_batchnorm_forward_training_metadata(
      "__cudnn$batchNormalizationForwardInference");
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_NE(FindInstruction(module.get(), HloOpcode::kAbs), nullptr);
}
6857 
// abs(x * x) should be simplified to x * x.
TEST_F(AlgebraicSimplifierTest, AbsEliminationMultiply) {
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      p = f32[32]{0} parameter(0)
      m = f32[32]{0} multiply(p, p)
      ROOT a = f32[32]{0} abs(m)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  const auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
}
6872 
// For integer types, compare(broadcast(a) + x, broadcast(b)) is expected to
// become compare(x, broadcast(b - a)). The same graph with f32 elements must
// not be rewritten.
TEST_F(AlgebraicSimplifierTest, BroadcastCompareSimplification) {
  std::string module_string = R"(
    HloModule m
    test {
      a = s32[] parameter(0)
      b = s32[] parameter(1)
      x = s32[10]{0} parameter(2)
      broadcast_a = s32[10]{0} broadcast(a), dimensions={}
      broadcast_b = s32[10]{0} broadcast(b), dimensions={}
      add = s32[10]{0} add(broadcast_a, x)
      ROOT cmp = pred[10]{0} compare(add, broadcast_b), direction=EQ
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_string));
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Compare(m::Parameter(2),
                                    m::Broadcast(m::Subtract(
                                        m::Parameter(1), m::Parameter(0))))));

  // Numerically unstable transformation shouldn't be applied to floating types.
  // Parse a fresh f32 copy of the module and run the pass on it. Re-running
  // on the already-simplified s32 module above would trivially report "no
  // change" and would not exercise the floating-point case at all.
  std::string module_string_f32 =
      absl::StrReplaceAll(module_string, {{"s32", "f32"}});
  TF_ASSERT_OK_AND_ASSIGN(auto m_f32,
                          ParseAndReturnVerifiedModule(module_string_f32));
  ASSERT_FALSE(
      AlgebraicSimplifier(default_options_).Run(m_f32.get()).ValueOrDie());
}
6898 
// abs(power(x, 2)) should end up as x * x with the abs removed.
TEST_F(AlgebraicSimplifierTest, AbsEliminationPower2) {
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      p0 = f32[32]{0} parameter(0)
      c0 = f32[] constant(2)
      b0 = f32[32]{0} broadcast(c0), dimensions={}
      pow = f32[32]{0} power(p0, b0)
      ROOT a = f32[32]{0} abs(pow)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  // power(A, 2) is rewritten to A * A; the enclosing abs is then dropped, so
  // the root is the bare multiply.
  const auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
}
6917 
// Two add-scatters into the same zero broadcast, summed with `shared`, should
// become one scatter into `shared` over concatenated indices and updates.
TEST_F(AlgebraicSimplifierTest, ScatterAddCombined) {
  constexpr char kHlo[] = R"(
  HloModule m
  apply {
   a = f32[] parameter(0)
   b = f32[] parameter(1)
   ROOT c = f32[] add(a, b)
  }
  test {
    z  = f32[] constant(0)
    init = f32[100,4] broadcast(z), dimensions={}
    shared = f32[100,4] parameter(0)
    index0 = s32[20] parameter(1)
    index1 = s32[10] parameter(2)
    update0 = f32[20,4] parameter(3)
    update1 = f32[10,4] parameter(4)
    scatter.0 = f32[100,4] scatter(init, index0, update0),
              to_apply=apply,
              update_window_dims={1},
              inserted_window_dims={0},
              scatter_dims_to_operand_dims={0},
              index_vector_dim=1
    scatter.1 = f32[100,4] scatter(init, index1, update1),
              to_apply=apply,
              update_window_dims={1},
              inserted_window_dims={0},
              scatter_dims_to_operand_dims={0},
              index_vector_dim=1
    add.0 = f32[100,4] add(shared, scatter.0)
    ROOT add.1 = f32[100,4] add(add.0, scatter.1)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  // Each call runs a freshly constructed simplifier once over the module.
  auto run_simplifier = [&]() {
    return AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie();
  };
  ASSERT_TRUE(run_simplifier());  // Combines the two scatters.
  ASSERT_TRUE(run_simplifier());  // Folds away the add with zero.
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Scatter(m::Parameter(0),
                            m::Concatenate(m::Parameter(1), m::Parameter(2)),
                            m::Concatenate(m::Parameter(3), m::Parameter(4)))));
}
6961 
// Like ScatterAddCombined, but the outer add takes scatter.1 first; the
// concatenation order is expected to follow that operand order.
TEST_F(AlgebraicSimplifierTest, ScatterAddCombinedSwapped) {
  constexpr char kHlo[] = R"(
  HloModule m
  apply {
   a = f32[] parameter(0)
   b = f32[] parameter(1)
   ROOT c = f32[] add(a, b)
  }
  test {
    z  = f32[] constant(0)
    init = f32[100,4] broadcast(z), dimensions={}
    shared = f32[100,4] parameter(0)
    index0 = s32[20] parameter(1)
    index1 = s32[10] parameter(2)
    update0 = f32[20,4] parameter(3)
    update1 = f32[10,4] parameter(4)
    scatter.0 = f32[100,4] scatter(init, index0, update0),
              to_apply=apply,
              update_window_dims={1},
              inserted_window_dims={0},
              scatter_dims_to_operand_dims={0},
              index_vector_dim=1
    scatter.1 = f32[100,4] scatter(init, index1, update1),
              to_apply=apply,
              update_window_dims={1},
              inserted_window_dims={0},
              scatter_dims_to_operand_dims={0},
              index_vector_dim=1
    add.0 = f32[100,4] add(shared, scatter.0)
    ROOT add.1 = f32[100,4] add(scatter.1, add.0)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  // Each call runs a freshly constructed simplifier once over the module.
  auto run_simplifier = [&]() {
    return AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie();
  };
  ASSERT_TRUE(run_simplifier());  // Combines the two scatters.
  ASSERT_TRUE(run_simplifier());  // Folds away the add with zero.
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Scatter(m::Parameter(0),
                            m::Concatenate(m::Parameter(2), m::Parameter(1)),
                            m::Concatenate(m::Parameter(4), m::Parameter(3)))));
}
7005 
// Scatter combining with index_vector_dim=0, i.e. the index vector is the
// leading dimension of the indices rather than the trailing one.
TEST_F(AlgebraicSimplifierTest, ScatterAddCombinedWeirdDnums) {
  constexpr char kHlo[] = R"(
  HloModule m
  apply {
   a = f32[] parameter(0)
   b = f32[] parameter(1)
   ROOT c = f32[] add(a, b)
  }
  test {
    z  = f32[] constant(0)
    init = f32[100,4] broadcast(z), dimensions={}
    shared = f32[100,4] parameter(0)
    index0 = s32[1,4,5] parameter(1)
    index1 = s32[1,2,5] parameter(2)
    update0 = f32[4,4,5] parameter(3)
    update1 = f32[2,4,5] parameter(4)
    scatter.0 = f32[100,4] scatter(init, index0, update0),
              to_apply=apply,
              update_window_dims={1},
              inserted_window_dims={0},
              scatter_dims_to_operand_dims={0},
              index_vector_dim=0
    scatter.1 = f32[100,4] scatter(init, index1, update1),
              to_apply=apply,
              update_window_dims={1},
              inserted_window_dims={0},
              scatter_dims_to_operand_dims={0},
              index_vector_dim=0
    ROOT add.1 = f32[100,4] add(scatter.0, scatter.1)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  // Each call runs a freshly constructed simplifier once over the module.
  auto run_simplifier = [&]() {
    return AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie();
  };
  ASSERT_TRUE(run_simplifier());  // Combines the two scatters.
  ASSERT_TRUE(run_simplifier());  // Simplifies the remaining add.
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Scatter(m::Broadcast(),
                            m::Concatenate(m::Parameter(1), m::Parameter(2)),
                            m::Concatenate(m::Parameter(3), m::Parameter(4)))));
}
7048 
// Scatter combining where the update window dimension comes first
// (update_window_dims={0}) and the index vector is trailing
// (index_vector_dim=2).
TEST_F(AlgebraicSimplifierTest, ScatterAddCombinedWeirdDnums2) {
  constexpr char kHlo[] = R"(
  HloModule m
  apply {
   a = f32[] parameter(0)
   b = f32[] parameter(1)
   ROOT c = f32[] add(a, b)
  }
  test {
    z  = f32[] constant(0)
    init = f32[100,4] broadcast(z), dimensions={}
    shared = f32[100,4] parameter(0)
    index0 = s32[4,3,1] parameter(1)
    index1 = s32[4,5,1] parameter(2)
    update0 = f32[4,4,3] parameter(3)
    update1 = f32[4,4,5] parameter(4)
    scatter.0 = f32[100,4] scatter(init, index0, update0),
              to_apply=apply,
              update_window_dims={0},
              inserted_window_dims={0},
              scatter_dims_to_operand_dims={0},
              index_vector_dim=2
    scatter.1 = f32[100,4] scatter(init, index1, update1),
              to_apply=apply,
              update_window_dims={0},
              inserted_window_dims={0},
              scatter_dims_to_operand_dims={0},
              index_vector_dim=2
    ROOT add.1 = f32[100,4] add(scatter.0, scatter.1)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  // Each call runs a freshly constructed simplifier once over the module.
  auto run_simplifier = [&]() {
    return AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie();
  };
  ASSERT_TRUE(run_simplifier());  // Combines the two scatters.
  ASSERT_TRUE(run_simplifier());  // Simplifies the remaining add.
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Scatter(m::Broadcast(),
                            m::Concatenate(m::Parameter(1), m::Parameter(2)),
                            m::Concatenate(m::Parameter(3), m::Parameter(4)))));
}
7091 
// Negative case for scatter combining: two add-scatters that each write a
// single index must be left alone.
TEST_F(AlgebraicSimplifierTest, ScalarScatter) {
  const char* hlo_string = R"(
  HloModule m
  apply {
   a = f32[] parameter(0)
   b = f32[] parameter(1)
   ROOT c = f32[] add(a, b)
  }
  test {
    z  = f32[] constant(0)
    init = f32[100,4,20] broadcast(z), dimensions={}
    shared = f32[100,4,20] parameter(0)
    index0 = s32[1] parameter(1)
    index1 = s32[1] parameter(2)
    update0 = f32[4,20] parameter(3)
    update1 = f32[4,20] parameter(4)
    scatter.0 = f32[100,4,20] scatter(init, index0, update0),
              to_apply=apply,
              update_window_dims={0, 1},
              inserted_window_dims={0},
              scatter_dims_to_operand_dims={0},
              index_vector_dim=0
    scatter.1 = f32[100,4,20] scatter(init, index1, update1),
              to_apply=apply,
              update_window_dims={0, 1},
              inserted_window_dims={0},
              scatter_dims_to_operand_dims={0},
              index_vector_dim=0
    ROOT add.1 = f32[100,4,20] add(scatter.0, scatter.1)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
  // The scatters must NOT be combined. Each index is s32[1] with
  // index_vector_dim=0, so there is no scatter batch dimension; presumably
  // that leaves no dimension along which to concatenate indices and updates.
  // The pass is therefore expected to report no change.
  ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
}
7127 
// A convolution whose 32x32 window is far larger than its 3x3 input should
// have its operands swapped, with the window, reversal and padding adjusted.
TEST_F(AlgebraicSimplifierTest, SwapConvOperands) {
  const char* hlo_string = R"(
  HloModule m
  test {
    a = f32[3,3,160,160] parameter(0)
    b = f32[128,32,32,160] parameter(1)
    ROOT c = f32[128,32,32,160] convolution(a,b),
     window={size=32x32 pad=30_30x30_30 rhs_reversal=1x1},
     dim_labels=01bf_o01i->f01b
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
  // Swap the convolution operands: parameter 1 becomes the input and
  // parameter 0 the kernel.
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
  const HloInstruction* conv = m->entry_computation()->root_instruction();
  EXPECT_THAT(conv,
              GmockMatch(m::Convolution(m::Parameter(1), m::Parameter(0))));
  // The new window is sized by the former input (3x3) and stays reversed.
  // Padding shrinks from 30 to 1 on each side.
  EXPECT_EQ(conv->window().dimensions(0).size(), 3);
  EXPECT_EQ(conv->window().dimensions(1).size(), 3);
  EXPECT_EQ(conv->window().dimensions(0).window_reversal(), true);
  EXPECT_EQ(conv->window().dimensions(1).window_reversal(), true);
  EXPECT_EQ(conv->window().dimensions(0).padding_low(), 1);
  EXPECT_EQ(conv->window().dimensions(1).padding_low(), 1);
  EXPECT_EQ(conv->window().dimensions(0).padding_high(), 1);
  EXPECT_EQ(conv->window().dimensions(1).padding_high(), 1);
}
7154 
// Dividing a converted predicate by a broadcast scalar should become a
// multiply by the broadcast reciprocal (1 / scalar).
TEST_F(AlgebraicSimplifierTest, ScalarDividePredicate) {
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      p0 = pred[2] parameter(0)
      cvt = f32[2] convert(p0)
      p1 = f32[] parameter(1)
      bcast = f32[2] broadcast(p1), dimensions={}
      ROOT div = f32[2] divide(cvt, bcast)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  const auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::MultiplyAnyOrder(
                  m::Convert(m::Parameter(0)),
                  m::Broadcast(m::Divide(m::ConstantScalar(1),
                                         m::Parameter(1))))));
}
7174 
// Two dots in one module (c64 and f64) are both strength-reduced. The module
// is then expected to hold 3 computations: presumably the entry plus one
// reduction computation per rewritten dot.
TEST_F(AlgebraicSimplifierTest, MultipleDotStrengthReductions) {
  constexpr char kHlo[] = R"(
    HloModule test
    ENTRY test {
      a = c64[2,2] parameter(0)
      b = c64[2] parameter(1)
      cd = c64[2] dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
      c = f64[2,2] parameter(2)
      d = f64[2] parameter(3)
      dd = f64[2] dot(c, d), lhs_contracting_dims={1}, rhs_contracting_dims={0}
      ROOT tuple = (c64[2], f64[2]) tuple(cd, dd)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_EQ(module->computation_count(), 3);
}
7192 
// A tuple-shaped reduce with a single operand should become an array-shaped
// reduce wrapped in a tuple. Its reducer should return the add directly.
TEST_F(AlgebraicSimplifierTest, UnaryVariadicReduce) {
  constexpr char kHlo[] = R"(
    HloModule m
    fn {
      p0 = f32[] parameter(0)
      p1 = f32[] parameter(1)
      a = f32[] add(p0, p1)
      ROOT t = (f32[]) tuple(a)
    }
    test {
      p0 = f32[32,8,6,7] parameter(0)
      c = f32[] constant(0)
      ROOT r = (f32[8,6,7]) reduce(p0, c), dimensions={0}, to_apply=fn
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  const auto* root = module->entry_computation()->root_instruction();
  ASSERT_THAT(root, GmockMatch(m::Tuple(m::Reduce(m::Parameter(0),
                                                  m::ConstantScalar(0)))));
  const auto* reduce = root->operand(0);
  ASSERT_EQ(reduce->called_computations().size(), 1);
  EXPECT_THAT(reduce->called_computations()[0]->root_instruction(),
              GmockMatch(m::Add(m::Parameter(0), m::Parameter(1))));
}
7226 
// pad(broadcast(scalar), c) should be reordered into
// broadcast(pad(broadcast(scalar), c)).
TEST_F(AlgebraicSimplifierTest, BroadcastAndPadReorder) {
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      c1 = pred[] constant(true)
      b2 = pred[32,1,768]{2,1,0} broadcast(pred[] c1), dimensions={}
      c3 = pred[] constant(false)
      ROOT p4 = pred[4096,1,768]{2,1,0} pad(pred[32,1,768]{2,1,0} b2, pred[] c3), padding=0_4064x0_0x0_0
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  const auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Pad(
                        m::Broadcast(m::Constant()), m::Constant()))));
}
7243 
// Same reorder as BroadcastAndPadReorder, but the pad feeds a tuple rather
// than being the root, and the padded dimension is the minor one.
TEST_F(AlgebraicSimplifierTest, BroadcastAndPadReorderWithUse) {
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      c1 = pred[] constant(true)
      b2 = pred[1,768,32]{2,1,0} broadcast(pred[] c1), dimensions={}
      c3 = pred[] constant(false)
      p4 = pred[1,768,4096]{2,1,0} pad(pred[1,768,32]{2,1,0} b2, pred[] c3), padding=0_0x0_0x0_4064
      ROOT p5 = (pred[1,768,4096]{2,1,0}) tuple(pred[1,768,4096]{2,1,0} p4)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  const auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Tuple(m::Broadcast(m::Pad(
                        m::Broadcast(m::Constant()), m::Constant())))));
}
7261 
// Broadcast/pad reorder where the broadcast source is a rank-1 parameter
// rather than a scalar constant.
TEST_F(AlgebraicSimplifierTest, BroadcastAndPadReorderWithNonScalar) {
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      c1 = pred[32] parameter(0)
      b2 = pred[1,768,32]{2,1,0} broadcast(pred[32] c1), dimensions={2}
      c3 = pred[] constant(false)
      p4 = pred[1,768,4096]{2,1,0} pad(pred[1,768,32]{2,1,0} b2, pred[] c3), padding=0_0x0_0x0_4064
      ROOT p5 = (pred[1,768,4096]{2,1,0}) tuple(pred[1,768,4096]{2,1,0} p4)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  const auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Tuple(m::Broadcast(m::Pad(
                        m::Broadcast(m::Parameter()), m::Constant())))));
}
7279 
7280 // Test that dynamic-update-slice with a scalar broadcast becomes a pad when the
7281 // start_indices are too big.
TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceOfBroadcastToPadOob) {
  // Start index 2 is out of range for writing an f32[1] update into an f32[2]
  // operand. The expectations below pin the update at offset 1: low padding
  // 1 and high padding 0.
  constexpr char kHlo[] = R"(
HloModule module

ENTRY f {
  constant.546 = f32[] constant(0)
  broadcast.467 = f32[2]{0} broadcast(constant.546), dimensions={}
  parameter.1 = f32[1]{0} parameter(0)
  constant.551 = s32[] constant(2)
  ROOT dynamic-update-slice.44 = f32[2]{0} dynamic-update-slice(broadcast.467, parameter.1, constant.551)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnVerifiedModule(kHlo));
  VLOG(2) << "Before rewrite dus->pad\n" << hlo_module->ToString();
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(hlo_module.get()).ValueOrDie());
  VLOG(2) << "After rewrite dus->pad\n" << hlo_module->ToString();

  const auto* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Pad(m::Parameter(0), m::ConstantScalar(0.0f))));
  const auto& padding = root->padding_config();
  EXPECT_FALSE(HasInteriorPadding(padding));
  ASSERT_EQ(padding.dimensions_size(), 1);
  EXPECT_EQ(padding.dimensions(0).edge_padding_low(), 1);
  EXPECT_EQ(padding.dimensions(0).edge_padding_high(), 0);
}
7308 
7309 }  // namespace
7310 }  // namespace xla
7311