1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
17
18 #include <memory>
19 #include <utility>
20
21 #include "absl/memory/memory.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/compiler/xla/layout_util.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
27 #include "tensorflow/compiler/xla/service/hlo_computation.h"
28 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
30 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
31 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
32 #include "tensorflow/compiler/xla/service/hlo_parser.h"
33 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
34 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
35 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
36 #include "tensorflow/compiler/xla/shape_util.h"
37 #include "tensorflow/compiler/xla/test.h"
38 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
39 #include "tensorflow/compiler/xla/types.h"
40 #include "tensorflow/compiler/xla/window_util.h"
41 #include "tensorflow/compiler/xla/xla_data.pb.h"
42 #include "tensorflow/core/lib/core/status_test_util.h"
43
44 namespace xla {
45 namespace {
46
47 using ::testing::ElementsAre;
48 namespace m = match;
49
50 class AlgebraicSimplifierTest : public HloTestBase {
51 protected:
52 AlgebraicSimplifierOptions default_options_;
53 };
54
55 // Test that A + 0 is simplified to A
TEST_F(AlgebraicSimplifierTest,AddZero)56 TEST_F(AlgebraicSimplifierTest, AddZero) {
57 auto m = CreateNewVerifiedModule();
58 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
59 HloComputation::Builder builder(TestName());
60 HloInstruction* param0 = builder.AddInstruction(
61 HloInstruction::CreateParameter(0, r0f32, "param0"));
62 HloInstruction* zero = builder.AddInstruction(
63 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
64 builder.AddInstruction(
65 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero));
66
67 auto computation = m->AddEntryComputation(builder.Build());
68 HloInstruction* root = computation->root_instruction();
69 EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
70 AlgebraicSimplifier simplifier(default_options_);
71 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
72 root = computation->root_instruction();
73 EXPECT_EQ(root, param0);
74 }
75
// For s32 operands, A*C + B*C is factored into (A+B)*C with C a parameter.
TEST_F(AlgebraicSimplifierTest, FactorIntegerAddition) {
  const char* kHloText = R"(
HloModule m
test {
p0 = s32[8] parameter(0)
p1 = s32[8] parameter(1)
p2 = s32[8] parameter(2)
x = s32[8] multiply(p0, p2)
y = s32[8] multiply(p1, p2)
ROOT sum = s32[8] add(x, y)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto summed_params = m::AddAnyOrder(m::Parameter(0), m::Parameter(1));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::MultiplyAnyOrder(summed_params, m::Parameter(2))));
}
96
97 // A*C + B*C => (A+B)*C if C is a floating-point power of 2.
TEST_F(AlgebraicSimplifierTest,FactorFpAddition)98 TEST_F(AlgebraicSimplifierTest, FactorFpAddition) {
99 const char* kModuleStr = R"(
100 HloModule m
101 test {
102 p0 = f32[] parameter(0)
103 p1 = f32[] parameter(1)
104 c = f32[] constant(0.125)
105 x = f32[] multiply(p0, c)
106 y = f32[] multiply(p1, c)
107 ROOT sum = f32[] add(x, y)
108 }
109 )";
110 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
111 ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
112 EXPECT_THAT(m->entry_computation()->root_instruction(),
113 GmockMatch(m::MultiplyAnyOrder(
114 m::AddAnyOrder(m::Parameter(0), m::Parameter(1)),
115 m::ConstantScalar(0.125))));
116 }
117
118 // A*C + B*C => (A+B)*C if C is a broadcast of a floating-point power of 2.
TEST_F(AlgebraicSimplifierTest,FactorFpAdditionWithBroadcast)119 TEST_F(AlgebraicSimplifierTest, FactorFpAdditionWithBroadcast) {
120 const char* kModuleStr = R"(
121 HloModule m
122 test {
123 p0 = f32[4] parameter(0)
124 p1 = f32[4] parameter(1)
125 c = f32[] constant(0.125)
126 b = f32[4] broadcast(c), dimensions={}
127 x = f32[4] multiply(p0, b)
128 y = f32[4] multiply(p1, b)
129 ROOT sum = f32[4] add(x, y)
130 }
131 )";
132 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
133 ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
134 EXPECT_THAT(m->entry_computation()->root_instruction(),
135 GmockMatch(m::MultiplyAnyOrder(
136 m::AddAnyOrder(m::Parameter(0), m::Parameter(1)),
137 m::Broadcast(m::ConstantScalar(0.125)))));
138 }
139
140 // A*C + B*C => (A+B)*C simplification should not happen if C is not a
141 // floating-point power of 2.
TEST_F(AlgebraicSimplifierTest,FactorFpAdditionNotPowerOf2)142 TEST_F(AlgebraicSimplifierTest, FactorFpAdditionNotPowerOf2) {
143 const char* kModuleStr = R"(
144 HloModule m
145 test {
146 p0 = f32[] parameter(0)
147 p1 = f32[] parameter(1)
148 c = f32[] constant(0.3)
149 x = f32[] multiply(p0, c)
150 y = f32[] multiply(p1, c)
151 ROOT sum = f32[] add(x, y)
152 }
153 )";
154 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
155 EXPECT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
156 }
157
158 // A*C + B*C => (A+B)*C simplification should not happen if A, B, and C are
159 // complex numbers.
TEST_F(AlgebraicSimplifierTest,FactorFpAdditionComplex)160 TEST_F(AlgebraicSimplifierTest, FactorFpAdditionComplex) {
161 const char* kModuleStr = R"(
162 HloModule m
163 test {
164 p0 = c64[8] parameter(0)
165 p1 = c64[8] parameter(1)
166 p2 = c64[8] parameter(2)
167 x = c64[8] multiply(p0, p2)
168 y = c64[8] multiply(p1, p2)
169 ROOT sum = c64[8] add(x, y)
170 }
171 )";
172 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
173 EXPECT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
174 }
175
176 // A*C + B*C => (A+B)*C simplification is OK if A, B, and C are complex.
TEST_F(AlgebraicSimplifierTest,FactorFpAdditionBfloat16)177 TEST_F(AlgebraicSimplifierTest, FactorFpAdditionBfloat16) {
178 const char* kModuleStr = R"(
179 HloModule m
180 test {
181 p0 = bf16[4] parameter(0)
182 p1 = bf16[4] parameter(1)
183 c = bf16[] constant(0.125)
184 b = bf16[4] broadcast(c), dimensions={}
185 x = bf16[4] multiply(p0, b)
186 y = bf16[4] multiply(p1, b)
187 ROOT sum = bf16[4] add(x, y)
188 }
189 )";
190 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
191 ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
192 EXPECT_THAT(m->entry_computation()->root_instruction(),
193 GmockMatch(m::MultiplyAnyOrder(
194 m::AddAnyOrder(m::Parameter(0), m::Parameter(1)),
195 m::Broadcast(m::ConstantScalar(0.125)))));
196 }
197
// Unsigned division by a broadcast power-of-2 constant (8) becomes a logical
// right shift by 3.
TEST_F(AlgebraicSimplifierTest, UnsignedDivideByPowerOf2) {
  const char* kHloText = R"(
HloModule m
test {
p = u32[4] parameter(0)
c = u32[] constant(8)
b = u32[4] broadcast(c), dimensions={}
ROOT d = u32[4] divide(p, b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto shift_amount = m::Broadcast(m::ConstantScalar(3));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::ShiftRightLogical(m::Parameter(0), shift_amount)));
}
214
// Signed division by a broadcast power-of-2 constant (8) is rewritten as a
// logical shift of the magnitude with the sign restored afterwards:
//   select(p < 0, -(|p| >> 3), |p| >> 3), where |p| = select(p < 0, -p, p).
TEST_F(AlgebraicSimplifierTest, SignedDivideByPowerOf2) {
  const char* kHloText = R"(
HloModule m
test {
p = s32[4] parameter(0)
c = s32[] constant(8)
b = s32[4] broadcast(c), dimensions={}
ROOT d = s32[4] divide(p, b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto is_negative =
      m::Lt(m::Parameter(0), m::Broadcast(m::ConstantScalar(0)));
  auto magnitude =
      m::Select(is_negative, m::Negate(m::Parameter(0)), m::Parameter(0));
  auto shifted =
      m::ShiftRightLogical(magnitude, m::Broadcast(m::ConstantScalar(3)));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Select(is_negative, m::Negate(shifted), shifted)));
}
237
// Unsigned remainder by a broadcast power-of-2 constant (8) becomes a bitwise
// AND with 7.
TEST_F(AlgebraicSimplifierTest, UnsignedRemainderByPowerOf2) {
  const char* kHloText = R"(
HloModule m
test {
p = u32[4] parameter(0)
c = u32[] constant(8)
b = u32[4] broadcast(c), dimensions={}
ROOT r = u32[4] remainder(p, b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto low_bits_mask = m::Broadcast(m::ConstantScalar(7));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::AndAnyOrder(m::Parameter(0), low_bits_mask)));
}
254
// Signed remainder by a broadcast power-of-2 constant (8) masks the magnitude
// and then restores the dividend's sign:
//   select(p < 0, -(|p| & 7), |p| & 7), where |p| = select(p < 0, -p, p).
TEST_F(AlgebraicSimplifierTest, SignedRemainderByPowerOf2) {
  const char* kHloText = R"(
HloModule m
test {
p = s32[4] parameter(0)
c = s32[] constant(8)
b = s32[4] broadcast(c), dimensions={}
ROOT r = s32[4] remainder(p, b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto is_negative =
      m::Lt(m::Parameter(0), m::Broadcast(m::ConstantScalar(0)));
  auto magnitude =
      m::Select(is_negative, m::Negate(m::Parameter(0)), m::Parameter(0));
  auto masked = m::AndAnyOrder(magnitude, m::Broadcast(m::ConstantScalar(7)));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Select(is_negative, m::Negate(masked), masked)));
}
277
278 // Test that A * 0 is simplified to 0
TEST_F(AlgebraicSimplifierTest,MulZero)279 TEST_F(AlgebraicSimplifierTest, MulZero) {
280 auto m = CreateNewVerifiedModule();
281 Shape r0s32 = ShapeUtil::MakeShape(S32, {});
282 HloComputation::Builder builder(TestName());
283 HloInstruction* param0 = builder.AddInstruction(
284 HloInstruction::CreateParameter(0, r0s32, "param0"));
285 HloInstruction* zero = builder.AddInstruction(
286 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
287 builder.AddInstruction(
288 HloInstruction::CreateBinary(r0s32, HloOpcode::kMultiply, param0, zero));
289
290 auto computation = m->AddEntryComputation(builder.Build());
291 HloInstruction* root = computation->root_instruction();
292 EXPECT_EQ(root->opcode(), HloOpcode::kMultiply);
293 AlgebraicSimplifier simplifier(default_options_);
294 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
295 EXPECT_EQ(computation->root_instruction(), zero);
296 }
297
298 // Test that select(true, a, b) is simplified to a
TEST_F(AlgebraicSimplifierTest,SelectTrue)299 TEST_F(AlgebraicSimplifierTest, SelectTrue) {
300 Shape r0s32 = ShapeUtil::MakeShape(S32, {});
301 HloComputation::Builder builder(TestName());
302 HloInstruction* param0 = builder.AddInstruction(
303 HloInstruction::CreateParameter(0, r0s32, "param0"));
304 HloInstruction* param1 = builder.AddInstruction(
305 HloInstruction::CreateParameter(1, r0s32, "param1"));
306 HloInstruction* one = builder.AddInstruction(
307 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
308 builder.AddInstruction(HloInstruction::CreateTernary(
309 r0s32, HloOpcode::kSelect, one, param0, param1));
310
311 auto module = CreateNewVerifiedModule();
312 auto computation = module->AddEntryComputation(builder.Build());
313 HloInstruction* root = computation->root_instruction();
314 EXPECT_EQ(root->opcode(), HloOpcode::kSelect);
315 AlgebraicSimplifier simplifier(default_options_);
316 ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
317 EXPECT_EQ(computation->root_instruction(), param0);
318 }
319
320 // Test that select(false, a, b) is simplified to b
TEST_F(AlgebraicSimplifierTest,SelectFalse)321 TEST_F(AlgebraicSimplifierTest, SelectFalse) {
322 Shape r0s32 = ShapeUtil::MakeShape(S32, {});
323 HloComputation::Builder builder(TestName());
324 HloInstruction* param0 = builder.AddInstruction(
325 HloInstruction::CreateParameter(0, r0s32, "param0"));
326 HloInstruction* param1 = builder.AddInstruction(
327 HloInstruction::CreateParameter(1, r0s32, "param1"));
328 HloInstruction* zero = builder.AddInstruction(
329 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
330 builder.AddInstruction(HloInstruction::CreateTernary(
331 r0s32, HloOpcode::kSelect, zero, param0, param1));
332
333 auto module = CreateNewVerifiedModule();
334 auto computation = module->AddEntryComputation(builder.Build());
335 HloInstruction* root = computation->root_instruction();
336 EXPECT_EQ(root->opcode(), HloOpcode::kSelect);
337 AlgebraicSimplifier simplifier(default_options_);
338 ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
339 EXPECT_EQ(computation->root_instruction(), param1);
340 }
341
342 // Test that select(a, b, b) is simplified to b
TEST_F(AlgebraicSimplifierTest,SelectIdentical)343 TEST_F(AlgebraicSimplifierTest, SelectIdentical) {
344 Shape r0s32 = ShapeUtil::MakeShape(S32, {});
345 HloComputation::Builder builder(TestName());
346 HloInstruction* param0 = builder.AddInstruction(
347 HloInstruction::CreateParameter(0, r0s32, "param0"));
348 HloInstruction* param1 = builder.AddInstruction(
349 HloInstruction::CreateParameter(1, r0s32, "param1"));
350 builder.AddInstruction(HloInstruction::CreateTernary(
351 r0s32, HloOpcode::kSelect, param0, param1, param1));
352
353 auto module = CreateNewVerifiedModule();
354 auto computation = module->AddEntryComputation(builder.Build());
355 HloInstruction* root = computation->root_instruction();
356 EXPECT_EQ(root->opcode(), HloOpcode::kSelect);
357 AlgebraicSimplifier simplifier(default_options_);
358 ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
359 EXPECT_EQ(computation->root_instruction(), param1);
360 }
361
362 // Test that Reduce(Reduce(A)) -> Reduce(A)
TEST_F(AlgebraicSimplifierTest,TwoReducesToOne)363 TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) {
364 auto m = CreateNewVerifiedModule();
365 HloComputation::Builder builder(TestName());
366 // Create add computation.
367 HloInstruction* zero = builder.AddInstruction(
368 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
369 HloComputation* add_computation = nullptr;
370 {
371 HloComputation::Builder builder(TestName() + ".add");
372 const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
373 HloInstruction* p0 = builder.AddInstruction(
374 HloInstruction::CreateParameter(0, scalar_shape, "p0"));
375 HloInstruction* p1 = builder.AddInstruction(
376 HloInstruction::CreateParameter(1, scalar_shape, "p1"));
377 builder.AddInstruction(
378 HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
379 add_computation = m->AddEmbeddedComputation(builder.Build());
380 }
381 Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
382 HloInstruction* param = builder.AddInstruction(
383 HloInstruction::CreateParameter(0, r4f32, "param"));
384 std::vector<int64> dims0({0});
385 Shape r3f32 = ShapeUtil::MakeShape(F32, {5, 6, 7});
386 HloInstruction* reduce0 = builder.AddInstruction(
387 HloInstruction::CreateReduce(r3f32, param, zero, dims0, add_computation));
388 std::vector<int64> dims1({1, 2});
389 Shape r1f32 = ShapeUtil::MakeShape(F32, {5});
390 builder.AddInstruction(HloInstruction::CreateReduce(r1f32, reduce0, zero,
391 dims1, add_computation));
392 m->AddEntryComputation(builder.Build());
393 AlgebraicSimplifier simplifier(default_options_);
394 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
395 HloInstruction* root = m->entry_computation()->root_instruction();
396 EXPECT_THAT(root, GmockMatch(m::Reduce(m::Parameter(0), m::Op().Is(zero))));
397 EXPECT_EQ(root->dimensions(), std::vector<int64>({0, 2, 3}));
398 }
399
400 // Test that Const + A is canonicalized to A + Const.
TEST_F(AlgebraicSimplifierTest,AddConstOnLHS)401 TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) {
402 auto m = CreateNewVerifiedModule();
403 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
404 HloComputation::Builder builder(TestName());
405 HloInstruction* param0 = builder.AddInstruction(
406 HloInstruction::CreateParameter(0, r0f32, "param0"));
407 HloInstruction* constant = builder.AddInstruction(
408 HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f)));
409 builder.AddInstruction(
410 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, constant, param0));
411
412 auto computation = m->AddEntryComputation(builder.Build());
413 HloInstruction* root = computation->root_instruction();
414 EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
415 AlgebraicSimplifier simplifier(default_options_);
416 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
417 root = computation->root_instruction();
418 EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(0), m::Constant())));
419 }
420
421 // Test that [(A + C1) + C2] => [A + (C1 + C2)] for constants C1 and C2.
TEST_F(AlgebraicSimplifierTest,AddReassociateMergeConstants)422 TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) {
423 auto m = CreateNewVerifiedModule();
424 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
425 HloComputation::Builder builder(TestName());
426 HloInstruction* param0 = builder.AddInstruction(
427 HloInstruction::CreateParameter(0, r0f32, "param0"));
428 HloInstruction* constant1 = builder.AddInstruction(
429 HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f)));
430 HloInstruction* constant2 = builder.AddInstruction(
431 HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.14159f)));
432
433 HloInstruction* add1 = builder.AddInstruction(
434 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, constant1));
435 builder.AddInstruction(
436 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, add1, constant2));
437
438 auto computation = m->AddEntryComputation(builder.Build());
439 HloInstruction* root = computation->root_instruction();
440 EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
441 AlgebraicSimplifier simplifier(default_options_);
442 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
443 root = computation->root_instruction();
444 EXPECT_THAT(root, GmockMatch(m::Add(
445 m::Op().Is(param0),
446 m::Add(m::Op().Is(constant1), m::Op().Is(constant2)))));
447 }
448
// broadcast(0) + A, where the zero is a rank-0 constant broadcast to A's
// shape, folds to A.
TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) {
  auto m = CreateNewVerifiedModule();
  Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r2f32, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  // The broadcast-dimension list must have one entry per operand dimension.
  // A rank-0 operand therefore takes an empty list. {0, 1} would be
  // ill-formed HLO that only slips past the verifier because the simplifier
  // deletes this instruction before the module is verified.
  HloInstruction* bcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(r2f32, zero, {}));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0));

  auto computation = m->AddEntryComputation(builder.Build());
  HloInstruction* root = computation->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  root = computation->root_instruction();
  EXPECT_EQ(root, param0);
}
470
// A map whose to_apply computation is a single scalar add is inlined into an
// elementwise add over the map's operands.
TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // Scalar f32 addition used as the map's to_apply computation.
  HloComputation* add_computation = [&]() -> HloComputation* {
    HloComputation::Builder add_builder(TestName() + ".add");
    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
    HloInstruction* lhs = add_builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "p0"));
    HloInstruction* rhs = add_builder.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "p1"));
    add_builder.AddInstruction(
        HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, lhs, rhs));
    return module->AddEmbeddedComputation(add_builder.Build());
  }();

  const Shape r2f32 = ShapeUtil::MakeShape(F32, {32, 1});
  HloInstruction* x = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r2f32, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  HloInstruction* zeros = builder.AddInstruction(
      HloInstruction::CreateBroadcast(r2f32, zero, {}));
  builder.AddInstruction(
      HloInstruction::CreateMap(r2f32, {x, zeros}, add_computation));

  HloComputation* entry = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kMap);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Add(m::Parameter(0),
                                m::Broadcast(m::Op().Is(zero)))));
}
507
// A map whose to_apply computation wraps a fusion (rather than a single
// elementwise op) must be left untouched; the pass reports no change.
TEST_F(AlgebraicSimplifierTest, KeepNontrivialMap) {
  const char* kHloText = R"(
HloModule m
fusion {
x = f32[] parameter(0)
c = f32[] constant(42)
m = f32[] multiply(x, x)
ROOT a = f32[] add(m, c)
}

map {
x = f32[] parameter(0)
ROOT f = f32[] fusion(x), kind=kLoop, calls=fusion
}

ENTRY test {
p = f32[2,2] parameter(0)
ROOT map = f32[2,2] map(p), dimensions={0,1}, to_apply=map
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
}
531
// broadcast(zeros) + A, where the zeros are a rank-1 constant broadcast to
// A's shape, folds to A.
TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
  auto m = CreateNewVerifiedModule();
  Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r2f32, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({0, 0, 0})));
  // The size-3 operand dimension must map onto an output dimension of the
  // same size, which is dimension 0 of f32[3,2]. Mapping it onto dimension 1
  // (size 2) would be ill-formed HLO that only slips past the verifier because
  // the simplifier deletes this instruction before the module is verified.
  HloInstruction* bcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, zero, {0}));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0));

  auto computation = m->AddEntryComputation(builder.Build());
  HloInstruction* root = computation->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  root = computation->root_instruction();
  EXPECT_EQ(root, param0);
}
553
// An R1 constant whose elements are all equal is rewritten as a broadcast of
// a constant holding that element value.
TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());
  b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({3.14f, 3.14f, 3.14f})));

  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Constant()));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  HloInstruction* result = entry->root_instruction();
  EXPECT_THAT(result, GmockMatch(m::Broadcast(m::Constant())));
  EXPECT_EQ(3.14f, result->operand(0)->literal().GetFirstElement<float>());
}
569
// An R1 constant with differing elements stays a constant; the pass reports
// no change.
TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());
  b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({3.14, 3.14, 4})));

  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Constant()));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Constant()));
}
584
// An R1 constant holding 0, 1, 2 is rewritten into an iota. Note that,
// despite the test's name, the expected result is an Iota, not a Broadcast.
TEST_F(AlgebraicSimplifierTest, IotaToBroadcast) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());
  b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f})));

  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Constant()));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Iota()));
}
599
600 // Test that A - 0 is simplified to A
TEST_F(AlgebraicSimplifierTest,SubZero)601 TEST_F(AlgebraicSimplifierTest, SubZero) {
602 auto m = CreateNewVerifiedModule();
603 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
604 HloComputation::Builder builder(TestName());
605 HloInstruction* param0 = builder.AddInstruction(
606 HloInstruction::CreateParameter(0, r0f32, "param0"));
607 HloInstruction* zero = builder.AddInstruction(
608 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
609 builder.AddInstruction(
610 HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero));
611
612 auto computation = m->AddEntryComputation(builder.Build());
613 HloInstruction* root = computation->root_instruction();
614 EXPECT_EQ(root->opcode(), HloOpcode::kSubtract);
615 AlgebraicSimplifier simplifier(default_options_);
616 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
617 root = computation->root_instruction();
618 EXPECT_EQ(root, param0);
619 }
620
621 // Test that A - Const is canonicalized to A + (-Const).
TEST_F(AlgebraicSimplifierTest,SubConstCanonicalization)622 TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) {
623 auto m = CreateNewVerifiedModule();
624 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
625 HloComputation::Builder builder(TestName());
626 HloInstruction* param0 = builder.AddInstruction(
627 HloInstruction::CreateParameter(0, r0f32, "param0"));
628 HloInstruction* constant = builder.AddInstruction(
629 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
630 builder.AddInstruction(HloInstruction::CreateBinary(
631 r0f32, HloOpcode::kSubtract, param0, constant));
632
633 auto computation = m->AddEntryComputation(builder.Build());
634 HloInstruction* root = computation->root_instruction();
635 EXPECT_EQ(root->opcode(), HloOpcode::kSubtract);
636 AlgebraicSimplifier simplifier(default_options_);
637 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
638 root = computation->root_instruction();
639 EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(0),
640 m::Negate(m::Op().Is(constant)))));
641 }
642
643 // Test that (A/B)/C is simplified to A/(B*C).
TEST_F(AlgebraicSimplifierTest,LhsDivOfDiv)644 TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) {
645 auto m = CreateNewVerifiedModule();
646 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
647 HloComputation::Builder builder(TestName());
648 HloInstruction* param0 = builder.AddInstruction(
649 HloInstruction::CreateParameter(0, r0f32, "param0"));
650 HloInstruction* param1 = builder.AddInstruction(
651 HloInstruction::CreateParameter(1, r0f32, "param1"));
652 HloInstruction* param2 = builder.AddInstruction(
653 HloInstruction::CreateParameter(2, r0f32, "param2"));
654 HloInstruction* div = builder.AddInstruction(
655 HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, param1));
656 builder.AddInstruction(
657 HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, div, param2));
658
659 auto computation = m->AddEntryComputation(builder.Build());
660
661 EXPECT_THAT(computation->root_instruction(),
662 GmockMatch(m::Divide(m::Divide(m::Parameter(0), m::Parameter(1)),
663 m::Parameter(2))));
664
665 AlgebraicSimplifier simplifier(default_options_);
666 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
667
668 EXPECT_THAT(
669 computation->root_instruction(),
670 GmockMatch(m::Divide(m::Parameter(0),
671 m::Multiply(m::Parameter(1), m::Parameter(2)))));
672 }
673
674 // Test that A/(B/C) is simplified to (A*C)/B.
TEST_F(AlgebraicSimplifierTest,RhsDivOfDiv)675 TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) {
676 auto m = CreateNewVerifiedModule();
677 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
678 HloComputation::Builder builder(TestName());
679 HloInstruction* param0 = builder.AddInstruction(
680 HloInstruction::CreateParameter(0, r0f32, "param0"));
681 HloInstruction* param1 = builder.AddInstruction(
682 HloInstruction::CreateParameter(1, r0f32, "param1"));
683 HloInstruction* param2 = builder.AddInstruction(
684 HloInstruction::CreateParameter(2, r0f32, "param2"));
685 HloInstruction* div = builder.AddInstruction(
686 HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param1, param2));
687 builder.AddInstruction(
688 HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, div));
689
690 auto computation = m->AddEntryComputation(builder.Build());
691
692 EXPECT_THAT(
693 computation->root_instruction(),
694 GmockMatch(m::Divide(m::Parameter(0),
695 m::Divide(m::Parameter(1), m::Parameter(2)))));
696
697 AlgebraicSimplifier simplifier(default_options_);
698 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
699
700 EXPECT_THAT(
701 computation->root_instruction(),
702 GmockMatch(m::Divide(m::Multiply(m::Parameter(0), m::Parameter(2)),
703 m::Parameter(1))));
704 }
705
706 // Test that (A/B)/(C/D) is simplified to (A*D)/(B*C).
TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) {
  // (a / b) / (c / d) on f32[42,123] operands should become (a * d) / (b * c).
  auto hlo_module = CreateNewVerifiedModule();
  const Shape shape = ShapeUtil::MakeShape(F32, {42, 123});
  HloComputation::Builder comp_builder(TestName());
  auto add_param = [&](int64 number, const char* name) {
    return comp_builder.AddInstruction(
        HloInstruction::CreateParameter(number, shape, name));
  };
  auto add_divide = [&](HloInstruction* lhs, HloInstruction* rhs) {
    return comp_builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kDivide, lhs, rhs));
  };
  HloInstruction* a = add_param(0, "param0");
  HloInstruction* b = add_param(1, "param1");
  HloInstruction* c = add_param(2, "param2");
  HloInstruction* d = add_param(3, "param3");
  HloInstruction* a_over_b = add_divide(a, b);
  HloInstruction* c_over_d = add_divide(c, d);
  add_divide(a_over_b, c_over_d);

  HloComputation* entry = hlo_module->AddEntryComputation(comp_builder.Build());

  // Sanity-check the input graph before running the pass.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Divide(m::Divide(m::Parameter(0), m::Parameter(1)),
                                   m::Divide(m::Parameter(2), m::Parameter(3)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Divide(m::Multiply(m::Parameter(0), m::Parameter(3)),
                           m::Multiply(m::Parameter(1), m::Parameter(2)))));
}
741
742 // Test that A/exp(B) is simplified to A*exp(-B).
TEST_F(AlgebraicSimplifierTest, DivOfExp) {
  // a / exp(b) on f32 scalars should be rewritten to a * exp(-b).
  auto hlo_module = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder comp_builder(TestName());
  auto* numerator = comp_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  auto* exponent = comp_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar, "param1"));
  auto* denominator = comp_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kExp, exponent));
  comp_builder.AddInstruction(HloInstruction::CreateBinary(
      scalar, HloOpcode::kDivide, numerator, denominator));

  HloComputation* entry = hlo_module->AddEntryComputation(comp_builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Divide(m::Parameter(0), m::Exp(m::Parameter(1)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(m::Parameter(0),
                                     m::Exp(m::Negate(m::Parameter(1))))));
}
768
769 // Test that A/pow(B,C) is simplified to A*pow(B,-C).
TEST_F(AlgebraicSimplifierTest, DivOfPower) {
  // a / pow(b, c) on f32 scalars should be rewritten to a * pow(b, -c).
  auto hlo_module = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder comp_builder(TestName());
  auto add_param = [&](int64 number, const char* name) {
    return comp_builder.AddInstruction(
        HloInstruction::CreateParameter(number, scalar, name));
  };
  HloInstruction* a = add_param(0, "param0");
  HloInstruction* b = add_param(1, "param1");
  HloInstruction* c = add_param(2, "param2");
  HloInstruction* b_pow_c = comp_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kPower, b, c));
  comp_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kDivide, a, b_pow_c));

  HloComputation* entry = hlo_module->AddEntryComputation(comp_builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Divide(m::Parameter(0),
                                   m::Power(m::Parameter(1), m::Parameter(2)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(
                  m::Parameter(0),
                  m::Power(m::Parameter(1), m::Negate(m::Parameter(2))))));
}
800
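// Test that broadcasting is done on the right step when simplifying A/pow(B,C)
// to A*pow(B,-C). Note: every operand below has the same shape, f32[7], so no
// broadcast instruction appears in either the input or the expected output.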
TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) {
  // Rank-1 (f32[7]) variant of a / pow(b, c) => a * pow(b, -c).
  auto hlo_module = CreateNewVerifiedModule();
  const Shape vec = ShapeUtil::MakeShape(F32, {7});
  HloComputation::Builder comp_builder(TestName());
  auto add_param = [&](int64 number, const char* name) {
    return comp_builder.AddInstruction(
        HloInstruction::CreateParameter(number, vec, name));
  };
  HloInstruction* a = add_param(0, "param0");
  HloInstruction* b = add_param(1, "param1");
  HloInstruction* c = add_param(2, "param2");
  HloInstruction* b_pow_c = comp_builder.AddInstruction(
      HloInstruction::CreateBinary(vec, HloOpcode::kPower, b, c));
  comp_builder.AddInstruction(
      HloInstruction::CreateBinary(vec, HloOpcode::kDivide, a, b_pow_c));

  HloComputation* entry = hlo_module->AddEntryComputation(comp_builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Divide(m::Parameter(0),
                                   m::Power(m::Parameter(1), m::Parameter(2)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  ASSERT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(
                  m::Parameter(0),
                  m::Power(m::Parameter(1), m::Negate(m::Parameter(2))))));
}
833
834 // A / Const => A * InvertedConst
TEST_F(AlgebraicSimplifierTest, DivideByConstant) {
  // A / Const => A * InvertedConst, where InvertedConst must hold the
  // element-wise reciprocals of the original divisor {1, 2, 3}.
  auto m = CreateNewVerifiedModule();
  Shape r1f32 = ShapeUtil::MakeShape(F32, {3});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "param0"));
  HloInstruction* constant =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1.f, 2.f, 3.f})));
  builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide,
                                                      param0, constant));

  auto computation = m->AddEntryComputation(builder.Build());

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  // Fatal, so that operand(1)->literal() below is only reached once operand(1)
  // is known to be a constant.
  ASSERT_THAT(root, GmockMatch(m::Multiply(m::Parameter(0), m::Constant())));
  // Check the values too: the rewrite must actually invert the divisor rather
  // than only swapping the opcode.
  const Literal& inverted = root->operand(1)->literal();
  EXPECT_FLOAT_EQ(inverted.Get<float>({0}), 1.f);
  EXPECT_FLOAT_EQ(inverted.Get<float>({1}), 0.5f);
  EXPECT_FLOAT_EQ(inverted.Get<float>({2}), 1.f / 3.f);
}
855
856 // pow(pow(A, X), Y) => pow(A, X*Y)
TEST_F(AlgebraicSimplifierTest, PowerOfPower) {
  // pow(pow(x, y), z) on f32[7] should be folded to pow(x, y * z).
  auto hlo_module = CreateNewVerifiedModule();
  const Shape vec = ShapeUtil::MakeShape(F32, {7});
  HloComputation::Builder comp_builder(TestName());
  auto add_param = [&](int64 number, const char* name) {
    return comp_builder.AddInstruction(
        HloInstruction::CreateParameter(number, vec, name));
  };
  auto add_power = [&](HloInstruction* lhs, HloInstruction* rhs) {
    return comp_builder.AddInstruction(
        HloInstruction::CreateBinary(vec, HloOpcode::kPower, lhs, rhs));
  };
  HloInstruction* x = add_param(0, "param0");
  HloInstruction* y = add_param(1, "param1");
  HloInstruction* z = add_param(2, "param2");
  HloInstruction* x_pow_y = add_power(x, y);
  add_power(x_pow_y, z);

  HloComputation* entry = hlo_module->AddEntryComputation(comp_builder.Build());
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Power(m::Op().Is(x),
                                  m::Multiply(m::Op().Is(y), m::Op().Is(z)))));
}
880
881 // Don't simplify pow(pow(A, X), Y) => pow(A, X*Y) if X and Y are complex
882 // numbers.
TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) {
  // The same pow(pow(x, y), z) graph as PowerOfPower, but in c64. The pass is
  // expected to leave it untouched and report no change.
  auto hlo_module = CreateNewVerifiedModule();
  const Shape complex_vec = ShapeUtil::MakeShape(C64, {7});
  HloComputation::Builder comp_builder(TestName());
  auto add_param = [&](int64 number, const char* name) {
    return comp_builder.AddInstruction(
        HloInstruction::CreateParameter(number, complex_vec, name));
  };
  auto add_power = [&](HloInstruction* lhs, HloInstruction* rhs) {
    return comp_builder.AddInstruction(
        HloInstruction::CreateBinary(complex_vec, HloOpcode::kPower, lhs, rhs));
  };
  HloInstruction* x = add_param(0, "param0");
  HloInstruction* y = add_param(1, "param1");
  HloInstruction* z = add_param(2, "param2");
  HloInstruction* x_pow_y = add_power(x, y);
  add_power(x_pow_y, z);

  hlo_module->AddEntryComputation(comp_builder.Build());
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(hlo_module.get()).ValueOrDie());
}
902
903 // Test that A/1 is simplified to A for a scalar.
TEST_F(AlgebraicSimplifierTest, DivOneScalar) {
  // x / 1.0f on an f32 scalar should fold away, leaving x as the root.
  auto hlo_module = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder comp_builder(TestName());
  auto* x = comp_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  auto* one = comp_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  auto* quotient = comp_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kDivide, x, one));

  HloComputation* entry = hlo_module->AddEntryComputation(comp_builder.Build());
  EXPECT_EQ(entry->root_instruction(), quotient);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  EXPECT_EQ(entry->root_instruction(), x);
}
923
924 // Test that A/1 is simplified to A for an array.
TEST_F(AlgebraicSimplifierTest, DivOneArray) {
  // x / ones(2x2) should fold away, leaving x as the root.
  auto hlo_module = CreateNewVerifiedModule();
  const Shape matrix = ShapeUtil::MakeShape(F32, {2, 2});
  HloComputation::Builder comp_builder(TestName());
  auto* x = comp_builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix, "param0"));
  auto* ones = comp_builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 1.0}, {1.0, 1.0}})));
  auto* quotient = comp_builder.AddInstruction(
      HloInstruction::CreateBinary(matrix, HloOpcode::kDivide, x, ones));

  HloComputation* entry = hlo_module->AddEntryComputation(comp_builder.Build());
  EXPECT_EQ(entry->root_instruction(), quotient);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  EXPECT_EQ(entry->root_instruction(), x);
}
944
945 // Test that complex(real(c), imag(c)) is simplified to c.
TEST_F(AlgebraicSimplifierTest, ComplexOfRealImagC) {
  // complex(real(z), imag(z)) should collapse back to z itself.
  auto hlo_module = CreateNewVerifiedModule();
  const Shape real_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape complex_shape = ShapeUtil::MakeShape(C64, {2, 2});
  HloComputation::Builder comp_builder(TestName());
  auto* z = comp_builder.AddInstruction(
      HloInstruction::CreateParameter(0, complex_shape, "param0"));
  auto* re = comp_builder.AddInstruction(
      HloInstruction::CreateUnary(real_shape, HloOpcode::kReal, z));
  auto* im = comp_builder.AddInstruction(
      HloInstruction::CreateUnary(real_shape, HloOpcode::kImag, z));
  auto* rebuilt = comp_builder.AddInstruction(
      HloInstruction::CreateBinary(complex_shape, HloOpcode::kComplex, re, im));

  HloComputation* entry = hlo_module->AddEntryComputation(comp_builder.Build());
  EXPECT_EQ(entry->root_instruction(), rebuilt);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  EXPECT_EQ(entry->root_instruction(), z);
}
968
969 // Test that real(complex(r,i)) is simplified to r.
TEST_F(AlgebraicSimplifierTest, RealOfComplex) {
  // real(complex(r, i)) should simplify to r.
  auto hlo_module = CreateNewVerifiedModule();
  const Shape real_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape complex_shape = ShapeUtil::ChangeElementType(real_shape, C64);
  HloComputation::Builder comp_builder(TestName());
  auto* r = comp_builder.AddInstruction(
      HloInstruction::CreateParameter(0, real_shape, "param0"));
  auto* i = comp_builder.AddInstruction(
      HloInstruction::CreateParameter(1, real_shape, "param1"));
  auto* z = comp_builder.AddInstruction(
      HloInstruction::CreateBinary(complex_shape, HloOpcode::kComplex, r, i));
  auto* real_part = comp_builder.AddInstruction(
      HloInstruction::CreateUnary(real_shape, HloOpcode::kReal, z));

  HloComputation* entry = hlo_module->AddEntryComputation(comp_builder.Build());
  EXPECT_EQ(entry->root_instruction(), real_part);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  EXPECT_EQ(entry->root_instruction(), r);
}
992
993 // Test that imag(complex(r,i)) is simplified to i.
TEST_F(AlgebraicSimplifierTest, ImagOfComplex) {
  // imag(complex(r, i)) should simplify to i.
  auto hlo_module = CreateNewVerifiedModule();
  const Shape real_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape complex_shape = ShapeUtil::ChangeElementType(real_shape, C64);
  HloComputation::Builder comp_builder(TestName());
  auto* r = comp_builder.AddInstruction(
      HloInstruction::CreateParameter(0, real_shape, "param0"));
  auto* i = comp_builder.AddInstruction(
      HloInstruction::CreateParameter(1, real_shape, "param1"));
  auto* z = comp_builder.AddInstruction(
      HloInstruction::CreateBinary(complex_shape, HloOpcode::kComplex, r, i));
  auto* imag_part = comp_builder.AddInstruction(
      HloInstruction::CreateUnary(real_shape, HloOpcode::kImag, z));

  HloComputation* entry = hlo_module->AddEntryComputation(comp_builder.Build());
  EXPECT_EQ(entry->root_instruction(), imag_part);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  EXPECT_EQ(entry->root_instruction(), i);
}
1016
1017 // Test that get_element(make_tuple({A,B}),1) is simplified to B
TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) {
  // add(get-tuple-element(tuple(p0, p1), 1), p2) should simplify to
  // add(p1, p2), forwarding the selected tuple operand directly.
  auto hlo_module = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder comp_builder(TestName());
  auto* first = comp_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  auto* second = comp_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar, "param1"));
  auto* addend = comp_builder.AddInstruction(
      HloInstruction::CreateParameter(2, scalar, "param2"));
  auto* pair =
      comp_builder.AddInstruction(HloInstruction::CreateTuple({first, second}));
  auto* selected = comp_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar, pair, 1));
  auto* sum = comp_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kAdd, selected, addend));

  HloComputation* entry = hlo_module->AddEntryComputation(comp_builder.Build());
  EXPECT_EQ(entry->root_instruction(), sum);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Add(m::Parameter(1), m::Parameter(2))));
}
1043
1044 // Test that exp(A)/exp(B) is simplified to exp(A-B)
TEST_F(AlgebraicSimplifierTest, ExpDiv) {
  // exp(a) / exp(b) should be rewritten to exp(a - b).
  auto hlo_module = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder comp_builder(TestName());
  auto add_param = [&](int64 number, const char* name) {
    return comp_builder.AddInstruction(
        HloInstruction::CreateParameter(number, scalar, name));
  };
  auto add_exp = [&](HloInstruction* operand) {
    return comp_builder.AddInstruction(
        HloInstruction::CreateUnary(scalar, HloOpcode::kExp, operand));
  };
  HloInstruction* a = add_param(0, "param0");
  HloInstruction* b = add_param(1, "param1");
  HloInstruction* exp_a = add_exp(a);
  HloInstruction* exp_b = add_exp(b);
  comp_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kDivide, exp_a, exp_b));

  HloComputation* entry = hlo_module->AddEntryComputation(comp_builder.Build());

  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Divide(m::Exp(m::Parameter(0)), m::Exp(m::Parameter(1)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Exp(m::Subtract(m::Parameter(0), m::Parameter(1)))));
}
1073
1074 // Test that exp(A)*exp(B) is simplified to exp(A+B)
TEST_F(AlgebraicSimplifierTest, ExpMul) {
  // exp(a) * exp(b) should be rewritten to exp(a + b).
  auto hlo_module = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder comp_builder(TestName());
  auto add_param = [&](int64 number, const char* name) {
    return comp_builder.AddInstruction(
        HloInstruction::CreateParameter(number, scalar, name));
  };
  auto add_exp = [&](HloInstruction* operand) {
    return comp_builder.AddInstruction(
        HloInstruction::CreateUnary(scalar, HloOpcode::kExp, operand));
  };
  HloInstruction* a = add_param(0, "param0");
  HloInstruction* b = add_param(1, "param1");
  HloInstruction* exp_a = add_exp(a);
  HloInstruction* exp_b = add_exp(b);
  comp_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kMultiply, exp_a, exp_b));

  HloComputation* entry = hlo_module->AddEntryComputation(comp_builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(m::Exp(m::Parameter(0)),
                                     m::Exp(m::Parameter(1)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Exp(m::Add(m::Parameter(0), m::Parameter(1)))));
}
1102
1103 // Test that pow(exp(A), B) is simplified to exp(A*B)
TEST_F(AlgebraicSimplifierTest, PowExp) {
  // pow(exp(a), b) should be rewritten to exp(a * b).
  auto hlo_module = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder comp_builder(TestName());
  auto* a = comp_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  auto* b = comp_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar, "param1"));
  auto* exp_a = comp_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kExp, a));
  comp_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kPower, exp_a, b));

  HloComputation* entry = hlo_module->AddEntryComputation(comp_builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Power(m::Exp(m::Parameter(0)), m::Parameter(1))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Exp(m::Multiply(m::Parameter(0), m::Parameter(1)))));
}
1129
1130 // Test that ln(pow(A, B)) is simplified to ln(A)*B
TEST_F(AlgebraicSimplifierTest, LnPow) {
  // log(pow(a, b)) should be rewritten to log(a) * b.
  auto hlo_module = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder comp_builder(TestName());
  auto* a = comp_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  auto* b = comp_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar, "param1"));
  auto* a_pow_b = comp_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kPower, a, b));
  comp_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kLog, a_pow_b));

  HloComputation* entry = hlo_module->AddEntryComputation(comp_builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Log(m::Power(m::Parameter(0), m::Parameter(1)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Multiply(m::Log(m::Parameter(0)), m::Parameter(1))));
}
1156
1157 // Test that ln(exp(A)) is simplified to A
TEST_F(AlgebraicSimplifierTest, LnExp) {
  // log(exp(a)) should collapse to a itself.
  auto hlo_module = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder comp_builder(TestName());
  auto* a = comp_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  auto* exp_a = comp_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kExp, a));
  comp_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kLog, exp_a));

  HloComputation* entry = hlo_module->AddEntryComputation(comp_builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Log(m::Exp(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  EXPECT_EQ(entry->root_instruction(), a);
}
1179
1180 // Test that ln(exp(A)/exp(B)) is simplified to A-B
TEST_F(AlgebraicSimplifierTest, LnExpDiv) {
  // log(exp(a) / exp(b)) should be rewritten to a - b.
  auto hlo_module = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder comp_builder(TestName());
  auto add_param = [&](int64 number, const char* name) {
    return comp_builder.AddInstruction(
        HloInstruction::CreateParameter(number, scalar, name));
  };
  auto add_unary = [&](HloOpcode opcode, HloInstruction* operand) {
    return comp_builder.AddInstruction(
        HloInstruction::CreateUnary(scalar, opcode, operand));
  };
  HloInstruction* a = add_param(0, "param0");
  HloInstruction* b = add_param(1, "param1");
  HloInstruction* exp_a = add_unary(HloOpcode::kExp, a);
  HloInstruction* exp_b = add_unary(HloOpcode::kExp, b);
  HloInstruction* ratio = comp_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kDivide, exp_a, exp_b));
  add_unary(HloOpcode::kLog, ratio);

  HloComputation* entry = hlo_module->AddEntryComputation(comp_builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Log(m::Divide(m::Exp(m::Parameter(0)),
                                          m::Exp(m::Parameter(1))))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Subtract(m::Parameter(0), m::Parameter(1))));
}
1210
1211 // Test that pow(A, 0) where A is a scalar is simplified to the scalar
1212 // constant 1.
TEST_F(AlgebraicSimplifierTest, Pow0Scalar) {
  // pow(A, 0) on an f32 scalar A should become the scalar constant 1.
  auto m = CreateNewVerifiedModule();
  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Power(m::Parameter(0), m::Op().Is(zero))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  // Fatal: root->literal() below is only meaningful when root is a constant,
  // so stop here instead of dereferencing an unexpected instruction.
  ASSERT_THAT(root, GmockMatch(m::Constant()));
  EXPECT_EQ(root->literal().GetFirstElement<float>(), 1);
}
1236
1237 // Test that pow(A, 0) where A is not a scalar is simplified to broadcast(1).
TEST_F(AlgebraicSimplifierTest, Pow0Vector) {
  // pow(A, 0) on an f32[42] A should become a broadcast of the scalar
  // constant 1 into A's shape.
  auto m = CreateNewVerifiedModule();
  Shape r1f32 = ShapeUtil::MakeShape(F32, {42});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Power(m::Parameter(0), m::Op().Is(zero))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  // The accessors used below (dimensions(), operand(0)->literal()) only make
  // sense for a broadcast of a constant. Check that structure fatally instead
  // of dereferencing an unexpected instruction.
  ASSERT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), r1f32))
      << ShapeUtil::HumanString(root->shape());
  EXPECT_EQ(root->dimensions().size(), 0);
  EXPECT_TRUE(ShapeUtil::IsScalar(root->operand(0)->shape()));
  EXPECT_EQ(root->operand(0)->literal().GetFirstElement<float>(), 1);
}
1265
1266 // Test that pow(A, 1) is simplified to A.
TEST_F(AlgebraicSimplifierTest, Pow1) {
  // pow(x, 1) should collapse to x itself.
  auto hlo_module = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder comp_builder(TestName());
  auto* x = comp_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  auto* exponent = comp_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
  comp_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kPower, x, exponent));

  HloComputation* entry = hlo_module->AddEntryComputation(comp_builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Power(m::Parameter(0), m::Op().Is(exponent))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  EXPECT_EQ(entry->root_instruction(), x);
}
1288
1289 // Test that pow(A, 2) is simplified to A*A.
TEST_F(AlgebraicSimplifierTest, Pow2) {
  // pow(x, 2) should be strength-reduced to x * x.
  auto hlo_module = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder comp_builder(TestName());
  auto* x = comp_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  auto* exponent = comp_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2)));
  comp_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kPower, x, exponent));

  HloComputation* entry = hlo_module->AddEntryComputation(comp_builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Power(m::Parameter(0), m::Op().Is(exponent))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
}
1312
1313 // Test that pow(A, -1) is simplified to 1/A.
TEST_F(AlgebraicSimplifierTest, PowNegative1) {
  // pow(A, -1) should become 1 / A, with the 1 expressed as a broadcast of a
  // scalar constant.
  auto m = CreateNewVerifiedModule();
  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  HloInstruction* negative_one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(-1)));
  builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower,
                                                      param0, negative_one));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Power(m::Parameter(0), m::Op().Is(negative_one))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  // Match the whole structure, including the constant that feeds the
  // broadcast, and make the check fatal. The literal read below dereferences
  // root->operand(0)->operand(0) and is only valid for that structure.
  ASSERT_THAT(root, GmockMatch(m::Divide(m::Broadcast(m::Constant()),
                                         m::Parameter(0))));
  EXPECT_EQ(root->operand(0)->operand(0)->literal().GetFirstElement<float>(),
            1);
}
1339
TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) {
  // The input and kernel of this convolution have a zero-sized feature
  // dimension. The simplifier is expected to replace the convolution with a
  // broadcast of a constant.
  auto m = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());
  // lhs layout per dnums below: [batch=3, spatial=3, feature=0].
  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {3, 3, 0}), "lhs"));

  // rhs layout per dnums below: [spatial=3, input_feature=0, output_feature=3].
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {3, 0, 3}), "rhs"));

  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(0);
  dnums.add_input_spatial_dimensions(1);
  dnums.set_input_feature_dimension(2);

  dnums.set_output_batch_dimension(0);
  dnums.add_output_spatial_dimensions(1);
  dnums.set_output_feature_dimension(2);

  dnums.add_kernel_spatial_dimensions(0);
  dnums.set_kernel_input_feature_dimension(1);
  dnums.set_kernel_output_feature_dimension(2);
  // A size-3 window with one element of padding on each side keeps the
  // spatial extent at (3 + 1 + 1 - 3) / 1 + 1 = 3, so the declared f32[3,3,3]
  // result shape is consistent with the operands. Without padding the
  // spatial extent would be 1.
  Window window;
  WindowDimension* dim = window.add_dimensions();
  dim->set_size(3);
  dim->set_padding_low(1);
  dim->set_padding_high(1);
  dim->set_stride(1);
  dim->set_window_dilation(1);
  dim->set_base_dilation(1);
  dim->set_window_reversal(false);
  // Create the convolution under test.
  builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
  m->AddEntryComputation(builder.Build());
  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Convolution(m::Op().Is(lhs), m::Op().Is(rhs))));
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Constant())));
}
1382
TEST_F(AlgebraicSimplifierTest, ReduceWindowIsReduceAndReshape) {
  // The 1x2x3x1 window spans the full extent of the first three dimensions of
  // the f32[1,2,3,4] operand. The reduce-window should therefore be rewritten
  // as reshape(reduce(...)).
  auto hlo_module = CreateNewVerifiedModule();
  auto entry_builder = HloComputation::Builder(TestName());
  HloInstruction* operand =
      entry_builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {1, 2, 3, 4}), "param"));

  // Window sizes per dimension: 1, 2, 3, 1. Unit stride and dilation, no
  // padding.
  Window window;
  for (int64 dim_index = 0; dim_index < 4; ++dim_index) {
    WindowDimension* window_dim = window.add_dimensions();
    window_dim->set_size((dim_index % 3) + 1);
    window_dim->set_stride(1);
    window_dim->set_padding_low(0);
    window_dim->set_padding_high(0);
    window_dim->set_window_dilation(1);
    window_dim->set_base_dilation(1);
  }

  // Scalar f32 addition used as the reduction function.
  HloComputation* add_computation = [&] {
    HloComputation::Builder add_builder(TestName() + ".add");
    const Shape scalar = ShapeUtil::MakeShape(F32, {});
    HloInstruction* lhs = add_builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar, "p0"));
    HloInstruction* rhs = add_builder.AddInstruction(
        HloInstruction::CreateParameter(1, scalar, "p1"));
    add_builder.AddInstruction(
        HloInstruction::CreateBinary(scalar, HloOpcode::kAdd, lhs, rhs));
    return hlo_module->AddEmbeddedComputation(add_builder.Build());
  }();

  HloInstruction* init = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  entry_builder.AddInstruction(HloInstruction::CreateReduceWindow(
      ShapeUtil::MakeShape(F32, {1, 1, 1, 4}), operand, init, window,
      add_computation));
  hlo_module->AddEntryComputation(entry_builder.Build());

  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              GmockMatch(m::ReduceWindow(m::Parameter(0), m::Constant())));
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  EXPECT_THAT(
      hlo_module->entry_computation()->root_instruction(),
      GmockMatch(m::Reshape(m::Reduce(m::Parameter(0), m::Constant()))));
}
1427
TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) {
  // The operand of this reduce-window has zero elements (f32[3,0]), so there
  // is nothing to reduce. The simplifier is expected to replace it with a
  // broadcast of a constant.
  auto m = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {3, 0}), "op"));
  // Size-1 windows with one element of padding on each side and unit stride
  // give output extents (3 + 2 - 1) / 1 + 1 = 5 and (0 + 2 - 1) / 1 + 1 = 2,
  // which matches the declared f32[5,2] result. The stride must be set
  // explicitly; left unset it would be 0.
  Window window;
  for (int64 i = 0; i < 2; ++i) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(1);
    dim->set_stride(1);
    dim->set_padding_low(1);
    dim->set_padding_high(1);
    dim->set_window_dilation(1);
    dim->set_base_dilation(1);
  }
  // Create add computation.
  HloComputation* add_computation = nullptr;
  {
    // Named distinctly to avoid shadowing the entry computation's builder.
    HloComputation::Builder add_builder(TestName() + ".add");
    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
    HloInstruction* p0 = add_builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "p0"));
    HloInstruction* p1 = add_builder.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "p1"));
    add_builder.AddInstruction(
        HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
    add_computation = m->AddEmbeddedComputation(add_builder.Build());
  }
  builder.AddInstruction(HloInstruction::CreateReduceWindow(
      ShapeUtil::MakeShape(F32, {5, 2}), param,
      builder.AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f))),
      window, add_computation));
  m->AddEntryComputation(builder.Build());
  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::ReduceWindow(m::Parameter(0), m::Constant())));
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Constant())));
}
1469
// Padding an operand that has zero elements produces a tensor consisting only
// of the padding value; the simplifier should express it as a broadcast of
// that constant.
TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder entry_builder(TestName());
  const Shape empty_shape = ShapeUtil::MakeShape(F32, {3, 0});
  const Shape padded_shape = ShapeUtil::MakeShape(F32, {5, 2});
  HloInstruction* operand = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, empty_shape, "op"));

  // One element of edge padding on both ends of each dimension, no interior
  // padding: {3, 0} -> {5, 2}.
  PaddingConfig config;
  for (int dim_index = 0; dim_index < 2; ++dim_index) {
    PaddingConfig::PaddingConfigDimension* dim_config =
        config.add_dimensions();
    dim_config->set_edge_padding_low(1);
    dim_config->set_edge_padding_high(1);
    dim_config->set_interior_padding(0);
  }

  HloInstruction* pad_value = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
  entry_builder.AddInstruction(
      HloInstruction::CreatePad(padded_shape, operand, pad_value, config));
  m->AddEntryComputation(entry_builder.Build());

  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Constant())));
  HloPassFix<AlgebraicSimplifier> fixed_point_simplifier(default_options_);
  ASSERT_TRUE(fixed_point_simplifier.Run(m.get()).ValueOrDie());
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Constant())));
}
1496
// reshape({3,2} <- broadcast({1,6} <- reshape({6} <- op))) only inserts and
// removes a size-1 dimension around op, so the whole chain should simplify
// back to op itself.
TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) {
  auto m = CreateNewVerifiedModule();

  auto builder = HloComputation::Builder(TestName());
  auto op = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {3, 2}), "op"));
  auto reshape1 = builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6}), op));
  // Maps operand dimension 0 to output dimension 1, adding a leading size-1
  // dimension: {6} -> {1, 6}.
  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {1, 6}), reshape1, {1}));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {3, 2}), broadcast));

  m->AddEntryComputation(builder.Build());

  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Reshape(m::Op().Is(op))))));

  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  EXPECT_THAT(m->entry_computation()->root_instruction(), op);
}
1522
1523 // Test that convert(A, $TYPE) is simplified to A if A is of type $TYPE.
TEST_F(AlgebraicSimplifierTest,ConvertBetweenSameType)1524 TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) {
1525 auto m = CreateNewVerifiedModule();
1526 HloComputation::Builder builder(TestName());
1527 HloInstruction* input = builder.AddInstruction(
1528 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
1529 builder.AddInstruction(
1530 HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));
1531
1532 auto computation = m->AddEntryComputation(builder.Build());
1533
1534 EXPECT_THAT(computation->root_instruction(),
1535 GmockMatch(m::Convert(m::Op().Is(input))));
1536
1537 AlgebraicSimplifier simplifier(default_options_);
1538 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1539
1540 EXPECT_THAT(computation->root_instruction(), input);
1541 }
1542
1543 // Test that copies are removed.
TEST_F(AlgebraicSimplifierTest,RemoveCopy)1544 TEST_F(AlgebraicSimplifierTest, RemoveCopy) {
1545 auto m = CreateNewVerifiedModule();
1546 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1547 HloComputation::Builder builder(TestName());
1548 HloInstruction* param0 = builder.AddInstruction(
1549 HloInstruction::CreateParameter(0, r0f32, "param0"));
1550 builder.AddInstruction(
1551 HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
1552
1553 auto computation = m->AddEntryComputation(builder.Build());
1554
1555 EXPECT_THAT(computation->root_instruction(),
1556 GmockMatch(m::Copy(m::Parameter(0))));
1557
1558 AlgebraicSimplifier simplifier(default_options_);
1559 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1560
1561 EXPECT_THAT(computation->root_instruction(), param0);
1562 }
1563
// Layout-sensitive: p is row-major {1,14,14,64}; copy -> column-major,
// reshape -> {196,64} (still column-major), copy -> row-major. The net effect
// leaves the physical element order unchanged, so the whole chain should
// become a single bitcast of p.
TEST_F(AlgebraicSimplifierTest, CopyOfReshapeOfCopyEqualsBitcast) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder entry_builder(TestName());
  const Shape input_row_major =
      ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {3, 2, 1, 0});
  const Shape input_col_major =
      ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {0, 1, 2, 3});
  const Shape flat_col_major =
      ShapeUtil::MakeShapeWithLayout(F32, {14 * 14, 64}, {0, 1});
  const Shape flat_row_major =
      ShapeUtil::MakeShapeWithLayout(F32, {14 * 14, 64}, {1, 0});

  HloInstruction* input = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_row_major, "param"));
  HloInstruction* inner_copy = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(input_col_major, HloOpcode::kCopy, input));
  HloInstruction* flatten = entry_builder.AddInstruction(
      HloInstruction::CreateReshape(flat_col_major, inner_copy));
  entry_builder.AddInstruction(
      HloInstruction::CreateUnary(flat_row_major, HloOpcode::kCopy, flatten));

  HloComputation* entry = m->AddEntryComputation(entry_builder.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Copy(m::Reshape(m::Copy(m::Parameter(0))))));

  AlgebraicSimplifierOptions layout_sensitive_options;
  layout_sensitive_options.set_is_layout_sensitive(true);
  AlgebraicSimplifier pass(layout_sensitive_options);
  ASSERT_TRUE(pass.Run(m.get()).ValueOrDie());
  // The copy/reshape/copy chain must collapse into one bitcast.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
1592
// Layout-sensitive: p is row-major {1,14,14,64}; copy -> column-major, then
// reshape to row-major {196,64}. The net effect leaves the physical element
// order unchanged, so reshape(copy(p)) should become a bitcast of p.
TEST_F(AlgebraicSimplifierTest, ReshapeOfCopyEqualsBitcast) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {3, 2, 1, 0}),
          "param"));
  HloInstruction* copy = builder.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {0, 1, 2, 3}),
      HloOpcode::kCopy, param));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShapeWithLayout(F32, {14 * 14, 64}, {1, 0}), copy));

  auto computation = m->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Copy(m::Parameter(0)))));

  AlgebraicSimplifierOptions options;
  options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(options);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  // Verify that the reshape of copy is replaced by a bitcast of the parameter.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
1618
// Layout-sensitive: the copy's layout {1,2,0,3} differs from the parameter's
// {0,1,2,3} only in where the size-1 dimension 0 sits, so no data actually
// moves. A callback that always returns false must block the rewrite; with
// default options the copy should become a bitcast.
TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder entry_builder(TestName());
  const Shape source_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {0, 1, 2, 3});
  const Shape copy_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {1, 2, 0, 3});
  HloInstruction* source = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, source_shape, "param"));
  entry_builder.AddInstruction(
      HloInstruction::CreateUnary(copy_shape, HloOpcode::kCopy, source));
  HloComputation* entry = m->AddEntryComputation(entry_builder.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Copy(m::Parameter(0))));

  // First pass: the callback rejects every shape pair, so nothing changes.
  AlgebraicSimplifierOptions rejecting_options(
      [](const Shape&, const Shape&) { return false; });
  rejecting_options.set_is_layout_sensitive(true);
  AlgebraicSimplifier rejecting_pass(rejecting_options);
  ASSERT_FALSE(rejecting_pass.Run(m.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Copy(m::Parameter(0))));

  // Second pass: default options, so the copy turns into a bitcast.
  AlgebraicSimplifierOptions default_layout_options;
  default_layout_options.set_is_layout_sensitive(true);
  AlgebraicSimplifier default_pass(default_layout_options);
  EXPECT_TRUE(default_pass.Run(m.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
1650
1651 // Test that unary concatenates are removed.
TEST_F(AlgebraicSimplifierTest,RemoveUnaryConcatenate)1652 TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) {
1653 auto m = CreateNewVerifiedModule();
1654 Shape r1f32 = ShapeUtil::MakeShape(F32, {100});
1655 HloComputation::Builder builder(TestName());
1656 HloInstruction* param0 = builder.AddInstruction(
1657 HloInstruction::CreateParameter(0, r1f32, "param0"));
1658 builder.AddInstruction(
1659 HloInstruction::CreateConcatenate(param0->shape(), {param0}, 0));
1660
1661 auto computation = m->AddEntryComputation(builder.Build());
1662
1663 EXPECT_THAT(computation->root_instruction(),
1664 GmockMatch(m::Concatenate(m::Parameter(0))));
1665
1666 AlgebraicSimplifier simplifier(default_options_);
1667 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1668
1669 EXPECT_THAT(computation->root_instruction(), param0);
1670 }
1671
1672 // Test that empty operands of concatenates are removed.
TEST_F(AlgebraicSimplifierTest,RemoveEmptyConcatenateOperands)1673 TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) {
1674 auto m = CreateNewVerifiedModule();
1675 const int kParamLength = 100;
1676 Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength});
1677 HloComputation::Builder builder(TestName());
1678 HloInstruction* param0 = builder.AddInstruction(
1679 HloInstruction::CreateParameter(0, r1f32, "param0"));
1680 HloInstruction* param1 = builder.AddInstruction(
1681 HloInstruction::CreateParameter(1, r1f32, "param1"));
1682 HloInstruction* empty_literal = builder.AddInstruction(
1683 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
1684 HloInstruction* empty_slice =
1685 builder.AddInstruction(HloInstruction::CreateSlice(
1686 ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}, {1}));
1687 Shape result_shape = ShapeUtil::MakeShape(F32, {3 * kParamLength});
1688 builder.AddInstruction(HloInstruction::CreateConcatenate(
1689 result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0));
1690
1691 auto computation = m->AddEntryComputation(builder.Build());
1692
1693 EXPECT_THAT(computation->root_instruction(),
1694 GmockMatch(m::Concatenate(
1695 m::Op().Is(empty_literal), m::Parameter(0), m::Parameter(0),
1696 m::Op().Is(empty_slice), m::Parameter(1))));
1697
1698 AlgebraicSimplifier simplifier(default_options_);
1699 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1700
1701 EXPECT_THAT(computation->root_instruction(),
1702 GmockMatch(m::Concatenate(m::Parameter(0), m::Parameter(0),
1703 m::Parameter(1))));
1704 }
1705
1706 // Test that reduce of concat is simplified.
TEST_F(AlgebraicSimplifierTest,SimplifyReduceOfConcat)1707 TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) {
1708 auto m = CreateNewVerifiedModule();
1709 const int kParamLength = 100;
1710 Shape r3f32 =
1711 ShapeUtil::MakeShape(F32, {kParamLength, kParamLength, kParamLength});
1712 HloComputation::Builder builder(TestName());
1713 HloInstruction* param0 = builder.AddInstruction(
1714 HloInstruction::CreateParameter(0, r3f32, "param0"));
1715 HloInstruction* param1 = builder.AddInstruction(
1716 HloInstruction::CreateParameter(1, r3f32, "param1"));
1717 HloInstruction* param2 = builder.AddInstruction(
1718 HloInstruction::CreateParameter(2, r3f32, "param2"));
1719 Shape concat_shape =
1720 ShapeUtil::MakeShape(F32, {kParamLength, 3 * kParamLength, kParamLength});
1721 HloInstruction* Concatenate =
1722 builder.AddInstruction(HloInstruction::CreateConcatenate(
1723 concat_shape, {param0, param1, param2}, 1));
1724 HloComputation* add_computation = nullptr;
1725 {
1726 HloComputation::Builder builder(TestName() + ".add");
1727 const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
1728 HloInstruction* p0 = builder.AddInstruction(
1729 HloInstruction::CreateParameter(0, scalar_shape, "p0"));
1730 HloInstruction* p1 = builder.AddInstruction(
1731 HloInstruction::CreateParameter(1, scalar_shape, "p1"));
1732 builder.AddInstruction(
1733 HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
1734 add_computation = m->AddEmbeddedComputation(builder.Build());
1735 }
1736 Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
1737 Shape reduce_shape = ShapeUtil::MakeShape(F32, {kParamLength});
1738
1739 HloInstruction* zero = builder.AddInstruction(
1740 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
1741 builder.AddInstruction(HloInstruction::CreateReduce(
1742 reduce_shape, Concatenate, zero, {1, 2}, add_computation));
1743
1744 auto computation = m->AddEntryComputation(builder.Build());
1745
1746 AlgebraicSimplifier simplifier(default_options_);
1747 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1748
1749 EXPECT_THAT(
1750 computation->root_instruction(),
1751 GmockMatch(m::Map(m::Map(m::Reduce(m::Parameter(0), m::Op().Is(zero)),
1752 m::Reduce(m::Parameter(1), m::Op().Is(zero))),
1753 m::Reduce(m::Parameter(2), m::Op().Is(zero)))));
1754 }
1755
1756 // Test a concatenate with only empty operands is removed.
TEST_F(AlgebraicSimplifierTest,OnlyEmptyConcatenateOperands)1757 TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) {
1758 auto m = CreateNewVerifiedModule();
1759 const int kParamLength = 100;
1760 Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength});
1761 HloComputation::Builder builder(TestName());
1762 HloInstruction* param0 = builder.AddInstruction(
1763 HloInstruction::CreateParameter(0, r1f32, "param0"));
1764 HloInstruction* empty_literal = builder.AddInstruction(
1765 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
1766 HloInstruction* empty_slice =
1767 builder.AddInstruction(HloInstruction::CreateSlice(
1768 ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42}, {1}));
1769 Shape result_shape = ShapeUtil::MakeShape(F32, {0});
1770 builder.AddInstruction(HloInstruction::CreateConcatenate(
1771 result_shape, {empty_literal, empty_slice}, 0));
1772
1773 auto computation = m->AddEntryComputation(builder.Build());
1774
1775 EXPECT_THAT(computation->root_instruction(),
1776 GmockMatch(m::Concatenate(m::Op().Is(empty_literal),
1777 m::Op().Is(empty_slice))));
1778
1779 AlgebraicSimplifier simplifier(default_options_);
1780 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1781
1782 EXPECT_EQ(computation->root_instruction(), empty_literal);
1783 }
1784
1785 // Test that concat with a scalar broadcast becomes a pad.
TEST_F(AlgebraicSimplifierTest,ConcatenateOfBroadcastBecomesPad)1786 TEST_F(AlgebraicSimplifierTest, ConcatenateOfBroadcastBecomesPad) {
1787 auto m = CreateNewVerifiedModule();
1788 Shape r1f32 = ShapeUtil::MakeShape(F32, {100});
1789 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1790 HloComputation::Builder builder(TestName());
1791 HloInstruction* param0 = builder.AddInstruction(
1792 HloInstruction::CreateParameter(0, r1f32, "param0"));
1793 HloInstruction* param1 = builder.AddInstruction(
1794 HloInstruction::CreateParameter(1, r0f32, "param1"));
1795 HloInstruction* broadcast = builder.AddInstruction(
1796 HloInstruction::CreateBroadcast(r1f32, param1, {}));
1797 builder.AddInstruction(HloInstruction::CreateConcatenate(
1798 ShapeUtil::MakeShape(F32, {200}), {broadcast, param0}, 0));
1799
1800 auto computation = m->AddEntryComputation(builder.Build());
1801
1802 AlgebraicSimplifier simplifier(default_options_);
1803 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1804 EXPECT_THAT(computation->root_instruction(),
1805 GmockMatch(m::Pad(m::Parameter(0), m::Parameter(1))));
1806 }
1807
// Adjacent slices of the same operand that are contiguous along the concat
// dimension (1) with matching starts/strides elsewhere should be merged into a
// single slice; all other neighbours must be left alone.
TEST_F(AlgebraicSimplifierTest, SimplifyConcatenateOfSlices) {
  auto m = CreateNewVerifiedModule();
  Shape r2f32 = ShapeUtil::MakeShape(F32, {100, 99});
  Shape concat_shape = ShapeUtil::MakeShape(F32, {50, 80});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r2f32, "param0"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, r2f32, "param1"));

  HloInstruction* slice0 = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{0, 0},
      /*limit_indices=*/{50, 10}, /*strides=*/{1, 1}));

  // Cannot merge 'slice0' and 'slice1' because of different start indices in
  // dimension 0.
  HloInstruction* slice1 = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 10},
      /*limit_indices=*/{100, 20}, /*strides=*/{1, 1}));

  // Cannot merge 'slice1' and 'slice2' because of stride in dimension 1.
  HloInstruction* slice2 = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 20},
      /*limit_indices=*/{100, 40}, /*strides=*/{1, 2}));

  // Cannot merge 'slice2' and 'slice3' because of stride in dimension 1.
  HloInstruction* slice3 = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 40},
      /*limit_indices=*/{100, 50}, /*strides=*/{1, 1}));

  // Can merge 'slice3' and 'slice4'.
  HloInstruction* slice4 = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 50},
      /*limit_indices=*/{100, 60}, /*strides=*/{1, 1}));

  // Can merge 'slice4' and 'slice5'.
  HloInstruction* slice5 = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 60},
      /*limit_indices=*/{100, 70}, /*strides=*/{1, 1}));

  // Cannot merge 'slice5' and 'slice6' because of overlap.
  HloInstruction* slice6 = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 69},
      /*limit_indices=*/{100, 79}, /*strides=*/{1, 1}));

  // Cannot merge 'slice6' and 'slice7' because of slicing from a different
  // parameter.
  HloInstruction* slice7 = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {50, 10}), param1, /*start_indices=*/{50, 79},
      /*limit_indices=*/{100, 89}, /*strides=*/{1, 1}));

  builder.AddInstruction(HloInstruction::CreateConcatenate(
      concat_shape,
      {slice0, slice1, slice2, slice3, slice4, slice5, slice6, slice7}, 1));
  auto computation = m->AddEntryComputation(builder.Build());

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  auto s = m::Slice(m::Parameter(0));
  EXPECT_THAT(
      computation->root_instruction(),
      GmockMatch(m::Concatenate(s, s, s, s, s, m::Slice(m::Parameter(1)))));
  // The operand 3 should be a merge of 'slice3', 'slice4' and 'slice5', so its
  // shape should have dimensions {50, 30}.
  EXPECT_TRUE(
      ShapeUtil::Equal(computation->root_instruction()->operand(3)->shape(),
                       ShapeUtil::MakeShape(F32, {50, 30})));
  EXPECT_EQ(computation->root_instruction()->operand(3)->slice_starts(1), 40);
}
1877
1878 // Test that a simplification which changes layouts is not performed if layout
1879 // sensitive is true.
TEST_F(AlgebraicSimplifierTest,CopyWithDifferentLayout)1880 TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) {
1881 auto m = CreateNewVerifiedModule();
1882 HloComputation::Builder builder(TestName());
1883 HloInstruction* param0 =
1884 builder.AddInstruction(HloInstruction::CreateParameter(
1885 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
1886 HloInstruction* copy = builder.AddInstruction(
1887 HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
1888
1889 auto computation = m->AddEntryComputation(builder.Build());
1890
1891 // Set to different layouts.
1892 *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
1893 *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0});
1894
1895 EXPECT_THAT(computation->root_instruction(),
1896 GmockMatch(m::Copy(m::Parameter(0))));
1897
1898 AlgebraicSimplifierOptions options;
1899 options.set_is_layout_sensitive(true);
1900 AlgebraicSimplifier simplifier(options);
1901 EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie());
1902
1903 // Copy has not been removed.
1904 EXPECT_THAT(computation->root_instruction(),
1905 GmockMatch(m::Copy(m::Parameter(0))));
1906 }
1907
1908 // Test that a simplification which preserves layouts is performed if layout
1909 // sensitive is true.
TEST_F(AlgebraicSimplifierTest,CopyWithSameLayout)1910 TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) {
1911 auto m = CreateNewVerifiedModule();
1912 HloComputation::Builder builder(TestName());
1913 HloInstruction* param0 =
1914 builder.AddInstruction(HloInstruction::CreateParameter(
1915 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
1916 HloInstruction* copy = builder.AddInstruction(
1917 HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
1918
1919 auto computation = m->AddEntryComputation(builder.Build());
1920
1921 // Set to same layouts.
1922 *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
1923 *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
1924
1925 EXPECT_THAT(computation->root_instruction(),
1926 GmockMatch(m::Copy(m::Parameter(0))));
1927
1928 AlgebraicSimplifierOptions options;
1929 options.set_is_layout_sensitive(true);
1930 AlgebraicSimplifier simplifier(options);
1931 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1932
1933 // Copy has been removed.
1934 EXPECT_THAT(computation->root_instruction(), param0);
1935 }
1936
1937 // Test that a reshape which could be replaced with a bitcast is not if
1938 // add_bitcasts is false.
TEST_F(AlgebraicSimplifierTest,NoBitcastAdded)1939 TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) {
1940 auto m = CreateNewVerifiedModule();
1941 HloComputation::Builder builder(TestName());
1942 HloInstruction* param0 =
1943 builder.AddInstruction(HloInstruction::CreateParameter(
1944 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
1945 HloInstruction* reshape =
1946 builder.AddInstruction(HloInstruction::CreateReshape(
1947 ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));
1948
1949 *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
1950 *reshape->mutable_shape()->mutable_layout() =
1951 LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
1952
1953 auto computation = m->AddEntryComputation(builder.Build());
1954
1955 EXPECT_THAT(computation->root_instruction(),
1956 GmockMatch(m::Reshape(m::Parameter(0))));
1957
1958 AlgebraicSimplifierOptions options(
1959 [](const Shape&, const Shape&) { return false; });
1960 options.set_is_layout_sensitive(true);
1961 AlgebraicSimplifier simplifier(options);
1962 EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie());
1963
1964 // Reshape is not replaced with a bitcast.
1965 EXPECT_THAT(computation->root_instruction(),
1966 GmockMatch(m::Reshape(m::Parameter(0))));
1967 }
1968
1969 // Test transforming reshapes and transposes of rng.
TEST_F(AlgebraicSimplifierTest,ReshapeOfTransposeOfRngToRng)1970 TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) {
1971 auto m = CreateNewVerifiedModule();
1972 HloComputation::Builder builder(TestName());
1973 HloInstruction* zero = builder.AddInstruction(
1974 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
1975 HloInstruction* one = builder.AddInstruction(
1976 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
1977 HloInstruction* rng0 = builder.AddInstruction(
1978 HloInstruction::CreateRng(ShapeUtil::MakeShape(F32, {2, 2}),
1979 RandomDistribution::RNG_UNIFORM, {zero, one}));
1980
1981 HloInstruction* transpose = builder.AddInstruction(
1982 HloInstruction::CreateTranspose(rng0->shape(), rng0, {1, 0}));
1983 Shape reshape_shape = builder
1984 .AddInstruction(HloInstruction::CreateReshape(
1985 ShapeUtil::MakeShape(F32, {4}), transpose))
1986 ->shape();
1987
1988 auto computation = m->AddEntryComputation(builder.Build());
1989
1990 AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
1991 EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1992
1993 // Verify that reshape(transpose(rng)) is replace by a single rng of the
1994 // same shape as the reshape.
1995 EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Rng()));
1996 EXPECT_TRUE(ShapeUtil::Equal(computation->root_instruction()->shape(),
1997 reshape_shape));
1998 }
1999
2000 // Test transforming reshapes to bitcasts under various conditions.
TEST_F(AlgebraicSimplifierTest,ReshapeReplacedWithBitcast)2001 TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) {
2002 auto m = CreateNewVerifiedModule();
2003 HloComputation::Builder builder(TestName());
2004 HloInstruction* param0 =
2005 builder.AddInstruction(HloInstruction::CreateParameter(
2006 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
2007 *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
2008
2009 // Reshape which can be transformed into a bitcast.
2010 HloInstruction* transformable_reshape =
2011 builder.AddInstruction(HloInstruction::CreateReshape(
2012 ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));
2013 *transformable_reshape->mutable_shape()->mutable_layout() =
2014 LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
2015
2016 // Reshape does not just add degenerate dimensions.
2017 HloInstruction* dimensions_wrong_reshape =
2018 builder.AddInstruction(HloInstruction::CreateReshape(
2019 ShapeUtil::MakeShape(F32, {1, 4, 1, 1, 1, 1}), param0));
2020 *dimensions_wrong_reshape->mutable_shape()->mutable_layout() =
2021 LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
2022
2023 // Reshape has wrong layout.
2024 HloInstruction* layout_wrong_reshape =
2025 builder.AddInstruction(HloInstruction::CreateReshape(
2026 ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));
2027 *layout_wrong_reshape->mutable_shape()->mutable_layout() =
2028 LayoutUtil::MakeLayout({5, 4, 3, 2, 1, 0});
2029
2030 // Collect all the reshapes into a tuple so they are not dead.
2031 builder.AddInstruction(HloInstruction::CreateTuple(
2032 {transformable_reshape, dimensions_wrong_reshape, layout_wrong_reshape}));
2033
2034 auto computation = m->AddEntryComputation(builder.Build());
2035
2036 EXPECT_THAT(computation->root_instruction(),
2037 GmockMatch(m::Tuple(m::Op().Is(transformable_reshape),
2038 m::Op().Is(dimensions_wrong_reshape),
2039 m::Op().Is(layout_wrong_reshape))));
2040
2041 AlgebraicSimplifierOptions options;
2042 options.set_is_layout_sensitive(true);
2043 AlgebraicSimplifier simplifier(options);
2044 simplifier.Run(m.get()).ValueOrDie();
2045
2046 // Verify that only the first reshape is replaced.
2047 EXPECT_THAT(
2048 computation->root_instruction(),
2049 GmockMatch(m::Tuple(m::Bitcast(), m::Op().Is(dimensions_wrong_reshape),
2050 m::Op().Is(layout_wrong_reshape))));
2051 }
2052
2053 // Regression test for a bug where if we failed to sink a reshape, we'd set the
2054 // 'changed' bit in AlgebraicSimplifier to false.
TEST_F(AlgebraicSimplifierTest,FailureToSinkReshapeDoesntAffectChangedBit)2055 TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) {
2056 auto m = CreateNewVerifiedModule();
2057 HloComputation::Builder builder(TestName());
2058
2059 // This add (param0 + 0) can be simplified.
2060 Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
2061 HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
2062 shape, HloOpcode::kAdd,
2063 builder.AddInstruction(
2064 HloInstruction::CreateParameter(0, shape, "param0")),
2065 builder.AddInstruction(HloInstruction::CreateConstant(
2066 LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}})))));
2067
2068 builder.AddInstruction(
2069 HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {4}), add));
2070
2071 AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
2072 m->AddEntryComputation(builder.Build());
2073 EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2074 }
2075
2076 // Regression test for a bug where if we failed to sink a reshape, we'd set the
2077 // 'changed' bit in AlgebraicSimplifier to false.
TEST_F(AlgebraicSimplifierTest,FailureToSinkBroadcastDoesntAffectChangedBit)2078 TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) {
2079 auto m = CreateNewVerifiedModule();
2080 HloComputation::Builder builder(TestName());
2081
2082 // This add (param0 + 0) can be simplified.
2083 Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
2084 HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
2085 shape, HloOpcode::kAdd,
2086 builder.AddInstruction(
2087 HloInstruction::CreateParameter(0, shape, "param0")),
2088 builder.AddInstruction(HloInstruction::CreateConstant(
2089 LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}})))));
2090
2091 builder.AddInstruction(
2092 HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {2, 2, 2}), add,
2093 /*broadcast_dimensions=*/{0, 1}));
2094
2095 AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
2096 m->AddEntryComputation(builder.Build());
2097 EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2098 }
2099
// Layout-sensitive: the transpose's output layout visits the input dimensions
// in the same minor-to-major order as the parameter's layout (1, 2, 0, 3), so
// the transpose moves no data and should become a bitcast.
TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {50, 14, 14, 64}), "param"));
  *param->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({1, 2, 0, 3});

  HloInstruction* transpose =
      builder.AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(F32, {14, 14, 50, 64}), param, {1, 2, 0, 3}));
  *transpose->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({0, 1, 2, 3});

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Transpose(m::Parameter(0))));

  AlgebraicSimplifierOptions options;
  options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(options);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // Verify that the transpose is replaced by a bitcast.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
2129
// The parameter layout {1,2,3,0} and the transpose result layout {3,1,2,0} are
// chosen so that, with layouts taken into account, the transpose does not move
// any data. A layout-sensitive simplifier is expected to rewrite it as a
// bitcast of the parameter.
TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {5, 2, 3, 4}), "param"));
  *param->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({1, 2, 3, 0});

  HloInstruction* transpose =
      builder.AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(F32, {5, 3, 4, 2}), param, {0, 2, 3, 1}));
  *transpose->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({3, 1, 2, 0});

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Transpose(m::Parameter(0))));

  AlgebraicSimplifierOptions options;
  options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(options);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // Verify that the transpose is replaced by a bitcast of the parameter.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
2159
// Two back-to-back reshapes, f32[2,2] -> f32[2,1,2] -> f32[1,2,1,1,2,1], should
// collapse into a single reshape of the parameter.
TEST_F(AlgebraicSimplifierTest, ReshapesMerged) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  const Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape middle_shape = ShapeUtil::MakeShape(F32, {2, 1, 2});
  const Shape output_shape = ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1});

  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param0"));
  HloInstruction* inner_reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(middle_shape, input));
  builder.AddInstruction(
      HloInstruction::CreateReshape(output_shape, inner_reshape));

  HloComputation* computation = m->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Reshape(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Parameter(0))));
}
2185
// Two consecutive layout-changing copies should be merged into a single copy
// of the parameter when the simplifier is layout sensitive.
TEST_F(AlgebraicSimplifierTest, CopiesMerged) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  const Shape param_shape =
      ShapeUtil::MakeShapeWithDescendingLayout(F32, {2, 2, 2});
  const Shape first_copy_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2});
  const Shape second_copy_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 2, 1});

  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param0"));
  HloInstruction* first_copy = builder.AddInstruction(
      HloInstruction::CreateUnary(first_copy_shape, HloOpcode::kCopy, param));
  builder.AddInstruction(HloInstruction::CreateUnary(
      second_copy_shape, HloOpcode::kCopy, first_copy));

  HloComputation* computation = m->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Copy(m::Copy(m::Parameter(0)))));

  AlgebraicSimplifierOptions layout_sensitive_options;
  layout_sensitive_options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(layout_sensitive_options);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Copy(m::Parameter(0))));
}
2215
// Two consecutive transposes (permutations {1,2,0} then {1,0,2}) should fold
// into one transpose of the parameter with the composed permutation {2,1,0}.
TEST_F(AlgebraicSimplifierTest, TransposesMerged) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {2, 3, 4}), "param0"));

  HloInstruction* transpose1 =
      builder.AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(F32, {3, 4, 2}), param0, {1, 2, 0}));

  builder.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(F32, {4, 3, 2}), transpose1, {1, 0, 2}));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Transpose(m::Op().Is(transpose1))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // Fatal check: the dimensions() query below is only meaningful on the
  // expected transpose, so stop here rather than query an unexpected op.
  ASSERT_THAT(computation->root_instruction(),
              GmockMatch(m::Transpose(m::Parameter(0))));
  EXPECT_EQ(std::vector<int64>({2, 1, 0}),
            computation->root_instruction()->dimensions());
}
2243
// reshape -> transpose -> reshape over data whose only non-degenerate dimension
// has size 10. Run to a fixed point, the chain should reduce to the parameter.
TEST_F(AlgebraicSimplifierTest, TransposeIsReshape) {
  const char* hlo_string = R"(
HloModule module

ENTRY test {
param = f32[10] parameter(0)
reshaped = f32[1,1,10] reshape(f32[10] param)
transposed = f32[10,1,1] transpose(f32[1,1,10] reshaped), dimensions={2,1,0}
ROOT reshaped_again = f32[10] reshape(f32[10,1,1] transposed)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  HloPassFix<AlgebraicSimplifier> fixed_point_simplifier(default_options_);
  EXPECT_TRUE(fixed_point_simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Parameter()));
}
2263
2264 // Test merging reshape and broadcast.
TEST_F(AlgebraicSimplifierTest,ReshapeAndBroadcastMerged)2265 TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) {
2266 auto m = CreateNewVerifiedModule();
2267 HloComputation::Builder builder(TestName());
2268 auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
2269 0, ShapeUtil::MakeShape(F32, {5}), "param0"));
2270 auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
2271 ShapeUtil::MakeShape(F32, {1, 5, 1}), param0));
2272 builder.AddInstruction(HloInstruction::CreateBroadcast(
2273 ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 3, 2}));
2274
2275 auto computation = m->AddEntryComputation(builder.Build());
2276
2277 EXPECT_THAT(computation->root_instruction(),
2278 GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));
2279
2280 AlgebraicSimplifier simplifier(default_options_);
2281 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2282
2283 EXPECT_THAT(computation->root_instruction(),
2284 GmockMatch(m::Broadcast(m::Parameter(0))));
2285 }
2286
2287 // Test merging broadcast and reshape.
TEST_F(AlgebraicSimplifierTest,BroadcastAndReshapeMerged)2288 TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) {
2289 auto m = CreateNewVerifiedModule();
2290 HloComputation::Builder builder(TestName());
2291 auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
2292 0, ShapeUtil::MakeShape(F32, {2, 3}), "param0"));
2293 auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast(
2294 ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1}), param0, {1, 2}));
2295 builder.AddInstruction(HloInstruction::CreateReshape(
2296 ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}), broadcast1));
2297
2298 auto computation = m->AddEntryComputation(builder.Build());
2299
2300 EXPECT_THAT(computation->root_instruction(),
2301 GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));
2302
2303 AlgebraicSimplifier simplifier(default_options_);
2304 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2305
2306 EXPECT_THAT(computation->root_instruction(),
2307 GmockMatch(m::Broadcast(m::Parameter(0))));
2308 }
2309
// Broadcast f32[1] -> f32[3,1] followed by reshape to f32[3]: the simplifier is
// expected to leave this pair alone and report no change.
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  HloInstruction* scalar_like =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {1}), "param"));
  HloInstruction* bcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {3, 1}), scalar_like, {1}));
  builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), bcast));

  HloComputation* computation = m->AddEntryComputation(builder.Build());
  const auto unchanged_pattern = m::Reshape(m::Broadcast(m::Parameter(0)));
  EXPECT_THAT(computation->root_instruction(), GmockMatch(unchanged_pattern));

  AlgebraicSimplifier simplifier(default_options_);
  EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(), GmockMatch(unchanged_pattern));
}
2331
// Broadcasting f32[4] into f32[3,2,4] and reshaping to f32[6,1,1,4] keeps the
// size-4 dimension intact. The pair should fold into one broadcast that maps
// the parameter onto output dimension 3.
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {4}), "param"));
  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {3, 2, 4}), param, {2}));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), broadcast));

  HloComputation* computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // Fatal check: dimensions() below is only meaningful on the expected
  // broadcast, so stop here rather than query an unexpected op.
  ASSERT_THAT(computation->root_instruction(),
              GmockMatch(m::Broadcast(m::Parameter(0))));
  EXPECT_THAT(computation->root_instruction()->dimensions(),
              ::testing::ElementsAre(3));
}
2355
// Broadcasting f32[1] into f32[3,2,1] and reshaping to f32[6,1,1,1] should
// fold into a single broadcast. The size-1 parameter may legitimately land on
// any of the size-1 output dimensions 1, 2 or 3.
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1}), "param"));
  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {3, 2, 1}), param, {2}));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), broadcast));

  HloComputation* computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // Fatal check: dimensions() below is only meaningful on the expected
  // broadcast, so stop here rather than query an unexpected op.
  ASSERT_THAT(computation->root_instruction(),
              GmockMatch(m::Broadcast(m::Parameter(0))));
  // ElementsAre checks the size before the element, so an empty dimension
  // list fails cleanly instead of being indexed out of bounds.
  EXPECT_THAT(computation->root_instruction()->dimensions(),
              ::testing::ElementsAre(::testing::AnyOf(1, 2, 3)));
}
2381
// Broadcast f32[4] -> f32[3,2,4,2] followed by reshape to f32[6,8]: the size-4
// dimension is merged with a trailing size-2 dimension. The simplifier is
// expected to leave the pair unchanged.
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  HloInstruction* vec = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {4}), "param"));
  HloInstruction* bcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {3, 2, 4, 2}), vec, {2}));
  builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6, 8}), bcast));

  HloComputation* computation = m->AddEntryComputation(builder.Build());
  const auto unchanged_pattern = m::Reshape(m::Broadcast(m::Parameter(0)));
  EXPECT_THAT(computation->root_instruction(), GmockMatch(unchanged_pattern));

  AlgebraicSimplifier simplifier(default_options_);
  EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(), GmockMatch(unchanged_pattern));
}
2403
// A reshape of an iota should be replaced by an iota that directly produces
// the reshaped shape.
TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  const Shape iota_shape = ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2});

  HloInstruction* counter =
      builder.AddInstruction(HloInstruction::CreateIota(iota_shape, 2));
  builder.AddInstruction(HloInstruction::CreateReshape(result_shape, counter));

  HloComputation* computation = m->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  HloInstruction* new_root = computation->root_instruction();
  EXPECT_THAT(new_root, GmockMatch(m::Iota()));
  EXPECT_TRUE(ShapeUtil::Equal(new_root->shape(), result_shape));
}
2424
// An iota over an f32[1,1] shape holds a single element. It should become a
// broadcast of the constant 0 with the same shape.
TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto iota = builder.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {1, 1}), 0));
  // Copied by value: the iota may be removed by the simplifier.
  auto result_shape = iota->shape();

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Iota()));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  auto root = computation->root_instruction();
  // Fatal check: operand(0)->literal() below requires the root to be a
  // broadcast of a constant.
  ASSERT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
  EXPECT_EQ(0.0f, root->operand(0)->literal().GetFirstElement<float>());
  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), result_shape));
}
2445
// Iota f32[3,2] along dimension 1, reshaped to f32[6]: the iota dimension is
// merged into a larger dimension, so the simplifier is expected to leave the
// pair unchanged.
TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2_6) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  HloInstruction* counter = builder.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2}), 1));
  builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6}), counter));

  HloComputation* computation = m->AddEntryComputation(builder.Build());
  const auto unchanged_pattern = m::Reshape(m::Iota());
  EXPECT_THAT(computation->root_instruction(), GmockMatch(unchanged_pattern));

  AlgebraicSimplifier simplifier(default_options_);
  EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(), GmockMatch(unchanged_pattern));
}
2465
// Iota f32[3,2,4] along dimension 2 (size 4), reshaped to f32[6,1,1,4]: the
// iota dimension survives unchanged as output dimension 3. The pair should
// become a single iota along dimension 3.
TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto iota = builder.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4}), 2));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), iota));

  HloComputation* computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // Fatal check: Cast<HloIotaInstruction> below requires the root to be an iota.
  ASSERT_THAT(computation->root_instruction(), GmockMatch(m::Iota()));
  EXPECT_EQ(Cast<HloIotaInstruction>(computation->root_instruction())
                ->iota_dimension(),
            3);
}
2487
// Iota f32[3,2,2] along dimension 2 (size 2), reshaped to f32[6,1,1,2].
// Element [a,0,0,b] of the reshape comes from linear index 2*a + b, whose iota
// value is b. The values therefore vary only along output dimension 3, and
// that is the only correct iota dimension. An iota along the size-1 dimensions
// 1 or 2 would be all zeros.
TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x2_6x1x1x2) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto iota = builder.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 2}), 2));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {6, 1, 1, 2}), iota));

  HloComputation* computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // Fatal check: Cast<HloIotaInstruction> below requires the root to be an iota.
  ASSERT_THAT(computation->root_instruction(), GmockMatch(m::Iota()));
  const int64 iota_dim =
      Cast<HloIotaInstruction>(computation->root_instruction())
          ->iota_dimension();
  EXPECT_EQ(iota_dim, 3);
}
2510
// Iota f32[3,2,4,2] along dimension 2, reshaped to f32[6,8]: the iota dimension
// is merged with its neighbour, so the simplifier is expected to leave the pair
// unchanged.
TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4x2_6x8) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  const Shape iota_shape = ShapeUtil::MakeShape(F32, {3, 2, 4, 2});
  HloInstruction* counter =
      builder.AddInstruction(HloInstruction::CreateIota(iota_shape, 2));
  builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6, 8}), counter));

  HloComputation* computation = m->AddEntryComputation(builder.Build());
  const auto unchanged_pattern = m::Reshape(m::Iota());
  EXPECT_THAT(computation->root_instruction(), GmockMatch(unchanged_pattern));

  AlgebraicSimplifier simplifier(default_options_);
  EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(), GmockMatch(unchanged_pattern));
}
2530
// A pad whose every dimension has zero low, high and interior padding is a
// no-op and should be replaced by its operand.
TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
  HloComputation::Builder builder(TestName());
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {2, 2});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "param"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));

  PaddingConfig no_padding;
  for (int dim = 0; dim < 2; ++dim) {
    auto* config = no_padding.add_dimensions();
    config->set_edge_padding_low(0);
    config->set_edge_padding_high(0);
    config->set_interior_padding(0);
  }
  builder.AddInstruction(
      HloInstruction::CreatePad(operand_shape, param, zero, no_padding));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_EQ(computation->root_instruction(), param);
}
2559
// Verify that a pad instruction with negative padding is replaced with a pad
// with non-negative padding followed by a slice.
TEST_F(AlgebraicSimplifierTest, NegativePadding) {
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {10, 10}), "param"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  PaddingConfig padding;
  const int64 low_padding[2] = {-1, -2};
  const int64 high_padding[2] = {2, -3};
  for (int i = 0; i < 2; ++i) {
    auto dimension = padding.add_dimensions();
    dimension->set_edge_padding_low(low_padding[i]);
    dimension->set_edge_padding_high(high_padding[i]);
    dimension->set_interior_padding(0);
  }
  HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {11, 5}), param, zero, padding));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  AlgebraicSimplifier simplifier(default_options_);

  // Returns true if any dimension of `instr`'s padding config has a negative
  // low or high edge padding. `instr` must be a pad. The parameter is not
  // named `pad`, to avoid shadowing the local above.
  auto has_negative_padding = [](const HloInstruction* instr) {
    for (const auto& padding_dimension : instr->padding_config().dimensions()) {
      if (padding_dimension.edge_padding_low() < 0 ||
          padding_dimension.edge_padding_high() < 0) {
        return true;
      }
    }
    return false;
  };

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero))));
  EXPECT_TRUE(has_negative_padding(pad));

  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Fatal check: the padding query below treats the root's operand as a pad,
  // so stop here if the expected slice(pad) structure is missing.
  ASSERT_THAT(computation->root_instruction(),
              GmockMatch(m::Slice(m::Pad(m::Parameter(0), m::Op().Is(zero)))));
  EXPECT_FALSE(
      has_negative_padding(computation->root_instruction()->operand(0)));
}
2607
// Verify that interior padding on a size-1 dimension, which inserts nothing,
// is removed while the pad itself (with its edge padding) is kept.
TEST_F(AlgebraicSimplifierTest, TrivialInteriorPadding) {
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {2, 1}), "param"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  PaddingConfig padding;
  for (int i = 0; i < 2; ++i) {
    auto dimension = padding.add_dimensions();
    dimension->set_edge_padding_low(3);
    dimension->set_edge_padding_high(3);
    // Dimension 0 gets no interior padding; dimension 1 (size 1) gets 3.
    dimension->set_interior_padding(i * 3);
  }
  HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {8, 7}), param, zero, padding));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  AlgebraicSimplifier simplifier(default_options_);

  ASSERT_THAT(computation->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero))));
  ASSERT_TRUE(HasInteriorPadding(pad->padding_config()));

  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Fatal check: padding_config() below requires the root to still be a pad.
  ASSERT_THAT(computation->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero))));
  EXPECT_FALSE(
      HasInteriorPadding(computation->root_instruction()->padding_config()));
}
2643
// A reshape to the operand's own shape is a no-op and should be replaced by
// the operand.
TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) {
  HloComputation::Builder builder(TestName());
  const Shape same_shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, same_shape, "param"));
  builder.AddInstruction(HloInstruction::CreateReshape(same_shape, param));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Parameter(0))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_EQ(computation->root_instruction(), param);
}
2663
// A unit-stride slice covering the whole operand is a no-op and should be
// replaced by the operand.
TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) {
  HloComputation::Builder builder(TestName());
  const int64 rows = 2;
  const int64 cols = 3;
  const Shape full_shape = ShapeUtil::MakeShape(F32, {rows, cols});

  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, full_shape, "param"));
  builder.AddInstruction(HloInstruction::CreateSlice(
      full_shape, param, /*start_indices=*/{0, 0},
      /*limit_indices=*/{rows, cols}, /*strides=*/{1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Slice(m::Parameter(0))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_EQ(computation->root_instruction(), param);
}
2686
// A slice of a slice should fold into a single slice of the parameter. Its
// starts are the sums of both slices' starts, and its limits are the outer
// start plus the inner limit.
TEST_F(AlgebraicSimplifierTest, SliceOfSliceToSlice) {
  HloComputation::Builder builder(TestName());
  const int64 dim0 = 11;
  const int64 dim1 = 12;
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {dim0, dim1}), "param"));
  HloInstruction* original_slice =
      builder.AddInstruction(HloInstruction::CreateSlice(
          ShapeUtil::MakeShape(F32, {dim0 - 2, dim1 - 4}), param,
          /*start_indices=*/{1, 2},
          /*limit_indices=*/{dim0 - 1, dim1 - 2}, /*strides=*/{1, 1}));

  builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {dim0 - 5, dim1 - 9}), original_slice,
      /*start_indices=*/{2, 3},
      /*limit_indices=*/{dim0 - 3, dim1 - 6}, /*strides=*/{1, 1}));
  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Slice(m::Slice(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Fatal check: slice_starts()/slice_limits() below require the root to be a
  // slice.
  HloInstruction* root = computation->root_instruction();
  ASSERT_THAT(root, GmockMatch(m::Slice(m::Parameter(0))));
  // Starts: {1 + 2, 2 + 3}. Limits: {1 + (dim0 - 3), 2 + (dim1 - 6)}.
  EXPECT_EQ(root->slice_starts(0), 3);
  EXPECT_EQ(root->slice_starts(1), 5);
  EXPECT_EQ(root->slice_limits(0), dim0 - 2);
  EXPECT_EQ(root->slice_limits(1), dim1 - 4);
}
2720
// Slicing only along the broadcast-introduced dimension of a broadcast should
// turn into a smaller broadcast of the parameter.
TEST_F(AlgebraicSimplifierTest, SliceOfBroadcastToBroadcast) {
  HloComputation::Builder builder(TestName());
  const int64 rows = 11;
  const int64 cols = 12;

  HloInstruction* column =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {rows}), "param"));
  HloInstruction* bcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {rows, cols}), column, {0}));
  const Shape sliced_shape = ShapeUtil::MakeShape(F32, {rows, cols - 9});
  builder.AddInstruction(HloInstruction::CreateSlice(
      sliced_shape, bcast,
      /*start_indices=*/{0, 3},
      /*limit_indices=*/{rows, cols - 6}, /*strides=*/{1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Slice(m::Broadcast(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Broadcast(m::Parameter(0))));
}
2747
// The slice here trims only the major-most dimension of a reshape of
// f32[dim0*dim1, dim2]. The simplifier should swap the order to
// reshape(slice(param)).
TEST_F(AlgebraicSimplifierTest, SliceOfReshapeToReshapeOfSlice) {
  HloComputation::Builder builder(TestName());
  const int64 dim0 = 11;
  const int64 dim1 = 12;
  const int64 dim2 = 13;

  const Shape flat_shape = ShapeUtil::MakeShape(F32, {dim0 * dim1, dim2});
  const Shape split_shape = ShapeUtil::MakeShape(F32, {dim0, dim1, dim2});
  const Shape sliced_shape = ShapeUtil::MakeShape(F32, {dim0 - 2, dim1, dim2});

  HloInstruction* flat = builder.AddInstruction(
      HloInstruction::CreateParameter(0, flat_shape, "param"));
  HloInstruction* split =
      builder.AddInstruction(HloInstruction::CreateReshape(split_shape, flat));
  builder.AddInstruction(HloInstruction::CreateSlice(
      sliced_shape, split,
      /*start_indices=*/{0, 0, 0},
      /*limit_indices=*/{dim0 - 2, dim1, dim2}, /*strides=*/{1, 1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Slice(m::Reshape(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Slice(m::Parameter(0)))));
}
2776
// Here the reshape merges the 144 and 25 dimensions into 3600, and the slice
// takes the first 960 rows, which does not line up with the 144 x 25
// boundaries. The simplifier is expected to make no change.
TEST_F(AlgebraicSimplifierTest, SliceOfReshapeUnchanged) {
  HloComputation::Builder builder(TestName());

  const Shape input_shape = ShapeUtil::MakeShape(F32, {1, 144, 25, 1, 512});
  const Shape reshaped_shape = ShapeUtil::MakeShape(F32, {3600, 512});
  const Shape sliced_shape = ShapeUtil::MakeShape(F32, {960, 512});

  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param"));
  HloInstruction* reshaped = builder.AddInstruction(
      HloInstruction::CreateReshape(reshaped_shape, input));
  builder.AddInstruction(HloInstruction::CreateSlice(
      sliced_shape, reshaped,
      /*start_indices=*/{0, 0},
      /*limit_indices=*/{960, 512}, /*strides=*/{1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Slice(m::Reshape(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
}
2799
// Sorting an f32[1] array cannot reorder anything. The sort should be replaced
// by its keys operand.
TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) {
  auto builder = HloComputation::Builder(TestName());
  auto module = CreateNewVerifiedModule();

  const Shape single_key_shape = ShapeUtil::MakeShape(F32, {1});
  HloInstruction* keys = builder.AddInstruction(
      HloInstruction::CreateParameter(0, single_key_shape, "keys"));
  const auto sort_or = MakeSortHlo(single_key_shape, {keys}, 0,
                                   /*is_stable=*/false, &builder, module.get());
  TF_ASSERT_OK(sort_or.status());

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_EQ(computation->root_instruction(), keys);
}
2815
// A key/value sort over zero-element f32[5,0] / s32[5,0] arrays should be
// replaced by a tuple of its unsorted operands.
TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) {
  auto builder = HloComputation::Builder(TestName());
  auto module = CreateNewVerifiedModule();

  const Shape keys_shape = ShapeUtil::MakeShape(F32, {5, 0});
  const Shape values_shape = ShapeUtil::MakeShape(S32, {5, 0});
  HloInstruction* keys = builder.AddInstruction(
      HloInstruction::CreateParameter(0, keys_shape, "keys"));
  HloInstruction* values0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, values_shape, "values0"));
  HloInstruction* values1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, values_shape, "values1"));

  const Shape sort_shape =
      ShapeUtil::MakeTupleShape({keys_shape, values_shape, values_shape});
  const auto sort_or =
      MakeSortHlo(sort_shape, {keys, values0, values1}, 0,
                  /*is_stable=*/false, &builder, module.get());
  TF_ASSERT_OK(sort_or.status());

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  const auto expected_tuple = m::Tuple(
      m::Op().Is(keys), m::Op().Is(values0), m::Op().Is(values1));
  EXPECT_THAT(computation->root_instruction(), GmockMatch(expected_tuple));
}
2840
2841 // Test that A && True is simplified to A
TEST_F(AlgebraicSimplifierTest,AndTrue)2842 TEST_F(AlgebraicSimplifierTest, AndTrue) {
2843 auto m = CreateNewVerifiedModule();
2844 Shape r0pred = ShapeUtil::MakeShape(PRED, {});
2845 HloComputation::Builder builder(TestName());
2846 HloInstruction* param0 = builder.AddInstruction(
2847 HloInstruction::CreateParameter(0, r0pred, "param0"));
2848 HloInstruction* const_true = builder.AddInstruction(
2849 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
2850 builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kAnd,
2851 param0, const_true));
2852
2853 auto computation = m->AddEntryComputation(builder.Build());
2854 HloInstruction* root = computation->root_instruction();
2855 EXPECT_EQ(root->opcode(), HloOpcode::kAnd);
2856 AlgebraicSimplifier simplifier(default_options_);
2857 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2858 root = computation->root_instruction();
2859 EXPECT_EQ(root, param0);
2860 }
2861
2862 // Test that True && A is simplified to A
TEST_F(AlgebraicSimplifierTest,AndTrue2)2863 TEST_F(AlgebraicSimplifierTest, AndTrue2) {
2864 auto m = CreateNewVerifiedModule();
2865 Shape r0pred = ShapeUtil::MakeShape(PRED, {});
2866 HloComputation::Builder builder(TestName());
2867 HloInstruction* param0 = builder.AddInstruction(
2868 HloInstruction::CreateParameter(0, r0pred, "param0"));
2869 HloInstruction* const_true = builder.AddInstruction(
2870 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
2871 builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kAnd,
2872 const_true, param0));
2873
2874 auto computation = m->AddEntryComputation(builder.Build());
2875 HloInstruction* root = computation->root_instruction();
2876 EXPECT_EQ(root->opcode(), HloOpcode::kAnd);
2877 AlgebraicSimplifier simplifier(default_options_);
2878 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2879 root = computation->root_instruction();
2880 EXPECT_EQ(root, param0);
2881 }
2882
2883 // Test that A && False is simplified to False
TEST_F(AlgebraicSimplifierTest,AndFalse)2884 TEST_F(AlgebraicSimplifierTest, AndFalse) {
2885 auto m = CreateNewVerifiedModule();
2886 Shape r0pred = ShapeUtil::MakeShape(PRED, {});
2887 HloComputation::Builder builder(TestName());
2888 HloInstruction* param0 = builder.AddInstruction(
2889 HloInstruction::CreateParameter(0, r0pred, "param0"));
2890 HloInstruction* const_false = builder.AddInstruction(
2891 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
2892 builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kAnd,
2893 param0, const_false));
2894
2895 auto computation = m->AddEntryComputation(builder.Build());
2896 HloInstruction* root = computation->root_instruction();
2897 EXPECT_EQ(root->opcode(), HloOpcode::kAnd);
2898 AlgebraicSimplifier simplifier(default_options_);
2899 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2900 root = computation->root_instruction();
2901 EXPECT_EQ(root, const_false);
2902 }
2903
2904 // Test that False && A is simplified to False
TEST_F(AlgebraicSimplifierTest,AndFalse2)2905 TEST_F(AlgebraicSimplifierTest, AndFalse2) {
2906 auto m = CreateNewVerifiedModule();
2907 Shape r0pred = ShapeUtil::MakeShape(PRED, {});
2908 HloComputation::Builder builder(TestName());
2909 HloInstruction* param0 = builder.AddInstruction(
2910 HloInstruction::CreateParameter(0, r0pred, "param0"));
2911 HloInstruction* const_false = builder.AddInstruction(
2912 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
2913 builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kAnd,
2914 const_false, param0));
2915
2916 auto computation = m->AddEntryComputation(builder.Build());
2917 HloInstruction* root = computation->root_instruction();
2918 EXPECT_EQ(root->opcode(), HloOpcode::kAnd);
2919 AlgebraicSimplifier simplifier(default_options_);
2920 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2921 root = computation->root_instruction();
2922 EXPECT_EQ(root, const_false);
2923 }
2924
2925 // Test that A || True is simplified to True
TEST_F(AlgebraicSimplifierTest,OrTrue)2926 TEST_F(AlgebraicSimplifierTest, OrTrue) {
2927 auto m = CreateNewVerifiedModule();
2928 Shape r0pred = ShapeUtil::MakeShape(PRED, {});
2929 HloComputation::Builder builder(TestName());
2930 HloInstruction* param0 = builder.AddInstruction(
2931 HloInstruction::CreateParameter(0, r0pred, "param0"));
2932 HloInstruction* const_true = builder.AddInstruction(
2933 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
2934 builder.AddInstruction(
2935 HloInstruction::CreateBinary(r0pred, HloOpcode::kOr, param0, const_true));
2936
2937 auto computation = m->AddEntryComputation(builder.Build());
2938 HloInstruction* root = computation->root_instruction();
2939 EXPECT_EQ(root->opcode(), HloOpcode::kOr);
2940 AlgebraicSimplifier simplifier(default_options_);
2941 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2942 root = computation->root_instruction();
2943 EXPECT_EQ(root, const_true);
2944 }
2945
2946 // Test that True || A is simplified to True
TEST_F(AlgebraicSimplifierTest,OrTrue2)2947 TEST_F(AlgebraicSimplifierTest, OrTrue2) {
2948 auto m = CreateNewVerifiedModule();
2949 Shape r0pred = ShapeUtil::MakeShape(PRED, {});
2950 HloComputation::Builder builder(TestName());
2951 HloInstruction* param0 = builder.AddInstruction(
2952 HloInstruction::CreateParameter(0, r0pred, "param0"));
2953 HloInstruction* const_true = builder.AddInstruction(
2954 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
2955 builder.AddInstruction(
2956 HloInstruction::CreateBinary(r0pred, HloOpcode::kOr, const_true, param0));
2957
2958 auto computation = m->AddEntryComputation(builder.Build());
2959 HloInstruction* root = computation->root_instruction();
2960 EXPECT_EQ(root->opcode(), HloOpcode::kOr);
2961 AlgebraicSimplifier simplifier(default_options_);
2962 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2963 root = computation->root_instruction();
2964 EXPECT_EQ(root, const_true);
2965 }
2966
2967 // Test that A || False is simplified to A
TEST_F(AlgebraicSimplifierTest,OrFalse)2968 TEST_F(AlgebraicSimplifierTest, OrFalse) {
2969 auto m = CreateNewVerifiedModule();
2970 Shape r0pred = ShapeUtil::MakeShape(PRED, {});
2971 HloComputation::Builder builder(TestName());
2972 HloInstruction* param0 = builder.AddInstruction(
2973 HloInstruction::CreateParameter(0, r0pred, "param0"));
2974 HloInstruction* const_false = builder.AddInstruction(
2975 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
2976 builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kOr,
2977 param0, const_false));
2978
2979 auto computation = m->AddEntryComputation(builder.Build());
2980 HloInstruction* root = computation->root_instruction();
2981 EXPECT_EQ(root->opcode(), HloOpcode::kOr);
2982 AlgebraicSimplifier simplifier(default_options_);
2983 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2984 root = computation->root_instruction();
2985 EXPECT_EQ(root, param0);
2986 }
2987
2988 // Test that False || A is simplified to A
TEST_F(AlgebraicSimplifierTest,OrFalse2)2989 TEST_F(AlgebraicSimplifierTest, OrFalse2) {
2990 auto m = CreateNewVerifiedModule();
2991 Shape r0pred = ShapeUtil::MakeShape(PRED, {});
2992 HloComputation::Builder builder(TestName());
2993 HloInstruction* param0 = builder.AddInstruction(
2994 HloInstruction::CreateParameter(0, r0pred, "param0"));
2995 HloInstruction* const_false = builder.AddInstruction(
2996 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
2997 builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kOr,
2998 const_false, param0));
2999
3000 auto computation = m->AddEntryComputation(builder.Build());
3001 HloInstruction* root = computation->root_instruction();
3002 EXPECT_EQ(root->opcode(), HloOpcode::kOr);
3003 AlgebraicSimplifier simplifier(default_options_);
3004 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3005 root = computation->root_instruction();
3006 EXPECT_EQ(root, param0);
3007 }
3008
3009 // Used for TEST_Ps that test merging (or not) of a kPad instruction into a
3010 // convolution's Window.
3011 struct ConvPaddingTestcase {
ConvPaddingTestcasexla::__anonac242c730111::ConvPaddingTestcase3012 ConvPaddingTestcase(absl::string_view padding,
3013 absl::string_view orig_conv_window,
3014 absl::string_view expected_conv_window)
3015 : ConvPaddingTestcase(padding, orig_conv_window, expected_conv_window,
3016 /*pad_value=*/0) {}
3017
ConvPaddingTestcasexla::__anonac242c730111::ConvPaddingTestcase3018 ConvPaddingTestcase(absl::string_view padding,
3019 absl::string_view orig_conv_window,
3020 absl::string_view expected_conv_window, float pad_value)
3021 : padding(padding),
3022 orig_conv_window(orig_conv_window),
3023 expected_conv_window(expected_conv_window),
3024 pad_value(pad_value) {}
3025
ToStringxla::__anonac242c730111::ConvPaddingTestcase3026 string ToString() const {
3027 return absl::StrFormat(
3028 "padding=%s, orig_conv_window=%s, expected_conv_window=%s, "
3029 "pad_value=%f",
3030 padding, orig_conv_window, expected_conv_window, pad_value);
3031 }
3032
3033 string padding;
3034 string orig_conv_window;
3035 string expected_conv_window;
3036 float pad_value;
3037 };
3038
3039 // ConvInputPaddingTest (and its one associated TEST_P testcase) checks that a
3040 // computation that does
3041 //
3042 // conv(pad(param0, padding=padding), param1), window=orig_conv_window
3043 //
3044 // gets transformed by AlgebraicSimplifier to
3045 //
3046 // conv(param0, param1), window=expected_conv_window
3047 //
3048 // or, if expected_conv_window is the empty string, checks that
3049 // AlgebraicSimplifier does *not* transform the original convolution.
3050 class ConvInputPaddingTest
3051 : public AlgebraicSimplifierTest,
3052 public ::testing::WithParamInterface<ConvPaddingTestcase> {};
3053
3054 INSTANTIATE_TEST_SUITE_P(
3055 ConvInputPaddingTestCases, ConvInputPaddingTest,
3056 ::testing::ValuesIn(std::vector<ConvPaddingTestcase>{
3057 // Merge this edge padding into the conv.
3058 {"0_0x0_0x1_1x2_2", "", "pad=1_1x2_2"},
3059 // Merge this edge padding with the conv's edge padding.
3060 {"0_0x0_0x1_2x3_4", "pad=10_10x20_20", "pad=11_12x23_24"},
3061 // Merge this interior-padded kPad with the unpadded conv. The 3x6
3062 // interior padding gets transformed to 4x7 conv lhs dilation.
3063 {"0_0x0_0x1_2_3x4_5_6", "", "pad=1_2x4_5 lhs_dilate=4x7"},
3064 // kPad has dilation on one dim, conv has it on the other; merge them.
3065 {"0_0x0_0x0_0_1x0_0_0", "lhs_dilate=1x10", "lhs_dilate=2x10"},
3066 // kPad has dilation and edge padding on one dim, conv has them on the
3067 // other; merge them.
3068 {"0_0x0_0x0_1_1x0_0_0", "pad=0_0x3_0 lhs_dilate=1x10",
3069 "pad=0_1x3_0 lhs_dilate=2x10"},
3070
3071 // Don't transform if the pad value is nonzero.
3072 {"0_0x0_0x1_1x2_2", "", "", /*pad_value=*/1},
3073
3074 // We refuse to transform the following because on some dimension, one
3075 // of the kPad and conv has dilation and the other has some sort of
3076 // padding.
3077 {"0_0x0_0x0_0_1x0_0", "pad=1_0x0_0", ""},
3078 {"0_0x0_0x0_0_1x0_0", "pad=0_1x0_0", ""},
3079 {"0_0x0_0x0_0_1x0_0", "lhs_dilate=2x1", ""},
3080 {"0_0x0_0x1_0_0x0_0", "lhs_dilate=2x1", ""},
3081 {"0_0x0_0x0_1_0x0_0", "lhs_dilate=2x1", ""},
3082 {"0_0x0_0x0_0_1x0_0", "lhs_dilate=2x1", ""},
3083
3084 // We can't merge feature or batch padding into the conv.
3085 {"1_0x0_0x0_0x0_0", "", ""},
3086 {"0_0x1_0x0_0x0_0", "", ""},
3087 }));
3088
TEST_P(ConvInputPaddingTest, DoTest) {
  // Bind by reference: GetParam() already returns a const reference.
  const ConvPaddingTestcase& testcase = GetParam();

  // It would be better to put the testcase's ToString into the test name, but
  // gUnit has constraints on what can go into test names, and any reasonable
  // implementation of ToString() seems to violate them.
  SCOPED_TRACE(testcase.ToString());

  auto builder = HloComputation::Builder(TestName());
  auto* input = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1024, 128, 100, 100}),  // bf01
      "input"));
  auto* pad_value = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR0(testcase.pad_value)));

  PaddingConfig padding_config =
      ParsePaddingConfig(testcase.padding).ValueOrDie();
  auto* lhs_pad = builder.AddInstruction(HloInstruction::CreatePad(
      ShapeInference::InferPadShape(input->shape(), pad_value->shape(),
                                    padding_config)
          .ValueOrDie(),
      input, pad_value, padding_config));

  // The filter's input-feature dim must match the padded input's feature dim.
  // Named "filter" (previously mislabeled "input" by copy-paste) so HLO dumps
  // and traces distinguish the two parameters.
  auto* filter = builder.AddInstruction(HloInstruction::CreateParameter(
      1,
      ShapeUtil::MakeShape(
          F32, {lhs_pad->shape().dimensions(1), 256, 3, 3}),  // io01
      "filter"));

  ConvolutionDimensionNumbers dnums =
      ParseConvolutionDimensionNumbers("bf01_io01->bf01").ValueOrDie();
  Window window =
      ParseWindow(absl::StrCat("size=3x3 ", testcase.orig_conv_window))
          .ValueOrDie();
  builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeInference::InferConvolveShape(lhs_pad->shape(), filter->shape(),
                                         /*feature_group_count=*/1,
                                         /*batch_group_count=*/1, window, dnums)
          .ValueOrDie(),
      lhs_pad, filter, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, DefaultPrecisionConfig(2)));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  AlgebraicSimplifier simplifier(default_options_);
  if (testcase.expected_conv_window.empty()) {
    // Empty expectation: the simplifier must not touch the graph at all.
    ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
  } else {
    ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
    auto* conv = module->entry_computation()->root_instruction();
    SCOPED_TRACE(module->ToString());
    // The kPad must have been folded away, leaving conv(param, param).
    ASSERT_THAT(conv,
                GmockMatch(m::Convolution(m::Parameter(), m::Parameter())));
    EXPECT_EQ(window_util::ToString(conv->window()),
              absl::StrCat("size=3x3 ", testcase.expected_conv_window));
  }
}
3146
3147 // ConvFilterPaddingTest (and its one associated TEST_P) checks that a
3148 // computation that does
3149 //
3150 // conv(param0, pad(param1, padding=padding)), window=orig_conv_window
3151 //
3152 // gets transformed by AlgebraicSimplifier to
3153 //
3154 // conv(param0, param1), window=expected_conv_window
3155 //
3156 // or, if expected_conv_window is the empty string, checks that
3157 // AlgebraicSimplifier does *not* transform the original convolution.
3158 class ConvFilterPaddingTest
3159 : public AlgebraicSimplifierTest,
3160 public ::testing::WithParamInterface<ConvPaddingTestcase> {};
3161
3162 INSTANTIATE_TEST_SUITE_P(
3163 ConvFilterPaddingTestCases, ConvFilterPaddingTest,
3164 ::testing::ValuesIn(std::vector<ConvPaddingTestcase>{
3165 // Can only merge interior padding on the filter's spatial dimensions;
3166 // all
3167 // other paddings (edge padding and interior padding on the channel
3168 // dims)
3169 // should be rejected out of hand.
3170 {"1_0_0x0_0_0x0_0x0_0", "", ""},
3171 {"0_1_0x0_0_0x0_0x0_0", "", ""},
3172 {"0_0_1x0_0_0x0_0x0_0", "", ""},
3173 {"0_0_0x1_0_0x0_0x0_0", "", ""},
3174 {"0_0_0x0_1_0x0_0x0_0", "", ""},
3175 {"0_0_0x0_0_1x0_0x0_0", "", ""},
3176 {"0_0_0x0_0_0x1_0x0_0", "", ""},
3177 {"0_0_0x0_0_0x0_1x0_0", "", ""},
3178 {"0_0_0x0_0_0x0_0x1_0", "", ""},
3179 {"0_0_0x0_0_0x0_0x0_1", "", ""},
3180
3181 // Interior padding on channel dims can be merged into the conv, so long
3182 // as the conv and pad don't have interior padding on the same dim.
3183 {"0_0x0_0x0_0_5x0_0", "", "rhs_dilate=6x1"},
3184 {"0_0x0_0x0_0x0_0_10", "", "rhs_dilate=1x11"},
3185 {"0_0x0_0x0_0_10x0_0_100", "", "rhs_dilate=11x101"},
3186 {"0_0x0_0x0_0_1x0_0", "rhs_dilate=1x10", "rhs_dilate=2x10"},
3187 {"0_0x0_0x0_0x0_0_5", "rhs_dilate=10x1", "rhs_dilate=10x6"},
3188
3189 // Can't merge if for a given dim there's interior padding on both the
3190 // pad and conv.
3191 {"0_0x0_0x0_0_1x0_0", "rhs_dilate=2x10", ""},
3192 {"0_0x0_0x0_0x0_0_5", "rhs_dilate=10x2", ""},
3193
3194 // Don't transform if the pad value is nonzero.
3195 {"0_0x0_0x0_0_5x0_0", "", "", /*pad_value=*/1},
3196 }));
3197
TEST_P(ConvFilterPaddingTest, DoIt) {
  // Bind by reference: GetParam() already returns a const reference.
  const ConvPaddingTestcase& testcase = GetParam();

  // It would be better to put the testcase's ToString into the test name, but
  // gUnit has constraints on what can go into test names, and any reasonable
  // implementation of ToString() seems to violate them.
  SCOPED_TRACE(testcase.ToString());

  auto builder = HloComputation::Builder(TestName());
  auto* pad_value = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR0(testcase.pad_value)));
  // Named "filter" (previously mislabeled "input" by copy-paste) so HLO dumps
  // and traces distinguish it from parameter 0.
  auto* filter = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {128, 256, 3, 3}),  // io01
      "filter"));
  PaddingConfig padding_config =
      ParsePaddingConfig(testcase.padding).ValueOrDie();
  auto* rhs_pad = builder.AddInstruction(HloInstruction::CreatePad(
      ShapeInference::InferPadShape(filter->shape(), pad_value->shape(),
                                    padding_config)
          .ValueOrDie(),
      filter, pad_value, padding_config));

  // The input's feature dim must match the padded filter's input-feature dim.
  auto* input = builder.AddInstruction(HloInstruction::CreateParameter(
      0,
      ShapeUtil::MakeShape(
          F32, {1024, rhs_pad->shape().dimensions(0), 100, 100}),  // bf01
      "input"));

  ConvolutionDimensionNumbers dnums =
      ParseConvolutionDimensionNumbers("bf01_io01->bf01").ValueOrDie();
  // The window size tracks the padded filter's spatial extent.
  Window window = ParseWindow(absl::StrFormat("size=%dx%d %s",
                                              rhs_pad->shape().dimensions(2),
                                              rhs_pad->shape().dimensions(3),
                                              testcase.orig_conv_window))
                      .ValueOrDie();

  // Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place
  // after the transformation.
  PrecisionConfig precision_config;
  precision_config.add_operand_precision(PrecisionConfig::HIGH);
  precision_config.add_operand_precision(PrecisionConfig::HIGHEST);

  builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(),
                                         /*feature_group_count=*/1,
                                         /*batch_group_count=*/1, window, dnums)
          .ValueOrDie(),
      input, rhs_pad, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, precision_config));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  AlgebraicSimplifier simplifier(default_options_);
  if (testcase.expected_conv_window.empty()) {
    // Empty expectation: the simplifier must not touch the graph at all.
    ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
  } else {
    ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
    auto* conv = module->entry_computation()->root_instruction();
    SCOPED_TRACE(module->ToString());
    // The kPad must have been folded away, leaving conv(param, param).
    ASSERT_THAT(conv,
                GmockMatch(m::Convolution(m::Parameter(), m::Parameter())));
    EXPECT_EQ(window_util::ToString(conv->window()),
              absl::StrFormat("size=%dx%d %s",
                              conv->operand(1)->shape().dimensions(2),
                              conv->operand(1)->shape().dimensions(3),
                              testcase.expected_conv_window));
    EXPECT_THAT(Cast<HloConvolutionInstruction>(conv)
                    ->precision_config()
                    .operand_precision(),
                ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::HIGHEST));
  }
}
3271
TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
  struct ConvTestOptions {
    int in_batch = 10;
    int in_height = 2;
    int in_width = 2;
    int in_channels = 3;
    int f_width = 1;
    int f_height = 1;
    int f_output_channels = 10;
    int row_stride = 1;
    int row_padding = 0;
    int col_stride = 1;
    int col_padding = 0;
    bool input_minor_to_major_layout = false;
    bool filter_minor_to_major_layout = false;
    bool output_minor_to_major_layout = false;

    const char* dim_order = "NHWC";         // can use chars NHWC in any order.
    const char* kernel_dim_order = "HWIO";  // can use chars HWIO in any order.

    ConvTestOptions& Reset() {
      *this = ConvTestOptions();
      return *this;
    }
  };

  ConvTestOptions options;

  // Builds a convolution from <options> and runs algebraic simplification on
  // the computation. Returns a string description of the result of
  // simplification.
  auto build_and_simplify = [&]() -> string {
    HloComputation::Builder b(TestName());

    Window window;
    auto* f_dim_1 = window.add_dimensions();
    f_dim_1->set_size(options.f_height);
    f_dim_1->set_stride(options.row_stride);
    f_dim_1->set_padding_low(options.row_padding);
    f_dim_1->set_padding_high(options.row_padding);
    f_dim_1->set_window_dilation(1);
    f_dim_1->set_base_dilation(1);
    auto* f_dim_2 = window.add_dimensions();
    f_dim_2->set_size(options.f_width);
    f_dim_2->set_stride(options.col_stride);
    f_dim_2->set_padding_low(options.col_padding);
    f_dim_2->set_padding_high(options.col_padding);
    f_dim_2->set_window_dilation(1);
    f_dim_2->set_base_dilation(1);

    ConvolutionDimensionNumbers dnums;
    std::vector<int64> in_dims;
    int in_channel_idx = -1;
    // Spatial dimension slots; filled in by the loop below.
    dnums.add_input_spatial_dimensions(-1);
    dnums.add_output_spatial_dimensions(-1);
    dnums.add_input_spatial_dimensions(-1);
    dnums.add_output_spatial_dimensions(-1);
    // Compute the length once, as a signed int, rather than re-evaluating
    // strlen and comparing signed against size_t on every iteration.
    const int dim_order_len = static_cast<int>(strlen(options.dim_order));
    for (int i = 0; i < dim_order_len; ++i) {
      char ch = options.dim_order[i];
      if (ch == 'N') {
        dnums.set_input_batch_dimension(i);
        dnums.set_output_batch_dimension(i);
        in_dims.push_back(options.in_batch);
      } else if (ch == 'H') {
        dnums.set_input_spatial_dimensions(0, i);
        dnums.set_output_spatial_dimensions(0, i);
        in_dims.push_back(options.in_height);
      } else if (ch == 'W') {
        dnums.set_input_spatial_dimensions(1, i);
        dnums.set_output_spatial_dimensions(1, i);
        in_dims.push_back(options.in_width);
      } else if (ch == 'C') {
        dnums.set_input_feature_dimension(i);
        dnums.set_output_feature_dimension(i);
        in_dims.push_back(options.in_channels);
        in_channel_idx = i;
      }
    }
    if (in_channel_idx < 0) {
      // Without a 'C' in dim_order, out_dims[in_channel_idx] below would index
      // out of bounds. Report the misconfiguration instead.
      return "INVALID_DIM_ORDER";
    }

    std::vector<int64> f_dims;
    dnums.add_kernel_spatial_dimensions(-1);  // filled in later
    dnums.add_kernel_spatial_dimensions(-1);  // filled in later
    const int kernel_dim_order_len =
        static_cast<int>(strlen(options.kernel_dim_order));
    for (int i = 0; i < kernel_dim_order_len; ++i) {
      char ch = options.kernel_dim_order[i];
      if (ch == 'H') {
        dnums.set_kernel_spatial_dimensions(0, i);
        f_dims.push_back(options.f_height);
      } else if (ch == 'W') {
        dnums.set_kernel_spatial_dimensions(1, i);
        f_dims.push_back(options.f_width);
      } else if (ch == 'I') {
        dnums.set_kernel_input_feature_dimension(i);
        f_dims.push_back(options.in_channels);
      } else if (ch == 'O') {
        dnums.set_kernel_output_feature_dimension(i);
        f_dims.push_back(options.f_output_channels);
      }
    }

    // With 1x1 unit-stride unpadded windows, the output keeps the input's
    // shape except that the channel dim becomes the filter's output channels.
    auto out_dims = in_dims;
    out_dims[in_channel_idx] = options.f_output_channels;

    auto make_shape = [](absl::Span<const int64> dims,
                         bool minor_to_major_layout) {
      if (minor_to_major_layout) {
        return ShapeUtil::MakeShapeWithLayout(F32, dims, {0, 1, 2, 3});
      } else {
        return ShapeUtil::MakeShape(F32, dims);
      }
    };
    auto in_shape = make_shape(in_dims, options.input_minor_to_major_layout);
    auto f_shape = make_shape(f_dims, options.filter_minor_to_major_layout);
    auto out_shape = make_shape(out_dims, options.output_minor_to_major_layout);

    HloInstruction* input =
        b.AddInstruction(HloInstruction::CreateParameter(0, in_shape, "input"));
    HloInstruction* filter =
        b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter"));

    b.AddInstruction(HloInstruction::CreateConvolve(
        out_shape, input, filter,
        /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums,
        DefaultPrecisionConfig(2)));

    // TODO(b/80488902): verify this module.
    auto module = CreateNewUnverifiedModule();
    auto* computation = module->AddEntryComputation(b.Build());

    AlgebraicSimplifierOptions simplifier_options;
    simplifier_options.set_is_layout_sensitive(true);
    AlgebraicSimplifier simplifier(simplifier_options);
    if (!simplifier.Run(module.get()).ValueOrDie()) {
      return "NO_CHANGE";
    }
    auto* root = computation->root_instruction();
    if (root->opcode() == HloOpcode::kBitcast &&
        root->operand(0)->opcode() == HloOpcode::kDot) {
      auto lhs_shape = root->operand(0)->operand(0)->shape();
      auto rhs_shape = root->operand(0)->operand(1)->shape();
      return absl::StrCat(absl::StrJoin(lhs_shape.dimensions(), "x"), " DOT ",
                          absl::StrJoin(rhs_shape.dimensions(), "x"));
    }
    return "UNEXPECTED CHANGE";
  };

  // Default options are the simplest case and succeed.
  options.Reset();
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());

  // Swapping dim spatial and batch order works.
  options.Reset().dim_order = "NWHC";
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
  options.Reset().dim_order = "WHNC";
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
  // Channel dimension earlier fails.
  options.Reset().dim_order = "HWCN";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().dim_order = "CHWN";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());

  // The filter's spatial dims can be anywhere, since they are 1x1.
  options.Reset().kernel_dim_order = "WHIO";
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
  options.Reset().kernel_dim_order = "IWOH";
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
  options.Reset().kernel_dim_order = "IWHO";
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
  // But moving output channel before input channel fails.
  options.Reset().kernel_dim_order = "HWOI";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().kernel_dim_order = "WHOI";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().kernel_dim_order = "OWIH";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().kernel_dim_order = "OWHI";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());

  // Combine different dim and kernel dim orders.
  options.Reset().kernel_dim_order = "IWHO";
  options.dim_order = "WHNC";
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());

  // Test invalid cases from wrong filter size, strides, or padding.
  options.Reset().f_width = 2;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().f_height = 2;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().row_stride = 2;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().col_stride = 2;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().col_padding = 1;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().row_padding = 1;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());

  // The default dim_order is "NHWC". Col-major layout makes C the most major.
  options.Reset().input_minor_to_major_layout = true;
  options.output_minor_to_major_layout = true;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());

  // The input and output have different layouts.
  options.Reset().input_minor_to_major_layout = true;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());

  // C is most minor, and I is more major than O.
  options.Reset().input_minor_to_major_layout = true;
  options.filter_minor_to_major_layout = true;
  options.output_minor_to_major_layout = true;
  options.dim_order = "CHWN";
  options.kernel_dim_order = "OIHW";
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());

  // C is not the most minor dimension.
  options.Reset().input_minor_to_major_layout = true;
  options.filter_minor_to_major_layout = true;
  options.output_minor_to_major_layout = true;
  options.dim_order = "HWNC";
  options.kernel_dim_order = "OIHW";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());

  // I is more minor than O.
  options.Reset().input_minor_to_major_layout = true;
  options.filter_minor_to_major_layout = true;
  options.output_minor_to_major_layout = true;
  options.dim_order = "CHWN";
  options.kernel_dim_order = "IOHW";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
}
3502
3503 // Test that slice(broadcast(/*scalar value*/)) simplifies to a single
3504 // broadcast.
TEST_F(AlgebraicSimplifierTest,ScalarBroadcastToSlice)3505 TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
3506 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
3507 HloComputation::Builder builder(TestName());
3508 HloInstruction* scalar_param = builder.AddInstruction(
3509 HloInstruction::CreateParameter(0, r0f32, "scalar_param"));
3510
3511 Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
3512 HloInstruction* broadcast = builder.AddInstruction(
3513 HloInstruction::CreateBroadcast(broadcast_shape, scalar_param, {}));
3514
3515 Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3});
3516 HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice(
3517 slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}, {1, 1, 1, 1}));
3518
3519 auto module = CreateNewVerifiedModule();
3520 auto computation = module->AddEntryComputation(builder.Build());
3521
3522 HloInstruction* root = computation->root_instruction();
3523 EXPECT_EQ(root, slice);
3524 EXPECT_TRUE(ShapeUtil::Equal(root->shape(), slice_shape));
3525
3526 AlgebraicSimplifier simplifier(default_options_);
3527
3528 ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
3529
3530 // Running simplification again should not result in any further changes.
3531 ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
3532 EXPECT_THAT(computation->root_instruction(),
3533 GmockMatch(m::Broadcast(m::Op().Is(scalar_param))
3534 .WithShapeEqualTo(&slice_shape)));
3535 }
3536
3537 // Test that reshape(transpose(broadcast(/*scalar value*/))) simplifies to a
3538 // single broadcast.
TEST_F(AlgebraicSimplifierTest,ScalarBroadcastToTransposeReshape)3539 TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
3540 HloComputation::Builder builder(TestName());
3541 HloInstruction* forty_two = builder.AddInstruction(
3542 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
3543
3544 Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6});
3545 HloInstruction* broadcast = builder.AddInstruction(
3546 HloInstruction::CreateBroadcast(broadcast_shape, forty_two, {}));
3547
3548 HloInstruction* transpose =
3549 builder.AddInstruction(HloInstruction::CreateTranspose(
3550 ShapeUtil::MakeShape(F32, {6, 5, 4}), broadcast, {2, 1, 0}));
3551
3552 Shape reshape_shape = ShapeUtil::MakeShape(F32, {30, 1, 4});
3553 HloInstruction* reshape = builder.AddInstruction(
3554 HloInstruction::CreateReshape(reshape_shape, transpose));
3555
3556 auto module = CreateNewVerifiedModule();
3557 auto computation = module->AddEntryComputation(builder.Build());
3558
3559 HloInstruction* root = computation->root_instruction();
3560 EXPECT_EQ(root, reshape);
3561 EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reshape_shape));
3562
3563 AlgebraicSimplifier simplifier(default_options_);
3564 ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
3565 EXPECT_THAT(computation->root_instruction(),
3566 GmockMatch(m::Broadcast(m::Op().Is(forty_two))
3567 .WithShapeEqualTo(&reshape_shape)));
3568 }
3569
3570 // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x).
TEST_F(AlgebraicSimplifierTest,FoldPadIntoReduceWindow)3571 TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
3572 // TODO(b/80488902): verify this module.
3573 auto module = CreateNewUnverifiedModule();
3574 HloComputation::Builder builder(TestName());
3575
3576 // Create operand to the pad.
3577 HloInstruction* operand =
3578 builder.AddInstruction(HloInstruction::CreateParameter(
3579 0, ShapeUtil::MakeShape(F32, {1, 2, 3, 4}), "p0"));
3580
3581 // Create the pad.
3582 PaddingConfig padding = MakeNoPaddingConfig(4);
3583 padding.mutable_dimensions(1)->set_edge_padding_low(1);
3584 padding.mutable_dimensions(3)->set_edge_padding_high(2);
3585
3586 HloInstruction* pad_value = builder.AddInstruction(
3587 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
3588 HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
3589 ShapeUtil::MakeShape(F32, {1, 3, 3, 5}), operand, pad_value, padding));
3590
3591 // Create add computation.
3592 HloComputation* add_computation = nullptr;
3593 {
3594 HloComputation::Builder builder(TestName() + ".add");
3595 const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
3596 HloInstruction* p0 = builder.AddInstruction(
3597 HloInstruction::CreateParameter(0, scalar_shape, "p0"));
3598 HloInstruction* p1 = builder.AddInstruction(
3599 HloInstruction::CreateParameter(1, scalar_shape, "p1"));
3600 builder.AddInstruction(
3601 HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
3602 add_computation = module->AddEmbeddedComputation(builder.Build());
3603 }
3604
3605 // Create the reduce-window.
3606 Window window;
3607 for (int64 i = 0; i < pad->shape().rank(); ++i) {
3608 auto* dim = window.add_dimensions();
3609 dim->set_size(1);
3610 dim->set_padding_low(10);
3611 dim->set_padding_high(100);
3612 dim->set_window_dilation(1);
3613 dim->set_base_dilation(1);
3614 }
3615 const Shape reduce_window_shape =
3616 ShapeUtil::MakeShape(F32, {111, 113, 113, 115});
3617 HloInstruction* reduce_init_value = builder.AddInstruction(
3618 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
3619 HloInstruction* reduce_window =
3620 builder.AddInstruction(HloInstruction::CreateReduceWindow(
3621 reduce_window_shape, pad, reduce_init_value, window,
3622 add_computation));
3623
3624 // Build the computation and run the simplifier.
3625 auto computation = module->AddEntryComputation(builder.Build());
3626 HloInstruction* root = computation->root_instruction();
3627 EXPECT_EQ(root, reduce_window);
3628 AlgebraicSimplifier simplifier(default_options_);
3629 ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
3630
3631 // Running simplification again should not result in any further changes.
3632 ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
3633
3634 // Verify the result
3635 root = computation->root_instruction();
3636 EXPECT_THAT(root,
3637 GmockMatch(m::ReduceWindow(m::Op().Is(operand), m::Constant())));
3638 EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape))
3639 << ShapeUtil::HumanString(root->shape()) << " vs "
3640 << ShapeUtil::HumanString(reduce_window_shape);
3641 EXPECT_EQ(root->window().dimensions(0).padding_low(), 10);
3642 EXPECT_EQ(root->window().dimensions(1).padding_low(), 11);
3643 EXPECT_EQ(root->window().dimensions(2).padding_low(), 10);
3644 EXPECT_EQ(root->window().dimensions(3).padding_low(), 10);
3645 EXPECT_EQ(root->window().dimensions(0).padding_high(), 100);
3646 EXPECT_EQ(root->window().dimensions(1).padding_high(), 100);
3647 EXPECT_EQ(root->window().dimensions(2).padding_high(), 100);
3648 EXPECT_EQ(root->window().dimensions(3).padding_high(), 102);
3649 }
3650
3651 // Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to
3652 // ReduceWindow(Convert(op), x).
TEST_F(AlgebraicSimplifierTest,FoldConvertedPadIntoReduceWindow)3653 TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
3654 // TODO(b/80488902): verify this module.
3655 auto module = CreateNewUnverifiedModule();
3656 HloComputation::Builder builder(TestName());
3657
3658 // Create operand to the pad.
3659 HloInstruction* parameter =
3660 builder.AddInstruction(HloInstruction::CreateParameter(
3661 0, ShapeUtil::MakeShape(BF16, {1, 2, 3, 4}), "p0"));
3662
3663 // Create the pad.
3664 PaddingConfig padding = MakeNoPaddingConfig(4);
3665 padding.mutable_dimensions(1)->set_edge_padding_low(1);
3666 padding.mutable_dimensions(3)->set_edge_padding_high(2);
3667
3668 HloInstruction* pad_value = builder.AddInstruction(
3669 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
3670 HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
3671 ShapeUtil::MakeShape(BF16, {1, 3, 3, 5}), parameter, pad_value, padding));
3672
3673 HloInstruction* convert =
3674 builder.AddInstruction(HloInstruction::CreateConvert(
3675 ShapeUtil::ChangeElementType(pad->shape(), F32), pad));
3676
3677 // Create add computation.
3678 HloComputation* add_computation = nullptr;
3679 {
3680 HloComputation::Builder builder(TestName() + ".add");
3681 const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
3682 HloInstruction* p0 = builder.AddInstruction(
3683 HloInstruction::CreateParameter(0, scalar_shape, "p0"));
3684 HloInstruction* p1 = builder.AddInstruction(
3685 HloInstruction::CreateParameter(1, scalar_shape, "p1"));
3686 builder.AddInstruction(
3687 HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
3688 add_computation = module->AddEmbeddedComputation(builder.Build());
3689 }
3690
3691 // Create the reduce-window.
3692 Window window;
3693 for (int64 i = 0; i < pad->shape().rank(); ++i) {
3694 auto* dim = window.add_dimensions();
3695 dim->set_size(1);
3696 dim->set_padding_low(10);
3697 dim->set_padding_high(100);
3698 dim->set_window_dilation(1);
3699 dim->set_base_dilation(1);
3700 }
3701 const Shape reduce_window_shape =
3702 ShapeUtil::MakeShape(F32, {111, 113, 113, 115});
3703 HloInstruction* reduce_init_value = builder.AddInstruction(
3704 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
3705 HloInstruction* reduce_window =
3706 builder.AddInstruction(HloInstruction::CreateReduceWindow(
3707 reduce_window_shape, convert, reduce_init_value, window,
3708 add_computation));
3709
3710 // Build the computation and run the simplifier.
3711 auto computation = module->AddEntryComputation(builder.Build());
3712 HloInstruction* root = computation->root_instruction();
3713 EXPECT_EQ(root, reduce_window);
3714 AlgebraicSimplifier simplifier(default_options_);
3715 ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
3716
3717 // Running simplification again should not result in any further changes.
3718 ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
3719
3720 // Verify the result
3721 root = computation->root_instruction();
3722 EXPECT_THAT(root, GmockMatch(m::ReduceWindow(m::Convert(m::Parameter(0)),
3723 m::Constant())));
3724 EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape))
3725 << ShapeUtil::HumanString(root->shape()) << " vs "
3726 << ShapeUtil::HumanString(reduce_window_shape);
3727 EXPECT_EQ(root->window().dimensions(0).padding_low(), 10);
3728 EXPECT_EQ(root->window().dimensions(1).padding_low(), 11);
3729 EXPECT_EQ(root->window().dimensions(2).padding_low(), 10);
3730 EXPECT_EQ(root->window().dimensions(3).padding_low(), 10);
3731 EXPECT_EQ(root->window().dimensions(0).padding_high(), 100);
3732 EXPECT_EQ(root->window().dimensions(1).padding_high(), 100);
3733 EXPECT_EQ(root->window().dimensions(2).padding_high(), 100);
3734 EXPECT_EQ(root->window().dimensions(3).padding_high(), 102);
3735 }
3736
// Reversing only size-1 dimensions does not move any element, so the reverse
// should be replaced by its operand.
TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) {
  const Shape shape = ShapeUtil::MakeShape(F32, {448, 2048, 1, 1});

  HloComputation::Builder b(TestName());
  HloInstruction* param =
      b.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
  b.AddInstruction(
      HloInstruction::CreateReverse(shape, param, /*dimensions=*/{2, 3}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  const HloInstruction* root = entry->root_instruction();
  EXPECT_EQ(root, param);
  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
}
3755
// Dots add computations to the parent module. This checks that updating the
// module's computations while the pass runs does not invalidate the iteration
// over the computations that remain to be visited.
TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
  auto module = CreateNewVerifiedModule();
  const Shape r1f32 = ShapeUtil::MakeShape(F32, {1});

  // Callee: a dot batched over the only dimension of two f32[1] parameters.
  HloComputation::Builder dot_builder(TestName() + ".Dot");
  HloInstruction* lhs = dot_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "x"));
  HloInstruction* rhs = dot_builder.AddInstruction(
      HloInstruction::CreateParameter(1, r1f32, "y"));
  DotDimensionNumbers dnums;
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_rhs_batch_dimensions(0);
  dot_builder.AddInstruction(HloInstruction::CreateDot(
      r1f32, lhs, rhs, dnums, DefaultPrecisionConfig(2)));
  std::unique_ptr<HloComputation> dot_computation = dot_builder.Build();

  // Caller: invokes the dot computation on two constant operands.
  HloComputation::Builder call_builder(TestName() + ".Call");
  HloInstruction* zero = call_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({0.0f})));
  HloInstruction* one = call_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1.0f})));
  call_builder.AddInstruction(
      HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get()));

  module->AddEmbeddedComputation(std::move(dot_computation));
  module->AddEntryComputation(call_builder.Build());

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
}
3787
3788 // Test that a constant with tuple shape becomes a tuple of constants.
TEST_F(AlgebraicSimplifierTest,ConstantTupleBecomesTupleOfConstants)3789 TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) {
3790 auto m = CreateNewVerifiedModule();
3791 HloComputation::Builder builder(TestName());
3792 const float constant_scalar = 7.3f;
3793 std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
3794 Literal elements[] = {LiteralUtil::CreateR0<float>(constant_scalar),
3795 LiteralUtil::CreateR1<float>(constant_vector)};
3796 Literal value = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
3797 builder.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
3798
3799 auto computation = m->AddEntryComputation(builder.Build());
3800
3801 AlgebraicSimplifier simplifier(default_options_);
3802 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3803 EXPECT_THAT(computation->root_instruction(),
3804 GmockMatch(m::Tuple(m::Constant(), m::Constant())));
3805 }
3806
3807 // A dynamic-slice is trivial if its start indices are all zeroes and the size
3808 // of its input equals the size of its output. In this case, the dynamic slice
3809 // is equal to its input.
TEST_F(AlgebraicSimplifierTest,TrivialDynamicSlice)3810 TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) {
3811 auto m = CreateNewVerifiedModule();
3812 HloComputation::Builder builder(TestName());
3813
3814 Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
3815 std::vector<HloInstruction*> params;
3816 for (int i = 0; i < 3; ++i) {
3817 params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
3818 i + 1, ShapeUtil::MakeShape(U32, {}), "slice_indices")));
3819 }
3820 builder.AddInstruction(HloInstruction::CreateDynamicSlice(
3821 shape,
3822 builder.AddInstruction(
3823 HloInstruction::CreateParameter(0, shape, "slice_from")),
3824 params,
3825 /*slice_sizes=*/{10, 100, 1000}));
3826
3827 auto computation = m->AddEntryComputation(builder.Build());
3828 AlgebraicSimplifier simplifier(default_options_);
3829 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3830 EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Parameter()));
3831 }
3832
3833 // A dynamic-update-slice is trivial if its start indices are all zeroes and the
3834 // size of its "update" equals the size of its output. In this case, the
3835 // dynamic-update-slice is equal to its update.
TEST_F(AlgebraicSimplifierTest,TrivialDynamicUpdateSlice)3836 TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) {
3837 auto m = CreateNewVerifiedModule();
3838 HloComputation::Builder builder(TestName());
3839
3840 Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
3841 Shape slice_shape = ShapeUtil::MakeShape(F32, {10, 1, 1000});
3842
3843 std::vector<HloInstruction*> slice_indices, update_indices;
3844 for (int i = 0; i < 3; ++i) {
3845 slice_indices.push_back(
3846 builder.AddInstruction(HloInstruction::CreateParameter(
3847 i + 1, ShapeUtil::MakeShape(U32, {}), "slice_indices")));
3848 update_indices.push_back(
3849 builder.AddInstruction(HloInstruction::CreateParameter(
3850 i + 5, ShapeUtil::MakeShape(U32, {}), "update_indices")));
3851 }
3852 HloInstruction* slice =
3853 builder.AddInstruction(HloInstruction::CreateDynamicSlice(
3854 slice_shape,
3855 builder.AddInstruction(
3856 HloInstruction::CreateParameter(0, full_shape, "slice_from")),
3857 slice_indices,
3858 /*slice_sizes=*/{10, 1, 1000}));
3859
3860 builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
3861 slice_shape,
3862 builder.AddInstruction(
3863 HloInstruction::CreateParameter(4, slice_shape, "to_update")),
3864 slice, update_indices));
3865
3866 auto computation = m->AddEntryComputation(builder.Build());
3867 AlgebraicSimplifier simplifier(default_options_);
3868 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3869 EXPECT_THAT(computation->root_instruction(),
3870 GmockMatch(m::DynamicSlice(m::Parameter(), m::Parameter(),
3871 m::Parameter(), m::Parameter())));
3872 }
3873
3874 // Test that two consecutive broadcasts can be merged to one.
TEST_F(AlgebraicSimplifierTest,MergeBroadcasts)3875 TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) {
3876 auto m = CreateNewVerifiedModule();
3877 HloComputation::Builder builder(TestName());
3878 Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
3879 HloInstruction* input_array = builder.AddInstruction(
3880 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({3, 4})));
3881 HloInstruction* inner_bcast = builder.AddInstruction(
3882 HloInstruction::CreateBroadcast(r2f32, input_array, {1}));
3883 Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2});
3884 builder.AddInstruction(
3885 HloInstruction::CreateBroadcast(r3f32, inner_bcast, {0, 2}));
3886
3887 auto computation = m->AddEntryComputation(builder.Build());
3888 HloInstruction* root = computation->root_instruction();
3889 EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
3890 AlgebraicSimplifier simplifier(default_options_);
3891 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3892 root = computation->root_instruction();
3893 EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
3894 EXPECT_THAT(root->dimensions(), ElementsAre(2));
3895 }
3896
3897 // Test that two consecutive broadcasts can be merged to one.
TEST_F(AlgebraicSimplifierTest,MergeBroadcasts2)3898 TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) {
3899 auto m = CreateNewVerifiedModule();
3900 HloComputation::Builder builder(TestName());
3901 Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 3});
3902 Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3});
3903 HloInstruction* param0 = builder.AddInstruction(
3904 HloInstruction::CreateParameter(0, r2f32, "param0"));
3905 // The initial dimensions go to places 0 and 2 in the 3-dim array,
3906 // and to places 1 and 3 in the 4-dim array,
3907 HloInstruction* inner_bcast = builder.AddInstruction(
3908 HloInstruction::CreateBroadcast(r3f32, param0, {0, 2}));
3909 Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3});
3910 builder.AddInstruction(
3911 HloInstruction::CreateBroadcast(r4f32, inner_bcast, {1, 2, 3}));
3912
3913 auto computation = m->AddEntryComputation(builder.Build());
3914 HloInstruction* root = computation->root_instruction();
3915 EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
3916 AlgebraicSimplifier simplifier(default_options_);
3917 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3918 root = computation->root_instruction();
3919 EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Parameter(0))));
3920 EXPECT_THAT(root->dimensions(), ElementsAre(1, 3));
3921 }
3922
3923 // Test that a broadcast of an iota can be merged to one iota.
TEST_F(AlgebraicSimplifierTest,MergeBroadcastAndIota)3924 TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) {
3925 auto m = CreateNewVerifiedModule();
3926 HloComputation::Builder builder(TestName());
3927 Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
3928 HloInstruction* iota =
3929 builder.AddInstruction(HloInstruction::CreateIota(r2f32, 1));
3930 Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2});
3931 builder.AddInstruction(HloInstruction::CreateBroadcast(r3f32, iota, {0, 2}));
3932
3933 auto computation = m->AddEntryComputation(builder.Build());
3934 HloInstruction* root = computation->root_instruction();
3935 EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
3936 AlgebraicSimplifier simplifier(default_options_);
3937 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3938 root = computation->root_instruction();
3939 EXPECT_THAT(root, GmockMatch(m::Iota()));
3940 EXPECT_EQ(Cast<HloIotaInstruction>(root)->iota_dimension(), 2);
3941 }
3942
3943 // Test that a broadcast of an iota can be merged to one iota.
TEST_F(AlgebraicSimplifierTest,MergeBroadcastAndIota2)3944 TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) {
3945 auto m = CreateNewVerifiedModule();
3946 HloComputation::Builder builder(TestName());
3947 Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3});
3948 HloInstruction* iota =
3949 builder.AddInstruction(HloInstruction::CreateIota(r3f32, 1));
3950 Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3});
3951 builder.AddInstruction(
3952 HloInstruction::CreateBroadcast(r4f32, iota, {1, 2, 3}));
3953
3954 auto computation = m->AddEntryComputation(builder.Build());
3955 HloInstruction* root = computation->root_instruction();
3956 EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
3957 AlgebraicSimplifier simplifier(default_options_);
3958 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3959 root = computation->root_instruction();
3960 EXPECT_THAT(root, GmockMatch(m::Iota()));
3961 EXPECT_EQ(Cast<HloIotaInstruction>(root)->iota_dimension(), 2);
3962 }
3963
// The sliced element (row 2, column 0) lies inside the low edge padding, so
// the slice should fold to a reshape of the padding constant.
TEST_F(AlgebraicSimplifierTest, SliceOfPadLow) {
  static constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
  param = f32[3,4] parameter(0)
  constant = f32[] constant(0.0)
  pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
  ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[2:3],[0:1]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifierOptions options;
  EXPECT_TRUE(AlgebraicSimplifier(options).Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Reshape(m::Constant())));
}
3984
// The sliced element (row 6, column 9) lies inside the high edge padding, so
// the slice should fold to a reshape of the padding constant.
TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) {
  static constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
  param = f32[3,4] parameter(0)
  constant = f32[] constant(0.0)
  pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
  ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[6:7],[9:10]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifierOptions options;
  EXPECT_TRUE(AlgebraicSimplifier(options).Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Reshape(m::Constant())));
}
4005
// The sliced element (row 5, column 4) comes from the non-scalar operand, not
// the padding, and the simplifier is expected to leave the module unchanged.
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) {
  static constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
  param = f32[3,4] parameter(0)
  constant = f32[] constant(0.0)
  pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
  ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[4:5]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifierOptions options;
  EXPECT_FALSE(AlgebraicSimplifier(options).Run(module.get()).ValueOrDie());
}
4024
// Row 5 is an operand row, but column 9 is in the high padding of dimension 1.
// The sliced element is therefore padding and should fold to a reshape of the
// padding constant.
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalarConstant) {
  static constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
  param = f32[3,4] parameter(0)
  constant = f32[] constant(0.0)
  pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
  ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[9:10]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifierOptions options;
  EXPECT_TRUE(AlgebraicSimplifier(options).Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Reshape(m::Constant())));
}
4045
// The operand is f32[1,1] and sits at (3, 4) of the padded array. Slicing
// exactly that element should yield the operand itself.
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) {
  static constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
  param = f32[1,1] parameter(0)
  constant = f32[] constant(0.0)
  pad = f32[8,10] pad(f32[1,1] param, f32[] constant), padding=3_4x4_5
  ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[3:4],[4:5]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifierOptions options;
  EXPECT_TRUE(AlgebraicSimplifier(options).Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Parameter()));
}
4066
// Dimensions 0 and 1 of the sliced element index operand data, but dimension 2
// falls in the low padding (2 elements). The element is therefore the padding
// value -7, and the root should become a reshape of that scalar constant.
TEST_F(AlgebraicSimplifierTest, SliceOfPadSomeDimsInPadding) {
  static constexpr char kHlo[] = R"(
HloModule module

ENTRY entry () -> f32[1]{0} {
  constant.val = f32[] constant(4)
  constant.pad = f32[] constant(-7)
  reshape.1 = f32[1,1,1]{2,1,0} reshape(f32[] constant.val)
  pad = f32[3,3,3]{2,1,0} pad(f32[1,1,1]{2,1,0} reshape.1, f32[] constant.pad), padding=0_2x0_2x2_0
  slice = f32[1,1,1]{2,1,0} slice(f32[3,3,3]{2,1,0} pad), slice={[0:1], [0:1], [0:1]}
  ROOT reshape.2 = f32[1]{0} reshape(f32[1,1,1]{2,1,0} slice)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifierOptions options;
  EXPECT_TRUE(AlgebraicSimplifier(options).Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Reshape(m::ConstantScalar(-7.0))));
}
4089
// Element [2:3] of concat(f32[2], f32[1], f32[3]) is exactly the single-element
// second operand, so the slice should be replaced by param.1.
TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) {
  static constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
  param.0 = f32[2] parameter(0)
  param.1 = f32[1] parameter(1)
  param.2 = f32[3] parameter(2)
  concat = f32[6] concatenate(param.0, param.1, param.2), dimensions={0}
  ROOT slice = f32[1] slice(concat), slice={[2:3]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifierOptions options;
  EXPECT_TRUE(AlgebraicSimplifier(options).Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Parameter(1)));
}
4111
// Element [4:5] of concat(f32[2], f32[1], f32[3]) lies inside param.2, which
// starts at offset 3. The slice should become slice(param.2), slice={[1:2]}.
TEST_F(AlgebraicSimplifierTest, SliceOfConcatNonScalarInput) {
  static constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
  param.0 = f32[2] parameter(0)
  param.1 = f32[1] parameter(1)
  param.2 = f32[3] parameter(2)
  concat = f32[6] concatenate(param.0, param.1, param.2), dimensions={0}
  ROOT slice = f32[1] slice(concat), slice={[4:5]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifierOptions options;
  EXPECT_TRUE(AlgebraicSimplifier(options).Run(module.get()).ValueOrDie());

  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Slice(m::Parameter(2))));
  EXPECT_EQ(root->slice_starts(0), 1);
  EXPECT_EQ(root->slice_limits(0), 2);
}
4135
// negate(negate(x)) simplifies to x.
TEST_F(AlgebraicSimplifierTest, NegateNegate) {
  static constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
  param.0 = f32[2] parameter(0)
  neg.0 = f32[2] negate(param.0)
  ROOT neg.1 = f32[2] negate(neg.0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifierOptions options;
  EXPECT_TRUE(AlgebraicSimplifier(options).Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Parameter(0)));
}
4155
// not(not(x)) on predicates simplifies to x.
TEST_F(AlgebraicSimplifierTest, NotNot) {
  static constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
  param.0 = pred[2] parameter(0)
  not.0 = pred[2] not(param.0)
  ROOT not.1 = pred[2] not(not.0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifierOptions options;
  EXPECT_TRUE(AlgebraicSimplifier(options).Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Parameter(0)));
}
4175
4176 struct PadReduceWindowEffectiveBroadcastCase {
4177 std::vector<int64> input_spatials;
4178 std::vector<int64> symmetric_pad_spatials;
4179 std::vector<int64> reduce_window_spatials;
4180 // Whether to use `B F S0 S1` form vs `B S0 S1 F` form.
4181 //
4182 // This doesn't test any different functionality but is useful for making sure
4183 // kBroadcast nodes are well formed.
4184 bool prepend_a;
4185 bool should_become_broadcast;
4186
ToTestCaseNamexla::__anonac242c730111::PadReduceWindowEffectiveBroadcastCase4187 string ToTestCaseName() const {
4188 return absl::StrCat(absl::StrJoin(input_spatials, ","), ";",
4189 absl::StrJoin(symmetric_pad_spatials, ","), ";",
4190 absl::StrJoin(reduce_window_spatials, ","), ";",
4191 prepend_a, ";", should_become_broadcast);
4192 }
4193 };
4194
PrintTo(const PadReduceWindowEffectiveBroadcastCase & c,std::ostream * os)4195 void PrintTo(const PadReduceWindowEffectiveBroadcastCase& c, std::ostream* os) {
4196 *os << c.ToTestCaseName();
4197 }
4198
4199 class PadReduceWindowEffectiveBroadcastTest
4200 : public AlgebraicSimplifierTest,
4201 public ::testing::WithParamInterface<
4202 PadReduceWindowEffectiveBroadcastCase> {};
4203
// Builds reduce-window(pad(input, 0.0), 0.0) with an f32 add reduction from the
// current parameter case and runs the simplifier. It then checks two things:
// the root keeps the reduce-window's output shape, and the root becomes a
// broadcast exactly when `should_become_broadcast` is set.
TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) {
  auto m = CreateNewVerifiedModule();
  const auto& param = GetParam();

  // a and b are parallel bounds we can either turn into a B F S0 S1 or
  // `B S0 S1 F` kind of pattern. With prepend_a the result is
  // {a, spatials..., b}; otherwise it is {spatials..., a, b}.
  //
  // `param` is captured by reference: GetParam() already returns a reference,
  // so there is no need to copy the case into the closure.
  auto decorate_spatials = [&param](absl::Span<const int64> spatials, int64 a,
                                    int64 b) {
    std::vector<int64> result;
    result.reserve(spatials.size() + 2);
    if (param.prepend_a) {
      result.push_back(a);
    }
    for (int64 s : spatials) {
      result.push_back(s);
    }
    if (!param.prepend_a) {
      result.push_back(a);
    }
    result.push_back(b);
    return result;
  };

  HloComputation::Builder builder(TestName());
  const Shape input_shape = ShapeUtil::MakeShape(
      F32, decorate_spatials(param.input_spatials, 128, 2048));
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "input"));

  // Symmetric padding on the spatial dimensions; the two decorated dimensions
  // get zero padding.
  PaddingConfig padding = window_util::MakeSymmetricPadding(
      decorate_spatials(param.symmetric_pad_spatials, 0, 0));
  TF_ASSERT_OK_AND_ASSIGN(
      const Shape pad_shape,
      ShapeInference::InferPadShape(input->shape(),
                                    ShapeUtil::MakeShape(F32, {}), padding));
  HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
      pad_shape, input,
      builder.AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))),
      padding));

  // Scalar f32 addition used as the reduction function.
  HloComputation* add_computation = nullptr;
  {
    HloComputation::Builder builder(TestName() + ".add");
    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
    HloInstruction* p0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "p0"));
    HloInstruction* p1 = builder.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "p1"));
    builder.AddInstruction(
        HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
    add_computation = m->AddEmbeddedComputation(builder.Build());
  }

  // Window sizes come from the case on spatial dimensions and are 1 on the two
  // decorated dimensions.
  Window window = window_util::MakeWindow(
      decorate_spatials(param.reduce_window_spatials, 1, 1));
  auto zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape,
                          ShapeInference::InferReduceWindowShape(
                              pad->shape(), zero->shape(), window,
                              add_computation->ComputeProgramShape()));
  builder.AddInstruction(HloInstruction::CreateReduceWindow(
      output_shape, pad, zero, window, add_computation));

  auto computation = m->AddEntryComputation(builder.Build());
  AlgebraicSimplifier simplifier(default_options_);
  TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get()));
  ASSERT_TRUE(run_successful);

  EXPECT_TRUE(
      ShapeUtil::Equal(computation->root_instruction()->shape(), output_shape));

  if (param.should_become_broadcast) {
    EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Broadcast()));
  } else {
    EXPECT_THAT(computation->root_instruction(),
                GmockMatch(m::ReduceWindow(m::Op(), m::Op().Is(zero))));
  }
}
4283
4284 const std::vector<PadReduceWindowEffectiveBroadcastCase>&
PadReduceWindowEffectiveBroadcastCases()4285 PadReduceWindowEffectiveBroadcastCases() {
4286 static auto* cases = new std::vector<PadReduceWindowEffectiveBroadcastCase>{
4287 {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{6, 6},
4288 /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/true,
4289 /*should_become_broadcast=*/true}, //
4290 {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{6, 6},
4291 /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/false,
4292 /*should_become_broadcast=*/true}, //
4293 {/*input_spatials=*/{2, 2}, /*symmetric_pad_amount=*/{6, 6},
4294 /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/true,
4295 /*should_become_broadcast=*/false}, //
4296 {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{2, 2},
4297 /*reduce_window_spatials=*/{1, 1}, /*prepend_a=*/true,
4298 /*should_become_broadcast=*/false}, //
4299 {/*input_spatials=*/{5, 1}, /*symmetric_pad_amount=*/{0, 2},
4300 /*reduce_window_spatials=*/{2, 5}, /*prepend_a=*/true,
4301 /*should_become_broadcast=*/false}, //
4302 };
4303 return *cases;
4304 }
4305
4306 INSTANTIATE_TEST_SUITE_P(
4307 PadReduceWindowEffectiveBroadcastInstantiation,
4308 PadReduceWindowEffectiveBroadcastTest,
4309 ::testing::ValuesIn(PadReduceWindowEffectiveBroadcastCases()));
4310
4311 class BatchDotStrengthReductionTest
4312 : public AlgebraicSimplifierTest,
4313 public ::testing::WithParamInterface<
4314 ::testing::tuple<int, int, int, PrimitiveType>> {};
TEST_P(BatchDotStrengthReductionTest,BatchDotStrengthReduction)4315 TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) {
4316 auto module = CreateNewVerifiedModule();
4317 int m, k, n;
4318 PrimitiveType element_type;
4319 std::tie(m, k, n, element_type) = GetParam();
4320 std::vector<int64> lhs_dims = {1, 3, 5};
4321 std::vector<int64> rhs_dims = lhs_dims;
4322 std::vector<int64> output_dims = lhs_dims;
4323 if (m > 0) {
4324 lhs_dims.push_back(m);
4325 output_dims.push_back(m);
4326 }
4327 if (k > 0) {
4328 lhs_dims.push_back(k);
4329 rhs_dims.push_back(k);
4330 }
4331 if (n > 0) {
4332 rhs_dims.push_back(n);
4333 output_dims.push_back(n);
4334 }
4335 Shape dot_shape = ShapeUtil::MakeShape(element_type, output_dims);
4336 Shape lhs_shape = ShapeUtil::MakeShape(element_type, lhs_dims);
4337 Shape rhs_shape = ShapeUtil::MakeShape(element_type, rhs_dims);
4338 HloComputation::Builder builder(TestName());
4339
4340 auto lhs = builder.AddInstruction(
4341 HloInstruction::CreateParameter(0, lhs_shape, "lhs"));
4342 auto rhs = builder.AddInstruction(
4343 HloInstruction::CreateParameter(1, rhs_shape, "rhs"));
4344 DotDimensionNumbers dot_dnums;
4345 dot_dnums.add_lhs_batch_dimensions(0);
4346 dot_dnums.add_lhs_batch_dimensions(1);
4347 dot_dnums.add_lhs_batch_dimensions(2);
4348 dot_dnums.add_rhs_batch_dimensions(0);
4349 dot_dnums.add_rhs_batch_dimensions(1);
4350 dot_dnums.add_rhs_batch_dimensions(2);
4351 if (k > 0) {
4352 dot_dnums.add_lhs_contracting_dimensions(m > 0 ? 4 : 3);
4353 dot_dnums.add_rhs_contracting_dimensions(3);
4354 }
4355 builder.AddInstruction(HloInstruction::CreateDot(
4356 dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
4357 auto computation = module->AddEntryComputation(builder.Build());
4358 AlgebraicSimplifier simplifier(default_options_);
4359 TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
4360 const bool dot_should_be_transformed =
4361 m == 1 || k == 1 || n == 1 || m == -1 || k == -1 || n == -1;
4362 EXPECT_EQ(changed, dot_should_be_transformed);
4363 bool has_no_dot = true;
4364 for (const auto& hlo : computation->instructions()) {
4365 if (hlo->opcode() == HloOpcode::kDot) {
4366 has_no_dot = false;
4367 break;
4368 }
4369 }
4370 EXPECT_EQ(has_no_dot, dot_should_be_transformed);
4371 }
4372
4373 INSTANTIATE_TEST_SUITE_P(BatchDotStrengthReductionTestInstantiation,
4374 BatchDotStrengthReductionTest,
4375 ::testing::Combine(::testing::Values(-1, 1, 2),
4376 ::testing::Values(-1, 1, 2),
4377 ::testing::Values(-1, 1, 2),
4378 ::testing::Values(F32, BF16)));
4379
4380 class DotStrengthReductionTest
4381 : public AlgebraicSimplifierTest,
4382 public ::testing::WithParamInterface<
4383 ::testing::tuple<int, int, int, bool, bool, PrimitiveType>> {};
TEST_P(DotStrengthReductionTest,DotStrengthReduction)4384 TEST_P(DotStrengthReductionTest, DotStrengthReduction) {
4385 auto module = CreateNewVerifiedModule();
4386 int m, k, n;
4387 bool transpose_lhs, transpose_rhs;
4388 PrimitiveType element_type;
4389 std::tie(m, k, n, transpose_lhs, transpose_rhs, element_type) = GetParam();
4390
4391 Shape dot_shape = ShapeUtil::MakeShape(element_type, {m, n});
4392 Shape lhs_shape = ShapeUtil::MakeShape(element_type, {m, k});
4393 Shape transposed_lhs_shape = ShapeUtil::MakeShape(element_type, {k, m});
4394 Shape rhs_shape = ShapeUtil::MakeShape(element_type, {k, n});
4395 Shape transposed_rhs_shape = ShapeUtil::MakeShape(element_type, {n, k});
4396 HloComputation::Builder builder(TestName());
4397
4398 auto lhs = builder.AddInstruction(HloInstruction::CreateParameter(
4399 0, transpose_lhs ? transposed_lhs_shape : lhs_shape, "lhs"));
4400 if (transpose_lhs) {
4401 lhs = builder.AddInstruction(
4402 HloInstruction::CreateTranspose(lhs_shape, lhs, {1, 0}));
4403 }
4404 auto rhs = builder.AddInstruction(HloInstruction::CreateParameter(
4405 1, transpose_rhs ? transposed_rhs_shape : rhs_shape, "rhs"));
4406 if (transpose_rhs) {
4407 rhs = builder.AddInstruction(
4408 HloInstruction::CreateTranspose(rhs_shape, rhs, {1, 0}));
4409 }
4410 DotDimensionNumbers dot_dnums;
4411 dot_dnums.add_lhs_contracting_dimensions(1);
4412 dot_dnums.add_rhs_contracting_dimensions(0);
4413 builder.AddInstruction(HloInstruction::CreateDot(
4414 dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
4415 auto computation = module->AddEntryComputation(builder.Build());
4416 AlgebraicSimplifier simplifier(default_options_);
4417 TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
4418 const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1;
4419 const bool computation_should_be_modified =
4420 dot_should_be_transformed || (transpose_lhs && transpose_rhs);
4421 EXPECT_EQ(changed, computation_should_be_modified);
4422 bool has_no_dot = true;
4423 for (const auto& hlo : computation->instructions()) {
4424 if (hlo->opcode() == HloOpcode::kDot) {
4425 has_no_dot = false;
4426 break;
4427 }
4428 }
4429 EXPECT_EQ(has_no_dot, dot_should_be_transformed);
4430 }
4431
4432 INSTANTIATE_TEST_SUITE_P(
4433 DotStrengthReductionTestInstantiation, DotStrengthReductionTest,
4434 ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2),
4435 ::testing::Values(1, 2), ::testing::Bool(),
4436 ::testing::Bool(), ::testing::Values(F32, BF16)));
4437
4438 struct DotOfConcatTestSpec {
4439 int64 m;
4440 int64 k;
4441 int64 n;
4442 };
4443
4444 class DotOfConcatSimplificationTest
4445 : public AlgebraicSimplifierTest,
4446 public ::testing::WithParamInterface<DotOfConcatTestSpec> {};
4447
4448 // Test that we transform
4449 // dot(const, concat(A, B, C))
4450 // to
4451 // add(dot(const_0, A), dot(const_1, B), dot(const_2, C))
TEST_P(DotOfConcatSimplificationTest,ConstantLHS)4452 TEST_P(DotOfConcatSimplificationTest, ConstantLHS) {
4453 auto m = CreateNewVerifiedModule();
4454 HloComputation::Builder builder(TestName());
4455
4456 DotOfConcatTestSpec spec = GetParam();
4457
4458 ASSERT_GE(spec.k, 3);
4459
4460 int64 k0 = spec.k / 3;
4461 int64 k1 = spec.k / 3;
4462 int64 k2 = spec.k - k0 - k1;
4463
4464 Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k});
4465 auto* lhs = builder.AddInstruction(
4466 HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
4467 /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.m, /*cols=*/spec.k)));
4468
4469 Shape rhs0_shape = ShapeUtil::MakeShape(F32, {k0, spec.n});
4470 Shape rhs1_shape = ShapeUtil::MakeShape(F32, {k1, spec.n});
4471 Shape rhs2_shape = ShapeUtil::MakeShape(F32, {k2, spec.n});
4472
4473 HloInstruction* rhs0 = builder.AddInstruction(
4474 HloInstruction::CreateParameter(0, rhs0_shape, "rhs0"));
4475 HloInstruction* rhs1 = builder.AddInstruction(
4476 HloInstruction::CreateParameter(1, rhs1_shape, "rhs1"));
4477 HloInstruction* rhs2 = builder.AddInstruction(
4478 HloInstruction::CreateParameter(2, rhs2_shape, "rhs2"));
4479
4480 Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n});
4481 HloInstruction* rhs = builder.AddInstruction(
4482 HloInstruction::CreateConcatenate(rhs_shape, {rhs0, rhs1, rhs2}, 0));
4483
4484 DotDimensionNumbers dot_dnums;
4485 dot_dnums.add_lhs_contracting_dimensions(1);
4486 dot_dnums.add_rhs_contracting_dimensions(0);
4487
4488 Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n});
4489 builder.AddInstruction(HloInstruction::CreateDot(
4490 dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
4491
4492 auto computation = m->AddEntryComputation(builder.Build());
4493 AlgebraicSimplifier simplifier(default_options_);
4494 TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get()));
4495 ASSERT_TRUE(run_successful);
4496
4497 EXPECT_TRUE(
4498 ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
4499
4500 auto match_dot_0 = m::Dot(m::Slice(m::Constant()), m::Parameter(0));
4501 auto match_dot_1 = m::Dot(m::Slice(m::Constant()), m::Parameter(1));
4502 auto match_dot_2 = m::Dot(m::Slice(m::Constant()), m::Parameter(2));
4503 EXPECT_THAT(
4504 computation->root_instruction(),
4505 GmockMatch(m::Add(m::Add(match_dot_0, match_dot_1), match_dot_2)));
4506 }
4507
// Test that we transform
//   dot(concat(A, B, C, D), const)
// to
//   add(dot(A, const_0), dot(B, const_1), dot(C, const_2), dot(D, const_3))
TEST_P(DotOfConcatSimplificationTest, ConstantRHS) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  const DotOfConcatTestSpec spec = GetParam();
  ASSERT_GE(spec.k, 4);

  // Split k into four concatenated pieces; the last piece absorbs the
  // remainder of k / 4.
  const int64 piece = spec.k / 4;
  const int64 piece_sizes[] = {piece, piece, piece, spec.k - 3 * piece};
  const char* const piece_names[] = {"lhs0", "lhs1", "lhs2", "lhs3"};

  std::vector<HloInstruction*> lhs_pieces;
  for (int64 i = 0; i < 4; ++i) {
    lhs_pieces.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
        i, ShapeUtil::MakeShape(F32, {spec.m, piece_sizes[i]}),
        piece_names[i])));
  }
  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateConcatenate(
      ShapeUtil::MakeShape(F32, {spec.m, spec.k}), lhs_pieces, 1));

  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
          /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.k, /*cols=*/spec.n)));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);

  const Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n});
  builder.AddInstruction(HloInstruction::CreateDot(
      dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  AlgebraicSimplifier simplifier(default_options_);
  TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(module.get()));
  ASSERT_TRUE(run_successful);
  EXPECT_TRUE(
      ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));

  // Each partial product is one piece dotted with a slice of the constant.
  auto partial_dot = [](int64 parameter_number) {
    return m::Dot(m::Parameter(parameter_number), m::Slice(m::Constant()));
  };
  EXPECT_THAT(
      computation->root_instruction(),
      GmockMatch(m::Add(
          m::Add(m::Add(partial_dot(0), partial_dot(1)), partial_dot(2)),
          partial_dot(3))));
}
4573
4574 DotOfConcatTestSpec kDotOfConcatTestSpecs[] = {
4575 {/*m=*/3, /*k=*/9, /*n=*/3}, //
4576 {/*m=*/3, /*k=*/20, /*n=*/3}, //
4577 {/*m=*/1, /*k=*/18, /*n=*/5}, //
4578 {/*m=*/20, /*k=*/20, /*n=*/1}, //
4579 {/*m=*/1, /*k=*/16, /*n=*/1}, //
4580 };
4581
4582 // Test that DynamicUpdateSlice update param with any dimension equal to zero
4583 // gets removed.
TEST_F(AlgebraicSimplifierTest,DynamicUpdateSliceZeroUpdate)4584 TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) {
4585 auto m = CreateNewVerifiedModule();
4586 HloComputation::Builder builder(TestName());
4587 const Shape dslice_shape = ShapeUtil::MakeShape(F32, {10});
4588 HloInstruction* const operand = builder.AddInstruction(
4589 HloInstruction::CreateParameter(0, dslice_shape, "operand"));
4590 const Shape update_shape = ShapeUtil::MakeShape(F32, {0});
4591 HloInstruction* const update = builder.AddInstruction(
4592 HloInstruction::CreateParameter(1, update_shape, "update"));
4593 HloInstruction* const start_indices = builder.AddInstruction(
4594 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>({})));
4595 builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
4596 dslice_shape, operand, update,
4597 std::initializer_list<HloInstruction*>({start_indices})));
4598 const HloComputation* const computation =
4599 m->AddEntryComputation(builder.Build());
4600
4601 AlgebraicSimplifier simplifier(default_options_);
4602 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
4603 EXPECT_THAT(computation->root_instruction(), operand);
4604 }
4605
4606 INSTANTIATE_TEST_SUITE_P(DotOfConcatSimplificationTestInstantiation,
4607 DotOfConcatSimplificationTest,
4608 ::testing::ValuesIn(kDotOfConcatTestSpecs));
4609
4610 struct DotOfGatherTestSpec {
4611 int64 m;
4612 int64 k;
4613 int64 n;
4614 int s; // start index for dynamic slice on the non-contracting dimension
4615 int64 lcd; // left contracting dimension
4616 int64 rcd; // right contracting dimension
4617 bool neg; // is negative testcase
4618 };
4619
4620 class DotOfGatherSimplificationTest
4621 : public AlgebraicSimplifierTest,
4622 public ::testing::WithParamInterface<DotOfGatherTestSpec> {};
4623
4624 // input: dot(DS(ctA), ctB))
4625 // where DS(ctA) = DS({M x K}, {s, 0}, {1, K}) and ctB = {K x N}.
4626 // => input dimensions: dot({1 x K}, {K x N}) => {1 x N}.
4627 // output: DS(dot(ctA, ctB))
4628 // => output dimensions: DS ({M x N}, {s, 0}, {1, N}) => {1 x N}.
TEST_P(DotOfGatherSimplificationTest,ConstantRHS)4629 TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {
4630 auto m = CreateNewVerifiedModule();
4631 HloComputation::Builder builder(TestName());
4632
4633 DotOfGatherTestSpec spec = GetParam();
4634
4635 ASSERT_LE(spec.s, spec.m);
4636
4637 // For negative tests, increase k of the dynamic slice argument to prevent the
4638 // optimization (constants ctA, ctB must have equal contracting dimensions).
4639 int64 k_increase = spec.neg ? 5 : 0;
4640 int64 lhs_rows = (spec.lcd == 0) ? (spec.k + k_increase) : spec.m;
4641 int64 lhs_cols = (spec.lcd == 0) ? spec.m : (spec.k + k_increase);
4642 Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols});
4643 auto* lhs = builder.AddInstruction(
4644 HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
4645 /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows,
4646 /*cols=*/lhs_cols)));
4647
4648 int32 start_row = (spec.lcd == 0) ? 0 : spec.s;
4649 int32 start_col = (spec.lcd == 0) ? spec.s : 0;
4650 std::vector<HloInstruction*> start_indices = {
4651 builder.AddInstruction(HloInstruction::CreateConstant(
4652 LiteralUtil::CreateR0<int32>(start_row))),
4653 builder.AddInstruction(HloInstruction::CreateConstant(
4654 LiteralUtil::CreateR0<int32>(start_col)))};
4655 int64 slice_row_size = (spec.lcd == 0) ? spec.k : 1;
4656 int64 slice_col_size = (spec.lcd == 0) ? 1 : spec.k;
4657 std::vector<int64> slice_sizes = {slice_row_size, slice_col_size};
4658 Shape ds_shape = ShapeUtil::MakeShape(F32, slice_sizes);
4659 auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
4660 ds_shape, lhs, start_indices, slice_sizes));
4661
4662 int64 rhs_rows = (spec.rcd == 0) ? spec.k : spec.n;
4663 int64 rhs_cols = (spec.rcd == 0) ? spec.n : spec.k;
4664 Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols});
4665 auto* rhs = builder.AddInstruction(
4666 HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
4667 /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows,
4668 /*cols=*/rhs_cols)));
4669
4670 DotDimensionNumbers dot_dnums;
4671 dot_dnums.add_lhs_contracting_dimensions(spec.lcd);
4672 dot_dnums.add_rhs_contracting_dimensions(spec.rcd);
4673
4674 int64 dot_row_size = 1;
4675 int64 dot_col_size = spec.n;
4676 Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size});
4677 builder.AddInstruction(HloInstruction::CreateDot(
4678 dot_shape, ds, rhs, dot_dnums, DefaultPrecisionConfig(2)));
4679
4680 auto computation = m->AddEntryComputation(builder.Build());
4681 AlgebraicSimplifier simplifier(default_options_);
4682 TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get()));
4683 ASSERT_TRUE(run_successful);
4684 EXPECT_TRUE(
4685 ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
4686
4687 if (spec.neg) {
4688 EXPECT_NE(computation->root_instruction()->opcode(),
4689 HloOpcode::kDynamicSlice);
4690 } else {
4691 EXPECT_THAT(computation->root_instruction(),
4692 GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()),
4693 m::Constant(), m::Constant())));
4694 }
4695 }
4696
4697 // input: dot(ctA, DS(ctB))
4698 // where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, s}, {K, 1}).
4699 // => input dimensions: dot({M x K}, {K x 1}) => {M x 1}.
4700 // output: DS(dot(ctA, ctB))
4701 // => output dimensions: DS ({M x N}, {0, s}, {M, 1}) => {M x 1}.
TEST_P(DotOfGatherSimplificationTest,ConstantLHS)4702 TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {
4703 auto m = CreateNewVerifiedModule();
4704 HloComputation::Builder builder(TestName());
4705
4706 DotOfGatherTestSpec spec = GetParam();
4707
4708 ASSERT_LE(spec.s, spec.n);
4709
4710 int64 lhs_rows = (spec.lcd == 0) ? spec.k : spec.m;
4711 int64 lhs_cols = (spec.lcd == 0) ? spec.m : spec.k;
4712 Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols});
4713 auto* lhs = builder.AddInstruction(
4714 HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
4715 /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows,
4716 /*cols=*/lhs_cols)));
4717
4718 // For negative tests increase k of the dynamic slice argument to prevent the
4719 // optimization
4720 int64 k_increase = spec.neg ? 5 : 0;
4721 int64 rhs_rows = (spec.rcd == 0) ? (spec.k + k_increase) : spec.n;
4722 int64 rhs_cols = (spec.rcd == 0) ? spec.n : (spec.k + k_increase);
4723 Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols});
4724 auto* rhs = builder.AddInstruction(
4725 HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
4726 /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows,
4727 /*cols=*/rhs_cols)));
4728
4729 int32 start_row = (spec.rcd == 0) ? 0 : spec.s;
4730 int32 start_col = (spec.rcd == 0) ? spec.s : 0;
4731 std::vector<HloInstruction*> start_indices = {
4732 builder.AddInstruction(HloInstruction::CreateConstant(
4733 LiteralUtil::CreateR0<int32>(start_row))),
4734 builder.AddInstruction(HloInstruction::CreateConstant(
4735 LiteralUtil::CreateR0<int32>(start_col)))};
4736 int64 slice_row_size = (spec.rcd == 0) ? spec.k : 1;
4737 int64 slice_col_size = (spec.rcd == 0) ? 1 : spec.k;
4738 std::vector<int64> slice_sizes = {slice_row_size, slice_col_size};
4739 Shape ds_shape = ShapeUtil::MakeShape(F32, slice_sizes);
4740 auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
4741 ds_shape, rhs, start_indices, slice_sizes));
4742
4743 DotDimensionNumbers dot_dnums;
4744 dot_dnums.add_lhs_contracting_dimensions(spec.lcd);
4745 dot_dnums.add_rhs_contracting_dimensions(spec.rcd);
4746
4747 int64 dot_row_size = spec.m;
4748 int64 dot_col_size = 1;
4749 Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size});
4750 builder.AddInstruction(HloInstruction::CreateDot(
4751 dot_shape, lhs, ds, dot_dnums, DefaultPrecisionConfig(2)));
4752
4753 auto computation = m->AddEntryComputation(builder.Build());
4754 AlgebraicSimplifier simplifier(default_options_);
4755 TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get()));
4756 ASSERT_TRUE(run_successful);
4757 EXPECT_TRUE(
4758 ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
4759
4760 if (spec.neg) {
4761 EXPECT_NE(computation->root_instruction()->opcode(),
4762 HloOpcode::kDynamicSlice);
4763 } else {
4764 EXPECT_THAT(computation->root_instruction(),
4765 GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()),
4766 m::Constant(), m::Constant())));
4767 }
4768 }
4769
DotOfGatherPositiveNegativeTests()4770 std::vector<DotOfGatherTestSpec> DotOfGatherPositiveNegativeTests() {
4771 std::vector<DotOfGatherTestSpec> positives = {
4772 // "Classical dot", i.e. matrix multiply:
4773 {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/1, /*rcd=*/0,
4774 /*neg=*/false},
4775 {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/1, /*rcd=*/0,
4776 /*neg=*/false},
4777 {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/1, /*rcd=*/0,
4778 /*neg=*/false},
4779 // Note: testing for m=1 and n=1 is unnecessary, as this optimizes to
4780 // dot(ct, ct) before DotOfGather optimization kicks in.
4781 // Contract on rows:
4782 {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/0, /*rcd=*/0,
4783 /*neg=*/false},
4784 {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/0, /*rcd=*/0,
4785 /*neg=*/false},
4786 {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/0, /*rcd=*/0,
4787 /*neg=*/false},
4788 // Reverse matrix multiply:
4789 {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/0, /*rcd=*/1,
4790 /*neg=*/false},
4791 {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/0, /*rcd=*/1,
4792 /*neg=*/false},
4793 {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/0, /*rcd=*/1,
4794 /*neg=*/false},
4795 // Contract on columns:
4796 {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/1, /*rcd=*/1,
4797 /*neg=*/false},
4798 {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/1, /*rcd=*/1,
4799 /*neg=*/false},
4800 {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/1, /*rcd=*/1,
4801 /*neg=*/false},
4802 };
4803 std::vector<DotOfGatherTestSpec> all;
4804 for (int i = 0; i < positives.size(); i++) {
4805 DotOfGatherTestSpec positive_test = positives[i];
4806 all.push_back(positive_test);
4807 DotOfGatherTestSpec negative_test = positive_test;
4808 negative_test.neg = true;
4809 all.push_back(negative_test);
4810 }
4811 return all;
4812 }
4813
4814 INSTANTIATE_TEST_SUITE_P(
4815 DotOfGatherSimplificationTestInstantiation, DotOfGatherSimplificationTest,
4816 ::testing::ValuesIn(DotOfGatherPositiveNegativeTests()));
4817
// A variadic (tuple-shaped) reduce of two scalars over no dimensions is
// expected to be rewritten into a tuple of reshapes of its two inputs.
TEST_F(AlgebraicSimplifierTest, TupleReduceReshape) {
  const char* hlo_string = R"(
HloModule module

reducer {
  parameter.1 = f32[] parameter(0)
  parameter.3 = f32[] parameter(2)
  add.2 = f32[] add(parameter.1, parameter.3)
  parameter.0 = f32[] parameter(1)
  parameter.2 = f32[] parameter(3)
  add.3 = f32[] add(parameter.0, parameter.2)
  ROOT tuple.4 = (f32[], f32[]) tuple(add.2, add.3)
}

ENTRY entry {
  parameter.6 = (f32[], f32[]) parameter(0)
  get-tuple-element.10 = f32[] get-tuple-element(parameter.6), index=0
  get-tuple-element.11 = f32[] get-tuple-element(parameter.6), index=1
  constant = f32[] constant(0)
  ROOT reduce = (f32[], f32[]) reduce(get-tuple-element.10, get-tuple-element.11, constant, constant), dimensions={}, to_apply=reducer
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = module->entry_computation()->root_instruction();
  const auto reshaped_element = [](int64 index) {
    return m::Reshape(m::GetTupleElement(m::Parameter(), index));
  };
  EXPECT_THAT(root, GmockMatch(m::Tuple(reshaped_element(0),
                                        reshaped_element(1))));
}
4851
// A variadic reduce over a size-0 dimension is expected to become a tuple of
// broadcasts of its init values (constant.0 and constant.1).
// NOTE: the reducer instruction named `mul.2` is actually an add.
TEST_F(AlgebraicSimplifierTest, TupleReduceBroadcast) {
  const char* hlo_string = R"(
HloModule module

reducer {
  parameter.1 = f32[] parameter(0)
  parameter.3 = f32[] parameter(2)
  mul.2 = f32[] add(parameter.1, parameter.3)
  parameter.0 = f32[] parameter(1)
  parameter.2 = f32[] parameter(3)
  add.3 = f32[] add(parameter.0, parameter.2)
  ROOT tuple.4 = (f32[], f32[]) tuple(mul.2, add.3)
}

ENTRY entry {
  parameter.6 = (f32[0, 10, 10], f32[0, 10, 10]) parameter(0)
  get-tuple-element.10 = f32[0, 10, 10] get-tuple-element(parameter.6), index=0
  get-tuple-element.11 = f32[0, 10, 10] get-tuple-element(parameter.6), index=1
  constant.0 = f32[] constant(0)
  constant.1 = f32[] constant(1)
  ROOT reduce = (f32[10, 10], f32[10, 10]) reduce(get-tuple-element.10, get-tuple-element.11, constant.0, constant.1), dimensions={0}, to_apply=reducer
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = module->entry_computation()->root_instruction();
  const auto broadcast_of = [](int value) {
    return m::Broadcast(m::ConstantScalar(value));
  };
  EXPECT_THAT(root, GmockMatch(m::Tuple(broadcast_of(0), broadcast_of(1))));
}
4885
// A reshape whose zero-element result shape carries no layout must still be
// simplified; the expected result here is a constant.
TEST_F(AlgebraicSimplifierTest, ZeroSizedReshapeWithoutLayout) {
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {1}),
                                      "param"));
  // f32[1] -> f32[0, 1], so the broadcast has no elements.
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {0, 1}), param,
                                      {1}));

  Shape layoutless_shape = ShapeUtil::MakeShape(F32, {0});
  layoutless_shape.clear_layout();
  builder.AddInstruction(
      HloInstruction::CreateReshape(layoutless_shape, broadcast));

  std::unique_ptr<VerifiedHloModule> module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Constant()));
}
4910
// Division of a layout-less scalar by a constant is expected to be rewritten
// into a multiply.
TEST_F(AlgebraicSimplifierTest, DividedByConstantInstructionWithoutLayout) {
  Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
  scalar_shape.clear_layout();

  HloComputation::Builder builder(TestName());
  HloInstruction* dividend = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "param"));
  HloInstruction* divisor = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(20.0f)));
  builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_shape, HloOpcode::kDivide, dividend, divisor));

  std::unique_ptr<VerifiedHloModule> module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Multiply()));
}
4932
// Test that X/sqrt(Y) is simplified to X*rsqrt(Y).
TEST_F(AlgebraicSimplifierTest, RecipSqrt) {
  const char* kModuleStr = R"(
HloModule m
test {
  p0 = f32[] parameter(0)
  p1 = f32[] parameter(1)
  sqrt = f32[] sqrt(p0)
  ROOT div = f32[] divide(p1, sqrt)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // p1 / sqrt(p0) is expected to become p1 * rsqrt(p0), in either order.
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::MultiplyAnyOrder(
                        m::Parameter(1), m::Rsqrt(m::Parameter(0)))));
}
4950
// Test that X/rsqrt(Y) is simplified to X*sqrt(Y).
TEST_F(AlgebraicSimplifierTest, RecipRsqrt) {
  const char* kModuleStr = R"(
HloModule m
test {
  p0 = f32[] parameter(0)
  p1 = f32[] parameter(1)
  rsqrt = f32[] rsqrt(p0)
  ROOT div = f32[] divide(p1, rsqrt)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // p1 / rsqrt(p0) is expected to become p1 * sqrt(p0), in either order.
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::MultiplyAnyOrder(
                        m::Parameter(1), m::Sqrt(m::Parameter(0)))));
}
4968
4969 } // namespace
4970 } // namespace xla
4971