1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
17
18 #include <memory>
19 #include <utility>
20
21 #include "absl/memory/memory.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/compiler/xla/layout_util.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
27 #include "tensorflow/compiler/xla/service/hlo_computation.h"
28 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
30 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
31 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
32 #include "tensorflow/compiler/xla/service/hlo_parser.h"
33 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
34 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
35 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
36 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
37 #include "tensorflow/compiler/xla/service/shape_inference.h"
38 #include "tensorflow/compiler/xla/shape_util.h"
39 #include "tensorflow/compiler/xla/test.h"
40 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
41 #include "tensorflow/compiler/xla/types.h"
42 #include "tensorflow/compiler/xla/window_util.h"
43 #include "tensorflow/compiler/xla/xla_data.pb.h"
44 #include "tensorflow/core/lib/core/status_test_util.h"
45
46 namespace xla {
47 namespace {
48
49 using ::testing::ElementsAre;
50 namespace m = match;
51
52 class AlgebraicSimplifierTest : public HloTestBase {
53 protected:
54 AlgebraicSimplifierOptions default_options_;
55 };
56
57 // Test that A + 0 is simplified to A
TEST_F(AlgebraicSimplifierTest,AddZero)58 TEST_F(AlgebraicSimplifierTest, AddZero) {
59 auto m = CreateNewVerifiedModule();
60 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
61 HloComputation::Builder builder(TestName());
62 HloInstruction* param0 = builder.AddInstruction(
63 HloInstruction::CreateParameter(0, r0f32, "param0"));
64 HloInstruction* zero = builder.AddInstruction(
65 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
66 builder.AddInstruction(
67 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero));
68
69 auto computation = m->AddEntryComputation(builder.Build());
70 HloInstruction* root = computation->root_instruction();
71 EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
72 AlgebraicSimplifier simplifier(default_options_);
73 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
74 root = computation->root_instruction();
75 EXPECT_EQ(root, param0);
76 }
77
// A*C + B*C => (A+B)*C for s32 operands: the common multiplicand p2 is
// factored out of the sum.
TEST_F(AlgebraicSimplifierTest, FactorIntegerAddition) {
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      p0 = s32[8] parameter(0)
      p1 = s32[8] parameter(1)
      p2 = s32[8] parameter(2)
      x = s32[8] multiply(p0, p2)
      y = s32[8] multiply(p1, p2)
      ROOT sum = s32[8] add(x, y)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  const bool changed =
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  auto sum_of_inputs = m::AddAnyOrder(m::Parameter(0), m::Parameter(1));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::MultiplyAnyOrder(sum_of_inputs, m::Parameter(2))));
}
98
99 // A*C + B*C => (A+B)*C if C is a floating-point power of 2.
TEST_F(AlgebraicSimplifierTest,FactorFpAddition)100 TEST_F(AlgebraicSimplifierTest, FactorFpAddition) {
101 const char* kModuleStr = R"(
102 HloModule m
103 test {
104 p0 = f32[] parameter(0)
105 p1 = f32[] parameter(1)
106 c = f32[] constant(0.125)
107 x = f32[] multiply(p0, c)
108 y = f32[] multiply(p1, c)
109 ROOT sum = f32[] add(x, y)
110 }
111 )";
112 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
113 ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
114 EXPECT_THAT(m->entry_computation()->root_instruction(),
115 GmockMatch(m::MultiplyAnyOrder(
116 m::AddAnyOrder(m::Parameter(0), m::Parameter(1)),
117 m::ConstantScalar(0.125))));
118 }
119
120 // (Abs(A)) * (Abs(A)) => (A*A)
TEST_F(AlgebraicSimplifierTest,SquareOfAbs)121 TEST_F(AlgebraicSimplifierTest, SquareOfAbs) {
122 const char* kModuleStr = R"(
123 HloModule m
124 test {
125 p = f32[] parameter(0)
126 a = f32[] abs(p)
127 ROOT z = f32[] multiply(a, a)
128 }
129 )";
130 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
131 ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
132 EXPECT_THAT(m->entry_computation()->root_instruction(),
133 GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
134 }
135
136 // (A*C1) * (B*C2) => (A*B)*(C1*C2)
TEST_F(AlgebraicSimplifierTest,MultiplyChain)137 TEST_F(AlgebraicSimplifierTest, MultiplyChain) {
138 const char* kModuleStr = R"(
139 HloModule m
140 test {
141 p0 = f32[] parameter(0)
142 p1 = f32[] parameter(1)
143 c = f32[] constant(2)
144 d = f32[] constant(4)
145 x = f32[] multiply(p0, c)
146 y = f32[] multiply(p1, d)
147 ROOT z = f32[] multiply(x, y)
148 }
149 )";
150 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
151 ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
152 EXPECT_THAT(
153 m->entry_computation()->root_instruction(),
154 GmockMatch(m::MultiplyAnyOrder(
155 m::MultiplyAnyOrder(m::Parameter(0), m::Parameter(1)),
156 m::MultiplyAnyOrder(m::ConstantScalar(2), m::ConstantScalar(4)))));
157 }
158
159 // (a*C1)*C2 => a*(C1*C2)
TEST_F(AlgebraicSimplifierTest,MultiplyChain2)160 TEST_F(AlgebraicSimplifierTest, MultiplyChain2) {
161 const char* kModuleStr = R"(
162 HloModule m
163 test {
164 p0 = f32[] parameter(0)
165 a = f32[] constant(2)
166 b = f32[] constant(4)
167 c = f32[] multiply(p0, a)
168 ROOT y = f32[] multiply(c, b)
169 }
170 )";
171 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
172 ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
173 EXPECT_THAT(m->entry_computation()->root_instruction(),
174 GmockMatch(m::MultiplyAnyOrder(
175 m::Parameter(0), m::MultiplyAnyOrder(m::ConstantScalar(2),
176 m::ConstantScalar(4)))));
177 }
178
179 // MUL(MUL(X, BROADCAST(constant)), BROADCAST(Y)) ==>
180 // MUL(X, BROADCAST(MUL(Y, BROADCAST(constant))))
TEST_F(AlgebraicSimplifierTest,MultiplyBroadcastReassoc)181 TEST_F(AlgebraicSimplifierTest, MultiplyBroadcastReassoc) {
182 const char* kModuleStr = R"(
183 HloModule m
184 test {
185 p0 = f32[2,2] parameter(0)
186 p1 = f32[] parameter(1)
187 b = f32[] constant(2)
188 c = f32[2, 2] broadcast(b), dimensions={}
189 x = f32[2,2] multiply(p0, c)
190 y = f32[2,2] broadcast(p1), dimensions={}
191 ROOT z = f32[2,2] multiply(y, x)
192 }
193 )";
194 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
195 ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
196 EXPECT_THAT(m->entry_computation()->root_instruction(),
197 GmockMatch(m::MultiplyAnyOrder(
198 m::Parameter(0), m::Broadcast(m::MultiplyAnyOrder(
199 m::Parameter(1), m::Constant())))));
200 }
201
202 // A*C + B*C => (A+B)*C if C is a broadcast of a floating-point power of 2.
TEST_F(AlgebraicSimplifierTest,FactorFpAdditionWithBroadcast)203 TEST_F(AlgebraicSimplifierTest, FactorFpAdditionWithBroadcast) {
204 const char* kModuleStr = R"(
205 HloModule m
206 test {
207 p0 = f32[4] parameter(0)
208 p1 = f32[4] parameter(1)
209 c = f32[] constant(0.125)
210 b = f32[4] broadcast(c), dimensions={}
211 x = f32[4] multiply(p0, b)
212 y = f32[4] multiply(p1, b)
213 ROOT sum = f32[4] add(x, y)
214 }
215 )";
216 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
217 ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
218 EXPECT_THAT(m->entry_computation()->root_instruction(),
219 GmockMatch(m::MultiplyAnyOrder(
220 m::AddAnyOrder(m::Parameter(0), m::Parameter(1)),
221 m::Broadcast(m::ConstantScalar(0.125)))));
222 }
223
224 // A*C + B*C => (A+B)*C simplification should not happen if C is not a
225 // floating-point power of 2.
TEST_F(AlgebraicSimplifierTest,FactorFpAdditionNotPowerOf2)226 TEST_F(AlgebraicSimplifierTest, FactorFpAdditionNotPowerOf2) {
227 const char* kModuleStr = R"(
228 HloModule m
229 test {
230 p0 = f32[] parameter(0)
231 p1 = f32[] parameter(1)
232 c = f32[] constant(0.3)
233 x = f32[] multiply(p0, c)
234 y = f32[] multiply(p1, c)
235 ROOT sum = f32[] add(x, y)
236 }
237 )";
238 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
239 EXPECT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
240 }
241
242 // A*C + B*C => (A+B)*C simplification should not happen if A, B, and C are
243 // complex numbers.
TEST_F(AlgebraicSimplifierTest,FactorFpAdditionComplex)244 TEST_F(AlgebraicSimplifierTest, FactorFpAdditionComplex) {
245 const char* kModuleStr = R"(
246 HloModule m
247 test {
248 p0 = c64[8] parameter(0)
249 p1 = c64[8] parameter(1)
250 p2 = c64[8] parameter(2)
251 x = c64[8] multiply(p0, p2)
252 y = c64[8] multiply(p1, p2)
253 ROOT sum = c64[8] add(x, y)
254 }
255 )";
256 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
257 EXPECT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
258 }
259
260 // A*C + B*C => (A+B)*C simplification is OK if A, B, and C are complex.
TEST_F(AlgebraicSimplifierTest,FactorFpAdditionBfloat16)261 TEST_F(AlgebraicSimplifierTest, FactorFpAdditionBfloat16) {
262 const char* kModuleStr = R"(
263 HloModule m
264 test {
265 p0 = bf16[4] parameter(0)
266 p1 = bf16[4] parameter(1)
267 c = bf16[] constant(0.125)
268 b = bf16[4] broadcast(c), dimensions={}
269 x = bf16[4] multiply(p0, b)
270 y = bf16[4] multiply(p1, b)
271 ROOT sum = bf16[4] add(x, y)
272 }
273 )";
274 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
275 ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
276 EXPECT_THAT(m->entry_computation()->root_instruction(),
277 GmockMatch(m::MultiplyAnyOrder(
278 m::AddAnyOrder(m::Parameter(0), m::Parameter(1)),
279 m::Broadcast(m::ConstantScalar(0.125)))));
280 }
281
// Unsigned division by a broadcast power of 2 (8 == 1 << 3) is rewritten as a
// logical right shift by the broadcast exponent 3.
TEST_F(AlgebraicSimplifierTest, UnsignedDivideByPowerOf2) {
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      p = u32[4] parameter(0)
      c = u32[] constant(8)
      b = u32[4] broadcast(c), dimensions={}
      ROOT d = u32[4] divide(p, b)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto shift_amount = m::Broadcast(m::ConstantScalar(3));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::ShiftRightLogical(m::Parameter(0), shift_amount)));
}
298
// Signed division by a broadcast 8 is expanded into sign-aware code. The
// magnitude of p is logically shifted right by 3, and the result is negated
// back when p < 0. Both steps are selects on the same p < 0 predicate.
TEST_F(AlgebraicSimplifierTest, SignedDivideByPowerOf2) {
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      p = s32[4] parameter(0)
      c = s32[] constant(8)
      b = s32[4] broadcast(c), dimensions={}
      ROOT d = s32[4] divide(p, b)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto is_negative =
      m::Lt(m::Parameter(0), m::Broadcast(m::ConstantScalar(0)));
  auto magnitude =
      m::Select(is_negative, m::Negate(m::Parameter(0)), m::Parameter(0));
  auto shifted =
      m::ShiftRightLogical(magnitude, m::Broadcast(m::ConstantScalar(3)));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Select(is_negative, m::Negate(shifted),
                                         shifted)));
}
321
// Unsigned remainder by a broadcast power of 2 (8) is rewritten as a bitwise
// AND with the broadcast mask 8 - 1 = 7.
TEST_F(AlgebraicSimplifierTest, UnsignedRemainderByPowerOf2) {
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      p = u32[4] parameter(0)
      c = u32[] constant(8)
      b = u32[4] broadcast(c), dimensions={}
      ROOT r = u32[4] remainder(p, b)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto mask = m::Broadcast(m::ConstantScalar(7));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::AndAnyOrder(m::Parameter(0), mask)));
}
338
// Signed remainder by a broadcast 8 is expanded into sign-aware code. The
// magnitude of p is masked with 7, and the result is negated back when p < 0.
// Both steps are selects on the same p < 0 predicate.
TEST_F(AlgebraicSimplifierTest, SignedRemainderByPowerOf2) {
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      p = s32[4] parameter(0)
      c = s32[] constant(8)
      b = s32[4] broadcast(c), dimensions={}
      ROOT r = s32[4] remainder(p, b)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto is_negative =
      m::Lt(m::Parameter(0), m::Broadcast(m::ConstantScalar(0)));
  auto magnitude =
      m::Select(is_negative, m::Negate(m::Parameter(0)), m::Parameter(0));
  auto masked = m::AndAnyOrder(magnitude, m::Broadcast(m::ConstantScalar(7)));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Select(is_negative, m::Negate(masked), masked)));
}
361
362 // Test that A * 0 is simplified to 0
TEST_F(AlgebraicSimplifierTest,MulZero)363 TEST_F(AlgebraicSimplifierTest, MulZero) {
364 auto m = CreateNewVerifiedModule();
365 Shape r0s32 = ShapeUtil::MakeShape(S32, {});
366 HloComputation::Builder builder(TestName());
367 HloInstruction* param0 = builder.AddInstruction(
368 HloInstruction::CreateParameter(0, r0s32, "param0"));
369 HloInstruction* zero = builder.AddInstruction(
370 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
371 builder.AddInstruction(
372 HloInstruction::CreateBinary(r0s32, HloOpcode::kMultiply, param0, zero));
373
374 auto computation = m->AddEntryComputation(builder.Build());
375 HloInstruction* root = computation->root_instruction();
376 EXPECT_EQ(root->opcode(), HloOpcode::kMultiply);
377 AlgebraicSimplifier simplifier(default_options_);
378 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
379 EXPECT_EQ(computation->root_instruction(), zero);
380 }
381
// (A * C1) * C2 => A * (C1 * C2) for scalar f32 constants C1 = 2 and C2 = 3.
TEST_F(AlgebraicSimplifierTest, MultiplyReassociateMergeConstants) {
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      p0 = f32[] parameter(0)
      c0 = f32[] constant(2.0)
      c1 = f32[] constant(3.0)
      multiply0 = f32[] multiply(p0, c0)
      ROOT multiply1 = f32[] multiply(multiply0, c1)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto merged_constants =
      m::Multiply(m::ConstantScalar(2.0), m::ConstantScalar(3.0));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Multiply(m::Parameter(0), merged_constants)));
}
400
// (A * broadcast(C1)) * broadcast(C2) => A * broadcast(C1 * C2): the scalar
// constants are multiplied before a single broadcast.
TEST_F(AlgebraicSimplifierTest, MultiplyReassociateMergeBroadcastedConstants) {
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      c0 = f32[] constant(2.0)
      c1 = f32[] constant(3.0)
      b0 = f32[4] broadcast(c0), dimensions={}
      b1 = f32[4] broadcast(c1), dimensions={}
      multiply0 = f32[4] multiply(p0, b0)
      ROOT multiply1 = f32[4] multiply(multiply0, b1)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto merged_scale = m::Broadcast(
      m::Multiply(m::ConstantScalar(2.0), m::ConstantScalar(3.0)));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Multiply(m::Parameter(0), merged_scale)));
}
422
// A multiply of two broadcasts of scalar parameters is rewritten so that an
// outer broadcast wraps the multiply. The expected inner multiply operands are
// still broadcasts of the two parameters.
TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsScalar) {
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      p0 = f32[] parameter(0)
      p1 = f32[] parameter(1)
      b0 = f32[4] broadcast(p0), dimensions={}
      b1 = f32[4] broadcast(p1), dimensions={}
      ROOT multiply = f32[4] multiply(b1, b0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto inner = m::Multiply(m::Broadcast(m::Parameter(1)),
                           m::Broadcast(m::Parameter(0)));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(inner)));
}
441
// The inputs are a scalar constant broadcast and an f32[4] parameter broadcast
// along dimension 0. They are sunk below the multiply, giving a broadcast of
// (parameter * broadcast(constant)).
TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsConstantMix) {
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      c0 = f32[] constant(2.0)
      b0 = f32[4,2] broadcast(c0), dimensions={}
      b1 = f32[4,2] broadcast(p0), dimensions={0}
      ROOT multiply = f32[4,2] multiply(b1, b0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto inner =
      m::Multiply(m::Parameter(0), m::Broadcast(m::ConstantScalar(2.0)));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(inner)));
}
459
// Two f32[4] parameters broadcast with the same mapping (dimensions={0}) are
// multiplied before a single broadcast.
TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsNonScalar) {
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      p1 = f32[4] parameter(1)
      b0 = f32[4,2] broadcast(p0), dimensions={0}
      b1 = f32[4,2] broadcast(p1), dimensions={0}
      ROOT multiply = f32[4,2] multiply(b1, b0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto inner = m::Multiply(m::Parameter(1), m::Parameter(0));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(inner)));
}
477
// Broadcasts that map their operands onto different output dimensions ({0}
// vs {1}) are not sunk. The pass reports no change and the multiply of
// broadcasts stays as the root.
TEST_F(AlgebraicSimplifierTest, ElementwiseNoSinkBroadcastsDifferentDims) {
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      p1 = f32[8] parameter(1)
      b0 = f32[4,8] broadcast(p0), dimensions={0}
      b1 = f32[4,8] broadcast(p1), dimensions={1}
      ROOT multiply = f32[4,8] multiply(b1, b0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());

  auto unchanged = m::Multiply(m::Broadcast(m::Parameter(1)),
                               m::Broadcast(m::Parameter(0)));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(unchanged));
}
495
// (C0 * broadcast(C1)) * broadcast(C2), where C0 is a non-scalar constant,
// becomes C0 * broadcast(C1 * C2). The two scalar constants are merged under
// one broadcast.
TEST_F(AlgebraicSimplifierTest,
       MultiplyReassociateMultiplyOfConstantAndBroadcast) {
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      c0 = f32[4] constant({2.0, 3.0, 4.0, 5.0})
      c1 = f32[] constant(3.0)
      c2 = f32[] constant(4.0)
      b0 = f32[4] broadcast(c1), dimensions={}
      b1 = f32[4] broadcast(c2), dimensions={}
      multiply0 = f32[4] multiply(c0, b0)
      ROOT multiply1 = f32[4] multiply(multiply0, b1)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto merged_scale = m::Broadcast(
      m::Multiply(m::ConstantScalar(3.0), m::ConstantScalar(4.0)));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Multiply(m::Constant(), merged_scale)));
}
518
519 // Test that select(true, a, b) is simplified to a
TEST_F(AlgebraicSimplifierTest,SelectTrue)520 TEST_F(AlgebraicSimplifierTest, SelectTrue) {
521 Shape r0s32 = ShapeUtil::MakeShape(S32, {});
522 HloComputation::Builder builder(TestName());
523 HloInstruction* param0 = builder.AddInstruction(
524 HloInstruction::CreateParameter(0, r0s32, "param0"));
525 HloInstruction* param1 = builder.AddInstruction(
526 HloInstruction::CreateParameter(1, r0s32, "param1"));
527 HloInstruction* one = builder.AddInstruction(
528 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
529 builder.AddInstruction(HloInstruction::CreateTernary(
530 r0s32, HloOpcode::kSelect, one, param0, param1));
531
532 auto module = CreateNewVerifiedModule();
533 auto computation = module->AddEntryComputation(builder.Build());
534 HloInstruction* root = computation->root_instruction();
535 EXPECT_EQ(root->opcode(), HloOpcode::kSelect);
536 AlgebraicSimplifier simplifier(default_options_);
537 ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
538 EXPECT_EQ(computation->root_instruction(), param0);
539 }
540
541 // Test that select(false, a, b) is simplified to b
TEST_F(AlgebraicSimplifierTest,SelectFalse)542 TEST_F(AlgebraicSimplifierTest, SelectFalse) {
543 Shape r0s32 = ShapeUtil::MakeShape(S32, {});
544 HloComputation::Builder builder(TestName());
545 HloInstruction* param0 = builder.AddInstruction(
546 HloInstruction::CreateParameter(0, r0s32, "param0"));
547 HloInstruction* param1 = builder.AddInstruction(
548 HloInstruction::CreateParameter(1, r0s32, "param1"));
549 HloInstruction* zero = builder.AddInstruction(
550 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
551 builder.AddInstruction(HloInstruction::CreateTernary(
552 r0s32, HloOpcode::kSelect, zero, param0, param1));
553
554 auto module = CreateNewVerifiedModule();
555 auto computation = module->AddEntryComputation(builder.Build());
556 HloInstruction* root = computation->root_instruction();
557 EXPECT_EQ(root->opcode(), HloOpcode::kSelect);
558 AlgebraicSimplifier simplifier(default_options_);
559 ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
560 EXPECT_EQ(computation->root_instruction(), param1);
561 }
562
563 // Test that select(a, b, b) is simplified to b
TEST_F(AlgebraicSimplifierTest,SelectIdentical)564 TEST_F(AlgebraicSimplifierTest, SelectIdentical) {
565 Shape r0s32 = ShapeUtil::MakeShape(S32, {});
566 HloComputation::Builder builder(TestName());
567 HloInstruction* param0 = builder.AddInstruction(
568 HloInstruction::CreateParameter(0, r0s32, "param0"));
569 HloInstruction* param1 = builder.AddInstruction(
570 HloInstruction::CreateParameter(1, r0s32, "param1"));
571 builder.AddInstruction(HloInstruction::CreateTernary(
572 r0s32, HloOpcode::kSelect, param0, param1, param1));
573
574 auto module = CreateNewVerifiedModule();
575 auto computation = module->AddEntryComputation(builder.Build());
576 HloInstruction* root = computation->root_instruction();
577 EXPECT_EQ(root->opcode(), HloOpcode::kSelect);
578 AlgebraicSimplifier simplifier(default_options_);
579 ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
580 EXPECT_EQ(computation->root_instruction(), param1);
581 }
582
583 // Test that Reduce(Reduce(A)) -> Reduce(A)
TEST_F(AlgebraicSimplifierTest,TwoReducesToOne)584 TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) {
585 auto m = CreateNewVerifiedModule();
586 HloComputation::Builder builder(TestName());
587 // Create add computation.
588 HloInstruction* zero = builder.AddInstruction(
589 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
590 HloComputation* add_computation = nullptr;
591 {
592 HloComputation::Builder builder(TestName() + ".add");
593 const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
594 HloInstruction* p0 = builder.AddInstruction(
595 HloInstruction::CreateParameter(0, scalar_shape, "p0"));
596 HloInstruction* p1 = builder.AddInstruction(
597 HloInstruction::CreateParameter(1, scalar_shape, "p1"));
598 builder.AddInstruction(
599 HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
600 add_computation = m->AddEmbeddedComputation(builder.Build());
601 }
602 Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
603 HloInstruction* param = builder.AddInstruction(
604 HloInstruction::CreateParameter(0, r4f32, "param"));
605 std::vector<int64> dims0({0});
606 Shape r3f32 = ShapeUtil::MakeShape(F32, {5, 6, 7});
607 HloInstruction* reduce0 = builder.AddInstruction(
608 HloInstruction::CreateReduce(r3f32, param, zero, dims0, add_computation));
609 std::vector<int64> dims1({1, 2});
610 Shape r1f32 = ShapeUtil::MakeShape(F32, {5});
611 builder.AddInstruction(HloInstruction::CreateReduce(r1f32, reduce0, zero,
612 dims1, add_computation));
613 m->AddEntryComputation(builder.Build());
614 AlgebraicSimplifier simplifier(default_options_);
615 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
616 HloInstruction* root = m->entry_computation()->root_instruction();
617 EXPECT_THAT(root, GmockMatch(m::Reduce(m::Parameter(0), m::Op().Is(zero))));
618 EXPECT_EQ(root->dimensions(), std::vector<int64>({0, 2, 3}));
619 }
620
621 // Test that Const + A is canonicalized to A + Const.
TEST_F(AlgebraicSimplifierTest,AddConstOnLHS)622 TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) {
623 auto m = CreateNewVerifiedModule();
624 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
625 HloComputation::Builder builder(TestName());
626 HloInstruction* param0 = builder.AddInstruction(
627 HloInstruction::CreateParameter(0, r0f32, "param0"));
628 HloInstruction* constant = builder.AddInstruction(
629 HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f)));
630 builder.AddInstruction(
631 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, constant, param0));
632
633 auto computation = m->AddEntryComputation(builder.Build());
634 HloInstruction* root = computation->root_instruction();
635 EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
636 AlgebraicSimplifier simplifier(default_options_);
637 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
638 root = computation->root_instruction();
639 EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(0), m::Constant())));
640 }
641
642 // Test that [(A + C1) + C2] => [A + (C1 + C2)] for constants C1 and C2.
TEST_F(AlgebraicSimplifierTest,AddReassociateMergeConstants)643 TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) {
644 auto m = CreateNewVerifiedModule();
645 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
646 HloComputation::Builder builder(TestName());
647 HloInstruction* param0 = builder.AddInstruction(
648 HloInstruction::CreateParameter(0, r0f32, "param0"));
649 HloInstruction* constant1 = builder.AddInstruction(
650 HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f)));
651 HloInstruction* constant2 = builder.AddInstruction(
652 HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.14159f)));
653
654 HloInstruction* add1 = builder.AddInstruction(
655 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, constant1));
656 builder.AddInstruction(
657 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, add1, constant2));
658
659 auto computation = m->AddEntryComputation(builder.Build());
660 HloInstruction* root = computation->root_instruction();
661 EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
662 AlgebraicSimplifier simplifier(default_options_);
663 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
664 root = computation->root_instruction();
665 EXPECT_THAT(root, GmockMatch(m::Add(
666 m::Op().Is(param0),
667 m::Add(m::Op().Is(constant1), m::Op().Is(constant2)))));
668 }
669
// (A + broadcast(C1)) + broadcast(C2) => A + broadcast(C1 + C2): the scalar
// constants are added before a single broadcast.
TEST_F(AlgebraicSimplifierTest, AddReassociateMergeBroadcastedConstants) {
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      c0 = f32[] constant(1.0)
      c1 = f32[] constant(2.0)
      b0 = f32[4] broadcast(c0), dimensions={}
      b1 = f32[4] broadcast(c1), dimensions={}
      add0 = f32[4] add(p0, b0)
      ROOT add1 = f32[4] add(add0, b1)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto merged_offset =
      m::Broadcast(m::Add(m::ConstantScalar(1.0), m::ConstantScalar(2.0)));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(0), merged_offset)));
}
690
// broadcast(0) + A => A, where the zero being broadcast is a scalar (R0)
// constant. The add folds to the parameter itself.
TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) {
  auto m = CreateNewVerifiedModule();
  Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r2f32, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  // A broadcast lists one output dimension per operand dimension. The operand
  // is rank 0, so the list is empty.
  HloInstruction* bcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(r2f32, zero, {}));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0));

  auto computation = m->AddEntryComputation(builder.Build());
  HloInstruction* root = computation->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  root = computation->root_instruction();
  EXPECT_EQ(root, param0);
}
712
// A map whose to_apply computation is a single scalar add is inlined. The
// root becomes an add of the map's operands: param0 and the broadcast zero.
TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  // Create the scalar add computation applied by the map. It gets its own
  // builder name so it does not shadow the entry computation's `builder`.
  HloComputation* add_computation = nullptr;
  {
    HloComputation::Builder add_builder(TestName() + ".add");
    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
    HloInstruction* p0 = add_builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "p0"));
    HloInstruction* p1 = add_builder.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "p1"));
    add_builder.AddInstruction(
        HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
    add_computation = m->AddEmbeddedComputation(add_builder.Build());
  }
  Shape r2f32 = ShapeUtil::MakeShape(F32, {32, 1});
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r2f32, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  // Broadcast the scalar zero to the map's operand shape. It is created before
  // the map, in the same order as when it was built inline in the argument
  // list.
  HloInstruction* zero_broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(r2f32, zero, {}));
  builder.AddInstruction(HloInstruction::CreateMap(
      r2f32, {param0, zero_broadcast}, add_computation));

  auto computation = m->AddEntryComputation(builder.Build());
  HloInstruction* root = computation->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kMap);
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  root = computation->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(0),
                                      m::Broadcast(m::Op().Is(zero)))));
}
749
// A map whose to_apply computation is not a single elementwise op is left
// alone. Here it calls a fusion computing x*x + 42, and the pass must report
// no change.
TEST_F(AlgebraicSimplifierTest, KeepNontrivialMap) {
  constexpr char kHlo[] = R"(
    HloModule m
    fusion {
      x = f32[] parameter(0)
      c = f32[] constant(42)
      m = f32[] multiply(x, x)
      ROOT a = f32[] add(m, c)
    }

    map {
      x = f32[] parameter(0)
      ROOT f = f32[] fusion(x), kind=kLoop, calls=fusion
    }

    ENTRY test {
      p = f32[2,2] parameter(0)
      ROOT map = f32[2,2] map(p), dimensions={0,1}, to_apply=map
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
}
773
// Test that Broadcast(zeros) + A is simplified to A when the zeros come from a
// rank-1 constant (rather than a scalar) broadcast into A's shape.
TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
  auto m = CreateNewVerifiedModule();
  Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r2f32, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({0, 0, 0})));
  // The f32[3] operand must be mapped onto output dimension 0, whose size is
  // 3. Mapping it onto dimension 1 (size 2) produces a broadcast whose operand
  // and output dimension sizes disagree, i.e. ill-formed HLO.
  HloInstruction* bcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, zero, {0}));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0));

  auto computation = m->AddEntryComputation(builder.Build());
  HloInstruction* root = computation->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  root = computation->root_instruction();
  EXPECT_EQ(root, param0);
}
795
// A rank-1 constant whose elements are all equal is expected to become a
// broadcast of a constant that holds the repeated value.
TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder entry_builder(TestName());
  entry_builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({3.14f, 3.14f, 3.14f})));

  HloComputation* const entry =
      module->AddEntryComputation(entry_builder.Build());
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Constant()));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* const simplified = entry->root_instruction();
  EXPECT_THAT(simplified, GmockMatch(m::Broadcast(m::Constant())));
  EXPECT_EQ(3.14f, simplified->operand(0)->literal().GetFirstElement<float>());
}
811
// A rank-1 constant with non-uniform elements (and no iota pattern) must stay
// a constant; the simplifier is expected to report no change.
TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder entry_builder(TestName());
  entry_builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({3.14, 3.14, 4})));

  HloComputation* const entry =
      module->AddEntryComputation(entry_builder.Build());
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Constant()));

  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Constant()));
}
826
// Despite the test name, this checks that a rank-1 constant holding the
// sequence 0, 1, 2 is rewritten as an iota.
TEST_F(AlgebraicSimplifierTest, IotaToBroadcast) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder entry_builder(TestName());
  entry_builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f})));

  HloComputation* const entry =
      module->AddEntryComputation(entry_builder.Build());
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Constant()));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Iota()));
}
841
842 // Test that A - 0 is simplified to A
TEST_F(AlgebraicSimplifierTest,SubZero)843 TEST_F(AlgebraicSimplifierTest, SubZero) {
844 auto m = CreateNewVerifiedModule();
845 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
846 HloComputation::Builder builder(TestName());
847 HloInstruction* param0 = builder.AddInstruction(
848 HloInstruction::CreateParameter(0, r0f32, "param0"));
849 HloInstruction* zero = builder.AddInstruction(
850 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
851 builder.AddInstruction(
852 HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero));
853
854 auto computation = m->AddEntryComputation(builder.Build());
855 HloInstruction* root = computation->root_instruction();
856 EXPECT_EQ(root->opcode(), HloOpcode::kSubtract);
857 AlgebraicSimplifier simplifier(default_options_);
858 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
859 root = computation->root_instruction();
860 EXPECT_EQ(root, param0);
861 }
862
863 // Test that A - Const is canonicalized to A + (-Const).
TEST_F(AlgebraicSimplifierTest,SubConstCanonicalization)864 TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) {
865 auto m = CreateNewVerifiedModule();
866 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
867 HloComputation::Builder builder(TestName());
868 HloInstruction* param0 = builder.AddInstruction(
869 HloInstruction::CreateParameter(0, r0f32, "param0"));
870 HloInstruction* constant = builder.AddInstruction(
871 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
872 builder.AddInstruction(HloInstruction::CreateBinary(
873 r0f32, HloOpcode::kSubtract, param0, constant));
874
875 auto computation = m->AddEntryComputation(builder.Build());
876 HloInstruction* root = computation->root_instruction();
877 EXPECT_EQ(root->opcode(), HloOpcode::kSubtract);
878 AlgebraicSimplifier simplifier(default_options_);
879 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
880 root = computation->root_instruction();
881 EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(0),
882 m::Negate(m::Op().Is(constant)))));
883 }
884
885 // Test that A - Broadcast(Const) is canonicalized to A + Broadcast(-Const).
TEST_F(AlgebraicSimplifierTest,SubBroadcastConstCanonicalization)886 TEST_F(AlgebraicSimplifierTest, SubBroadcastConstCanonicalization) {
887 const char* kModuleStr = R"(
888 HloModule m
889 test {
890 p0 = f32[4] parameter(0)
891 c = f32[] constant(0.125)
892 b = f32[4] broadcast(c), dimensions={}
893 ROOT sub = f32[4] subtract(p0, b)
894 }
895 )";
896 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
897 ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
898 EXPECT_THAT(
899 m->entry_computation()->root_instruction(),
900 GmockMatch(m::Add(m::Parameter(0),
901 m::Broadcast(m::Negate(m::ConstantScalar(0.125))))));
902 }
903
904 // Test that Broadcast(x) where x has degenerate dimensions first removes the
905 // degenerate dimensions.
TEST_F(AlgebraicSimplifierTest,DegenerateDimsInOperandRemovedFromBroadcast)906 TEST_F(AlgebraicSimplifierTest, DegenerateDimsInOperandRemovedFromBroadcast) {
907 const char* kModuleStr = R"(
908 HloModule m
909 test {
910 c = f32[1,4] parameter(0)
911 ROOT b = f32[5,1,4,3] broadcast(c), dimensions={1,2}
912 }
913 )";
914 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
915 ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
916 EXPECT_THAT(m->entry_computation()->root_instruction(),
917 GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));
918 }
919
920 // Test that (A/B)/C is simplified to A/(B*C).
TEST_F(AlgebraicSimplifierTest,LhsDivOfDiv)921 TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) {
922 auto m = CreateNewVerifiedModule();
923 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
924 HloComputation::Builder builder(TestName());
925 HloInstruction* param0 = builder.AddInstruction(
926 HloInstruction::CreateParameter(0, r0f32, "param0"));
927 HloInstruction* param1 = builder.AddInstruction(
928 HloInstruction::CreateParameter(1, r0f32, "param1"));
929 HloInstruction* param2 = builder.AddInstruction(
930 HloInstruction::CreateParameter(2, r0f32, "param2"));
931 HloInstruction* div = builder.AddInstruction(
932 HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, param1));
933 builder.AddInstruction(
934 HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, div, param2));
935
936 auto computation = m->AddEntryComputation(builder.Build());
937
938 EXPECT_THAT(computation->root_instruction(),
939 GmockMatch(m::Divide(m::Divide(m::Parameter(0), m::Parameter(1)),
940 m::Parameter(2))));
941
942 AlgebraicSimplifier simplifier(default_options_);
943 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
944
945 EXPECT_THAT(
946 computation->root_instruction(),
947 GmockMatch(m::Divide(m::Parameter(0),
948 m::Multiply(m::Parameter(1), m::Parameter(2)))));
949 }
950
951 // Test that A/(B/C) is simplified to (A*C)/B.
TEST_F(AlgebraicSimplifierTest,RhsDivOfDiv)952 TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) {
953 auto m = CreateNewVerifiedModule();
954 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
955 HloComputation::Builder builder(TestName());
956 HloInstruction* param0 = builder.AddInstruction(
957 HloInstruction::CreateParameter(0, r0f32, "param0"));
958 HloInstruction* param1 = builder.AddInstruction(
959 HloInstruction::CreateParameter(1, r0f32, "param1"));
960 HloInstruction* param2 = builder.AddInstruction(
961 HloInstruction::CreateParameter(2, r0f32, "param2"));
962 HloInstruction* div = builder.AddInstruction(
963 HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param1, param2));
964 builder.AddInstruction(
965 HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, div));
966
967 auto computation = m->AddEntryComputation(builder.Build());
968
969 EXPECT_THAT(
970 computation->root_instruction(),
971 GmockMatch(m::Divide(m::Parameter(0),
972 m::Divide(m::Parameter(1), m::Parameter(2)))));
973
974 AlgebraicSimplifier simplifier(default_options_);
975 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
976
977 EXPECT_THAT(
978 computation->root_instruction(),
979 GmockMatch(m::Divide(m::Multiply(m::Parameter(0), m::Parameter(2)),
980 m::Parameter(1))));
981 }
982
983 // Test that (A/B)/(C/D) is simplified to (A*D)/(B*C).
TEST_F(AlgebraicSimplifierTest,DivOfDivAndDiv)984 TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) {
985 auto m = CreateNewVerifiedModule();
986 Shape r2f32 = ShapeUtil::MakeShape(F32, {42, 123});
987 HloComputation::Builder builder(TestName());
988 HloInstruction* param0 = builder.AddInstruction(
989 HloInstruction::CreateParameter(0, r2f32, "param0"));
990 HloInstruction* param1 = builder.AddInstruction(
991 HloInstruction::CreateParameter(1, r2f32, "param1"));
992 HloInstruction* param2 = builder.AddInstruction(
993 HloInstruction::CreateParameter(2, r2f32, "param2"));
994 HloInstruction* param3 = builder.AddInstruction(
995 HloInstruction::CreateParameter(3, r2f32, "param3"));
996 HloInstruction* div0 = builder.AddInstruction(
997 HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, param1));
998 HloInstruction* div1 = builder.AddInstruction(
999 HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param2, param3));
1000 builder.AddInstruction(
1001 HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, div0, div1));
1002
1003 auto computation = m->AddEntryComputation(builder.Build());
1004
1005 EXPECT_THAT(
1006 computation->root_instruction(),
1007 GmockMatch(m::Divide(m::Divide(m::Parameter(0), m::Parameter(1)),
1008 m::Divide(m::Parameter(2), m::Parameter(3)))));
1009
1010 AlgebraicSimplifier simplifier(default_options_);
1011 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1012
1013 EXPECT_THAT(
1014 computation->root_instruction(),
1015 GmockMatch(m::Divide(m::Multiply(m::Parameter(0), m::Parameter(3)),
1016 m::Multiply(m::Parameter(1), m::Parameter(2)))));
1017 }
1018
1019 // Test that A/exp(B) is simplified to A*exp(-B).
TEST_F(AlgebraicSimplifierTest,DivOfExp)1020 TEST_F(AlgebraicSimplifierTest, DivOfExp) {
1021 auto m = CreateNewVerifiedModule();
1022 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1023 HloComputation::Builder builder(TestName());
1024 HloInstruction* param0 = builder.AddInstruction(
1025 HloInstruction::CreateParameter(0, r0f32, "param0"));
1026 HloInstruction* param1 = builder.AddInstruction(
1027 HloInstruction::CreateParameter(1, r0f32, "param1"));
1028 HloInstruction* exp = builder.AddInstruction(
1029 HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1));
1030 builder.AddInstruction(
1031 HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, exp));
1032
1033 auto computation = m->AddEntryComputation(builder.Build());
1034
1035 EXPECT_THAT(computation->root_instruction(),
1036 GmockMatch(m::Divide(m::Parameter(0), m::Exp(m::Parameter(1)))));
1037
1038 AlgebraicSimplifier simplifier(default_options_);
1039 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1040
1041 EXPECT_THAT(computation->root_instruction(),
1042 GmockMatch(m::Multiply(m::Parameter(0),
1043 m::Exp(m::Negate(m::Parameter(1))))));
1044 }
1045
1046 // Test that A/pow(B,C) is simplified to A*pow(B,-C).
TEST_F(AlgebraicSimplifierTest,DivOfPower)1047 TEST_F(AlgebraicSimplifierTest, DivOfPower) {
1048 auto m = CreateNewVerifiedModule();
1049 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1050 HloComputation::Builder builder(TestName());
1051 HloInstruction* param0 = builder.AddInstruction(
1052 HloInstruction::CreateParameter(0, r0f32, "param0"));
1053 HloInstruction* param1 = builder.AddInstruction(
1054 HloInstruction::CreateParameter(1, r0f32, "param1"));
1055 HloInstruction* param2 = builder.AddInstruction(
1056 HloInstruction::CreateParameter(2, r0f32, "param2"));
1057 HloInstruction* power = builder.AddInstruction(
1058 HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param1, param2));
1059 builder.AddInstruction(
1060 HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, power));
1061
1062 auto computation = m->AddEntryComputation(builder.Build());
1063
1064 EXPECT_THAT(
1065 computation->root_instruction(),
1066 GmockMatch(m::Divide(m::Parameter(0),
1067 m::Power(m::Parameter(1), m::Parameter(2)))));
1068
1069 AlgebraicSimplifier simplifier(default_options_);
1070 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1071
1072 EXPECT_THAT(computation->root_instruction(),
1073 GmockMatch(m::Multiply(
1074 m::Parameter(0),
1075 m::Power(m::Parameter(1), m::Negate(m::Parameter(2))))));
1076 }
1077
1078 // Test that broadcasting is done on the right step when simplifying A/pow(B,C)
1079 // to A*pow(B,-C).
TEST_F(AlgebraicSimplifierTest,DivOfBroadcastingPower)1080 TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) {
1081 auto m = CreateNewVerifiedModule();
1082 Shape r1f32 = ShapeUtil::MakeShape(F32, {7});
1083 HloComputation::Builder builder(TestName());
1084 HloInstruction* param0 = builder.AddInstruction(
1085 HloInstruction::CreateParameter(0, r1f32, "param0"));
1086 HloInstruction* param1 = builder.AddInstruction(
1087 HloInstruction::CreateParameter(1, r1f32, "param1"));
1088 HloInstruction* param2 = builder.AddInstruction(
1089 HloInstruction::CreateParameter(2, r1f32, "param2"));
1090 HloInstruction* power = builder.AddInstruction(
1091 HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param1, param2));
1092 builder.AddInstruction(
1093 HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, power));
1094
1095 auto computation = m->AddEntryComputation(builder.Build());
1096
1097 EXPECT_THAT(
1098 computation->root_instruction(),
1099 GmockMatch(m::Divide(m::Parameter(0),
1100 m::Power(m::Parameter(1), m::Parameter(2)))));
1101
1102 AlgebraicSimplifier simplifier(default_options_);
1103 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1104
1105 ASSERT_THAT(computation->root_instruction(),
1106 GmockMatch(m::Multiply(
1107 m::Parameter(0),
1108 m::Power(m::Parameter(1), m::Negate(m::Parameter(2))))));
1109 }
1110
1111 // A / Const => A * InvertedConst
TEST_F(AlgebraicSimplifierTest,DivideByConstant)1112 TEST_F(AlgebraicSimplifierTest, DivideByConstant) {
1113 auto m = CreateNewVerifiedModule();
1114 Shape r1f32 = ShapeUtil::MakeShape(F32, {3});
1115 HloComputation::Builder builder(TestName());
1116 HloInstruction* param0 = builder.AddInstruction(
1117 HloInstruction::CreateParameter(0, r1f32, "param0"));
1118 HloInstruction* constant =
1119 builder.AddInstruction(HloInstruction::CreateConstant(
1120 LiteralUtil::CreateR1<float>({1.f, 2.f, 3.f})));
1121 builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide,
1122 param0, constant));
1123
1124 auto computation = m->AddEntryComputation(builder.Build());
1125
1126 AlgebraicSimplifier simplifier(default_options_);
1127 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1128
1129 EXPECT_THAT(computation->root_instruction(),
1130 GmockMatch(m::Multiply(m::Parameter(0), m::Constant())));
1131 }
1132
1133 // A / Broadcast(Const) => A * Broadcast(InvertedConst)
TEST_F(AlgebraicSimplifierTest,DivideByBroadcastedConstant)1134 TEST_F(AlgebraicSimplifierTest, DivideByBroadcastedConstant) {
1135 const char* kModuleStr = R"(
1136 HloModule m
1137 test {
1138 p = f32[4] parameter(0)
1139 c = f32[] constant(256.0)
1140 b = f32[4] broadcast(c), dimensions={}
1141 ROOT d = f32[4] divide(p, b)
1142 }
1143 )";
1144 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
1145 ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
1146
1147 EXPECT_THAT(m->entry_computation()->root_instruction(),
1148 GmockMatch(m::Multiply(
1149 m::Parameter(0),
1150 m::Broadcast(m::Op().IsConstantScalar(1.0f / 256.0f)))));
1151 }
1152
1153 // pow(pow(A, X), Y) => pow(A, X*Y)
TEST_F(AlgebraicSimplifierTest,PowerOfPower)1154 TEST_F(AlgebraicSimplifierTest, PowerOfPower) {
1155 auto m = CreateNewVerifiedModule();
1156 Shape r1f32 = ShapeUtil::MakeShape(F32, {7});
1157 HloComputation::Builder builder(TestName());
1158 HloInstruction* base = builder.AddInstruction(
1159 HloInstruction::CreateParameter(0, r1f32, "param0"));
1160 HloInstruction* exp1 = builder.AddInstruction(
1161 HloInstruction::CreateParameter(1, r1f32, "param1"));
1162 HloInstruction* exp2 = builder.AddInstruction(
1163 HloInstruction::CreateParameter(2, r1f32, "param2"));
1164 HloInstruction* inner_power = builder.AddInstruction(
1165 HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, base, exp1));
1166 builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower,
1167 inner_power, exp2));
1168
1169 AlgebraicSimplifier simplifier(default_options_);
1170 ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
1171 }
1172
1173 // Don't simplify pow(pow(A, X), Y) => pow(A, X*Y) if X and Y are complex
1174 // numbers.
TEST_F(AlgebraicSimplifierTest,PowerOfPowerComplex)1175 TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) {
1176 auto m = CreateNewVerifiedModule();
1177 Shape r1c64 = ShapeUtil::MakeShape(C64, {7});
1178 HloComputation::Builder builder(TestName());
1179 HloInstruction* base = builder.AddInstruction(
1180 HloInstruction::CreateParameter(0, r1c64, "param0"));
1181 HloInstruction* exp1 = builder.AddInstruction(
1182 HloInstruction::CreateParameter(1, r1c64, "param1"));
1183 HloInstruction* exp2 = builder.AddInstruction(
1184 HloInstruction::CreateParameter(2, r1c64, "param2"));
1185 HloInstruction* inner_power = builder.AddInstruction(
1186 HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, base, exp1));
1187 builder.AddInstruction(HloInstruction::CreateBinary(r1c64, HloOpcode::kPower,
1188 inner_power, exp2));
1189
1190 m->AddEntryComputation(builder.Build());
1191 AlgebraicSimplifier simplifier(default_options_);
1192 ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
1193 }
1194
1195 // Test that A/1 is simplified to A for a scalar.
TEST_F(AlgebraicSimplifierTest,DivOneScalar)1196 TEST_F(AlgebraicSimplifierTest, DivOneScalar) {
1197 auto m = CreateNewVerifiedModule();
1198 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1199 HloComputation::Builder builder(TestName());
1200 HloInstruction* param0 = builder.AddInstruction(
1201 HloInstruction::CreateParameter(0, r0f32, "param0"));
1202 HloInstruction* one = builder.AddInstruction(
1203 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
1204 HloInstruction* div = builder.AddInstruction(
1205 HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one));
1206
1207 auto computation = m->AddEntryComputation(builder.Build());
1208 HloInstruction* root = computation->root_instruction();
1209 EXPECT_EQ(root, div);
1210 AlgebraicSimplifier simplifier(default_options_);
1211 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1212 root = computation->root_instruction();
1213 EXPECT_EQ(root, param0);
1214 }
1215
1216 // Test that A/1 is simplified to A for an array.
TEST_F(AlgebraicSimplifierTest,DivOneArray)1217 TEST_F(AlgebraicSimplifierTest, DivOneArray) {
1218 auto m = CreateNewVerifiedModule();
1219 Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
1220 HloComputation::Builder builder(TestName());
1221 HloInstruction* param0 = builder.AddInstruction(
1222 HloInstruction::CreateParameter(0, r2f32, "param0"));
1223 HloInstruction* one = builder.AddInstruction(HloInstruction::CreateConstant(
1224 LiteralUtil::CreateR2<float>({{1.0, 1.0}, {1.0, 1.0}})));
1225 HloInstruction* div = builder.AddInstruction(
1226 HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one));
1227
1228 auto computation = m->AddEntryComputation(builder.Build());
1229 HloInstruction* root = computation->root_instruction();
1230 EXPECT_EQ(root, div);
1231 AlgebraicSimplifier simplifier(default_options_);
1232 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1233 root = computation->root_instruction();
1234 EXPECT_EQ(root, param0);
1235 }
1236
1237 // Test that complex(real(c), imag(c)) is simplified to c.
TEST_F(AlgebraicSimplifierTest,ComplexOfRealImagC)1238 TEST_F(AlgebraicSimplifierTest, ComplexOfRealImagC) {
1239 auto m = CreateNewVerifiedModule();
1240 Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
1241 Shape r2c64 = ShapeUtil::MakeShape(C64, {2, 2});
1242 HloComputation::Builder builder(TestName());
1243 HloInstruction* param0 = builder.AddInstruction(
1244 HloInstruction::CreateParameter(0, r2c64, "param0"));
1245 HloInstruction* real = builder.AddInstruction(
1246 HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, param0));
1247 HloInstruction* imag = builder.AddInstruction(
1248 HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, param0));
1249 HloInstruction* cplx = builder.AddInstruction(
1250 HloInstruction::CreateBinary(r2c64, HloOpcode::kComplex, real, imag));
1251
1252 auto computation = m->AddEntryComputation(builder.Build());
1253 HloInstruction* root = computation->root_instruction();
1254 EXPECT_EQ(root, cplx);
1255 AlgebraicSimplifier simplifier(default_options_);
1256 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1257 root = computation->root_instruction();
1258 EXPECT_EQ(root, param0);
1259 }
1260
1261 // Test that real(complex(r,i)) is simplified to r.
TEST_F(AlgebraicSimplifierTest,RealOfComplex)1262 TEST_F(AlgebraicSimplifierTest, RealOfComplex) {
1263 auto m = CreateNewVerifiedModule();
1264 Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
1265 HloComputation::Builder builder(TestName());
1266 HloInstruction* param0 = builder.AddInstruction(
1267 HloInstruction::CreateParameter(0, r2f32, "param0"));
1268 HloInstruction* param1 = builder.AddInstruction(
1269 HloInstruction::CreateParameter(1, r2f32, "param1"));
1270 HloInstruction* cplx = builder.AddInstruction(
1271 HloInstruction::CreateBinary(ShapeUtil::ChangeElementType(r2f32, C64),
1272 HloOpcode::kComplex, param0, param1));
1273 HloInstruction* real = builder.AddInstruction(
1274 HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, cplx));
1275
1276 auto computation = m->AddEntryComputation(builder.Build());
1277 HloInstruction* root = computation->root_instruction();
1278 EXPECT_EQ(root, real);
1279 AlgebraicSimplifier simplifier(default_options_);
1280 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1281 root = computation->root_instruction();
1282 EXPECT_EQ(root, param0);
1283 }
1284
1285 // Test that imag(complex(r,i)) is simplified to i.
TEST_F(AlgebraicSimplifierTest,ImagOfComplex)1286 TEST_F(AlgebraicSimplifierTest, ImagOfComplex) {
1287 auto m = CreateNewVerifiedModule();
1288 Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
1289 HloComputation::Builder builder(TestName());
1290 HloInstruction* param0 = builder.AddInstruction(
1291 HloInstruction::CreateParameter(0, r2f32, "param0"));
1292 HloInstruction* param1 = builder.AddInstruction(
1293 HloInstruction::CreateParameter(1, r2f32, "param1"));
1294 HloInstruction* cplx = builder.AddInstruction(
1295 HloInstruction::CreateBinary(ShapeUtil::ChangeElementType(r2f32, C64),
1296 HloOpcode::kComplex, param0, param1));
1297 HloInstruction* imag = builder.AddInstruction(
1298 HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, cplx));
1299
1300 auto computation = m->AddEntryComputation(builder.Build());
1301 HloInstruction* root = computation->root_instruction();
1302 EXPECT_EQ(root, imag);
1303 AlgebraicSimplifier simplifier(default_options_);
1304 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1305 root = computation->root_instruction();
1306 EXPECT_EQ(root, param1);
1307 }
1308
1309 // Test that get_element(make_tuple({A,B}),1) is simplified to B
TEST_F(AlgebraicSimplifierTest,SelectMakeTuple)1310 TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) {
1311 auto m = CreateNewVerifiedModule();
1312 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1313 HloComputation::Builder builder(TestName());
1314 HloInstruction* param0 = builder.AddInstruction(
1315 HloInstruction::CreateParameter(0, r0f32, "param0"));
1316 HloInstruction* param1 = builder.AddInstruction(
1317 HloInstruction::CreateParameter(1, r0f32, "param1"));
1318 HloInstruction* param2 = builder.AddInstruction(
1319 HloInstruction::CreateParameter(2, r0f32, "param2"));
1320 HloInstruction* tuple =
1321 builder.AddInstruction(HloInstruction::CreateTuple({param0, param1}));
1322 HloInstruction* get = builder.AddInstruction(
1323 HloInstruction::CreateGetTupleElement(r0f32, tuple, 1));
1324 HloInstruction* add = builder.AddInstruction(
1325 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, get, param2));
1326
1327 auto computation = m->AddEntryComputation(builder.Build());
1328 HloInstruction* root = computation->root_instruction();
1329 EXPECT_EQ(root, add);
1330 AlgebraicSimplifier simplifier(default_options_);
1331 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1332 root = computation->root_instruction();
1333 EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(1), m::Parameter(2))));
1334 }
1335
1336 // Test that exp(A)/exp(B) is simplified to exp(A-B)
TEST_F(AlgebraicSimplifierTest,ExpDiv)1337 TEST_F(AlgebraicSimplifierTest, ExpDiv) {
1338 auto m = CreateNewVerifiedModule();
1339 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1340 HloComputation::Builder builder(TestName());
1341 HloInstruction* param0 = builder.AddInstruction(
1342 HloInstruction::CreateParameter(0, r0f32, "param0"));
1343 HloInstruction* param1 = builder.AddInstruction(
1344 HloInstruction::CreateParameter(1, r0f32, "param1"));
1345 HloInstruction* exp0 = builder.AddInstruction(
1346 HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0));
1347 HloInstruction* exp1 = builder.AddInstruction(
1348 HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1));
1349 builder.AddInstruction(
1350 HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1));
1351
1352 auto computation = m->AddEntryComputation(builder.Build());
1353
1354 EXPECT_THAT(
1355 computation->root_instruction(),
1356 GmockMatch(m::Divide(m::Exp(m::Parameter(0)), m::Exp(m::Parameter(1)))));
1357
1358 AlgebraicSimplifier simplifier(default_options_);
1359 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1360
1361 EXPECT_THAT(
1362 computation->root_instruction(),
1363 GmockMatch(m::Exp(m::Subtract(m::Parameter(0), m::Parameter(1)))));
1364 }
1365
1366 // Test that exp(A)*exp(B) is simplified to exp(A+B)
TEST_F(AlgebraicSimplifierTest,ExpMul)1367 TEST_F(AlgebraicSimplifierTest, ExpMul) {
1368 auto m = CreateNewVerifiedModule();
1369 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1370 HloComputation::Builder builder(TestName());
1371 HloInstruction* param0 = builder.AddInstruction(
1372 HloInstruction::CreateParameter(0, r0f32, "param0"));
1373 HloInstruction* param1 = builder.AddInstruction(
1374 HloInstruction::CreateParameter(1, r0f32, "param1"));
1375 HloInstruction* exp0 = builder.AddInstruction(
1376 HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0));
1377 HloInstruction* exp1 = builder.AddInstruction(
1378 HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1));
1379 builder.AddInstruction(
1380 HloInstruction::CreateBinary(r0f32, HloOpcode::kMultiply, exp0, exp1));
1381
1382 auto computation = m->AddEntryComputation(builder.Build());
1383
1384 EXPECT_THAT(computation->root_instruction(),
1385 GmockMatch(m::Multiply(m::Exp(m::Parameter(0)),
1386 m::Exp(m::Parameter(1)))));
1387
1388 AlgebraicSimplifier simplifier(default_options_);
1389 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1390
1391 EXPECT_THAT(computation->root_instruction(),
1392 GmockMatch(m::Exp(m::Add(m::Parameter(0), m::Parameter(1)))));
1393 }
1394
1395 // Test that pow(exp(A), B) is simplified to exp(A*B)
TEST_F(AlgebraicSimplifierTest,PowExp)1396 TEST_F(AlgebraicSimplifierTest, PowExp) {
1397 auto m = CreateNewVerifiedModule();
1398 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1399 HloComputation::Builder builder(TestName());
1400 HloInstruction* param0 = builder.AddInstruction(
1401 HloInstruction::CreateParameter(0, r0f32, "param0"));
1402 HloInstruction* param1 = builder.AddInstruction(
1403 HloInstruction::CreateParameter(1, r0f32, "param1"));
1404 HloInstruction* exp0 = builder.AddInstruction(
1405 HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0));
1406 builder.AddInstruction(
1407 HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, exp0, param1));
1408
1409 auto computation = m->AddEntryComputation(builder.Build());
1410
1411 EXPECT_THAT(computation->root_instruction(),
1412 GmockMatch(m::Power(m::Exp(m::Parameter(0)), m::Parameter(1))));
1413
1414 AlgebraicSimplifier simplifier(default_options_);
1415 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1416
1417 EXPECT_THAT(
1418 computation->root_instruction(),
1419 GmockMatch(m::Exp(m::Multiply(m::Parameter(0), m::Parameter(1)))));
1420 }
1421
1422 // Test that ln(pow(A, B)) is simplified to ln(A)*B
TEST_F(AlgebraicSimplifierTest,LnPow)1423 TEST_F(AlgebraicSimplifierTest, LnPow) {
1424 auto m = CreateNewVerifiedModule();
1425 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1426 HloComputation::Builder builder(TestName());
1427 HloInstruction* param0 = builder.AddInstruction(
1428 HloInstruction::CreateParameter(0, r0f32, "param0"));
1429 HloInstruction* param1 = builder.AddInstruction(
1430 HloInstruction::CreateParameter(1, r0f32, "param1"));
1431 HloInstruction* pow = builder.AddInstruction(
1432 HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, param1));
1433 builder.AddInstruction(
1434 HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, pow));
1435
1436 auto computation = m->AddEntryComputation(builder.Build());
1437
1438 EXPECT_THAT(computation->root_instruction(),
1439 GmockMatch(m::Log(m::Power(m::Parameter(0), m::Parameter(1)))));
1440
1441 AlgebraicSimplifier simplifier(default_options_);
1442 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1443
1444 EXPECT_THAT(computation->root_instruction(),
1445 GmockMatch(m::Multiply(m::Log(m::Abs(m::Parameter(0))),
1446 m::Parameter(1))));
1447 }
1448
// Test that ln(sqrt(A)) is simplified to ln(A) * 0.5.
TEST_F(AlgebraicSimplifierTest, LnSqrt) {
  auto module = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* const a = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  HloInstruction* const sqrt_a = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kSqrt, a));
  entry_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kLog, sqrt_a));

  HloComputation* const entry =
      module->AddEntryComputation(entry_builder.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Log(m::Sqrt(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(m::Log(m::Parameter(0)),
                                     m::ConstantScalar(0.5))));
}
1472
TEST_F(AlgebraicSimplifierTest, LnRsqrt) {
  // Checks that log(rsqrt(A)) is rewritten to log(A) * -0.5.
  auto module = CreateNewVerifiedModule();
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder hlo_builder(TestName());
  HloInstruction* input = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param0"));
  HloInstruction* inverse_root = hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_f32, HloOpcode::kRsqrt, input));
  hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_f32, HloOpcode::kLog, inverse_root));

  HloComputation* entry = module->AddEntryComputation(hlo_builder.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Log(m::Rsqrt(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(m::Log(m::Parameter(0)),
                                     m::ConstantScalar(-0.5))));
}
1496
1497 // Test that ln(exp(A)) is simplified to A
TEST_F(AlgebraicSimplifierTest, LnExp) {
  auto module = CreateNewVerifiedModule();
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder hlo_builder(TestName());
  HloInstruction* input = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param0"));
  HloInstruction* exponential = hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_f32, HloOpcode::kExp, input));
  hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_f32, HloOpcode::kLog, exponential));

  HloComputation* entry = module->AddEntryComputation(hlo_builder.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Log(m::Exp(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // log and exp cancel, leaving the parameter itself as the root.
  EXPECT_EQ(entry->root_instruction(), input);
}
1519
1520 // Test that ln(exp(A)/exp(B)) is simplified to A-B
TEST_F(AlgebraicSimplifierTest, LnExpDiv) {
  auto module = CreateNewVerifiedModule();
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder hlo_builder(TestName());
  HloInstruction* lhs_param = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param0"));
  HloInstruction* rhs_param = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_f32, "param1"));
  // Build log(exp(A) / exp(B)).
  HloInstruction* numerator = hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_f32, HloOpcode::kExp, lhs_param));
  HloInstruction* denominator = hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_f32, HloOpcode::kExp, rhs_param));
  HloInstruction* quotient =
      hlo_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_f32, HloOpcode::kDivide, numerator, denominator));
  hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_f32, HloOpcode::kLog, quotient));

  HloComputation* entry = module->AddEntryComputation(hlo_builder.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Log(m::Divide(m::Exp(m::Parameter(0)),
                                          m::Exp(m::Parameter(1))))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The whole expression reduces to A - B.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Subtract(m::Parameter(0), m::Parameter(1))));
}
1550
1551 // Test that pow(A, 0) where A is a scalar is simplified to the scalar
1552 // constant 1.
TEST_F(AlgebraicSimplifierTest, Pow0Scalar) {
  auto m = CreateNewVerifiedModule();
  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Power(m::Parameter(0), m::Op().Is(zero))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  // Fatal check: the literal() access below assumes the root is a constant,
  // so stop here instead of reading a literal from some other node kind.
  ASSERT_THAT(root, GmockMatch(m::Constant()));
  EXPECT_EQ(root->literal().GetFirstElement<float>(), 1);
}
1576
1577 // Test that pow(A, 0) where A is not a scalar is simplified to broadcast(1).
TEST_F(AlgebraicSimplifierTest, Pow0Vector) {
  auto m = CreateNewVerifiedModule();
  Shape r1f32 = ShapeUtil::MakeShape(F32, {42});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Power(m::Parameter(0), m::Op().Is(zero))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  // Fatal check: the assertions below dereference the broadcast's operand and
  // read its literal, which is only meaningful for broadcast(constant).
  ASSERT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), r1f32))
      << ShapeUtil::HumanString(root->shape());
  // Broadcasting a scalar maps no operand dimensions.
  EXPECT_TRUE(root->dimensions().empty());
  EXPECT_TRUE(ShapeUtil::IsScalar(root->operand(0)->shape()));
  EXPECT_EQ(root->operand(0)->literal().GetFirstElement<float>(), 1);
}
1605
1606 // Test that pow(A, 1) is simplified to A.
TEST_F(AlgebraicSimplifierTest, Pow1) {
  auto module = CreateNewVerifiedModule();
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder hlo_builder(TestName());
  HloInstruction* base = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param0"));
  HloInstruction* exponent_one = hlo_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
  hlo_builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_f32, HloOpcode::kPower, base, exponent_one));

  HloComputation* entry = module->AddEntryComputation(hlo_builder.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Power(m::Parameter(0), m::Op().Is(exponent_one))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // A^1 collapses to the base itself.
  EXPECT_EQ(entry->root_instruction(), base);
}
1628
1629 // Test that pow(A, 2) is simplified to A*A.
TEST_F(AlgebraicSimplifierTest, Pow2) {
  auto module = CreateNewVerifiedModule();
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder hlo_builder(TestName());
  HloInstruction* base = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param0"));
  HloInstruction* exponent_two = hlo_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2)));
  hlo_builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_f32, HloOpcode::kPower, base, exponent_two));

  HloComputation* entry = module->AddEntryComputation(hlo_builder.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Power(m::Parameter(0), m::Op().Is(exponent_two))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // A^2 becomes a single self-multiplication.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
}
1652
1653 // Test that pow(A, 3) is simplified to A*A*A.
TEST_F(AlgebraicSimplifierTest, Pow3) {
  auto module = CreateNewVerifiedModule();
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder hlo_builder(TestName());
  HloInstruction* base = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param0"));
  HloInstruction* exponent_three = hlo_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3)));
  hlo_builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_f32, HloOpcode::kPower, base, exponent_three));

  HloComputation* entry = module->AddEntryComputation(hlo_builder.Build());
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Power(m::Parameter(0), m::Op().Is(exponent_three))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // A^3 becomes A * (A * A).
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(
                  m::Parameter(0),
                  m::Multiply(m::Parameter(0), m::Parameter(0)))));
}
1678
1679 // Test that pow(A, -1) is simplified to 1/A.
TEST_F(AlgebraicSimplifierTest, PowNegative1) {
  auto m = CreateNewVerifiedModule();
  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  HloInstruction* negative_one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(-1)));
  builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower,
                                                      param0, negative_one));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Power(m::Parameter(0), m::Op().Is(negative_one))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  // Fatal check: the next line reads the literal of operand 0, which assumes
  // the root is divide(constant, param).
  ASSERT_THAT(root, GmockMatch(m::Divide(m::Constant(), m::Parameter(0))));
  EXPECT_EQ(root->operand(0)->literal().GetFirstElement<float>(), 1);
}
1703
// Checks that a convolution whose input-feature dimension has size 0 in both
// operands (lhs dim 2, rhs dim 1) is replaced by a broadcast of a constant.
TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) {
  auto m = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {3, 3, 0}), "lhs"));

  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {3, 0, 3}), "rhs"));

  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(0);
  dnums.add_input_spatial_dimensions(1);
  dnums.set_input_feature_dimension(2);

  dnums.set_output_batch_dimension(0);
  dnums.add_output_spatial_dimensions(1);
  dnums.set_output_feature_dimension(2);

  dnums.add_kernel_spatial_dimensions(0);
  dnums.set_kernel_input_feature_dimension(1);
  dnums.set_kernel_output_feature_dimension(2);
  // One spatial dimension: size-3 window, unit stride, no padding/dilation.
  Window window;
  WindowDimension* dim = window.add_dimensions();
  dim->set_size(3);
  dim->set_padding_low(0);
  dim->set_padding_high(0);
  dim->set_stride(1);
  dim->set_window_dilation(1);
  dim->set_base_dilation(1);
  dim->set_window_reversal(false);
  // Create the convolution; as the last instruction added it is the root.
  builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
  m->AddEntryComputation(builder.Build());
  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Convolution(m::Op().Is(lhs), m::Op().Is(rhs))));
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Constant())));
}
1746
TEST_F(AlgebraicSimplifierTest, ReduceWindowIsReduceAndReshape) {
  auto module = CreateNewVerifiedModule();
  auto entry_builder = HloComputation::Builder(TestName());
  HloInstruction* input =
      entry_builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {1, 2, 3, 4}), "param"));

  // A 1x2x3x1 window with unit strides, no padding and no dilation.
  const int64 kWindowSizes[] = {1, 2, 3, 1};
  Window window;
  for (const int64 size : kWindowSizes) {
    WindowDimension* window_dim = window.add_dimensions();
    window_dim->set_size(size);
    window_dim->set_stride(1);
    window_dim->set_padding_low(0);
    window_dim->set_padding_high(0);
    window_dim->set_window_dilation(1);
    window_dim->set_base_dilation(1);
  }

  // Scalar f32 addition used as the reducer.
  HloComputation* add_computation = nullptr;
  {
    HloComputation::Builder add_builder(TestName() + ".add");
    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
    HloInstruction* lhs = add_builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "p0"));
    HloInstruction* rhs = add_builder.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "p1"));
    add_builder.AddInstruction(
        HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, lhs, rhs));
    add_computation = module->AddEmbeddedComputation(add_builder.Build());
  }

  HloInstruction* init_value = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  entry_builder.AddInstruction(HloInstruction::CreateReduceWindow(
      ShapeUtil::MakeShape(F32, {1, 1, 1, 4}), input, init_value, window,
      add_computation));
  module->AddEntryComputation(entry_builder.Build());
  HloComputation* entry = module->entry_computation();

  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::ReduceWindow(m::Parameter(0), m::Constant())));
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  // The window covers dimensions 0-2 completely, so it becomes a plain
  // reduce followed by a reshape back to the 4-D result shape.
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Reshape(m::Reduce(m::Parameter(0), m::Constant()))));
}
1791
// Checks that a reduce-window over a zero-element operand is replaced by a
// broadcast of a constant.
TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) {
  auto m = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {3, 0}), "op"));
  // Size-1 window with one element of padding on each side of both
  // dimensions, giving the f32[5,2] result declared below.
  Window window;
  for (int64 i = 0; i < 2; ++i) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(1);
    // Set the stride explicitly. Otherwise the field keeps its proto default
    // (presumably 0), which does not describe the f32[5,2] output shape.
    dim->set_stride(1);
    dim->set_padding_low(1);
    dim->set_padding_high(1);
    dim->set_window_dilation(1);
    dim->set_base_dilation(1);
  }
  // Create add computation.
  HloComputation* add_computation = nullptr;
  {
    // Distinct name so the outer entry-computation builder is not shadowed.
    HloComputation::Builder add_builder(TestName() + ".add");
    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
    HloInstruction* p0 = add_builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "p0"));
    HloInstruction* p1 = add_builder.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "p1"));
    add_builder.AddInstruction(
        HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
    add_computation = m->AddEmbeddedComputation(add_builder.Build());
  }
  builder.AddInstruction(HloInstruction::CreateReduceWindow(
      ShapeUtil::MakeShape(F32, {5, 2}), param,
      builder.AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f))),
      window, add_computation));
  m->AddEntryComputation(builder.Build());
  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::ReduceWindow(m::Parameter(0), m::Constant())));
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Constant())));
}
1833
TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) {
  auto module = CreateNewVerifiedModule();
  auto entry_builder = HloComputation::Builder(TestName());
  HloInstruction* input =
      entry_builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {3, 0}), "op"));

  // One element of edge padding on both sides of each dimension and no
  // interior padding: f32[3,0] -> f32[5,2].
  PaddingConfig padding;
  for (int dim_index = 0; dim_index < 2; ++dim_index) {
    PaddingConfig::PaddingConfigDimension* pad_dim = padding.add_dimensions();
    pad_dim->set_edge_padding_low(1);
    pad_dim->set_edge_padding_high(1);
    pad_dim->set_interior_padding(0);
  }

  HloInstruction* pad_value = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
  entry_builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {5, 2}), input, pad_value, padding));
  module->AddEntryComputation(entry_builder.Build());
  HloComputation* entry = module->entry_computation();

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Constant())));
  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  // Padding an empty array yields only padding values: a broadcast constant.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Broadcast(m::Constant())));
}
1860
// Checks that reshape(broadcast(reshape(op))) that round-trips to op's own
// shape f32[3,2] is simplified to op.
TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) {
  auto m = CreateNewVerifiedModule();

  auto builder = HloComputation::Builder(TestName());
  auto op = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {3, 2}), "op"));
  auto reshape1 = builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6}), op));
  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {1, 6}), reshape1, {1}));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {3, 2}), broadcast));

  m->AddEntryComputation(builder.Build());

  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Reshape(m::Op().Is(op))))));

  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // Pointer identity: the root must be the original parameter.
  EXPECT_EQ(m->entry_computation()->root_instruction(), op);
}
1886
1887 // Test that convert(A, $TYPE) is simplified to A if A is of type $TYPE.
TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) {
  // convert(f32 constant, f32) should be removed, leaving the constant as root.
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  builder.AddInstruction(
      HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Convert(m::Op().Is(input))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  EXPECT_EQ(computation->root_instruction(), input);
}
1906
// Test that convert(convert(A, $TYPE1), $TYPE2) is simplified to A if A is of
// $TYPE2 and convert(A, $TYPE1) is an upcast.
TEST_F(AlgebraicSimplifierTest, EliminateConvertPairUpCast) {
  // f16 -> f32 -> f16 round trip; the pair of converts should disappear.
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* input =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShapeWithLayout(F16, {1, 14, 14, 64}, {3, 2, 1, 0}),
          "param"));
  HloInstruction* convert_1 =
      builder.AddInstruction(HloInstruction::CreateConvert(
          ShapeUtil::ChangeElementType(input->shape(), F32), input));
  builder.AddInstruction(HloInstruction::CreateConvert(
      ShapeUtil::ChangeElementType(convert_1->shape(), F16), convert_1));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Convert(m::Convert(m::Op().Is(input)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  EXPECT_EQ(computation->root_instruction(), input);
}
1932
// Test that convert(convert(A, $TYPE1), $TYPE2) is NOT simplified to A even if
// A is of $TYPE2, since convert(A, $TYPE1) is a downcast.
TEST_F(AlgebraicSimplifierTest, DoNotEliminateConvertPairDownCast) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder hlo_builder(TestName());
  const Shape f32_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {3, 2, 1, 0});
  HloInstruction* input = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "param"));
  // f32 -> f16 is a downcast, so the round trip back to f32 must be kept.
  HloInstruction* to_f16 =
      hlo_builder.AddInstruction(HloInstruction::CreateConvert(
          ShapeUtil::ChangeElementType(input->shape(), F16), input));
  hlo_builder.AddInstruction(HloInstruction::CreateConvert(
      ShapeUtil::ChangeElementType(to_f16->shape(), F32), to_f16));

  HloComputation* entry = module->AddEntryComputation(hlo_builder.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Convert(m::Convert(m::Op().Is(input)))));

  AlgebraicSimplifier simplifier(default_options_);
  // The pass must report no change and leave the graph untouched.
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Convert(m::Convert(m::Op().Is(input)))));
}
1957
// Test that Tuple(convert(A, $TYPE1), convert(convert(A, $TYPE1), $TYPE2),
// floor(convert(convert(A, $TYPE1), $TYPE2))) is simplified to
// Tuple(convert(A, $TYPE1), A, floor(A)), showing a case where the first
// convert has a fan-out.
TEST_F(AlgebraicSimplifierTest, EliminateConvertPairMultiOut) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* input =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShapeWithLayout(F16, {1, 14, 14, 64}, {3, 2, 1, 0}),
          "param"));
  // f16 -> f32 upcast; this convert has two users (convert_2 and the tuple).
  HloInstruction* convert_1 =
      builder.AddInstruction(HloInstruction::CreateConvert(
          ShapeUtil::ChangeElementType(input->shape(), F32), input));
  HloInstruction* convert_2 =
      builder.AddInstruction(HloInstruction::CreateConvert(
          ShapeUtil::ChangeElementType(convert_1->shape(), F16), convert_1));

  HloInstruction* floor = builder.AddInstruction(HloInstruction::CreateUnary(
      convert_2->shape(), HloOpcode::kFloor, convert_2));

  // Collect the converts and the floor into a tuple so they are not dead.
  builder.AddInstruction(
      HloInstruction::CreateTuple({convert_1, convert_2, floor}));

  auto computation = m->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Tuple(m::Op().Is(convert_1), m::Op().Is(convert_2),
                                  m::Op().Is(floor))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // convert_1 survives (it is still a tuple element); the round-trip pair is
  // replaced by the input both directly and under the floor.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Tuple(m::Op().Is(convert_1), m::Op().Is(input),
                                  m::Floor(m::Op().Is(input)))));
}
1995
1996 // Test that copies are removed.
TEST_F(AlgebraicSimplifierTest, RemoveCopy) {
  // A same-shape copy of a scalar parameter should be dropped.
  auto m = CreateNewVerifiedModule();
  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  builder.AddInstruction(
      HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Copy(m::Parameter(0))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  EXPECT_EQ(computation->root_instruction(), param0);
}
2016
TEST_F(AlgebraicSimplifierTest, CopyOfReshapeOfCopyEqualsBitcast) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder hlo_builder(TestName());

  // Graph: param (layout {3,2,1,0}) -> copy to layout {0,1,2,3}
  //        -> reshape to [196,64] with layout {0,1} -> copy to layout {1,0}.
  const Shape param_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {3, 2, 1, 0});
  const Shape relayout_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {0, 1, 2, 3});
  const Shape flat_col_major =
      ShapeUtil::MakeShapeWithLayout(F32, {14 * 14, 64}, {0, 1});
  const Shape flat_row_major =
      ShapeUtil::MakeShapeWithLayout(F32, {14 * 14, 64}, {1, 0});

  HloInstruction* input = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param"));
  HloInstruction* relayout = hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(relayout_shape, HloOpcode::kCopy, input));
  HloInstruction* flattened = hlo_builder.AddInstruction(
      HloInstruction::CreateReshape(flat_col_major, relayout));
  hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(flat_row_major, HloOpcode::kCopy, flattened));

  HloComputation* entry = module->AddEntryComputation(hlo_builder.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Copy(m::Reshape(m::Copy(m::Parameter(0))))));

  // The simplifier runs in layout-sensitive mode for this rewrite.
  AlgebraicSimplifierOptions layout_sensitive_options;
  layout_sensitive_options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(layout_sensitive_options);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  // The whole copy/reshape/copy chain collapses into one bitcast.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
2045
// Checks that, in layout-sensitive mode, reshape(copy(param)) is replaced by a
// bitcast of the parameter.
TEST_F(AlgebraicSimplifierTest, ReshapeOfCopyEqualsBitcast) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {3, 2, 1, 0}),
          "param"));
  HloInstruction* copy = builder.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {0, 1, 2, 3}),
      HloOpcode::kCopy, param));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShapeWithLayout(F32, {14 * 14, 64}, {1, 0}), copy));

  auto computation = m->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Copy(m::Parameter(0)))));

  AlgebraicSimplifierOptions options;
  options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(options);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  // Verify that the reshape of copy is replaced.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
2071
// Checks that a layout-changing copy becomes a bitcast only when the
// simplifier options allow it: first with a reshape-validity callback that
// always returns false (no change), then with default options.
TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {0, 1, 2, 3}),
          "param"));
  builder.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {1, 2, 0, 3}),
      HloOpcode::kCopy, param));
  auto computation = m->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Copy(m::Parameter(0))));

  AlgebraicSimplifierOptions options(
      [](const Shape&, const Shape&) { return false; });
  options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier1(options);
  ASSERT_FALSE(simplifier1.Run(m.get()).ValueOrDie());
  // Verify that the copy is not replaced.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Copy(m::Parameter(0))));

  AlgebraicSimplifierOptions options2;
  options2.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier2(options2);
  // ASSERT, consistent with the other tests: if the pass reports no change,
  // the structural check below is meaningless.
  ASSERT_TRUE(simplifier2.Run(m.get()).ValueOrDie());
  // Verify that the copy is replaced.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
2103
2104 // Test that unary concatenates are removed.
TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) {
  // concatenate with a single operand is an identity and should be removed.
  auto m = CreateNewVerifiedModule();
  Shape r1f32 = ShapeUtil::MakeShape(F32, {100});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "param0"));
  builder.AddInstruction(
      HloInstruction::CreateConcatenate(param0->shape(), {param0}, 0));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Concatenate(m::Parameter(0))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  EXPECT_EQ(computation->root_instruction(), param0);
}
2124
// Checks that slice(reverse(x)) is rewritten to reverse(slice(x)) with the
// slice bounds remapped for the reversed dimensions.
TEST_F(AlgebraicSimplifierTest, SliceReverse) {
  const char* const hlo_string = R"(
HloModule module

ENTRY test {
  param = f32[6,7,32] parameter(0)
  constant = f32[] constant(0)
  pad = f32[8,7,32] pad(param, constant), padding=1_1x0_0x0_0
  rev = f32[8,7,32] reverse(pad), dimensions={0,2}
  slice = f32[1,7,32] slice(rev), slice={[2:3:1], [0:7:1], [0:32:1]}
  ROOT tuple = (f32[1,7,32]) tuple(slice)
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  HloComputation* computation = module->entry_computation();
  // Fatal check: the operand chain below is only valid for this structure.
  ASSERT_THAT(computation->root_instruction(),
              GmockMatch(m::Tuple(m::Reverse(m::Slice(m::Pad())))));
  const HloInstruction* slice =
      computation->root_instruction()->operand(0)->operand(0);
  EXPECT_TRUE(
      ShapeUtil::Equal(slice->shape(), ShapeUtil::MakeShape(F32, {1, 7, 32})));
  // With unit stride, a slice [start:limit) of a reversed dimension of size n
  // maps to [n - limit : n - start) of the un-reversed operand.
  // Dimension 0 (n = 8): [2:3) -> [5:6).
  // Dimension 1 is not reversed and keeps [0:7).
  // Dimension 2 (n = 32) is reversed, but [0:32) spans the whole dimension,
  // so its bounds map to themselves.
  EXPECT_EQ(slice->slice_starts(0), 5);
  EXPECT_EQ(slice->slice_limits(0), 6);
  EXPECT_EQ(slice->slice_starts(1), 0);
  EXPECT_EQ(slice->slice_limits(1), 7);
  EXPECT_EQ(slice->slice_starts(2), 0);
  EXPECT_EQ(slice->slice_limits(2), 32);
  EXPECT_EQ(slice->slice_strides(0), 1);
  EXPECT_EQ(slice->slice_strides(1), 1);
  EXPECT_EQ(slice->slice_strides(2), 1);
}
2162
// Same as SliceReverse, but every dimension is reversed and the slice uses
// non-unit (even and odd) strides.
TEST_F(AlgebraicSimplifierTest, SliceReverseNonUnitEvenOddStrides) {
  const char* const hlo_string = R"(
HloModule module

ENTRY test {
  param = f32[6,7,32] parameter(0)
  constant = f32[] constant(0)
  pad = f32[8,7,32] pad(param, constant), padding=1_1x0_0x0_0
  rev = f32[8,7,32] reverse(pad), dimensions={0,1,2}
  slice = f32[1,2,7] slice(rev), slice={[2:3:2], [0:7:4], [0:32:5]}
  ROOT tuple = (f32[1,2,7]) tuple(slice)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  HloComputation* computation = module->entry_computation();
  // Fatal check: the operand chain below is only valid for this structure.
  ASSERT_THAT(computation->root_instruction(),
              GmockMatch(m::Tuple(m::Reverse(m::Slice(m::Pad())))));
  const HloInstruction* slice =
      computation->root_instruction()->operand(0)->operand(0);
  EXPECT_TRUE(
      ShapeUtil::Equal(slice->shape(), ShapeUtil::MakeShape(F32, {1, 2, 7})));
  // Element i of a reversed dimension of size n is element n - 1 - i of the
  // operand. Strides are unchanged; the new slice covers the mirrored
  // elements:
  //   dim 0 (n = 8):  picks {2}            -> {5}             -> [5:6:2]
  //   dim 1 (n = 7):  picks {0, 4}         -> {6, 2}          -> [2:7:4]
  //   dim 2 (n = 32): picks {0, 5, ..., 30} -> {31, 26, ..., 1} -> [1:32:5]
  EXPECT_EQ(slice->slice_starts(0), 5);
  EXPECT_EQ(slice->slice_limits(0), 6);
  EXPECT_EQ(slice->slice_starts(1), 2);
  EXPECT_EQ(slice->slice_limits(1), 7);
  EXPECT_EQ(slice->slice_starts(2), 1);
  EXPECT_EQ(slice->slice_limits(2), 32);
  EXPECT_EQ(slice->slice_strides(0), 2);
  EXPECT_EQ(slice->slice_strides(1), 4);
  EXPECT_EQ(slice->slice_strides(2), 5);
}
2198
2199 // Test that empty operands of concatenates are removed.
TEST_F(AlgebraicSimplifierTest,RemoveEmptyConcatenateOperands)2200 TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) {
2201 auto m = CreateNewVerifiedModule();
2202 const int kParamLength = 100;
2203 Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength});
2204 HloComputation::Builder builder(TestName());
2205 HloInstruction* param0 = builder.AddInstruction(
2206 HloInstruction::CreateParameter(0, r1f32, "param0"));
2207 HloInstruction* param1 = builder.AddInstruction(
2208 HloInstruction::CreateParameter(1, r1f32, "param1"));
2209 HloInstruction* empty_literal = builder.AddInstruction(
2210 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
2211 HloInstruction* empty_slice =
2212 builder.AddInstruction(HloInstruction::CreateSlice(
2213 ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}, {1}));
2214 Shape result_shape = ShapeUtil::MakeShape(F32, {3 * kParamLength});
2215 builder.AddInstruction(HloInstruction::CreateConcatenate(
2216 result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0));
2217
2218 auto computation = m->AddEntryComputation(builder.Build());
2219
2220 EXPECT_THAT(computation->root_instruction(),
2221 GmockMatch(m::Concatenate(
2222 m::Op().Is(empty_literal), m::Parameter(0), m::Parameter(0),
2223 m::Op().Is(empty_slice), m::Parameter(1))));
2224
2225 AlgebraicSimplifier simplifier(default_options_);
2226 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2227
2228 EXPECT_THAT(computation->root_instruction(),
2229 GmockMatch(m::Concatenate(m::Parameter(0), m::Parameter(0),
2230 m::Parameter(1))));
2231 }
2232
2233 // Test that reduce of concat is simplified.
TEST_F(AlgebraicSimplifierTest,SimplifyReduceOfConcat)2234 TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) {
2235 auto m = CreateNewVerifiedModule();
2236 const int kParamLength = 100;
2237 Shape r3f32 =
2238 ShapeUtil::MakeShape(F32, {kParamLength, kParamLength, kParamLength});
2239 HloComputation::Builder builder(TestName());
2240 HloInstruction* param0 = builder.AddInstruction(
2241 HloInstruction::CreateParameter(0, r3f32, "param0"));
2242 HloInstruction* param1 = builder.AddInstruction(
2243 HloInstruction::CreateParameter(1, r3f32, "param1"));
2244 HloInstruction* param2 = builder.AddInstruction(
2245 HloInstruction::CreateParameter(2, r3f32, "param2"));
2246 Shape concat_shape =
2247 ShapeUtil::MakeShape(F32, {kParamLength, 3 * kParamLength, kParamLength});
2248 HloInstruction* Concatenate =
2249 builder.AddInstruction(HloInstruction::CreateConcatenate(
2250 concat_shape, {param0, param1, param2}, 1));
2251 HloComputation* add_computation = nullptr;
2252 {
2253 HloComputation::Builder builder(TestName() + ".add");
2254 const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
2255 HloInstruction* p0 = builder.AddInstruction(
2256 HloInstruction::CreateParameter(0, scalar_shape, "p0"));
2257 HloInstruction* p1 = builder.AddInstruction(
2258 HloInstruction::CreateParameter(1, scalar_shape, "p1"));
2259 builder.AddInstruction(
2260 HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
2261 add_computation = m->AddEmbeddedComputation(builder.Build());
2262 }
2263 Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
2264 Shape reduce_shape = ShapeUtil::MakeShape(F32, {kParamLength});
2265
2266 HloInstruction* zero = builder.AddInstruction(
2267 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
2268 builder.AddInstruction(HloInstruction::CreateReduce(
2269 reduce_shape, Concatenate, zero, {1, 2}, add_computation));
2270
2271 auto computation = m->AddEntryComputation(builder.Build());
2272
2273 AlgebraicSimplifier simplifier(default_options_);
2274 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2275
2276 EXPECT_THAT(
2277 computation->root_instruction(),
2278 GmockMatch(m::Map(m::Map(m::Reduce(m::Parameter(0), m::Op().Is(zero)),
2279 m::Reduce(m::Parameter(1), m::Op().Is(zero))),
2280 m::Reduce(m::Parameter(2), m::Op().Is(zero)))));
2281 }
2282
2283 // Test a concatenate with only empty operands is removed.
TEST_F(AlgebraicSimplifierTest,OnlyEmptyConcatenateOperands)2284 TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) {
2285 auto m = CreateNewVerifiedModule();
2286 const int kParamLength = 100;
2287 Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength});
2288 HloComputation::Builder builder(TestName());
2289 HloInstruction* param0 = builder.AddInstruction(
2290 HloInstruction::CreateParameter(0, r1f32, "param0"));
2291 HloInstruction* empty_literal = builder.AddInstruction(
2292 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
2293 HloInstruction* empty_slice =
2294 builder.AddInstruction(HloInstruction::CreateSlice(
2295 ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42}, {1}));
2296 Shape result_shape = ShapeUtil::MakeShape(F32, {0});
2297 builder.AddInstruction(HloInstruction::CreateConcatenate(
2298 result_shape, {empty_literal, empty_slice}, 0));
2299
2300 auto computation = m->AddEntryComputation(builder.Build());
2301
2302 EXPECT_THAT(computation->root_instruction(),
2303 GmockMatch(m::Concatenate(m::Op().Is(empty_literal),
2304 m::Op().Is(empty_slice))));
2305
2306 AlgebraicSimplifier simplifier(default_options_);
2307 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2308
2309 EXPECT_EQ(computation->root_instruction(), empty_literal);
2310 }
2311
2312 // Test that concat with a scalar broadcast becomes a pad.
TEST_F(AlgebraicSimplifierTest,ConcatenateOfBroadcastBecomesPad)2313 TEST_F(AlgebraicSimplifierTest, ConcatenateOfBroadcastBecomesPad) {
2314 auto m = CreateNewVerifiedModule();
2315 Shape r1f32 = ShapeUtil::MakeShape(F32, {100});
2316 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
2317 HloComputation::Builder builder(TestName());
2318 HloInstruction* param0 = builder.AddInstruction(
2319 HloInstruction::CreateParameter(0, r1f32, "param0"));
2320 HloInstruction* param1 = builder.AddInstruction(
2321 HloInstruction::CreateParameter(1, r0f32, "param1"));
2322 HloInstruction* broadcast = builder.AddInstruction(
2323 HloInstruction::CreateBroadcast(r1f32, param1, {}));
2324 builder.AddInstruction(HloInstruction::CreateConcatenate(
2325 ShapeUtil::MakeShape(F32, {200}), {broadcast, param0}, 0));
2326
2327 auto computation = m->AddEntryComputation(builder.Build());
2328
2329 AlgebraicSimplifier simplifier(default_options_);
2330 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2331 EXPECT_THAT(computation->root_instruction(),
2332 GmockMatch(m::Pad(m::Parameter(0), m::Parameter(1))));
2333 }
2334
// Checks that adjacent slices of the same operand, concatenated along the
// slicing dimension, are merged into a single slice when their start indices,
// strides and contiguity allow it.
TEST_F(AlgebraicSimplifierTest, SimplifyConcatenateOfSlices) {
  auto m = CreateNewVerifiedModule();
  Shape r2f32 = ShapeUtil::MakeShape(F32, {100, 99});
  Shape concat_shape = ShapeUtil::MakeShape(F32, {50, 90});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r2f32, "param0"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, r2f32, "param1"));

  HloInstruction* slice0 = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{0, 0},
      /*limit_indices=*/{50, 10}, /*strides=*/{1, 1}));

  // Cannot merge 'slice0' and 'slice1' because of different start indices in
  // dimension 0.
  HloInstruction* slice1 = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 10},
      /*limit_indices=*/{100, 20}, /*strides=*/{1, 1}));

  // Cannot merge 'slice1' and 'slice2' because of the stride in dimension 1.
  HloInstruction* slice2 = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 20},
      /*limit_indices=*/{100, 40}, /*strides=*/{1, 2}));

  // Cannot merge 'slice2' and 'slice3' because of the stride in dimension 1.
  HloInstruction* slice3 = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 40},
      /*limit_indices=*/{100, 50}, /*strides=*/{1, 1}));

  // Can merge 'slice3' and 'slice4'.
  HloInstruction* slice4 = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 50},
      /*limit_indices=*/{100, 60}, /*strides=*/{1, 1}));

  // Can merge 'slice4' and 'slice5'.
  HloInstruction* slice5 = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 60},
      /*limit_indices=*/{100, 70}, /*strides=*/{1, 1}));

  // Cannot merge 'slice5' and 'slice6' because of overlap.
  HloInstruction* slice6 = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 69},
      /*limit_indices=*/{100, 79}, /*strides=*/{1, 1}));

  // Cannot merge 'slice6' and 'slice7' because of slicing from a different
  // parameter.
  HloInstruction* slice7 = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {50, 10}), param1, /*start_indices=*/{50, 79},
      /*limit_indices=*/{100, 89}, /*strides=*/{1, 1}));
  // Can merge 'slice7' and 'slice8'.
  HloInstruction* slice8 = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {50, 10}), param1, /*start_indices=*/{50, 89},
      /*limit_indices=*/{100, 99}, /*strides=*/{1, 1}));

  builder.AddInstruction(HloInstruction::CreateConcatenate(
      concat_shape,
      {slice0, slice1, slice2, slice3, slice4, slice5, slice6, slice7, slice8},
      1));
  auto computation = m->AddEntryComputation(builder.Build());

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  // Expected operands: slice0, slice1, slice2, merge(slice3..slice5), slice6,
  // merge(slice7, slice8).
  auto s = m::Slice(m::Parameter(0));
  EXPECT_THAT(
      computation->root_instruction(),
      GmockMatch(m::Concatenate(s, s, s, s, s, m::Slice(m::Parameter(1)))));
  // The operand 3 should be a merge of 'slice3', 'slice4' and 'slice5', so its
  // shape should have dimensions {50, 30}.
  EXPECT_TRUE(
      ShapeUtil::Equal(computation->root_instruction()->operand(3)->shape(),
                       ShapeUtil::MakeShape(F32, {50, 30})));
  EXPECT_EQ(computation->root_instruction()->operand(3)->slice_starts(1), 40);

  // The operand 5 should be a merge of 'slice7' and 'slice8', so its shape
  // should have dimensions {50, 20} and it should start where 'slice7' starts
  // in dimension 1.
  EXPECT_TRUE(
      ShapeUtil::Equal(computation->root_instruction()->operand(5)->shape(),
                       ShapeUtil::MakeShape(F32, {50, 20})));
  EXPECT_EQ(computation->root_instruction()->operand(5)->slice_starts(1), 79);
}
2415
2416 // Test that a simplification which changes layouts is not performed if layout
2417 // sensitive is true.
TEST_F(AlgebraicSimplifierTest,CopyWithDifferentLayout)2418 TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) {
2419 auto m = CreateNewVerifiedModule();
2420 HloComputation::Builder builder(TestName());
2421 HloInstruction* param0 =
2422 builder.AddInstruction(HloInstruction::CreateParameter(
2423 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
2424 HloInstruction* copy = builder.AddInstruction(
2425 HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
2426
2427 auto computation = m->AddEntryComputation(builder.Build());
2428
2429 // Set to different layouts.
2430 *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
2431 *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0});
2432
2433 EXPECT_THAT(computation->root_instruction(),
2434 GmockMatch(m::Copy(m::Parameter(0))));
2435
2436 AlgebraicSimplifierOptions options;
2437 options.set_is_layout_sensitive(true);
2438 AlgebraicSimplifier simplifier(options);
2439 EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie());
2440
2441 // Copy has not been removed.
2442 EXPECT_THAT(computation->root_instruction(),
2443 GmockMatch(m::Copy(m::Parameter(0))));
2444 }
2445
2446 // Test that a simplification which preserves layouts is performed if layout
2447 // sensitive is true.
TEST_F(AlgebraicSimplifierTest,CopyWithSameLayout)2448 TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) {
2449 auto m = CreateNewVerifiedModule();
2450 HloComputation::Builder builder(TestName());
2451 HloInstruction* param0 =
2452 builder.AddInstruction(HloInstruction::CreateParameter(
2453 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
2454 HloInstruction* copy = builder.AddInstruction(
2455 HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
2456
2457 auto computation = m->AddEntryComputation(builder.Build());
2458
2459 // Set to same layouts.
2460 *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
2461 *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
2462
2463 EXPECT_THAT(computation->root_instruction(),
2464 GmockMatch(m::Copy(m::Parameter(0))));
2465
2466 AlgebraicSimplifierOptions options;
2467 options.set_is_layout_sensitive(true);
2468 AlgebraicSimplifier simplifier(options);
2469 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2470
2471 // Copy has been removed.
2472 EXPECT_THAT(computation->root_instruction(), param0);
2473 }
2474
2475 // Test that a reshape which could be replaced with a bitcast is not if
2476 // add_bitcasts is false.
TEST_F(AlgebraicSimplifierTest,NoBitcastAdded)2477 TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) {
2478 auto m = CreateNewVerifiedModule();
2479 HloComputation::Builder builder(TestName());
2480 HloInstruction* param0 =
2481 builder.AddInstruction(HloInstruction::CreateParameter(
2482 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
2483 HloInstruction* reshape =
2484 builder.AddInstruction(HloInstruction::CreateReshape(
2485 ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));
2486
2487 *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
2488 *reshape->mutable_shape()->mutable_layout() =
2489 LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
2490
2491 auto computation = m->AddEntryComputation(builder.Build());
2492
2493 EXPECT_THAT(computation->root_instruction(),
2494 GmockMatch(m::Reshape(m::Parameter(0))));
2495
2496 AlgebraicSimplifierOptions options(
2497 [](const Shape&, const Shape&) { return false; });
2498 options.set_is_layout_sensitive(true);
2499 AlgebraicSimplifier simplifier(options);
2500 EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie());
2501
2502 // Reshape is not replaced with a bitcast.
2503 EXPECT_THAT(computation->root_instruction(),
2504 GmockMatch(m::Reshape(m::Parameter(0))));
2505 }
2506
2507 // Test transforming reshapes and transposes of rng.
TEST_F(AlgebraicSimplifierTest,ReshapeOfTransposeOfRngToRng)2508 TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) {
2509 auto m = CreateNewVerifiedModule();
2510 HloComputation::Builder builder(TestName());
2511 HloInstruction* zero = builder.AddInstruction(
2512 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
2513 HloInstruction* one = builder.AddInstruction(
2514 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
2515 HloInstruction* rng0 = builder.AddInstruction(
2516 HloInstruction::CreateRng(ShapeUtil::MakeShape(F32, {2, 2}),
2517 RandomDistribution::RNG_UNIFORM, {zero, one}));
2518
2519 HloInstruction* transpose = builder.AddInstruction(
2520 HloInstruction::CreateTranspose(rng0->shape(), rng0, {1, 0}));
2521 Shape reshape_shape = builder
2522 .AddInstruction(HloInstruction::CreateReshape(
2523 ShapeUtil::MakeShape(F32, {4}), transpose))
2524 ->shape();
2525
2526 auto computation = m->AddEntryComputation(builder.Build());
2527
2528 AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
2529 EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2530
2531 // Verify that reshape(transpose(rng)) is replace by a single rng of the
2532 // same shape as the reshape.
2533 EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Rng()));
2534 EXPECT_TRUE(ShapeUtil::Equal(computation->root_instruction()->shape(),
2535 reshape_shape));
2536 }
2537
2538 // Test transforming reshapes to bitcasts under various conditions.
TEST_F(AlgebraicSimplifierTest,ReshapeReplacedWithBitcast)2539 TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) {
2540 auto m = CreateNewVerifiedModule();
2541 HloComputation::Builder builder(TestName());
2542 HloInstruction* param0 =
2543 builder.AddInstruction(HloInstruction::CreateParameter(
2544 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
2545 *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
2546
2547 // Reshape which can be transformed into a bitcast.
2548 HloInstruction* transformable_reshape =
2549 builder.AddInstruction(HloInstruction::CreateReshape(
2550 ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));
2551 *transformable_reshape->mutable_shape()->mutable_layout() =
2552 LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
2553
2554 // Reshape does not just add degenerate dimensions.
2555 HloInstruction* dimensions_wrong_reshape =
2556 builder.AddInstruction(HloInstruction::CreateReshape(
2557 ShapeUtil::MakeShape(F32, {1, 4, 1, 1, 1, 1}), param0));
2558 *dimensions_wrong_reshape->mutable_shape()->mutable_layout() =
2559 LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
2560
2561 // Reshape has wrong layout.
2562 HloInstruction* layout_wrong_reshape =
2563 builder.AddInstruction(HloInstruction::CreateReshape(
2564 ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));
2565 *layout_wrong_reshape->mutable_shape()->mutable_layout() =
2566 LayoutUtil::MakeLayout({5, 4, 3, 2, 1, 0});
2567
2568 // Collect all the reshapes into a tuple so they are not dead.
2569 builder.AddInstruction(HloInstruction::CreateTuple(
2570 {transformable_reshape, dimensions_wrong_reshape, layout_wrong_reshape}));
2571
2572 auto computation = m->AddEntryComputation(builder.Build());
2573
2574 EXPECT_THAT(computation->root_instruction(),
2575 GmockMatch(m::Tuple(m::Op().Is(transformable_reshape),
2576 m::Op().Is(dimensions_wrong_reshape),
2577 m::Op().Is(layout_wrong_reshape))));
2578
2579 AlgebraicSimplifierOptions options;
2580 options.set_is_layout_sensitive(true);
2581 AlgebraicSimplifier simplifier(options);
2582 simplifier.Run(m.get()).ValueOrDie();
2583
2584 // Verify that only the first reshape is replaced.
2585 EXPECT_THAT(
2586 computation->root_instruction(),
2587 GmockMatch(m::Tuple(m::Bitcast(), m::Op().Is(dimensions_wrong_reshape),
2588 m::Op().Is(layout_wrong_reshape))));
2589 }
2590
2591 // Regression test for a bug where if we failed to sink a reshape, we'd set the
2592 // 'changed' bit in AlgebraicSimplifier to false.
TEST_F(AlgebraicSimplifierTest,FailureToSinkReshapeDoesntAffectChangedBit)2593 TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) {
2594 auto m = CreateNewVerifiedModule();
2595 HloComputation::Builder builder(TestName());
2596
2597 // This add (param0 + 0) can be simplified.
2598 Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
2599 HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
2600 shape, HloOpcode::kAdd,
2601 builder.AddInstruction(
2602 HloInstruction::CreateParameter(0, shape, "param0")),
2603 builder.AddInstruction(HloInstruction::CreateConstant(
2604 LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}})))));
2605
2606 builder.AddInstruction(
2607 HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {4}), add));
2608
2609 AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
2610 m->AddEntryComputation(builder.Build());
2611 EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2612 }
2613
2614 // Regression test for a bug where if we failed to sink a reshape, we'd set the
2615 // 'changed' bit in AlgebraicSimplifier to false.
TEST_F(AlgebraicSimplifierTest,FailureToSinkBroadcastDoesntAffectChangedBit)2616 TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) {
2617 auto m = CreateNewVerifiedModule();
2618 HloComputation::Builder builder(TestName());
2619
2620 // This add (param0 + 0) can be simplified.
2621 Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
2622 HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
2623 shape, HloOpcode::kAdd,
2624 builder.AddInstruction(
2625 HloInstruction::CreateParameter(0, shape, "param0")),
2626 builder.AddInstruction(HloInstruction::CreateConstant(
2627 LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}})))));
2628
2629 builder.AddInstruction(
2630 HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {2, 2, 2}), add,
2631 /*broadcast_dimensions=*/{0, 1}));
2632
2633 AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
2634 m->AddEntryComputation(builder.Build());
2635 EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2636 }
2637
// A transpose whose operand and result layouts make it equivalent to a bitcast
// is expected to be replaced by a bitcast in layout-sensitive mode.
TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {50, 14, 14, 64}), "param"));
  *input->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({1, 2, 0, 3});

  HloInstruction* permuted =
      builder.AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(F32, {14, 14, 50, 64}), input, {1, 2, 0, 3}));
  *permuted->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({0, 1, 2, 3});

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Transpose(m::Parameter(0))));

  AlgebraicSimplifierOptions layout_sensitive_options;
  layout_sensitive_options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(layout_sensitive_options);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The transpose has become a bitcast of the parameter.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
2667
// Same idea as TransposeEqualsBitcast1, but also checks that the rewrite is
// gated by the replace_transpose_with_bitcast option.
TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {5, 2, 3, 4}), "param"));
  *input->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({1, 2, 3, 0});

  HloInstruction* permuted =
      builder.AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(F32, {5, 3, 4, 2}), input, {0, 2, 3, 1}));
  *permuted->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({3, 1, 2, 0});

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Transpose(m::Parameter(0))));

  AlgebraicSimplifierOptions options;
  options.set_is_layout_sensitive(true);

  // With the replacement disabled, the pass must leave the module untouched.
  options.set_replace_transpose_with_bitcast(false);
  AlgebraicSimplifier no_replace_pass(options);
  ASSERT_FALSE(no_replace_pass.Run(module.get()).ValueOrDie());

  // With the replacement enabled, the transpose is rewritten.
  options.set_replace_transpose_with_bitcast(true);
  AlgebraicSimplifier replace_pass(options);
  ASSERT_TRUE(replace_pass.Run(module.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
2704
// Two chained reshapes are expected to collapse into a single reshape of the
// original parameter.
TEST_F(AlgebraicSimplifierTest, ReshapesMerged) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
  HloInstruction* inner_reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {2, 1, 2}), input));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), inner_reshape));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Reshape(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Parameter(0))));
}
2730
// Two chained layout-changing copies are expected to collapse into a single
// copy in layout-sensitive mode.
TEST_F(AlgebraicSimplifierTest, CopiesMerged) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  const Shape input_shape =
      ShapeUtil::MakeShapeWithDescendingLayout(F32, {2, 2, 2});
  const Shape first_copy_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2});
  const Shape second_copy_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 2, 1});

  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param0"));
  HloInstruction* first_copy = builder.AddInstruction(
      HloInstruction::CreateUnary(first_copy_shape, HloOpcode::kCopy, input));
  builder.AddInstruction(HloInstruction::CreateUnary(
      second_copy_shape, HloOpcode::kCopy, first_copy));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Copy(m::Copy(m::Parameter(0)))));

  AlgebraicSimplifierOptions layout_sensitive_options;
  layout_sensitive_options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(layout_sensitive_options);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Copy(m::Parameter(0))));
}
2760
// Two consecutive transposes are expected to fold into one transpose whose
// permutation composes both ({1,2,0} followed by {1,0,2} gives {2,1,0}).
TEST_F(AlgebraicSimplifierTest, TransposesMerged) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {2, 3, 4}), "param0"));
  HloInstruction* inner_transpose =
      builder.AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(F32, {3, 4, 2}), input, {1, 2, 0}));
  builder.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(F32, {4, 3, 2}), inner_transpose, {1, 0, 2}));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Transpose(m::Op().Is(inner_transpose))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Transpose(m::Parameter(0))));
  const std::vector<int64> expected_permutation = {2, 1, 0};
  EXPECT_EQ(expected_permutation, root->dimensions());
}
2788
// A strided slice of a broadcast is expected to become a broadcast of a slice
// of the broadcast's operand.
TEST_F(AlgebraicSimplifierTest, SliceOfBroadcast) {
  const char* const hlo_text = R"(
HloModule module

ENTRY test {
  p0 = f32[10,20] parameter(0)
  b = f32[10,30,20] broadcast(p0), dimensions={0,2}
  ROOT s = f32[5,5,5] slice(b), slice={[0:5:1], [5:25:4], [5:15:2]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));

  HloPassFix<AlgebraicSimplifier> fixed_point_simplifier(default_options_);
  EXPECT_TRUE(fixed_point_simplifier.Run(hlo_module.get()).ValueOrDie());
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Slice(m::Parameter(0)))));
}
2807
// Like SliceOfBroadcast, but with tiled non-default layouts: the rewritten
// root must keep exactly the original slice's shape, including its layout.
TEST_F(AlgebraicSimplifierTest, SliceOfBroadcastPreserveLayout) {
  const char* const hlo_text = R"(
HloModule module

ENTRY test {
  p0 = f32[10,20] parameter(0)
  b = f32[10,30,20]{2,0,1:T(256)} broadcast(p0), dimensions={0,2}
  ROOT s = f32[5,5,5]{2,0,1:T(256)} slice(b), slice={[0:5:1], [5:25:4], [5:15:2]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  HloComputation* entry = hlo_module->entry_computation();
  const Shape expected_shape = entry->root_instruction()->shape();

  HloPassFix<AlgebraicSimplifier> fixed_point_simplifier(default_options_);
  EXPECT_TRUE(fixed_point_simplifier.Run(hlo_module.get()).ValueOrDie());

  auto* new_root = entry->root_instruction();
  EXPECT_THAT(new_root, GmockMatch(m::Broadcast(m::Slice(m::Parameter(0)))));
  EXPECT_TRUE(ShapeUtil::Equal(new_root->shape(), expected_shape));
}
2829
// A dynamic-slice of a broadcast is expected to become a broadcast of a
// dynamic-slice of the broadcast's operand. Only the start indices for the
// operand's dimensions (parameters 1 and 3) are carried over.
TEST_F(AlgebraicSimplifierTest, DynamicSliceOfBroadcast) {
  const char* const hlo_text = R"(
HloModule module

ENTRY test {
  p0 = f32[10,20] parameter(0)
  i0 = s32[] parameter(1)
  i1 = s32[] parameter(2)
  i2 = s32[] parameter(3)
  b = f32[10,30,20] broadcast(p0), dimensions={0,2}
  ROOT ds = f32[5,5,5] dynamic-slice(b, i0, i1, i2), dynamic_slice_sizes={5,5,5}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));

  HloPassFix<AlgebraicSimplifier> fixed_point_simplifier(default_options_);
  EXPECT_TRUE(fixed_point_simplifier.Run(hlo_module.get()).ValueOrDie());

  const auto expected = m::Broadcast(
      m::DynamicSlice(m::Parameter(0), m::Parameter(1), m::Parameter(3)));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              GmockMatch(expected));
}
2852
// Like DynamicSliceOfBroadcast, but with tiled non-default layouts: the
// rewritten root must keep exactly the original dynamic-slice's shape.
TEST_F(AlgebraicSimplifierTest, DynamicSliceOfBroadcastPreserveLayout) {
  const char* const hlo_text = R"(
HloModule module

ENTRY test {
  p0 = f32[10,20] parameter(0)
  i0 = s32[] parameter(1)
  i1 = s32[] parameter(2)
  i2 = s32[] parameter(3)
  b = f32[10,30,20]{2,0,1:T(256)} broadcast(p0), dimensions={0,2}
  ROOT ds = f32[5,5,5]{2,0,1:T(256)} dynamic-slice(b, i0, i1, i2), dynamic_slice_sizes={5,5,5}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  HloComputation* entry = hlo_module->entry_computation();
  const Shape expected_shape = entry->root_instruction()->shape();

  HloPassFix<AlgebraicSimplifier> fixed_point_simplifier(default_options_);
  EXPECT_TRUE(fixed_point_simplifier.Run(hlo_module.get()).ValueOrDie());

  auto* new_root = entry->root_instruction();
  EXPECT_THAT(new_root, GmockMatch(m::Broadcast(m::DynamicSlice(
                            m::Parameter(0), m::Parameter(1), m::Parameter(3)))));
  EXPECT_TRUE(ShapeUtil::Equal(new_root->shape(), expected_shape));
}
2878
TEST_F(AlgebraicSimplifierTest, TransposeIsReshape) {
  // Transposing f32[1,1,10] to f32[10,1,1] moves only around size-1
  // dimensions, so reshape -> transpose -> reshape is expected to collapse all
  // the way back to the parameter.
  const char* hlo_string = R"(
HloModule module

ENTRY test {
  param = f32[10] parameter(0)
  reshaped = f32[1,1,10] reshape(f32[10] param)
  transposed = f32[10,1,1] transpose(f32[1,1,10] reshaped), dimensions={2,1,0}
  ROOT reshaped_again = f32[10] reshape(f32[10,1,1] transposed)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Parameter()));
}
2898
2899 // Test merging reshape and broadcast.
TEST_F(AlgebraicSimplifierTest,ReshapeAndBroadcastMerged)2900 TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) {
2901 auto m = CreateNewVerifiedModule();
2902 HloComputation::Builder builder(TestName());
2903 auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
2904 0, ShapeUtil::MakeShape(F32, {5}), "param0"));
2905 auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
2906 ShapeUtil::MakeShape(F32, {1, 5, 1}), param0));
2907 builder.AddInstruction(HloInstruction::CreateBroadcast(
2908 ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 3, 2}));
2909
2910 auto computation = m->AddEntryComputation(builder.Build());
2911
2912 EXPECT_THAT(computation->root_instruction(),
2913 GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));
2914
2915 AlgebraicSimplifier simplifier(default_options_);
2916 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2917
2918 EXPECT_THAT(computation->root_instruction(),
2919 GmockMatch(m::Broadcast(m::Parameter(0))));
2920 }
2921
2922 // Test merging broadcast and reshape.
TEST_F(AlgebraicSimplifierTest,BroadcastAndReshapeMerged)2923 TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) {
2924 auto m = CreateNewVerifiedModule();
2925 HloComputation::Builder builder(TestName());
2926 auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
2927 0, ShapeUtil::MakeShape(F32, {2, 3}), "param0"));
2928 auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast(
2929 ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1}), param0, {1, 2}));
2930 builder.AddInstruction(HloInstruction::CreateReshape(
2931 ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}), broadcast1));
2932
2933 auto computation = m->AddEntryComputation(builder.Build());
2934
2935 EXPECT_THAT(computation->root_instruction(),
2936 GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));
2937
2938 AlgebraicSimplifier simplifier(default_options_);
2939 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2940
2941 EXPECT_THAT(computation->root_instruction(),
2942 GmockMatch(m::Broadcast(m::Parameter(0))));
2943 }
2944
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) {
  // f32[1] --broadcast(dims={1})--> f32[3,1] --reshape--> f32[3].
  // The simplifier is expected to move the reshape onto the operand, producing
  // broadcast(reshape(param)).
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1}), "param"));
  HloInstruction* broadcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::MakeShape(F32, {3, 1}), param, {1}));
  builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), broadcast));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  // ASSERT rather than EXPECT, like the sibling BroadcastAndReshape_* tests:
  // if the pass reports no change, the structural check below is meaningless.
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));
}
2966
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) {
  // f32[4] --broadcast(dims={2})--> f32[3,2,4] --reshape--> f32[6,1,1,4].
  // The reshape leaves the size-4 operand dimension intact (now output
  // dimension 3), so the pair should collapse into one broadcast of the
  // parameter with dimensions={3}.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {4}), "param"));
  HloInstruction* broadcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::MakeShape(F32, {3, 2, 4}), param, {2}));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), broadcast));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Parameter(0))));
  EXPECT_THAT(root->dimensions(), ::testing::ElementsAre(3));
}
2990
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) {
  // f32[1] --broadcast(dims={2})--> f32[3,2,1] --reshape--> f32[6,1,1,1].
  // Expected result: broadcast(reshape(param)), where the reshape produces a
  // rank-0 operand and so the broadcast maps no operand dimensions.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1}), "param"));
  HloInstruction* broadcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::MakeShape(F32, {3, 2, 1}), param, {2}));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), broadcast));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));
  // IsEmpty() avoids comparing a signed literal against size_t and prints the
  // actual dimensions on failure.
  EXPECT_THAT(root->dimensions(), ::testing::IsEmpty());
}
3013
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) {
  // f32[4] --broadcast(dims={2})--> f32[3,2,4,2] --reshape--> f32[6,8].
  // The reshape fuses the operand's size-4 dimension with a broadcasted
  // dimension (4x2 -> 8), so the pass is expected to leave the graph alone.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {4}), "param"));
  HloInstruction* broadcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::MakeShape(F32, {3, 2, 4, 2}), param, {2}));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {6, 8}), broadcast));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  auto reshape_of_broadcast =
      GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))));
  EXPECT_THAT(computation->root_instruction(), reshape_of_broadcast);

  AlgebraicSimplifier simplifier(default_options_);
  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(), reshape_of_broadcast);
}
3035
TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) {
  // iota(f32[1,2,3,7,12,1], dim=2) reshaped to f32[2,3,7,2,1,3,2]. The size-3
  // iota dimension is left intact by the reshape, so the pair is expected to
  // fold into a single iota that has the reshape's result shape.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  const Shape iota_shape = ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2});
  HloInstruction* iota =
      builder.AddInstruction(HloInstruction::CreateIota(iota_shape, 2));
  builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Iota()));
  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), result_shape));
}
3056
TEST_F(AlgebraicSimplifierTest, IotaAndReshapeToMixedRadix) {
  // A 1-D iota of 21 elements reshaped to f32[7,3] is not a single iota. It is
  // expected to be rewritten in mixed-radix form,
  // iota + iota * broadcast(constant), with the reshape's result shape.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  const Shape result_shape = ShapeUtil::MakeShape(F32, {7, 3});
  HloInstruction* iota = builder.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {21}), 0));
  builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  auto scaled_iota = m::Multiply(m::Iota(), m::Broadcast(m::ConstantScalar()));
  EXPECT_THAT(root, GmockMatch(m::Add(m::Iota(), scaled_iota)));
  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), result_shape));
}
TEST_F(AlgebraicSimplifierTest, IotaAndReshapeToMixedRadixExtraDims) {
  // iota(f32[42,24,15], dim=1) reshaped to f32[3,14,4,3,2,5,3]. The size-24
  // iota dimension is split into 4x3x2, so the result is expected to be a
  // two-term mixed-radix sum:
  //   (iota + iota * broadcast(c)) + iota * broadcast(c').
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  const Shape result_shape = ShapeUtil::MakeShape(F32, {3, 14, 4, 3, 2, 5, 3});
  HloInstruction* iota = builder.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {42, 24, 15}), 1));
  builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  auto scaled_iota = m::Multiply(m::Iota(), m::Broadcast(m::ConstantScalar()));
  EXPECT_THAT(root,
              GmockMatch(m::Add(m::Add(m::Iota(), scaled_iota), scaled_iota)));
  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), result_shape));
}
TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) {
  // An iota over a single-element shape (f32[1,1]) can only produce 0. It is
  // expected to become a broadcast of a zero constant with the same shape.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  const Shape iota_shape = ShapeUtil::MakeShape(F32, {1, 1});
  builder.AddInstruction(HloInstruction::CreateIota(iota_shape, 0));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Iota()));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
  EXPECT_EQ(0.0f, root->operand(0)->literal().GetFirstElement<float>());
  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), iota_shape));
}
3125
TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2_6) {
  // Flattening iota(f32[3,2], dim=1) to f32[6] yields the repeating pattern
  // 0,1,0,1,0,1, which is not an iota. The pass is expected to report no
  // change and leave reshape(iota) in place.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* iota = builder.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2}), 1));
  builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6}), iota));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  auto reshape_of_iota = GmockMatch(m::Reshape(m::Iota()));
  EXPECT_THAT(computation->root_instruction(), reshape_of_iota);

  AlgebraicSimplifier simplifier(default_options_);
  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(), reshape_of_iota);
}
3145
TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) {
  // iota(f32[3,2,4], dim=2) reshaped to f32[6,1,1,4]. The size-4 iota
  // dimension survives intact as output dimension 3, so the reshape should
  // fold into a new iota along that dimension.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* iota = builder.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4}), 2));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), iota));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Iota()));
  EXPECT_EQ(Cast<HloIotaInstruction>(root)->iota_dimension(), 3);
}
3167
TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x2_6x1x1x2) {
  // iota(f32[3,2,2], dim=2) reshaped to f32[6,1,1,2]. Element [i,0,0,c] of the
  // result equals c, so the only correct replacement is an iota along output
  // dimension 3. Dimensions 1 and 2 have size 1; an iota along either would
  // produce all zeros, so they must not be accepted.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* iota = builder.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 2}), 2));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {6, 1, 1, 2}), iota));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Iota()));
  const int64 iota_dim = Cast<HloIotaInstruction>(root)->iota_dimension();
  EXPECT_EQ(iota_dim, 3);
}
3190
TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4x2_6x8) {
  // iota(f32[3,2,4,2], dim=2) reshaped to f32[6,8]. The reshape fuses the
  // size-4 iota dimension with the trailing size-2 dimension (4x2 -> 8), and
  // the pass is expected to report no change.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* iota = builder.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4, 2}), 2));
  builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6, 8}), iota));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  auto reshape_of_iota = GmockMatch(m::Reshape(m::Iota()));
  EXPECT_THAT(computation->root_instruction(), reshape_of_iota);

  AlgebraicSimplifier simplifier(default_options_);
  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(), reshape_of_iota);
}
3210
TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
  // A pad with zero edge and interior padding in every dimension is an
  // identity, so it should be replaced by its operand.
  HloComputation::Builder builder(TestName());
  const Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));

  PaddingConfig no_padding;
  for (int dim = 0; dim < 2; ++dim) {
    auto* config = no_padding.add_dimensions();
    config->set_edge_padding_low(0);
    config->set_edge_padding_high(0);
    config->set_interior_padding(0);
  }
  builder.AddInstruction(
      HloInstruction::CreatePad(shape, param, zero, no_padding));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_EQ(computation->root_instruction(), param);
}
3239
TEST_F(AlgebraicSimplifierTest, RemoveNoopSliceOfPad) {
  // Each dimension of the 2x2 parameter gets 2 low-edge zeros and 1 interior
  // zero, giving 2 + 2 + 1 * (2 - 1) = 5 elements per dimension. The strided
  // slice [2:5:2] picks indices 2 and 4, which are exactly the original
  // elements, so slice(pad(param)) should simplify to param.
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {2, 2}), "param"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  PaddingConfig padding;
  for (int i = 0; i < 2; ++i) {
    auto dimension = padding.add_dimensions();
    dimension->set_edge_padding_low(2);
    dimension->set_edge_padding_high(0);
    dimension->set_interior_padding(1);
  }
  auto pad = builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {5, 5}), param, zero, padding));
  builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {2, 2}), pad, /*start_indices=*/{2, 2},
      /*limit_indices=*/{5, 5}, /*strides=*/{2, 2}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Slice(m::Pad(m::Parameter(0), m::Op().Is(zero)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(), param);
}
3271
TEST_F(AlgebraicSimplifierTest, NegativePadding) {
  // A pad with negative edge padding should be rewritten as a pad that uses
  // only non-negative padding, followed by a slice.
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {10, 10}), "param"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));

  // Output extents: dimension 0 is 10 - 1 + 2 = 11, dimension 1 is
  // 10 - 2 - 3 = 5.
  const int64 low_padding[2] = {-1, -2};
  const int64 high_padding[2] = {2, -3};
  PaddingConfig padding;
  for (int dim = 0; dim < 2; ++dim) {
    auto* config = padding.add_dimensions();
    config->set_edge_padding_low(low_padding[dim]);
    config->set_edge_padding_high(high_padding[dim]);
    config->set_interior_padding(0);
  }
  HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {11, 5}), param, zero, padding));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  AlgebraicSimplifier simplifier(default_options_);

  // True if any dimension of `instr`'s padding config has a negative edge.
  auto has_negative_padding = [](const HloInstruction* instr) {
    for (const auto& dim : instr->padding_config().dimensions()) {
      if (dim.edge_padding_low() < 0 || dim.edge_padding_high() < 0) {
        return true;
      }
    }
    return false;
  };

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero))));
  EXPECT_TRUE(has_negative_padding(pad));

  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Slice(m::Pad(m::Parameter(0), m::Op().Is(zero)))));
  EXPECT_FALSE(has_negative_padding(root->operand(0)));
}
3319
TEST_F(AlgebraicSimplifierTest, CanDisableNegativePadding) {
  // Verify that when set_enable_negative_padding_replacement(false) is used,
  // a pad instruction with negative padding is left untouched (the enabled
  // case is covered by NegativePadding).
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {10, 10}), "param"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  PaddingConfig padding;
  int64 low_padding[2] = {-1, -2};
  int64 high_padding[2] = {2, -3};
  for (int i = 0; i < 2; ++i) {
    auto dimension = padding.add_dimensions();
    dimension->set_edge_padding_low(low_padding[i]);
    dimension->set_edge_padding_high(high_padding[i]);
    dimension->set_interior_padding(0);
  }
  HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {11, 5}), param, zero, padding));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  // Verify that we can disable the negative padding optimization.
  AlgebraicSimplifierOptions opts = default_options_;
  opts.set_enable_negative_padding_replacement(false);

  AlgebraicSimplifier simplifier(opts);

  // True if any dimension of `instr`'s padding config has a negative edge.
  auto has_negative_padding = [](const HloInstruction* instr) {
    for (auto& padding_dimension : instr->padding_config().dimensions()) {
      if (padding_dimension.edge_padding_low() < 0 ||
          padding_dimension.edge_padding_high() < 0) {
        return true;
      }
    }
    return false;
  };

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero))));
  EXPECT_TRUE(has_negative_padding(pad));

  // Nothing has changed since the negative padding replacement is disabled.
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());

  // The root must still be the original pad, negative padding intact.
  EXPECT_EQ(computation->root_instruction(), pad);
  EXPECT_TRUE(has_negative_padding(computation->root_instruction()));
}
3367
TEST_F(AlgebraicSimplifierTest, TrivialInteriorPadding) {
  // Interior padding on a size-1 dimension has no element gaps to fill. The
  // simplifier should drop the interior padding and keep the pad itself.
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {2, 1}), "param"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));

  // Both dimensions get 3 edge elements on each side. Only dimension 1, which
  // has size 1, gets interior padding (of 3).
  PaddingConfig padding;
  for (int dim = 0; dim < 2; ++dim) {
    auto* config = padding.add_dimensions();
    config->set_edge_padding_low(3);
    config->set_edge_padding_high(3);
    config->set_interior_padding(dim * 3);
  }
  HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {8, 7}), param, zero, padding));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  AlgebraicSimplifier simplifier(default_options_);

  ASSERT_THAT(computation->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero))));
  ASSERT_TRUE(HasInteriorPadding(pad->padding_config()));

  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero))));
  EXPECT_FALSE(HasInteriorPadding(root->padding_config()));
}
3403
TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) {
  // A reshape whose output shape equals its operand's shape is an identity
  // and should be replaced by the operand.
  HloComputation::Builder builder(TestName());
  const Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param"));
  builder.AddInstruction(HloInstruction::CreateReshape(shape, param));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Parameter(0))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_EQ(computation->root_instruction(), param);
}
3423
TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) {
  // A unit-stride slice that covers the whole operand is an identity and
  // should be replaced by the operand.
  HloComputation::Builder builder(TestName());
  const int64 dim0 = 2;
  const int64 dim1 = 3;
  const Shape shape = ShapeUtil::MakeShape(F32, {dim0, dim1});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param"));
  builder.AddInstruction(HloInstruction::CreateSlice(
      shape, param, /*start_indices=*/{0, 0},
      /*limit_indices=*/{dim0, dim1}, /*strides=*/{1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Slice(m::Parameter(0))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_EQ(computation->root_instruction(), param);
}
3446
TEST_F(AlgebraicSimplifierTest, SliceOfSliceToSlice) {
  // Two unit-stride slices applied back to back should fold into one slice of
  // the parameter. The fold starts at first_start + second_start and ends at
  // first_start + second_limit.
  HloComputation::Builder builder(TestName());
  const int64 dim0 = 11;
  const int64 dim1 = 12;
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {dim0, dim1}), "param"));

  HloInstruction* original_slice =
      builder.AddInstruction(HloInstruction::CreateSlice(
          ShapeUtil::MakeShape(F32, {dim0 - 2, dim1 - 4}), param,
          /*start_indices=*/{1, 2}, /*limit_indices=*/{dim0 - 1, dim1 - 2},
          /*strides=*/{1, 1}));
  builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {dim0 - 5, dim1 - 9}), original_slice,
      /*start_indices=*/{2, 3}, /*limit_indices=*/{dim0 - 3, dim1 - 6},
      /*strides=*/{1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Slice(m::Slice(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Slice(m::Parameter(0))));
  EXPECT_EQ(root->slice_starts(0), 3);
  EXPECT_EQ(root->slice_starts(1), 5);
  EXPECT_EQ(root->slice_limits(0), dim0 - 2);
  EXPECT_EQ(root->slice_limits(1), dim1 - 4);
}
3480
TEST_F(AlgebraicSimplifierTest, SliceOfBroadcastToBroadcast) {
  // The slice keeps dimension 0 whole and trims only dimension 1, which the
  // broadcast (dimensions={0}) does not map from its operand. So
  // slice(broadcast(param)) should become a single broadcast of param.
  HloComputation::Builder builder(TestName());
  const int64 dim0 = 11;
  const int64 dim1 = 12;
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {dim0}),
                                      "param"));
  HloInstruction* broadcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::MakeShape(F32, {dim0, dim1}), param, {0}));
  builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {dim0, dim1 - 9}), broadcast,
      /*start_indices=*/{0, 3}, /*limit_indices=*/{dim0, dim1 - 6},
      /*strides=*/{1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Slice(m::Broadcast(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Broadcast(m::Parameter(0))));
}
3507
TEST_F(AlgebraicSimplifierTest, SliceOfReshapeToReshapeOfSlice) {
  // The reshape splits the parameter's leading dim0*dim1 dimension into
  // [dim0, dim1]. The slice takes only a prefix of dim0 and keeps the other
  // dimensions whole, so it is expected to be pushed through the reshape:
  // reshape(slice(param)).
  HloComputation::Builder builder(TestName());
  const int64 dim0 = 11;
  const int64 dim1 = 12;
  const int64 dim2 = 13;
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {dim0 * dim1, dim2}), "param"));
  HloInstruction* original_reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(F32, {dim0, dim1, dim2}), param));
  builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {dim0 - 2, dim1, dim2}), original_reshape,
      /*start_indices=*/{0, 0, 0}, /*limit_indices=*/{dim0 - 2, dim1, dim2},
      /*strides=*/{1, 1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Slice(m::Reshape(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Slice(m::Parameter(0)))));
}
3536
TEST_F(AlgebraicSimplifierTest, SliceOfReshapeUnchanged) {
  // The reshape merges the parameter's 144x25 dimensions into 3600. Taking
  // the first 960 rows does not line up with that split (960 is not a
  // multiple of 25), presumably why no rewrite applies. The pass must report
  // no change.
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {1, 144, 25, 1, 512}), "param"));
  HloInstruction* original_reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3600, 512}),
                                    param));
  builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {960, 512}), original_reshape,
      /*start_indices=*/{0, 0}, /*limit_indices=*/{960, 512},
      /*strides=*/{1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Slice(m::Reshape(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
}
3559
TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) {
  // Sorting keys of shape f32[1] is expected to be removed, leaving the keys
  // parameter as the root.
  auto module = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());

  const Shape keys_shape = ShapeUtil::MakeShape(F32, {1});
  HloInstruction* keys = builder.AddInstruction(
      HloInstruction::CreateParameter(0, keys_shape, "keys"));
  TF_ASSERT_OK(MakeSortHlo(keys_shape, {keys}, /*dimension_to_sort=*/0,
                           /*is_stable=*/false, &builder, module.get())
                   .status());

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_EQ(computation->root_instruction(), keys);
}
3575
TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) {
  // A key/value sort whose operands have shape [5,0] (no elements) is
  // expected to be replaced by a plain tuple of its unsorted operands.
  auto module = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());

  const Shape keys_shape = ShapeUtil::MakeShape(F32, {5, 0});
  const Shape values_shape = ShapeUtil::MakeShape(S32, {5, 0});
  HloInstruction* keys = builder.AddInstruction(
      HloInstruction::CreateParameter(0, keys_shape, "keys"));
  HloInstruction* values0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, values_shape, "values0"));
  HloInstruction* values1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, values_shape, "values1"));

  const Shape sort_shape =
      ShapeUtil::MakeTupleShape({keys_shape, values_shape, values_shape});
  TF_ASSERT_OK(MakeSortHlo(sort_shape, {keys, values0, values1},
                           /*dimension_to_sort=*/0, /*is_stable=*/false,
                           &builder, module.get())
                   .status());

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Tuple(m::Op().Is(keys), m::Op().Is(values0),
                                  m::Op().Is(values1))));
}
3600
3601 // Test that A && True is simplified to A
TEST_F(AlgebraicSimplifierTest,AndTrue)3602 TEST_F(AlgebraicSimplifierTest, AndTrue) {
3603 auto m = CreateNewVerifiedModule();
3604 Shape r0pred = ShapeUtil::MakeShape(PRED, {});
3605 HloComputation::Builder builder(TestName());
3606 HloInstruction* param0 = builder.AddInstruction(
3607 HloInstruction::CreateParameter(0, r0pred, "param0"));
3608 HloInstruction* const_true = builder.AddInstruction(
3609 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
3610 builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kAnd,
3611 param0, const_true));
3612
3613 auto computation = m->AddEntryComputation(builder.Build());
3614 HloInstruction* root = computation->root_instruction();
3615 EXPECT_EQ(root->opcode(), HloOpcode::kAnd);
3616 AlgebraicSimplifier simplifier(default_options_);
3617 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3618 root = computation->root_instruction();
3619 EXPECT_EQ(root, param0);
3620 }
3621
3622 // Test that True && A is simplified to A
TEST_F(AlgebraicSimplifierTest,AndTrue2)3623 TEST_F(AlgebraicSimplifierTest, AndTrue2) {
3624 auto m = CreateNewVerifiedModule();
3625 Shape r0pred = ShapeUtil::MakeShape(PRED, {});
3626 HloComputation::Builder builder(TestName());
3627 HloInstruction* param0 = builder.AddInstruction(
3628 HloInstruction::CreateParameter(0, r0pred, "param0"));
3629 HloInstruction* const_true = builder.AddInstruction(
3630 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
3631 builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kAnd,
3632 const_true, param0));
3633
3634 auto computation = m->AddEntryComputation(builder.Build());
3635 HloInstruction* root = computation->root_instruction();
3636 EXPECT_EQ(root->opcode(), HloOpcode::kAnd);
3637 AlgebraicSimplifier simplifier(default_options_);
3638 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3639 root = computation->root_instruction();
3640 EXPECT_EQ(root, param0);
3641 }
3642
3643 // Test that A && False is simplified to False
TEST_F(AlgebraicSimplifierTest,AndFalse)3644 TEST_F(AlgebraicSimplifierTest, AndFalse) {
3645 auto m = CreateNewVerifiedModule();
3646 Shape r0pred = ShapeUtil::MakeShape(PRED, {});
3647 HloComputation::Builder builder(TestName());
3648 HloInstruction* param0 = builder.AddInstruction(
3649 HloInstruction::CreateParameter(0, r0pred, "param0"));
3650 HloInstruction* const_false = builder.AddInstruction(
3651 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
3652 builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kAnd,
3653 param0, const_false));
3654
3655 auto computation = m->AddEntryComputation(builder.Build());
3656 HloInstruction* root = computation->root_instruction();
3657 EXPECT_EQ(root->opcode(), HloOpcode::kAnd);
3658 AlgebraicSimplifier simplifier(default_options_);
3659 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3660 root = computation->root_instruction();
3661 EXPECT_EQ(root, const_false);
3662 }
3663
3664 // Test that False && A is simplified to False
TEST_F(AlgebraicSimplifierTest,AndFalse2)3665 TEST_F(AlgebraicSimplifierTest, AndFalse2) {
3666 auto m = CreateNewVerifiedModule();
3667 Shape r0pred = ShapeUtil::MakeShape(PRED, {});
3668 HloComputation::Builder builder(TestName());
3669 HloInstruction* param0 = builder.AddInstruction(
3670 HloInstruction::CreateParameter(0, r0pred, "param0"));
3671 HloInstruction* const_false = builder.AddInstruction(
3672 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
3673 builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kAnd,
3674 const_false, param0));
3675
3676 auto computation = m->AddEntryComputation(builder.Build());
3677 HloInstruction* root = computation->root_instruction();
3678 EXPECT_EQ(root->opcode(), HloOpcode::kAnd);
3679 AlgebraicSimplifier simplifier(default_options_);
3680 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3681 root = computation->root_instruction();
3682 EXPECT_EQ(root, const_false);
3683 }
3684
3685 // Test that A || True is simplified to True
TEST_F(AlgebraicSimplifierTest,OrTrue)3686 TEST_F(AlgebraicSimplifierTest, OrTrue) {
3687 auto m = CreateNewVerifiedModule();
3688 Shape r0pred = ShapeUtil::MakeShape(PRED, {});
3689 HloComputation::Builder builder(TestName());
3690 HloInstruction* param0 = builder.AddInstruction(
3691 HloInstruction::CreateParameter(0, r0pred, "param0"));
3692 HloInstruction* const_true = builder.AddInstruction(
3693 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
3694 builder.AddInstruction(
3695 HloInstruction::CreateBinary(r0pred, HloOpcode::kOr, param0, const_true));
3696
3697 auto computation = m->AddEntryComputation(builder.Build());
3698 HloInstruction* root = computation->root_instruction();
3699 EXPECT_EQ(root->opcode(), HloOpcode::kOr);
3700 AlgebraicSimplifier simplifier(default_options_);
3701 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3702 root = computation->root_instruction();
3703 EXPECT_EQ(root, const_true);
3704 }
3705
3706 // Test that True || A is simplified to True
TEST_F(AlgebraicSimplifierTest,OrTrue2)3707 TEST_F(AlgebraicSimplifierTest, OrTrue2) {
3708 auto m = CreateNewVerifiedModule();
3709 Shape r0pred = ShapeUtil::MakeShape(PRED, {});
3710 HloComputation::Builder builder(TestName());
3711 HloInstruction* param0 = builder.AddInstruction(
3712 HloInstruction::CreateParameter(0, r0pred, "param0"));
3713 HloInstruction* const_true = builder.AddInstruction(
3714 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
3715 builder.AddInstruction(
3716 HloInstruction::CreateBinary(r0pred, HloOpcode::kOr, const_true, param0));
3717
3718 auto computation = m->AddEntryComputation(builder.Build());
3719 HloInstruction* root = computation->root_instruction();
3720 EXPECT_EQ(root->opcode(), HloOpcode::kOr);
3721 AlgebraicSimplifier simplifier(default_options_);
3722 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3723 root = computation->root_instruction();
3724 EXPECT_EQ(root, const_true);
3725 }
3726
3727 // Test that A || False is simplified to A
TEST_F(AlgebraicSimplifierTest,OrFalse)3728 TEST_F(AlgebraicSimplifierTest, OrFalse) {
3729 auto m = CreateNewVerifiedModule();
3730 Shape r0pred = ShapeUtil::MakeShape(PRED, {});
3731 HloComputation::Builder builder(TestName());
3732 HloInstruction* param0 = builder.AddInstruction(
3733 HloInstruction::CreateParameter(0, r0pred, "param0"));
3734 HloInstruction* const_false = builder.AddInstruction(
3735 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
3736 builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kOr,
3737 param0, const_false));
3738
3739 auto computation = m->AddEntryComputation(builder.Build());
3740 HloInstruction* root = computation->root_instruction();
3741 EXPECT_EQ(root->opcode(), HloOpcode::kOr);
3742 AlgebraicSimplifier simplifier(default_options_);
3743 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3744 root = computation->root_instruction();
3745 EXPECT_EQ(root, param0);
3746 }
3747
3748 // Test that False || A is simplified to A
TEST_F(AlgebraicSimplifierTest,OrFalse2)3749 TEST_F(AlgebraicSimplifierTest, OrFalse2) {
3750 auto m = CreateNewVerifiedModule();
3751 Shape r0pred = ShapeUtil::MakeShape(PRED, {});
3752 HloComputation::Builder builder(TestName());
3753 HloInstruction* param0 = builder.AddInstruction(
3754 HloInstruction::CreateParameter(0, r0pred, "param0"));
3755 HloInstruction* const_false = builder.AddInstruction(
3756 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
3757 builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kOr,
3758 const_false, param0));
3759
3760 auto computation = m->AddEntryComputation(builder.Build());
3761 HloInstruction* root = computation->root_instruction();
3762 EXPECT_EQ(root->opcode(), HloOpcode::kOr);
3763 AlgebraicSimplifier simplifier(default_options_);
3764 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3765 root = computation->root_instruction();
3766 EXPECT_EQ(root, param0);
3767 }
3768
3769 // Used for TEST_Ps that test merging (or not) of a kPad instruction into a
3770 // convolution's Window.
3771 struct ConvPaddingTestcase {
ConvPaddingTestcasexla::__anonac242c730111::ConvPaddingTestcase3772 ConvPaddingTestcase(absl::string_view padding,
3773 absl::string_view orig_conv_window,
3774 absl::string_view expected_conv_window)
3775 : ConvPaddingTestcase(padding, orig_conv_window, expected_conv_window,
3776 /*pad_value=*/0) {}
3777
ConvPaddingTestcasexla::__anonac242c730111::ConvPaddingTestcase3778 ConvPaddingTestcase(absl::string_view padding,
3779 absl::string_view orig_conv_window,
3780 absl::string_view expected_conv_window, float pad_value)
3781 : padding(padding),
3782 orig_conv_window(orig_conv_window),
3783 expected_conv_window(expected_conv_window),
3784 pad_value(pad_value) {}
3785
ToStringxla::__anonac242c730111::ConvPaddingTestcase3786 string ToString() const {
3787 return absl::StrFormat(
3788 "padding=%s, orig_conv_window=%s, expected_conv_window=%s, "
3789 "pad_value=%f",
3790 padding, orig_conv_window, expected_conv_window, pad_value);
3791 }
3792
3793 string padding;
3794 string orig_conv_window;
3795 string expected_conv_window;
3796 float pad_value;
3797 };
3798
3799 // ConvInputPaddingTest (and its one associated TEST_P testcase) checks that a
3800 // computation that does
3801 //
3802 // conv(pad(param0, padding=padding), param1), window=orig_conv_window
3803 //
3804 // gets transformed by AlgebraicSimplifier to
3805 //
3806 // conv(param0, param1), window=expected_conv_window
3807 //
3808 // or, if expected_conv_window is the empty string, checks that
3809 // AlgebraicSimplifier does *not* transform the original convolution.
3810 class ConvInputPaddingTest
3811 : public AlgebraicSimplifierTest,
3812 public ::testing::WithParamInterface<ConvPaddingTestcase> {};
3813
3814 INSTANTIATE_TEST_SUITE_P(
3815 ConvInputPaddingTestCases, ConvInputPaddingTest,
3816 ::testing::ValuesIn(std::vector<ConvPaddingTestcase>{
3817 // Merge this edge padding into the conv.
3818 {"0_0x0_0x1_1x2_2", "", "pad=1_1x2_2"},
3819 // Merge this edge padding with the conv's edge padding.
3820 {"0_0x0_0x1_2x3_4", "pad=10_10x20_20", "pad=11_12x23_24"},
3821 // Merge this interior-padded kPad with the unpadded conv. The 3x6
3822 // interior padding gets transformed to 4x7 conv lhs dilation.
3823 {"0_0x0_0x1_2_3x4_5_6", "", "pad=1_2x4_5 lhs_dilate=4x7"},
3824 // kPad has dilation on one dim, conv has it on the other; merge them.
3825 {"0_0x0_0x0_0_1x0_0_0", "lhs_dilate=1x10", "lhs_dilate=2x10"},
3826 // kPad has dilation and edge padding on one dim, conv has them on the
3827 // other; merge them.
3828 {"0_0x0_0x0_1_1x0_0_0", "pad=0_0x3_0 lhs_dilate=1x10",
3829 "pad=0_1x3_0 lhs_dilate=2x10"},
3830
3831 // Don't transform if the pad value is nonzero.
3832 {"0_0x0_0x1_1x2_2", "", "", /*pad_value=*/1},
3833
3834 // We refuse to transform the following because on some dimension, one
3835 // of the kPad and conv has dilation and the other has some sort of
3836 // padding.
3837 {"0_0x0_0x0_0_1x0_0", "pad=1_0x0_0", ""},
3838 {"0_0x0_0x0_0_1x0_0", "pad=0_1x0_0", ""},
3839 {"0_0x0_0x0_0_1x0_0", "lhs_dilate=2x1", ""},
3840 {"0_0x0_0x1_0_0x0_0", "lhs_dilate=2x1", ""},
3841 {"0_0x0_0x0_1_0x0_0", "lhs_dilate=2x1", ""},
3842 {"0_0x0_0x0_0_1x0_0", "lhs_dilate=2x1", ""},
3843
3844 // We can't merge feature or batch padding into the conv.
3845 {"1_0x0_0x0_0x0_0", "", ""},
3846 {"0_0x1_0x0_0x0_0", "", ""},
3847 }));
3848
// Builds conv(pad(input), filter) from the test case's padding and window
// strings, runs AlgebraicSimplifier, and checks that the kPad is folded into
// the convolution's window exactly when expected_conv_window is non-empty.
TEST_P(ConvInputPaddingTest, DoTest) {
  const ConvPaddingTestcase& testcase = GetParam();

  // It would be better to put the testcase's ToString into the test name, but
  // gUnit has constraints on what can go into test names, and any reasonable
  // implementation of ToString() seems to violate them.
  SCOPED_TRACE(testcase.ToString());

  auto builder = HloComputation::Builder(TestName());
  auto* input = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1024, 128, 100, 100}),  // bf01
      "input"));
  auto* pad_value = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR0(testcase.pad_value)));

  PaddingConfig padding_config =
      ParsePaddingConfig(testcase.padding).ValueOrDie();
  auto* lhs_pad = builder.AddInstruction(HloInstruction::CreatePad(
      ShapeInference::InferPadShape(input->shape(), pad_value->shape(),
                                    padding_config)
          .ValueOrDie(),
      input, pad_value, padding_config));

  // The filter's input-feature size tracks the padded input's feature
  // dimension, which feature padding in the kPad may have enlarged.
  auto* filter = builder.AddInstruction(HloInstruction::CreateParameter(
      1,
      ShapeUtil::MakeShape(
          F32, {lhs_pad->shape().dimensions(1), 256, 3, 3}),  // io01
      "filter"));

  ConvolutionDimensionNumbers dnums =
      ParseConvolutionDimensionNumbers("bf01_io01->bf01").ValueOrDie();
  Window window =
      ParseWindow(absl::StrCat("size=3x3 ", testcase.orig_conv_window))
          .ValueOrDie();
  builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeInference::InferConvolveShape(
          lhs_pad->shape(), filter->shape(),
          /*feature_group_count=*/1,
          /*batch_group_count=*/1, window, dnums,
          /*preferred_element_type=*/absl::nullopt)
          .ValueOrDie(),
      lhs_pad, filter, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, DefaultPrecisionConfig(2)));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  AlgebraicSimplifier simplifier(default_options_);
  if (testcase.expected_conv_window.empty()) {
    ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
  } else {
    ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
    auto* conv = module->entry_computation()->root_instruction();
    SCOPED_TRACE(module->ToString());
    ASSERT_THAT(conv,
                GmockMatch(m::Convolution(m::Parameter(), m::Parameter())));
    EXPECT_EQ(window_util::ToString(conv->window()),
              absl::StrCat("size=3x3 ", testcase.expected_conv_window));
  }
}
3908
3909 // ConvFilterPaddingTest (and its one associated TEST_P) checks that a
3910 // computation that does
3911 //
3912 // conv(param0, pad(param1, padding=padding)), window=orig_conv_window
3913 //
3914 // gets transformed by AlgebraicSimplifier to
3915 //
3916 // conv(param0, param1), window=expected_conv_window
3917 //
3918 // or, if expected_conv_window is the empty string, checks that
3919 // AlgebraicSimplifier does *not* transform the original convolution.
3920 class ConvFilterPaddingTest
3921 : public AlgebraicSimplifierTest,
3922 public ::testing::WithParamInterface<ConvPaddingTestcase> {};
3923
3924 INSTANTIATE_TEST_SUITE_P(
3925 ConvFilterPaddingTestCases, ConvFilterPaddingTest,
3926 ::testing::ValuesIn(std::vector<ConvPaddingTestcase>{
3927 // Can only merge interior padding on the filter's spatial dimensions;
3928 // all
3929 // other paddings (edge padding and interior padding on the channel
3930 // dims)
3931 // should be rejected out of hand.
3932 {"1_0_0x0_0_0x0_0x0_0", "", ""},
3933 {"0_1_0x0_0_0x0_0x0_0", "", ""},
3934 {"0_0_1x0_0_0x0_0x0_0", "", ""},
3935 {"0_0_0x1_0_0x0_0x0_0", "", ""},
3936 {"0_0_0x0_1_0x0_0x0_0", "", ""},
3937 {"0_0_0x0_0_1x0_0x0_0", "", ""},
3938 {"0_0_0x0_0_0x1_0x0_0", "", ""},
3939 {"0_0_0x0_0_0x0_1x0_0", "", ""},
3940 {"0_0_0x0_0_0x0_0x1_0", "", ""},
3941 {"0_0_0x0_0_0x0_0x0_1", "", ""},
3942
3943 // Interior padding on channel dims can be merged into the conv, so long
3944 // as the conv and pad don't have interior padding on the same dim.
3945 {"0_0x0_0x0_0_5x0_0", "", "rhs_dilate=6x1"},
3946 {"0_0x0_0x0_0x0_0_10", "", "rhs_dilate=1x11"},
3947 {"0_0x0_0x0_0_10x0_0_100", "", "rhs_dilate=11x101"},
3948 {"0_0x0_0x0_0_1x0_0", "rhs_dilate=1x10", "rhs_dilate=2x10"},
3949 {"0_0x0_0x0_0x0_0_5", "rhs_dilate=10x1", "rhs_dilate=10x6"},
3950
3951 // Can't merge if for a given dim there's interior padding on both the
3952 // pad and conv.
3953 {"0_0x0_0x0_0_1x0_0", "rhs_dilate=2x10", ""},
3954 {"0_0x0_0x0_0x0_0_5", "rhs_dilate=10x2", ""},
3955
3956 // Don't transform if the pad value is nonzero.
3957 {"0_0x0_0x0_0_5x0_0", "", "", /*pad_value=*/1},
3958 }));
3959
// Builds conv(input, pad(filter)) from the test case's padding and window
// strings, runs AlgebraicSimplifier, and checks that the kPad is folded into
// the convolution's window exactly when expected_conv_window is non-empty.
// Also checks that the convolution's PrecisionConfig survives the rewrite.
TEST_P(ConvFilterPaddingTest, DoIt) {
  const ConvPaddingTestcase& testcase = GetParam();

  // It would be better to put the testcase's ToString into the test name, but
  // gUnit has constraints on what can go into test names, and any reasonable
  // implementation of ToString() seems to violate them.
  SCOPED_TRACE(testcase.ToString());

  auto builder = HloComputation::Builder(TestName());
  auto* pad_value = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR0(testcase.pad_value)));
  auto* filter = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {128, 256, 3, 3}),  // io01
      "filter"));
  PaddingConfig padding_config =
      ParsePaddingConfig(testcase.padding).ValueOrDie();
  auto* rhs_pad = builder.AddInstruction(HloInstruction::CreatePad(
      ShapeInference::InferPadShape(filter->shape(), pad_value->shape(),
                                    padding_config)
          .ValueOrDie(),
      filter, pad_value, padding_config));

  // The input's feature size tracks the padded filter's input-feature
  // dimension, which padding on that dimension may have enlarged.
  auto* input = builder.AddInstruction(HloInstruction::CreateParameter(
      0,
      ShapeUtil::MakeShape(
          F32, {1024, rhs_pad->shape().dimensions(0), 100, 100}),  // bf01
      "input"));

  ConvolutionDimensionNumbers dnums =
      ParseConvolutionDimensionNumbers("bf01_io01->bf01").ValueOrDie();
  Window window = ParseWindow(absl::StrFormat("size=%dx%d %s",
                                              rhs_pad->shape().dimensions(2),
                                              rhs_pad->shape().dimensions(3),
                                              testcase.orig_conv_window))
                      .ValueOrDie();

  // Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place
  // after the transformation.
  PrecisionConfig precision_config;
  precision_config.add_operand_precision(PrecisionConfig::HIGH);
  precision_config.add_operand_precision(PrecisionConfig::HIGHEST);

  builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeInference::InferConvolveShape(
          input->shape(), rhs_pad->shape(),
          /*feature_group_count=*/1,
          /*batch_group_count=*/1, window, dnums,
          /*preferred_element_type=*/absl::nullopt)
          .ValueOrDie(),
      input, rhs_pad, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, precision_config));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  AlgebraicSimplifier simplifier(default_options_);
  if (testcase.expected_conv_window.empty()) {
    ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
  } else {
    ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
    auto* conv = module->entry_computation()->root_instruction();
    SCOPED_TRACE(module->ToString());
    ASSERT_THAT(conv,
                GmockMatch(m::Convolution(m::Parameter(), m::Parameter())));
    EXPECT_EQ(window_util::ToString(conv->window()),
              absl::StrFormat("size=%dx%d %s",
                              conv->operand(1)->shape().dimensions(2),
                              conv->operand(1)->shape().dimensions(3),
                              testcase.expected_conv_window));
    EXPECT_THAT(Cast<HloConvolutionInstruction>(conv)
                    ->precision_config()
                    .operand_precision(),
                ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::HIGHEST));
  }
}
4035
// Checks which convolutions the layout-sensitive simplifier rewrites into
// bitcast(dot(...)), varying input/kernel dimension orders, filter sizes,
// strides, padding and layouts.
TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
  // Knobs for the convolution built by build_and_simplify(). The defaults
  // describe a 1x1, stride-1, unpadded NHWC/HWIO convolution with default
  // layouts.
  struct ConvTestOptions {
    int in_batch = 10;
    int in_height = 2;
    int in_width = 2;
    int in_channels = 3;
    int f_width = 1;
    int f_height = 1;
    int f_output_channels = 10;
    int row_stride = 1;
    int row_padding = 0;
    int col_stride = 1;
    int col_padding = 0;
    bool input_minor_to_major_layout = false;
    bool filter_minor_to_major_layout = false;
    bool output_minor_to_major_layout = false;

    const char* dim_order = "NHWC";         // can use chars NHWC in any order.
    const char* kernel_dim_order = "HWIO";  // can use chars HWIO in any order.

    ConvTestOptions& Reset() {
      *this = ConvTestOptions();
      return *this;
    }
  };

  ConvTestOptions options;

  // Builds a convolution from `options` and runs layout-sensitive algebraic
  // simplification on it. Returns a description of the outcome:
  //   "NO_CHANGE"         - the simplifier reported no change;
  //   "<lhs> DOT <rhs>"   - the root became bitcast(dot(lhs, rhs)); each side
  //                         is the dot operand's dimensions joined with "x";
  //   "UNEXPECTED CHANGE" - anything else.
  auto build_and_simplify = [&]() -> string {
    HloComputation::Builder b(TestName());

    // Window dimension 0 covers rows (height), dimension 1 columns (width).
    Window window;
    auto* f_dim_1 = window.add_dimensions();
    f_dim_1->set_size(options.f_height);
    f_dim_1->set_stride(options.row_stride);
    f_dim_1->set_padding_low(options.row_padding);
    f_dim_1->set_padding_high(options.row_padding);
    f_dim_1->set_window_dilation(1);
    f_dim_1->set_base_dilation(1);
    auto* f_dim_2 = window.add_dimensions();
    f_dim_2->set_size(options.f_width);
    f_dim_2->set_stride(options.col_stride);
    f_dim_2->set_padding_low(options.col_padding);
    f_dim_2->set_padding_high(options.col_padding);
    f_dim_2->set_window_dilation(1);
    f_dim_2->set_base_dilation(1);

    ConvolutionDimensionNumbers dnums;
    std::vector<int64> in_dims;
    // Placeholder spatial dimension numbers, overwritten in the loop below.
    dnums.add_input_spatial_dimensions(-1);
    dnums.add_output_spatial_dimensions(-1);
    dnums.add_input_spatial_dimensions(-1);
    dnums.add_output_spatial_dimensions(-1);
    // The output uses the same dimension order as the input.
    const absl::string_view dim_order = options.dim_order;
    for (int64 i = 0; i < static_cast<int64>(dim_order.size()); ++i) {
      switch (dim_order[i]) {
        case 'N':
          dnums.set_input_batch_dimension(i);
          dnums.set_output_batch_dimension(i);
          in_dims.push_back(options.in_batch);
          break;
        case 'H':
          dnums.set_input_spatial_dimensions(0, i);
          dnums.set_output_spatial_dimensions(0, i);
          in_dims.push_back(options.in_height);
          break;
        case 'W':
          dnums.set_input_spatial_dimensions(1, i);
          dnums.set_output_spatial_dimensions(1, i);
          in_dims.push_back(options.in_width);
          break;
        case 'C':
          dnums.set_input_feature_dimension(i);
          dnums.set_output_feature_dimension(i);
          in_dims.push_back(options.in_channels);
          break;
        default:
          // Other characters are ignored.
          break;
      }
    }

    std::vector<int64> f_dims;
    dnums.add_kernel_spatial_dimensions(-1);  // filled in below
    dnums.add_kernel_spatial_dimensions(-1);  // filled in below
    const absl::string_view kernel_dim_order = options.kernel_dim_order;
    for (int64 i = 0; i < static_cast<int64>(kernel_dim_order.size()); ++i) {
      switch (kernel_dim_order[i]) {
        case 'H':
          dnums.set_kernel_spatial_dimensions(0, i);
          f_dims.push_back(options.f_height);
          break;
        case 'W':
          dnums.set_kernel_spatial_dimensions(1, i);
          f_dims.push_back(options.f_width);
          break;
        case 'I':
          dnums.set_kernel_input_feature_dimension(i);
          f_dims.push_back(options.in_channels);
          break;
        case 'O':
          dnums.set_kernel_output_feature_dimension(i);
          f_dims.push_back(options.f_output_channels);
          break;
        default:
          // Other characters are ignored.
          break;
      }
    }

    // With `minor_to_major_layout`, dimension 0 is the most minor.
    auto make_shape = [](absl::Span<const int64> dims,
                         bool minor_to_major_layout) {
      if (minor_to_major_layout) {
        return ShapeUtil::MakeShapeWithLayout(F32, dims, {0, 1, 2, 3});
      } else {
        return ShapeUtil::MakeShape(F32, dims);
      }
    };
    auto in_shape = make_shape(in_dims, options.input_minor_to_major_layout);
    auto f_shape = make_shape(f_dims, options.filter_minor_to_major_layout);

    HloInstruction* input =
        b.AddInstruction(HloInstruction::CreateParameter(0, in_shape, "input"));
    HloInstruction* filter =
        b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter"));
    Shape out_shape = ShapeInference::InferConvolveShape(
                          in_shape, f_shape, /*feature_group_count=*/1,
                          /*batch_group_count=*/1, window, dnums,
                          /*preferred_element_type=*/absl::nullopt)
                          .ValueOrDie();
    if (options.output_minor_to_major_layout) {
      out_shape = ShapeUtil::MakeShapeWithLayout(F32, out_shape.dimensions(),
                                                 {0, 1, 2, 3});
    }

    b.AddInstruction(HloInstruction::CreateConvolve(
        out_shape, input, filter,
        /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums,
        DefaultPrecisionConfig(2)));

    auto module = CreateNewVerifiedModule();
    auto* computation = module->AddEntryComputation(b.Build());

    AlgebraicSimplifierOptions simplifier_options;
    simplifier_options.set_is_layout_sensitive(true);
    AlgebraicSimplifier simplifier(simplifier_options);
    if (!simplifier.Run(module.get()).ValueOrDie()) {
      return "NO_CHANGE";
    }
    auto* root = computation->root_instruction();
    if (root->opcode() == HloOpcode::kBitcast &&
        root->operand(0)->opcode() == HloOpcode::kDot) {
      auto lhs_shape = root->operand(0)->operand(0)->shape();
      auto rhs_shape = root->operand(0)->operand(1)->shape();
      return absl::StrCat(absl::StrJoin(lhs_shape.dimensions(), "x"), " DOT ",
                          absl::StrJoin(rhs_shape.dimensions(), "x"));
    }
    return "UNEXPECTED CHANGE";
  };

  // Default options are the simplest case and succeed.
  options.Reset();
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());

  // Reordering the input's batch and spatial dimensions works.
  options.Reset().dim_order = "NWHC";
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
  options.Reset().dim_order = "WHNC";
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
  // Moving the channel dimension earlier fails.
  options.Reset().dim_order = "HWCN";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().dim_order = "CHWN";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());

  // The filter's spatial dims can be anywhere, since they are 1x1.
  options.Reset().kernel_dim_order = "WHIO";
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
  options.Reset().kernel_dim_order = "IWOH";
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
  options.Reset().kernel_dim_order = "IWHO";
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
  // But moving the output channel before the input channel fails.
  options.Reset().kernel_dim_order = "HWOI";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().kernel_dim_order = "WHOI";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().kernel_dim_order = "OWIH";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().kernel_dim_order = "OWHI";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());

  // Combine different dim and kernel dim orders.
  options.Reset().kernel_dim_order = "IWHO";
  options.dim_order = "WHNC";
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());

  // Test invalid cases from wrong filter size, strides, or padding.
  options.Reset().f_width = 2;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().f_height = 2;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().row_stride = 2;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().col_stride = 2;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().col_padding = 1;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().row_padding = 1;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());

  // The default dim_order is "NHWC". Col-major layout makes C the most major.
  options.Reset().input_minor_to_major_layout = true;
  options.output_minor_to_major_layout = true;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());

  // The input and output have different layouts.
  options.Reset().input_minor_to_major_layout = true;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());

  // C is most minor, and I is more major than O.
  options.Reset().input_minor_to_major_layout = true;
  options.filter_minor_to_major_layout = true;
  options.output_minor_to_major_layout = true;
  options.dim_order = "CHWN";
  options.kernel_dim_order = "OIHW";
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());

  // C is not the most minor dimension.
  options.Reset().input_minor_to_major_layout = true;
  options.filter_minor_to_major_layout = true;
  options.output_minor_to_major_layout = true;
  options.dim_order = "HWNC";
  options.kernel_dim_order = "OIHW";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());

  // I is more minor than O.
  options.Reset().input_minor_to_major_layout = true;
  options.filter_minor_to_major_layout = true;
  options.output_minor_to_major_layout = true;
  options.dim_order = "CHWN";
  options.kernel_dim_order = "IOHW";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
}
4270
4271 // Test that slice(broadcast(/*scalar value*/)) simplifies to a single
4272 // broadcast.
TEST_F(AlgebraicSimplifierTest,ScalarBroadcastToSlice)4273 TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
4274 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
4275 HloComputation::Builder builder(TestName());
4276 HloInstruction* scalar_param = builder.AddInstruction(
4277 HloInstruction::CreateParameter(0, r0f32, "scalar_param"));
4278
4279 Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
4280 HloInstruction* broadcast = builder.AddInstruction(
4281 HloInstruction::CreateBroadcast(broadcast_shape, scalar_param, {}));
4282
4283 Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3});
4284 HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice(
4285 slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}, {1, 1, 1, 1}));
4286
4287 auto module = CreateNewVerifiedModule();
4288 auto computation = module->AddEntryComputation(builder.Build());
4289
4290 HloInstruction* root = computation->root_instruction();
4291 EXPECT_EQ(root, slice);
4292 EXPECT_TRUE(ShapeUtil::Equal(root->shape(), slice_shape));
4293
4294 AlgebraicSimplifier simplifier(default_options_);
4295
4296 ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
4297
4298 // Running simplification again should not result in any further changes.
4299 ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
4300 EXPECT_THAT(computation->root_instruction(),
4301 GmockMatch(m::Broadcast(m::Op().Is(scalar_param))
4302 .WithShapeEqualTo(&slice_shape)));
4303 }
4304
4305 // Test that reshape(transpose(broadcast(/*scalar value*/))) simplifies to a
4306 // single broadcast.
TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
  // Input graph: reshape(transpose(broadcast(42.0f))), with shapes
  // f32[4,5,6] -> f32[6,5,4] -> f32[30,1,4]. Every element is the same scalar,
  // so the simplifier should collapse the chain into one broadcast of it.
  HloComputation::Builder b(TestName());
  HloInstruction* const scalar = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));

  const Shape bcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6});
  HloInstruction* const bcast = b.AddInstruction(
      HloInstruction::CreateBroadcast(bcast_shape, scalar, {}));

  const Shape transposed_shape = ShapeUtil::MakeShape(F32, {6, 5, 4});
  HloInstruction* const transposed = b.AddInstruction(
      HloInstruction::CreateTranspose(transposed_shape, bcast, {2, 1, 0}));

  Shape reshape_shape = ShapeUtil::MakeShape(F32, {30, 1, 4});
  HloInstruction* const reshaped = b.AddInstruction(
      HloInstruction::CreateReshape(reshape_shape, transposed));

  auto module = CreateNewVerifiedModule();
  HloComputation* const computation = module->AddEntryComputation(b.Build());

  // Sanity-check the graph before simplification.
  HloInstruction* const original_root = computation->root_instruction();
  EXPECT_EQ(original_root, reshaped);
  EXPECT_TRUE(ShapeUtil::Equal(original_root->shape(), reshape_shape));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Broadcast(m::Op().Is(scalar))
                             .WithShapeEqualTo(&reshape_shape)));
}
4337
4338 // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x).
TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // Create operand to the pad.
  HloInstruction* operand =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {1, 2, 3, 4}), "p0"));

  // Create the pad: 1 element of low padding in dimension 1 and 2 elements of
  // high padding in dimension 3, so f32[1,2,3,4] is padded to f32[1,3,3,6].
  PaddingConfig padding = MakeNoPaddingConfig(4);
  padding.mutable_dimensions(1)->set_edge_padding_low(1);
  padding.mutable_dimensions(3)->set_edge_padding_high(2);

  // The pad value equals the reduce-window init value created below (both
  // 5.0f); the fold is expected to depend on that.
  HloInstruction* pad_value = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
  // The pad shape must agree with `padding` (4 + 2 = 6 in the last dimension)
  // and with reduce_window_shape below. A last dimension of 5 would make the
  // input graph ill-formed; that presumably only went unnoticed because the pad
  // is gone by the time the module is verified.
  HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {1, 3, 3, 6}), operand, pad_value, padding));

  // Create add computation.
  HloComputation* add_computation = nullptr;
  {
    HloComputation::Builder builder(TestName() + ".add");
    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
    HloInstruction* p0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "p0"));
    HloInstruction* p1 = builder.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "p1"));
    builder.AddInstruction(
        HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
    add_computation = module->AddEmbeddedComputation(builder.Build());
  }

  // Create the reduce-window. Every window dimension has size 1, stride 1, no
  // dilation, and padding 10 (low) + 100 (high). Each output dimension is
  // therefore the padded input dimension plus 110:
  // f32[1,3,3,6] -> f32[111,113,113,116].
  Window window;
  for (int64 i = 0; i < pad->shape().rank(); ++i) {
    auto* dim = window.add_dimensions();
    dim->set_size(1);
    dim->set_padding_low(10);
    dim->set_padding_high(100);
    dim->set_window_dilation(1);
    dim->set_base_dilation(1);
    dim->set_stride(1);
  }
  const Shape reduce_window_shape =
      ShapeUtil::MakeShape(F32, {111, 113, 113, 116});
  HloInstruction* reduce_init_value = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
  HloInstruction* reduce_window =
      builder.AddInstruction(HloInstruction::CreateReduceWindow(
          reduce_window_shape, pad, reduce_init_value, window,
          add_computation));

  // Build the computation and run the simplifier.
  auto computation = module->AddEntryComputation(builder.Build());
  HloInstruction* root = computation->root_instruction();
  EXPECT_EQ(root, reduce_window);
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Running simplification again should not result in any further changes.
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());

  // Verify the result: the pad is gone, and its edge padding has been added to
  // the window padding (+1 low in dimension 1, +2 high in dimension 3). The
  // output shape is unchanged.
  root = computation->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::ReduceWindow(m::Op().Is(operand), m::Constant())));
  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape))
      << ShapeUtil::HumanString(root->shape()) << " vs "
      << ShapeUtil::HumanString(reduce_window_shape);
  EXPECT_EQ(root->window().dimensions(0).padding_low(), 10);
  EXPECT_EQ(root->window().dimensions(1).padding_low(), 11);
  EXPECT_EQ(root->window().dimensions(2).padding_low(), 10);
  EXPECT_EQ(root->window().dimensions(3).padding_low(), 10);
  EXPECT_EQ(root->window().dimensions(0).padding_high(), 100);
  EXPECT_EQ(root->window().dimensions(1).padding_high(), 100);
  EXPECT_EQ(root->window().dimensions(2).padding_high(), 100);
  EXPECT_EQ(root->window().dimensions(3).padding_high(), 102);
}
4418
4419 // Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to
4420 // ReduceWindow(Convert(op), x).
TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // Create operand to the pad.
  HloInstruction* parameter =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(BF16, {1, 2, 3, 4}), "p0"));

  // Create the pad: 1 element of low padding in dimension 1 and 2 elements of
  // high padding in dimension 3, so bf16[1,2,3,4] is padded to bf16[1,3,3,6].
  PaddingConfig padding = MakeNoPaddingConfig(4);
  padding.mutable_dimensions(1)->set_edge_padding_low(1);
  padding.mutable_dimensions(3)->set_edge_padding_high(2);

  // Pad requires the padding value to have the operand's element type, so
  // this is a BF16 constant. It is numerically equal to the F32 reduce-window
  // init value created below (5.0 is exact in bf16). The fold is therefore
  // expected to see the two values as equivalent across the convert.
  HloInstruction* pad_value =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(5.0f))));
  // The pad shape must agree with `padding` (4 + 2 = 6 in the last dimension)
  // and with reduce_window_shape below.
  HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(BF16, {1, 3, 3, 6}), parameter, pad_value, padding));

  // Widen the padded bf16 array to f32 before the reduce-window.
  HloInstruction* convert =
      builder.AddInstruction(HloInstruction::CreateConvert(
          ShapeUtil::ChangeElementType(pad->shape(), F32), pad));

  // Create add computation.
  HloComputation* add_computation = nullptr;
  {
    HloComputation::Builder builder(TestName() + ".add");
    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
    HloInstruction* p0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "p0"));
    HloInstruction* p1 = builder.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "p1"));
    builder.AddInstruction(
        HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
    add_computation = module->AddEmbeddedComputation(builder.Build());
  }

  // Create the reduce-window. Every window dimension has size 1, stride 1, no
  // dilation, and padding 10 (low) + 100 (high). Each output dimension is
  // therefore the padded input dimension plus 110:
  // f32[1,3,3,6] -> f32[111,113,113,116].
  Window window;
  for (int64 i = 0; i < pad->shape().rank(); ++i) {
    auto* dim = window.add_dimensions();
    dim->set_size(1);
    dim->set_padding_low(10);
    dim->set_padding_high(100);
    dim->set_window_dilation(1);
    dim->set_base_dilation(1);
    dim->set_stride(1);
  }
  const Shape reduce_window_shape =
      ShapeUtil::MakeShape(F32, {111, 113, 113, 116});
  HloInstruction* reduce_init_value = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
  HloInstruction* reduce_window =
      builder.AddInstruction(HloInstruction::CreateReduceWindow(
          reduce_window_shape, convert, reduce_init_value, window,
          add_computation));

  // Build the computation and run the simplifier.
  auto computation = module->AddEntryComputation(builder.Build());
  HloInstruction* root = computation->root_instruction();
  EXPECT_EQ(root, reduce_window);
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Running simplification again should not result in any further changes.
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());

  // Verify the result: the convert now applies directly to the unpadded
  // parameter, and the pad's edge padding has moved into the window padding.
  root = computation->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::ReduceWindow(m::Convert(m::Parameter(0)),
                                               m::Constant())));
  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape))
      << ShapeUtil::HumanString(root->shape()) << " vs "
      << ShapeUtil::HumanString(reduce_window_shape);
  EXPECT_EQ(root->window().dimensions(0).padding_low(), 10);
  EXPECT_EQ(root->window().dimensions(1).padding_low(), 11);
  EXPECT_EQ(root->window().dimensions(2).padding_low(), 10);
  EXPECT_EQ(root->window().dimensions(3).padding_low(), 10);
  EXPECT_EQ(root->window().dimensions(0).padding_high(), 100);
  EXPECT_EQ(root->window().dimensions(1).padding_high(), 100);
  EXPECT_EQ(root->window().dimensions(2).padding_high(), 100);
  EXPECT_EQ(root->window().dimensions(3).padding_high(), 102);
}
4504
TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) {
  // Reversing only size-1 dimensions (2 and 3) leaves the data unchanged, so
  // the reverse should disappear and the parameter itself become the root.
  const Shape shape = ShapeUtil::MakeShape(F32, {448, 2048, 1, 1});
  HloComputation::Builder builder(TestName());
  HloInstruction* const param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
  builder.AddInstruction(
      HloInstruction::CreateReverse(shape, param, /*dimensions=*/{2, 3}));

  auto module = CreateNewVerifiedModule();
  HloComputation* const computation =
      module->AddEntryComputation(builder.Build());

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* const new_root = computation->root_instruction();
  EXPECT_EQ(param, new_root);
  EXPECT_TRUE(ShapeUtil::Equal(new_root->shape(), shape));
}
4523
TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
  // Regression test. Simplifying a dot may update the module's computations,
  // and that must not invalidate the iteration the simplifier uses to visit
  // the remaining computations.
  auto m = CreateNewVerifiedModule();
  const Shape r1f32 = ShapeUtil::MakeShape(F32, {1});

  // Callee: a dot of two f32[1] parameters, batched over dimension 0.
  HloComputation::Builder dot_builder(TestName() + ".Dot");
  HloInstruction* const lhs = dot_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "x"));
  HloInstruction* const rhs = dot_builder.AddInstruction(
      HloInstruction::CreateParameter(1, r1f32, "y"));
  DotDimensionNumbers dnums;
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_rhs_batch_dimensions(0);
  dot_builder.AddInstruction(HloInstruction::CreateDot(
      r1f32, lhs, rhs, dnums, DefaultPrecisionConfig(2)));
  std::unique_ptr<HloComputation> dot_computation = dot_builder.Build();

  // Caller: calls the dot computation on the constants {0} and {1}.
  HloComputation::Builder call_builder(TestName() + ".Call");
  HloInstruction* const zero = call_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({0.0f})));
  HloInstruction* const one = call_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1.0f})));
  call_builder.AddInstruction(
      HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get()));

  m->AddEmbeddedComputation(std::move(dot_computation));
  m->AddEntryComputation(call_builder.Build());

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
}
4555
4556 // Test that a constant with tuple shape becomes a tuple of constants.
TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // A single tuple-shaped constant: (f32[] 7.3, f32[3] {1.1, 2.0, 3.3}).
  Literal scalar_part = LiteralUtil::CreateR0<float>(7.3f);
  Literal vector_part = LiteralUtil::CreateR1<float>({1.1f, 2.0f, 3.3f});
  Literal tuple_literal =
      LiteralUtil::MakeTuple({&scalar_part, &vector_part});
  builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(tuple_literal)));

  HloComputation* const computation = m->AddEntryComputation(builder.Build());

  // The simplifier should split it into a tuple of two array constants.
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Tuple(m::Constant(), m::Constant())));
}
4574
// A dynamic-slice is trivial if its slice sizes equal the size of its input.
// Start indices are clamped so the slice stays in bounds, so they are then
// effectively all zeroes, even when (as here) they are runtime parameters. In
// this case, the dynamic slice is equal to its input.
TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  const Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});

  // Parameters 1..3 are the runtime start indices, one per dimension.
  std::vector<HloInstruction*> start_indices;
  start_indices.reserve(3);
  for (int param_no = 1; param_no <= 3; ++param_no) {
    start_indices.push_back(
        builder.AddInstruction(HloInstruction::CreateParameter(
            param_no, ShapeUtil::MakeShape(U32, {}), "slice_indices")));
  }

  // Slice the full extent of every dimension of parameter 0.
  HloInstruction* const operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "slice_from"));
  builder.AddInstruction(HloInstruction::CreateDynamicSlice(
      shape, operand, start_indices, /*slice_sizes=*/{10, 100, 1000}));

  HloComputation* const computation = m->AddEntryComputation(builder.Build());
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Parameter()));
}
4600
TEST_F(AlgebraicSimplifierTest, ConstantDynamicSlice) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  const Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});

  // Constant start indices 4, 8 and 16 (2 << 1, 2 << 2, 2 << 3).
  std::vector<HloInstruction*> start_indices;
  start_indices.reserve(3);
  for (int shift = 1; shift <= 3; ++shift) {
    start_indices.push_back(builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2 << shift))));
  }

  // With every start index known at compile time, the dynamic-slice is
  // expected to become a static slice of the operand.
  const Shape ds_shape = ShapeUtil::MakeShape(F32, {2, 20, 200});
  HloInstruction* const operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "operand"));
  builder.AddInstruction(HloInstruction::CreateDynamicSlice(
      ds_shape, operand, start_indices, /*slice_sizes=*/{2, 20, 200}));

  HloComputation* const computation = m->AddEntryComputation(builder.Build());
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Slice(m::Parameter())));
}
4625
// A dynamic-update-slice is trivial if the size of its "update" equals the size
// of its output. The clamped start indices are then effectively all zeroes,
// even when they are runtime parameters. In this case, the dynamic-update-slice
// is equal to its update.
TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  const Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
  const Shape slice_shape = ShapeUtil::MakeShape(F32, {10, 1, 1000});

  // Parameters 1..3 index the dynamic-slice; parameters 5..7 index the
  // dynamic-update-slice. They are created interleaved, one pair per dimension.
  std::vector<HloInstruction*> slice_indices;
  std::vector<HloInstruction*> update_indices;
  for (int dim = 0; dim < 3; ++dim) {
    slice_indices.push_back(
        builder.AddInstruction(HloInstruction::CreateParameter(
            dim + 1, ShapeUtil::MakeShape(U32, {}), "slice_indices")));
    update_indices.push_back(
        builder.AddInstruction(HloInstruction::CreateParameter(
            dim + 5, ShapeUtil::MakeShape(U32, {}), "update_indices")));
  }

  HloInstruction* const slice_source = builder.AddInstruction(
      HloInstruction::CreateParameter(0, full_shape, "slice_from"));
  HloInstruction* const slice =
      builder.AddInstruction(HloInstruction::CreateDynamicSlice(
          slice_shape, slice_source, slice_indices,
          /*slice_sizes=*/{10, 1, 1000}));

  // The update covers the whole destination, so the DUS reduces to `slice`.
  HloInstruction* const destination = builder.AddInstruction(
      HloInstruction::CreateParameter(4, slice_shape, "to_update"));
  builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
      slice_shape, destination, slice, update_indices));

  HloComputation* const computation = m->AddEntryComputation(builder.Build());
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::DynamicSlice(m::Parameter(), m::Parameter(),
                                         m::Parameter(), m::Parameter())));
}
4666
4667 // Test that two consecutive broadcasts can be merged to one.
TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // broadcast(broadcast(c -> f32[2,2], {1}) -> f32[2,2,2], {0, 2}).
  // c's only dimension maps to inner dim 1, which maps to outer dim 2, so the
  // pair should merge into a single broadcast of c along dimension 2.
  const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2});
  HloInstruction* const c = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({3, 4})));
  HloInstruction* const inner = builder.AddInstruction(
      HloInstruction::CreateBroadcast(r2f32, c, {1}));
  builder.AddInstruction(HloInstruction::CreateBroadcast(r3f32, inner, {0, 2}));

  HloComputation* const computation = m->AddEntryComputation(builder.Build());
  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kBroadcast);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  HloInstruction* const merged = computation->root_instruction();
  EXPECT_THAT(merged, GmockMatch(m::Broadcast(m::Constant())));
  EXPECT_THAT(merged->dimensions(), ElementsAre(2));
}
4689
4690 // Test that two consecutive broadcasts can be merged to one.
TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 3});
  const Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3});
  const Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3});

  // The parameter's two dimensions go to dimensions 0 and 2 of the rank-3
  // broadcast. The outer broadcast sends those to dimensions 1 and 3 of the
  // rank-4 result.
  HloInstruction* const param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r2f32, "param0"));
  HloInstruction* const inner = builder.AddInstruction(
      HloInstruction::CreateBroadcast(r3f32, param0, {0, 2}));
  builder.AddInstruction(
      HloInstruction::CreateBroadcast(r4f32, inner, {1, 2, 3}));

  HloComputation* const computation = m->AddEntryComputation(builder.Build());
  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kBroadcast);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  HloInstruction* const merged = computation->root_instruction();
  EXPECT_THAT(merged, GmockMatch(m::Broadcast(m::Parameter(0))));
  EXPECT_THAT(merged->dimensions(), ElementsAre(1, 3));
}
4715
4716 // Test that a broadcast of an iota can be merged to one iota.
TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // iota along dim 1 of f32[2,2], broadcast into dims {0, 2} of f32[2,2,2]:
  // the iota dimension lands on output dimension 2.
  const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2});
  HloInstruction* const iota =
      builder.AddInstruction(HloInstruction::CreateIota(r2f32, 1));
  builder.AddInstruction(HloInstruction::CreateBroadcast(r3f32, iota, {0, 2}));

  HloComputation* const computation = m->AddEntryComputation(builder.Build());
  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kBroadcast);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  HloInstruction* const merged = computation->root_instruction();
  EXPECT_THAT(merged, GmockMatch(m::Iota()));
  EXPECT_EQ(Cast<HloIotaInstruction>(merged)->iota_dimension(), 2);
}
4735
4736 // Test that a broadcast of an iota can be merged to one iota.
TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // iota along dim 1 of f32[2,5,3], broadcast into dims {1, 2, 3} of
  // f32[4,2,5,3]: the iota dimension lands on output dimension 2.
  const Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3});
  const Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3});
  HloInstruction* const iota =
      builder.AddInstruction(HloInstruction::CreateIota(r3f32, 1));
  builder.AddInstruction(
      HloInstruction::CreateBroadcast(r4f32, iota, {1, 2, 3}));

  HloComputation* const computation = m->AddEntryComputation(builder.Build());
  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kBroadcast);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  HloInstruction* const merged = computation->root_instruction();
  EXPECT_THAT(merged, GmockMatch(m::Iota()));
  EXPECT_EQ(Cast<HloIotaInstruction>(merged)->iota_dimension(), 2);
}
4756
TEST_F(AlgebraicSimplifierTest, TransposeOfDot) {
  // transpose(dot(lhs, rhs)) should become dot(rhs, lhs); the transpose is
  // absorbed by swapping the dot's operands.
  constexpr char kModuleStr[] = R"(
HloModule module

ENTRY test {
  lhs = f32[3,4,5] parameter(0)
  rhs = f32[6,3,4] parameter(1)
  dot = f32[5,6] dot(lhs,rhs), lhs_contracting_dims={0,1}, rhs_contracting_dims={1,2}
  ROOT transpose = f32[6,5] transpose(dot), dimensions={1,0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModuleStr));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Dot(m::Parameter(1), m::Parameter(0))));
}
4777
TEST_F(AlgebraicSimplifierTest, SliceOfPadLow) {
  // The slice [2:3]x[0:1] lies entirely inside the low padding (3 rows, 1
  // column), so it should become a broadcast of the pad value.
  constexpr char kModuleStr[] = R"(
HloModule module

ENTRY test {
  param = f32[3,4] parameter(0)
  constant = f32[] constant(0.0)
  pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
  ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[2:3],[0:1]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModuleStr));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Constant())));
}
4798
TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) {
  // The slice [6:7]x[9:10] lies entirely inside the high padding (rows >= 6,
  // columns >= 5), so it should become a broadcast of the pad value.
  constexpr char kModuleStr[] = R"(
HloModule module

ENTRY test {
  param = f32[3,4] parameter(0)
  constant = f32[] constant(0.0)
  pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
  ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[6:7],[9:10]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModuleStr));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Constant())));
}
4819
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) {
  // The slice [5:6]x[4:5] reads only original data (rows 3..5, columns 1..4
  // of the pad), so it should become a slice of the unpadded parameter.
  constexpr char kModuleStr[] = R"(
HloModule module

ENTRY test {
  param = f32[3,4] parameter(0)
  constant = f32[] constant(0.0)
  pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
  ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[4:5]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModuleStr));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Slice(m::Parameter(0))));
}
4840
TEST_F(AlgebraicSimplifierTest, SliceOfPad) {
  // The slice [4:6]x[2:5] reads only original data; shifted back by the low
  // padding (3, 1), it starts at (1, 1) in the parameter.
  constexpr char kModuleStr[] = R"(
HloModule module

ENTRY test {
  param = f32[3,4] parameter(0)
  constant = f32[] constant(0.0)
  pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
  ROOT slice = f32[2,3] slice(f32[8,10] pad), slice={[4:6],[2:5]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModuleStr));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  HloInstruction* const new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root, GmockMatch(m::Slice(m::Parameter(0))));
  EXPECT_THAT(new_root->slice_starts(), ElementsAre(1, 1));
}
4862
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalarConstant) {
  // Row 5 is original data but column 9 is high padding, so the single
  // selected element is the pad value and should become a broadcast of it.
  constexpr char kModuleStr[] = R"(
HloModule module

ENTRY test {
  param = f32[3,4] parameter(0)
  constant = f32[] constant(0.0)
  pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
  ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[9:10]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModuleStr));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Constant())));
}
4883
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) {
  // The slice [3:4]x[4:5] selects exactly the single element of the f32[1,1]
  // parameter, so the pad and slice should both vanish.
  constexpr char kModuleStr[] = R"(
HloModule module

ENTRY test {
  param = f32[1,1] parameter(0)
  constant = f32[] constant(0.0)
  pad = f32[8,10] pad(f32[1,1] param, f32[] constant), padding=3_4x4_5
  ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[3:4],[4:5]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModuleStr));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Parameter()));
}
4904
TEST_F(AlgebraicSimplifierTest, SliceOfPadSomeDimsInPadding) {
  // Dimension 2 has 2 elements of low padding, so index 0 there is padding
  // even though dimensions 0 and 1 hit real data. The whole result is
  // therefore the pad value -7.
  constexpr char kModuleStr[] = R"(
HloModule module

ENTRY entry () -> f32[1]{0} {
  constant.val = f32[] constant(4)
  constant.pad = f32[] constant(-7)
  reshape.1 = f32[1,1,1]{2,1,0} reshape(f32[] constant.val)
  pad = f32[3,3,3]{2,1,0} pad(f32[1,1,1]{2,1,0} reshape.1, f32[] constant.pad), padding=0_2x0_2x2_0
  slice = f32[1,1,1]{2,1,0} slice(f32[3,3,3]{2,1,0} pad), slice={[0:1], [0:1], [0:1]}
  ROOT reshape.2 = f32[1]{0} reshape(f32[1,1,1]{2,1,0} slice)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModuleStr));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::ConstantScalar(-7.0))));
}
4927
TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) {
  // Element [2:3] of the concatenation is exactly param.1 (which occupies
  // offsets 2..3), so the slice should be replaced by that parameter.
  constexpr char kModuleStr[] = R"(
HloModule module

ENTRY test {
  param.0 = f32[2] parameter(0)
  param.1 = f32[1] parameter(1)
  param.2 = f32[3] parameter(2)
  concat = f32[6] concatenate(param.0, param.1, param.2), dimensions={0}
  ROOT slice = f32[1] slice(concat), slice={[2:3]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModuleStr));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Parameter(1)));
}
4949
TEST_F(AlgebraicSimplifierTest, SliceOfConcatNonScalarInput) {
  // Element [4:5] falls inside param.2 (which starts at offset 3), so the slice
  // should be rewritten as param.2[1:2].
  constexpr char kModuleStr[] = R"(
HloModule module

ENTRY test {
  param.0 = f32[2] parameter(0)
  param.1 = f32[1] parameter(1)
  param.2 = f32[3] parameter(2)
  concat = f32[6] concatenate(param.0, param.1, param.2), dimensions={0}
  ROOT slice = f32[1] slice(concat), slice={[4:5]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModuleStr));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  HloInstruction* const new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root, GmockMatch(m::Slice(m::Parameter(2))));
  EXPECT_EQ(new_root->slice_starts(0), 1);
  EXPECT_EQ(new_root->slice_limits(0), 2);
}
4973
TEST_F(AlgebraicSimplifierTest, ConcatToBroadcast) {
  // Concatenating the same operand six times along its size-1 dimension just
  // repeats it, so it should become broadcast(reshape(p)).
  constexpr char kModuleStr[] = R"(
HloModule module

ENTRY test {
  p = f32[2,1,4] parameter(0)
  ROOT concat = f32[2,6,4] concatenate(p,p,p,p,p,p), dimensions={1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModuleStr));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));
}
4992
TEST_F(AlgebraicSimplifierTest, NegateNegate) {
  // -(-x) should simplify to x.
  constexpr char kModuleStr[] = R"(
HloModule module

ENTRY test {
  param.0 = f32[2] parameter(0)
  neg.0 = f32[2] negate(param.0)
  ROOT neg.1 = f32[2] negate(neg.0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModuleStr));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Parameter(0)));
}
5012
TEST_F(AlgebraicSimplifierTest, NotNot) {
  // not(not(x)) on predicates should simplify to x.
  constexpr char kModuleStr[] = R"(
HloModule module

ENTRY test {
  param.0 = pred[2] parameter(0)
  not.0 = pred[2] not(param.0)
  ROOT not.1 = pred[2] not(not.0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModuleStr));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Parameter(0)));
}
5032
5033 struct PadReduceWindowEffectiveBroadcastCase {
5034 std::vector<int64> input_spatials;
5035 std::vector<int64> symmetric_pad_spatials;
5036 std::vector<int64> reduce_window_spatials;
5037 // Whether to use `B F S0 S1` form vs `B S0 S1 F` form.
5038 //
5039 // This doesn't test any different functionality but is useful for making sure
5040 // kBroadcast nodes are well formed.
5041 bool prepend_a;
5042 bool should_become_broadcast;
5043
ToTestCaseNamexla::__anonac242c730111::PadReduceWindowEffectiveBroadcastCase5044 string ToTestCaseName() const {
5045 return absl::StrCat(absl::StrJoin(input_spatials, ","), ";",
5046 absl::StrJoin(symmetric_pad_spatials, ","), ";",
5047 absl::StrJoin(reduce_window_spatials, ","), ";",
5048 prepend_a, ";", should_become_broadcast);
5049 }
5050 };
5051
PrintTo(const PadReduceWindowEffectiveBroadcastCase & c,std::ostream * os)5052 void PrintTo(const PadReduceWindowEffectiveBroadcastCase& c, std::ostream* os) {
5053 *os << c.ToTestCaseName();
5054 }
5055
5056 class PadReduceWindowEffectiveBroadcastTest
5057 : public AlgebraicSimplifierTest,
5058 public ::testing::WithParamInterface<
5059 PadReduceWindowEffectiveBroadcastCase> {};
5060
// Builds reduce-window(pad(input), 0) with an f32 add reducer from the current
// PadReduceWindowEffectiveBroadcastCase. Checks that the simplifier reports a
// change, keeps the result shape, and yields a broadcast root exactly when the
// case says it should (otherwise a reduce-window whose init value is `zero`).
TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) {
  auto module = CreateNewVerifiedModule();
  const auto& param = GetParam();

  // Surrounds the spatial extents with two extra bounds a and b, producing
  // [a, spatials..., b] when param.prepend_a is set and [spatials..., a, b]
  // otherwise.
  auto decorate_spatials = [&param](absl::Span<const int64> spatials, int64 a,
                                    int64 b) {
    std::vector<int64> result;
    if (param.prepend_a) {
      result.push_back(a);
    }
    for (int64 s : spatials) {
      result.push_back(s);
    }
    if (!param.prepend_a) {
      result.push_back(a);
    }
    result.push_back(b);
    return result;
  };

  HloComputation::Builder builder(TestName());
  const Shape input_shape = ShapeUtil::MakeShape(
      F32, decorate_spatials(param.input_spatials, 128, 2048));
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "input"));

  // Only the spatial dimensions are padded; the a/b bounds get zero padding.
  PaddingConfig padding = window_util::MakeSymmetricPadding(
      decorate_spatials(param.symmetric_pad_spatials, 0, 0));
  TF_ASSERT_OK_AND_ASSIGN(
      const Shape pad_shape,
      ShapeInference::InferPadShape(input->shape(),
                                    ShapeUtil::MakeShape(F32, {}), padding));
  HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
      pad_shape, input,
      builder.AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))),
      padding));

  // Scalar f32 add used as the reduce-window reducer. It uses its own builder
  // name so it does not shadow the entry computation's `builder`.
  HloComputation* add_computation = nullptr;
  {
    HloComputation::Builder add_builder(TestName() + ".add");
    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
    HloInstruction* p0 = add_builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "p0"));
    HloInstruction* p1 = add_builder.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "p1"));
    add_builder.AddInstruction(
        HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
    add_computation = module->AddEmbeddedComputation(add_builder.Build());
  }

  // The window spans the spatial dimensions; a and b get window size 1.
  Window window = window_util::MakeWindow(
      decorate_spatials(param.reduce_window_spatials, 1, 1));
  auto zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape,
                          ShapeInference::InferReduceWindowShape(
                              pad->shape(), zero->shape(), window,
                              add_computation->ComputeProgramShape()));
  builder.AddInstruction(HloInstruction::CreateReduceWindow(
      output_shape, pad, zero, window, add_computation));

  auto computation = module->AddEntryComputation(builder.Build());
  AlgebraicSimplifier simplifier(default_options_);
  // Run() returns whether the module was changed.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
  ASSERT_TRUE(changed);

  EXPECT_TRUE(
      ShapeUtil::Equal(computation->root_instruction()->shape(), output_shape));

  if (param.should_become_broadcast) {
    EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Broadcast()));
  } else {
    EXPECT_THAT(computation->root_instruction(),
                GmockMatch(m::ReduceWindow(m::Op(), m::Op().Is(zero))));
  }
}
5140
5141 const std::vector<PadReduceWindowEffectiveBroadcastCase>&
PadReduceWindowEffectiveBroadcastCases()5142 PadReduceWindowEffectiveBroadcastCases() {
5143 static auto* cases = new std::vector<PadReduceWindowEffectiveBroadcastCase>{
5144 {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{6, 6},
5145 /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/true,
5146 /*should_become_broadcast=*/true}, //
5147 {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{6, 6},
5148 /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/false,
5149 /*should_become_broadcast=*/true}, //
5150 {/*input_spatials=*/{2, 2}, /*symmetric_pad_amount=*/{6, 6},
5151 /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/true,
5152 /*should_become_broadcast=*/false}, //
5153 {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{2, 2},
5154 /*reduce_window_spatials=*/{1, 1}, /*prepend_a=*/true,
5155 /*should_become_broadcast=*/false}, //
5156 {/*input_spatials=*/{5, 1}, /*symmetric_pad_amount=*/{0, 2},
5157 /*reduce_window_spatials=*/{2, 5}, /*prepend_a=*/true,
5158 /*should_become_broadcast=*/false}, //
5159 };
5160 return *cases;
5161 }
5162
5163 INSTANTIATE_TEST_SUITE_P(
5164 PadReduceWindowEffectiveBroadcastInstantiation,
5165 PadReduceWindowEffectiveBroadcastTest,
5166 ::testing::ValuesIn(PadReduceWindowEffectiveBroadcastCases()));
5167
5168 class BatchDotStrengthReductionTest
5169 : public AlgebraicSimplifierTest,
5170 public ::testing::WithParamInterface<
5171 ::testing::tuple<int, int, int, PrimitiveType>> {};
TEST_P(BatchDotStrengthReductionTest,BatchDotStrengthReduction)5172 TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) {
5173 auto module = CreateNewVerifiedModule();
5174 int m, k, n;
5175 PrimitiveType element_type;
5176 std::tie(m, k, n, element_type) = GetParam();
5177 std::vector<int64> lhs_dims = {2, 3, 5};
5178 std::vector<int64> rhs_dims = lhs_dims;
5179 std::vector<int64> output_dims = lhs_dims;
5180 if (m > 0) {
5181 lhs_dims.push_back(m);
5182 output_dims.push_back(m);
5183 }
5184 if (k > 0) {
5185 lhs_dims.push_back(k);
5186 rhs_dims.push_back(k);
5187 }
5188 if (n > 0) {
5189 rhs_dims.push_back(n);
5190 output_dims.push_back(n);
5191 }
5192 Shape dot_shape = ShapeUtil::MakeShape(element_type, output_dims);
5193 Shape lhs_shape = ShapeUtil::MakeShape(element_type, lhs_dims);
5194 Shape rhs_shape = ShapeUtil::MakeShape(element_type, rhs_dims);
5195 HloComputation::Builder builder(TestName());
5196
5197 auto lhs = builder.AddInstruction(
5198 HloInstruction::CreateParameter(0, lhs_shape, "lhs"));
5199 auto rhs = builder.AddInstruction(
5200 HloInstruction::CreateParameter(1, rhs_shape, "rhs"));
5201 DotDimensionNumbers dot_dnums;
5202 dot_dnums.add_lhs_batch_dimensions(0);
5203 dot_dnums.add_lhs_batch_dimensions(1);
5204 dot_dnums.add_lhs_batch_dimensions(2);
5205 dot_dnums.add_rhs_batch_dimensions(0);
5206 dot_dnums.add_rhs_batch_dimensions(1);
5207 dot_dnums.add_rhs_batch_dimensions(2);
5208 if (k > 0) {
5209 dot_dnums.add_lhs_contracting_dimensions(m > 0 ? 4 : 3);
5210 dot_dnums.add_rhs_contracting_dimensions(3);
5211 }
5212 builder.AddInstruction(HloInstruction::CreateDot(
5213 dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
5214 auto computation = module->AddEntryComputation(builder.Build());
5215 AlgebraicSimplifier simplifier(default_options_);
5216 TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
5217 const bool dot_should_be_transformed =
5218 m == 1 || k == 1 || n == 1 || m == -1 || k == -1 || n == -1;
5219 EXPECT_EQ(changed, dot_should_be_transformed);
5220 TF_ASSERT_OK_AND_ASSIGN(changed, simplifier.Run(module.get()));
5221 bool has_no_dot = true;
5222 for (const auto& hlo : computation->instructions()) {
5223 if (hlo->opcode() == HloOpcode::kDot) {
5224 has_no_dot = false;
5225 break;
5226 }
5227 }
5228 EXPECT_EQ(has_no_dot, dot_should_be_transformed);
5229 }
5230
5231 INSTANTIATE_TEST_SUITE_P(
5232 BatchDotStrengthReductionTestInstantiation, BatchDotStrengthReductionTest,
5233 ::testing::Combine(::testing::Values(-1, 1, 2), ::testing::Values(-1, 1, 2),
5234 ::testing::Values(-1, 1, 2),
5235 ::testing::Values(C128, C64, F64, F32, BF16)));
5236
5237 class DotStrengthReductionTest
5238 : public AlgebraicSimplifierTest,
5239 public ::testing::WithParamInterface<
5240 ::testing::tuple<int, int, int, bool, bool, PrimitiveType>> {};
TEST_P(DotStrengthReductionTest,DotStrengthReduction)5241 TEST_P(DotStrengthReductionTest, DotStrengthReduction) {
5242 auto module = CreateNewVerifiedModule();
5243 int m, k, n;
5244 bool transpose_lhs, transpose_rhs;
5245 PrimitiveType element_type;
5246 std::tie(m, k, n, transpose_lhs, transpose_rhs, element_type) = GetParam();
5247
5248 Shape dot_shape = ShapeUtil::MakeShape(element_type, {m, n});
5249 Shape lhs_shape = ShapeUtil::MakeShape(element_type, {m, k});
5250 Shape transposed_lhs_shape = ShapeUtil::MakeShape(element_type, {k, m});
5251 Shape rhs_shape = ShapeUtil::MakeShape(element_type, {k, n});
5252 Shape transposed_rhs_shape = ShapeUtil::MakeShape(element_type, {n, k});
5253 HloComputation::Builder builder(TestName());
5254
5255 auto lhs = builder.AddInstruction(HloInstruction::CreateParameter(
5256 0, transpose_lhs ? transposed_lhs_shape : lhs_shape, "lhs"));
5257 if (transpose_lhs) {
5258 lhs = builder.AddInstruction(
5259 HloInstruction::CreateTranspose(lhs_shape, lhs, {1, 0}));
5260 }
5261 auto rhs = builder.AddInstruction(HloInstruction::CreateParameter(
5262 1, transpose_rhs ? transposed_rhs_shape : rhs_shape, "rhs"));
5263 if (transpose_rhs) {
5264 rhs = builder.AddInstruction(
5265 HloInstruction::CreateTranspose(rhs_shape, rhs, {1, 0}));
5266 }
5267 DotDimensionNumbers dot_dnums;
5268 dot_dnums.add_lhs_contracting_dimensions(1);
5269 dot_dnums.add_rhs_contracting_dimensions(0);
5270 builder.AddInstruction(HloInstruction::CreateDot(
5271 dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
5272 auto computation = module->AddEntryComputation(builder.Build());
5273 AlgebraicSimplifier simplifier(default_options_);
5274 // First pass of algebraic simplifier will remove degenerate dimensions
5275 // and optimize dot(transpose(x),transpose(y))
5276 TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
5277 const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1;
5278 const bool computation_should_be_modified =
5279 dot_should_be_transformed || (transpose_lhs && transpose_rhs);
5280 EXPECT_EQ(changed, computation_should_be_modified);
5281 // The second pass of algebraic simplifier will remove dots without
5282 // non-contracting dimensions or contracting dimensions.
5283 TF_ASSERT_OK_AND_ASSIGN(changed, simplifier.Run(module.get()));
5284 EXPECT_EQ(changed, computation_should_be_modified);
5285 bool has_no_dot = true;
5286 for (const auto& hlo : computation->instructions()) {
5287 if (hlo->opcode() == HloOpcode::kDot) {
5288 has_no_dot = false;
5289 break;
5290 }
5291 }
5292 EXPECT_EQ(has_no_dot, dot_should_be_transformed);
5293 }
5294
5295 INSTANTIATE_TEST_SUITE_P(
5296 DotStrengthReductionTestInstantiation, DotStrengthReductionTest,
5297 ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2),
5298 ::testing::Values(1, 2), ::testing::Bool(),
5299 ::testing::Bool(),
5300 ::testing::Values(C128, C64, F64, F32, BF16)));
5301
5302 struct DotOfConcatTestSpec {
5303 int64 m;
5304 int64 k;
5305 int64 n;
5306 };
5307
5308 class DotOfConcatSimplificationTest
5309 : public AlgebraicSimplifierTest,
5310 public ::testing::WithParamInterface<DotOfConcatTestSpec> {};
5311
5312 // Test that we transform
5313 // dot(const, concat(A, B, C))
5314 // to
5315 // add(dot(const_0, A), dot(const_1, B), dot(const_2, C))
TEST_P(DotOfConcatSimplificationTest,ConstantLHS)5316 TEST_P(DotOfConcatSimplificationTest, ConstantLHS) {
5317 auto m = CreateNewVerifiedModule();
5318 HloComputation::Builder builder(TestName());
5319
5320 DotOfConcatTestSpec spec = GetParam();
5321
5322 ASSERT_GE(spec.k, 3);
5323
5324 int64 k0 = spec.k / 3;
5325 int64 k1 = spec.k / 3;
5326 int64 k2 = spec.k - k0 - k1;
5327
5328 Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k});
5329 auto* lhs = builder.AddInstruction(
5330 HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
5331 /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.m, /*cols=*/spec.k)));
5332
5333 Shape rhs0_shape = ShapeUtil::MakeShape(F32, {k0, spec.n});
5334 Shape rhs1_shape = ShapeUtil::MakeShape(F32, {k1, spec.n});
5335 Shape rhs2_shape = ShapeUtil::MakeShape(F32, {k2, spec.n});
5336
5337 HloInstruction* rhs0 = builder.AddInstruction(
5338 HloInstruction::CreateParameter(0, rhs0_shape, "rhs0"));
5339 HloInstruction* rhs1 = builder.AddInstruction(
5340 HloInstruction::CreateParameter(1, rhs1_shape, "rhs1"));
5341 HloInstruction* rhs2 = builder.AddInstruction(
5342 HloInstruction::CreateParameter(2, rhs2_shape, "rhs2"));
5343
5344 Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n});
5345 HloInstruction* rhs = builder.AddInstruction(
5346 HloInstruction::CreateConcatenate(rhs_shape, {rhs0, rhs1, rhs2}, 0));
5347
5348 DotDimensionNumbers dot_dnums;
5349 dot_dnums.add_lhs_contracting_dimensions(1);
5350 dot_dnums.add_rhs_contracting_dimensions(0);
5351
5352 Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n});
5353 builder.AddInstruction(HloInstruction::CreateDot(
5354 dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
5355
5356 auto computation = m->AddEntryComputation(builder.Build());
5357 AlgebraicSimplifier simplifier(default_options_);
5358 TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get()));
5359 ASSERT_TRUE(run_successful);
5360
5361 EXPECT_TRUE(
5362 ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
5363
5364 auto match_dot_0 = m::Dot(m::Slice(m::Constant()), m::Parameter(0));
5365 auto match_dot_1 = m::Dot(m::Slice(m::Constant()), m::Parameter(1));
5366 auto match_dot_2 = m::Dot(m::Slice(m::Constant()), m::Parameter(2));
5367 EXPECT_THAT(
5368 computation->root_instruction(),
5369 GmockMatch(m::Add(m::Add(match_dot_0, match_dot_1), match_dot_2)));
5370 }
5371
5372 // Test that we transform
5373 // dot(concat(A, B, C), const)
5374 // to
5375 // add(dot(A, const_0), dot(B, const_1), dot(C, const_2))
TEST_P(DotOfConcatSimplificationTest,ConstantRHS)5376 TEST_P(DotOfConcatSimplificationTest, ConstantRHS) {
5377 auto m = CreateNewVerifiedModule();
5378 HloComputation::Builder builder(TestName());
5379
5380 DotOfConcatTestSpec spec = GetParam();
5381
5382 ASSERT_GE(spec.k, 4);
5383
5384 int64 k0 = spec.k / 4;
5385 int64 k1 = spec.k / 4;
5386 int64 k2 = spec.k / 4;
5387 int64 k3 = spec.k - k0 - k1 - k2;
5388
5389 Shape lhs0_shape = ShapeUtil::MakeShape(F32, {spec.m, k0});
5390 Shape lhs1_shape = ShapeUtil::MakeShape(F32, {spec.m, k1});
5391 Shape lhs2_shape = ShapeUtil::MakeShape(F32, {spec.m, k2});
5392 Shape lhs3_shape = ShapeUtil::MakeShape(F32, {spec.m, k3});
5393
5394 HloInstruction* lhs0 = builder.AddInstruction(
5395 HloInstruction::CreateParameter(0, lhs0_shape, "lhs0"));
5396 HloInstruction* lhs1 = builder.AddInstruction(
5397 HloInstruction::CreateParameter(1, lhs1_shape, "lhs1"));
5398 HloInstruction* lhs2 = builder.AddInstruction(
5399 HloInstruction::CreateParameter(2, lhs2_shape, "lhs2"));
5400 HloInstruction* lhs3 = builder.AddInstruction(
5401 HloInstruction::CreateParameter(3, lhs3_shape, "lhs3"));
5402
5403 Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k});
5404 HloInstruction* lhs =
5405 builder.AddInstruction(HloInstruction::CreateConcatenate(
5406 lhs_shape, {lhs0, lhs1, lhs2, lhs3}, 1));
5407
5408 Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n});
5409 auto* rhs = builder.AddInstruction(
5410 HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
5411 /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.k, /*cols=*/spec.n)));
5412
5413 DotDimensionNumbers dot_dnums;
5414 dot_dnums.add_lhs_contracting_dimensions(1);
5415 dot_dnums.add_rhs_contracting_dimensions(0);
5416
5417 Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n});
5418 builder.AddInstruction(HloInstruction::CreateDot(
5419 dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
5420
5421 auto computation = m->AddEntryComputation(builder.Build());
5422 AlgebraicSimplifier simplifier(default_options_);
5423 TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get()));
5424 ASSERT_TRUE(run_successful);
5425 EXPECT_TRUE(
5426 ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
5427
5428 auto match_dot_0 = m::Dot(m::Parameter(0), m::Slice(m::Constant()));
5429 auto match_dot_1 = m::Dot(m::Parameter(1), m::Slice(m::Constant()));
5430 auto match_dot_2 = m::Dot(m::Parameter(2), m::Slice(m::Constant()));
5431 auto match_dot_3 = m::Dot(m::Parameter(3), m::Slice(m::Constant()));
5432 EXPECT_THAT(
5433 computation->root_instruction(),
5434 GmockMatch(m::Add(m::Add(m::Add(match_dot_0, match_dot_1), match_dot_2),
5435 match_dot_3)));
5436 }
5437
5438 DotOfConcatTestSpec kDotOfConcatTestSpecs[] = {
5439 {/*m=*/3, /*k=*/9, /*n=*/3}, //
5440 {/*m=*/3, /*k=*/20, /*n=*/3}, //
5441 {/*m=*/1, /*k=*/18, /*n=*/5}, //
5442 {/*m=*/20, /*k=*/20, /*n=*/1}, //
5443 {/*m=*/1, /*k=*/16, /*n=*/1}, //
5444 };
5445
// With dot strength reduction disabled, the simplifier must report no change
// for a scalar-result dot of a rank-1 concatenate with a constant.
// NOTE(review): this is a TEST_F on the value-parameterized fixture. It never
// calls GetParam(), so AlgebraicSimplifierTest may be the intended fixture.
TEST_F(DotOfConcatSimplificationTest, ConcatIntoScalarDot) {
  const char* kModuleStr = R"(
HloModule m
test {
param0 = f32[4] parameter(0)
param1 = f32[1] parameter(1)
constant = f32[5] constant({-0.38, 0.07, -0.62, 0.66, 0.20})
concat = f32[5] concatenate(param0, param1), dimensions={0}
ROOT dot = f32[] dot(concat, constant), lhs_contracting_dims={0},
rhs_contracting_dims={0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr));

  AlgebraicSimplifierOptions options = default_options_;
  options.set_enable_dot_strength_reduction(false);
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}
5462
5463 // Test that DynamicUpdateSlice update param with any dimension equal to zero
5464 // gets removed.
TEST_F(AlgebraicSimplifierTest,DynamicUpdateSliceZeroUpdate)5465 TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) {
5466 auto m = CreateNewVerifiedModule();
5467 HloComputation::Builder builder(TestName());
5468 const Shape dslice_shape = ShapeUtil::MakeShape(F32, {10});
5469 HloInstruction* const operand = builder.AddInstruction(
5470 HloInstruction::CreateParameter(0, dslice_shape, "operand"));
5471 const Shape update_shape = ShapeUtil::MakeShape(F32, {0});
5472 HloInstruction* const update = builder.AddInstruction(
5473 HloInstruction::CreateParameter(1, update_shape, "update"));
5474 HloInstruction* const start_indices = builder.AddInstruction(
5475 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>({})));
5476 builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
5477 dslice_shape, operand, update,
5478 std::initializer_list<HloInstruction*>({start_indices})));
5479 const HloComputation* const computation =
5480 m->AddEntryComputation(builder.Build());
5481
5482 AlgebraicSimplifier simplifier(default_options_);
5483 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
5484 EXPECT_THAT(computation->root_instruction(), operand);
5485 }
5486
5487 // Test that dynamic-update-slice with a scalar broadcast becomes a pad.
TEST_F(AlgebraicSimplifierTest,DynamicUpdateSliceOfBroadcastToPad)5488 TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceOfBroadcastToPad) {
5489 const char* hlo_string = R"(
5490 HloModule AddBroadcastZeroWithDynamicSlice
5491
5492 ENTRY AddBroadcastZeroWithDynamicSlice {
5493 param0 = f32[1800,12,512]{2,1,0} parameter(0)
5494 constant = f32[] constant(0)
5495 broadcast = f32[1800,12,512]{2,1,0} broadcast(constant), dimensions={}
5496 param1 = f32[1,12,512]{2,1,0} parameter(1)
5497 constant.1 = s32[] constant(0)
5498 dynamic-update-slice = f32[1800,12,512]{2,1,0} dynamic-update-slice(broadcast, param1, constant.1, constant.1, constant.1)
5499 ROOT add = f32[1800,12,512]{2,1,0} add(param0, dynamic-update-slice)
5500 }
5501 )";
5502 TF_ASSERT_OK_AND_ASSIGN(auto module,
5503 ParseAndReturnVerifiedModule(hlo_string));
5504 VLOG(2) << "Before rewrite dus->pad\n" << module->ToString();
5505 AlgebraicSimplifier simplifier(default_options_);
5506 ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
5507 VLOG(2) << "After rewrite dus->pad\n" << module->ToString();
5508 auto root = module->entry_computation()->root_instruction();
5509 EXPECT_THAT(root->opcode(), HloOpcode::kAdd);
5510 EXPECT_THAT(root->operand(1)->opcode(), HloOpcode::kPad);
5511 }
5512
// With scalar-multiply reduction enabled, the root is expected to become a
// multiply by broadcast(0.5 * 1.109), i.e. the two broadcast-constant factors
// that fed the convolution are combined into one.
TEST_F(AlgebraicSimplifierTest, ScalarMultiplyReduction) {
  const char* const kHloText = R"(
HloModule ConstScalarMultiply
ENTRY ConstScalarMultiply {
param0 = f32[16,512,4096]{2,1,0} parameter(0)
constant.0 = f32[] constant(0.5)
broadcast.0 = f32[16,512,4096] broadcast(constant.0), dimensions={}
multiply.0 = f32[16,512,4096]{2,1,0} multiply(param0, broadcast.0)
param1 = f32[16,512,4096]{2,1,0} parameter(1)
multiply.1 = f32[16,512,4096]{2,1,0} multiply(multiply.0, param1)
param2 = f32[16,512,1024]{2,1,0} parameter(2)
constant.1 = f32[] constant(1.109)
broadcast.1 = f32[16,512,1024] broadcast(constant.1), dimensions={}
multiply.2 = f32[16,512,1024]{2,1,0} multiply(param2, broadcast.1)
ROOT convolution = f32[4096,1024,1]{1,0,2} convolution(multiply.1, multiply.2), window={size=16}, dim_labels=0fb_0io->bf0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifierOptions options;
  options.set_enable_scalar_multiply_reduction(true);
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  auto* const root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kMultiply);
  const float combined_scale = 0.5f * 1.109f;
  EXPECT_THAT(root, GmockMatch(m::MultiplyAnyOrder(
                        m::Op(), m::Broadcast(m::ConstantScalar(combined_scale)))));
}
5542
// The convolution result has two users (a scalar multiply and a multiply by
// param2). With scalar-multiply reduction enabled, the simplifier is expected
// to leave this module unchanged, presumably because of those multiple users.
TEST_F(AlgebraicSimplifierTest, ScalarMultiplyReductionMultiUser) {
  const char* const kHloText = R"(
HloModule ConstScalarMultiply
ENTRY ConstScalarMultiply {
param0 = f32[16,512,1024] parameter(0)
param1 = f32[4096,1024,1] parameter(1)
convolution = f32[16,512,4096] convolution(param0, param1), window={size=1}, dim_labels=0bf_oi0->0bf
constant.1 = f32[] constant(0.5)
broadcast.1 = f32[16,512,4096] broadcast(constant.1), dimensions={}
multiply.1 = f32[16,512,4096] multiply(convolution, broadcast.1)
param2 = f32[16,512,4096] parameter(2)
multiply.2 = f32[16,512,4096] multiply(convolution, param2)
ROOT add.1 = f32[16,512,4096] add(multiply.1, multiply.2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifierOptions options;
  options.set_enable_scalar_multiply_reduction(true);
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}

// Runs both DotOfConcatSimplificationTest cases over kDotOfConcatTestSpecs.
INSTANTIATE_TEST_SUITE_P(DotOfConcatSimplificationTestInstantiation,
                         DotOfConcatSimplificationTest,
                         ::testing::ValuesIn(kDotOfConcatTestSpecs));
5569
5570 struct DotOfGatherTestSpec {
5571 int64 m;
5572 int64 k;
5573 int64 n;
5574 int s; // start index for dynamic slice on the non-contracting dimension
5575 int64 lcd; // left contracting dimension
5576 int64 rcd; // right contracting dimension
5577 bool neg; // is negative testcase
5578 };
5579
5580 class DotOfGatherSimplificationTest
5581 : public AlgebraicSimplifierTest,
5582 public ::testing::WithParamInterface<DotOfGatherTestSpec> {};
5583
5584 // input: dot(DS(ctA), ctB))
5585 // where DS(ctA) = DS({M x K}, {s, 0}, {1, K}) and ctB = {K x N}.
5586 // => input dimensions: dot({1 x K}, {K x N}) => {1 x N}.
5587 // output: DS(dot(ctA, ctB))
5588 // => output dimensions: DS ({M x N}, {s, 0}, {1, N}) => {1 x N}.
TEST_P(DotOfGatherSimplificationTest,ConstantRHS)5589 TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {
5590 auto m = CreateNewVerifiedModule();
5591 HloComputation::Builder builder(TestName());
5592
5593 DotOfGatherTestSpec spec = GetParam();
5594
5595 ASSERT_LE(spec.s, spec.m);
5596
5597 // For negative tests, increase k of the dynamic slice argument to prevent the
5598 // optimization (constants ctA, ctB must have equal contracting dimensions).
5599 int64 k_increase = spec.neg ? 5 : 0;
5600 int64 lhs_rows = (spec.lcd == 0) ? (spec.k + k_increase) : spec.m;
5601 int64 lhs_cols = (spec.lcd == 0) ? spec.m : (spec.k + k_increase);
5602 Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols});
5603 auto* lhs = builder.AddInstruction(
5604 HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
5605 /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows,
5606 /*cols=*/lhs_cols)));
5607
5608 int32 start_row = (spec.lcd == 0) ? 0 : spec.s;
5609 int32 start_col = (spec.lcd == 0) ? spec.s : 0;
5610 std::vector<HloInstruction*> start_indices = {
5611 builder.AddInstruction(HloInstruction::CreateConstant(
5612 LiteralUtil::CreateR0<int32>(start_row))),
5613 builder.AddInstruction(HloInstruction::CreateConstant(
5614 LiteralUtil::CreateR0<int32>(start_col)))};
5615 int64 slice_row_size = (spec.lcd == 0) ? spec.k : 1;
5616 int64 slice_col_size = (spec.lcd == 0) ? 1 : spec.k;
5617 std::vector<int64> slice_sizes = {slice_row_size, slice_col_size};
5618 Shape ds_shape = ShapeUtil::MakeShape(F32, slice_sizes);
5619 auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
5620 ds_shape, lhs, start_indices, slice_sizes));
5621
5622 int64 rhs_rows = (spec.rcd == 0) ? spec.k : spec.n;
5623 int64 rhs_cols = (spec.rcd == 0) ? spec.n : spec.k;
5624 Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols});
5625 auto* rhs = builder.AddInstruction(
5626 HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
5627 /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows,
5628 /*cols=*/rhs_cols)));
5629
5630 DotDimensionNumbers dot_dnums;
5631 dot_dnums.add_lhs_contracting_dimensions(spec.lcd);
5632 dot_dnums.add_rhs_contracting_dimensions(spec.rcd);
5633
5634 int64 dot_row_size = 1;
5635 int64 dot_col_size = spec.n;
5636 Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size});
5637 builder.AddInstruction(HloInstruction::CreateDot(
5638 dot_shape, ds, rhs, dot_dnums, DefaultPrecisionConfig(2)));
5639
5640 auto computation = m->AddEntryComputation(builder.Build());
5641 AlgebraicSimplifier simplifier(default_options_);
5642 TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get()));
5643 ASSERT_TRUE(run_successful);
5644 EXPECT_TRUE(
5645 ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
5646
5647 if (spec.neg) {
5648 EXPECT_NE(computation->root_instruction()->opcode(),
5649 HloOpcode::kDynamicSlice);
5650 } else {
5651 EXPECT_THAT(computation->root_instruction(),
5652 GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()),
5653 m::Constant(), m::Constant())));
5654 }
5655 }
5656
5657 // input: dot(ctA, DS(ctB))
5658 // where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, s}, {K, 1}).
5659 // => input dimensions: dot({M x K}, {K x 1}) => {M x 1}.
5660 // output: DS(dot(ctA, ctB))
5661 // => output dimensions: DS ({M x N}, {0, s}, {M, 1}) => {M x 1}.
TEST_P(DotOfGatherSimplificationTest,ConstantLHS)5662 TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {
5663 auto m = CreateNewVerifiedModule();
5664 HloComputation::Builder builder(TestName());
5665
5666 DotOfGatherTestSpec spec = GetParam();
5667
5668 ASSERT_LE(spec.s, spec.n);
5669
5670 int64 lhs_rows = (spec.lcd == 0) ? spec.k : spec.m;
5671 int64 lhs_cols = (spec.lcd == 0) ? spec.m : spec.k;
5672 Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols});
5673 auto* lhs = builder.AddInstruction(
5674 HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
5675 /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows,
5676 /*cols=*/lhs_cols)));
5677
5678 // For negative tests increase k of the dynamic slice argument to prevent the
5679 // optimization
5680 int64 k_increase = spec.neg ? 5 : 0;
5681 int64 rhs_rows = (spec.rcd == 0) ? (spec.k + k_increase) : spec.n;
5682 int64 rhs_cols = (spec.rcd == 0) ? spec.n : (spec.k + k_increase);
5683 Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols});
5684 auto* rhs = builder.AddInstruction(
5685 HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
5686 /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows,
5687 /*cols=*/rhs_cols)));
5688
5689 int32 start_row = (spec.rcd == 0) ? 0 : spec.s;
5690 int32 start_col = (spec.rcd == 0) ? spec.s : 0;
5691 std::vector<HloInstruction*> start_indices = {
5692 builder.AddInstruction(HloInstruction::CreateConstant(
5693 LiteralUtil::CreateR0<int32>(start_row))),
5694 builder.AddInstruction(HloInstruction::CreateConstant(
5695 LiteralUtil::CreateR0<int32>(start_col)))};
5696 int64 slice_row_size = (spec.rcd == 0) ? spec.k : 1;
5697 int64 slice_col_size = (spec.rcd == 0) ? 1 : spec.k;
5698 std::vector<int64> slice_sizes = {slice_row_size, slice_col_size};
5699 Shape ds_shape = ShapeUtil::MakeShape(F32, slice_sizes);
5700 auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
5701 ds_shape, rhs, start_indices, slice_sizes));
5702
5703 DotDimensionNumbers dot_dnums;
5704 dot_dnums.add_lhs_contracting_dimensions(spec.lcd);
5705 dot_dnums.add_rhs_contracting_dimensions(spec.rcd);
5706
5707 int64 dot_row_size = spec.m;
5708 int64 dot_col_size = 1;
5709 Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size});
5710 builder.AddInstruction(HloInstruction::CreateDot(
5711 dot_shape, lhs, ds, dot_dnums, DefaultPrecisionConfig(2)));
5712
5713 auto computation = m->AddEntryComputation(builder.Build());
5714 AlgebraicSimplifier simplifier(default_options_);
5715 TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get()));
5716 ASSERT_TRUE(run_successful);
5717 EXPECT_TRUE(
5718 ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
5719
5720 if (spec.neg) {
5721 EXPECT_NE(computation->root_instruction()->opcode(),
5722 HloOpcode::kDynamicSlice);
5723 } else {
5724 EXPECT_THAT(computation->root_instruction(),
5725 GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()),
5726 m::Constant(), m::Constant())));
5727 }
5728 }
5729
DotOfGatherPositiveNegativeTests()5730 std::vector<DotOfGatherTestSpec> DotOfGatherPositiveNegativeTests() {
5731 std::vector<DotOfGatherTestSpec> positives = {
5732 // "Classical dot", i.e. matrix multiply:
5733 {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/1, /*rcd=*/0,
5734 /*neg=*/false},
5735 {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/1, /*rcd=*/0,
5736 /*neg=*/false},
5737 {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/1, /*rcd=*/0,
5738 /*neg=*/false},
5739 // Note: testing for m=1 and n=1 is unnecessary, as this optimizes to
5740 // dot(ct, ct) before DotOfGather optimization kicks in.
5741 // Contract on rows:
5742 {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/0, /*rcd=*/0,
5743 /*neg=*/false},
5744 {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/0, /*rcd=*/0,
5745 /*neg=*/false},
5746 {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/0, /*rcd=*/0,
5747 /*neg=*/false},
5748 // Reverse matrix multiply:
5749 {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/0, /*rcd=*/1,
5750 /*neg=*/false},
5751 {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/0, /*rcd=*/1,
5752 /*neg=*/false},
5753 {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/0, /*rcd=*/1,
5754 /*neg=*/false},
5755 // Contract on columns:
5756 {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/1, /*rcd=*/1,
5757 /*neg=*/false},
5758 {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/1, /*rcd=*/1,
5759 /*neg=*/false},
5760 {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/1, /*rcd=*/1,
5761 /*neg=*/false},
5762 };
5763 std::vector<DotOfGatherTestSpec> all;
5764 for (int i = 0; i < positives.size(); i++) {
5765 DotOfGatherTestSpec positive_test = positives[i];
5766 all.push_back(positive_test);
5767 DotOfGatherTestSpec negative_test = positive_test;
5768 negative_test.neg = true;
5769 all.push_back(negative_test);
5770 }
5771 return all;
5772 }
5773
5774 INSTANTIATE_TEST_SUITE_P(
5775 DotOfGatherSimplificationTestInstantiation, DotOfGatherSimplificationTest,
5776 ::testing::ValuesIn(DotOfGatherPositiveNegativeTests()));
5777
// The gather reads 1x1 slices out of an f32[1,1] operand, so its result is
// expected to be rewritten as broadcast(reshape(operand)).
TEST_F(AlgebraicSimplifierTest, GatherOfScalarToBroadcast) {
  const char* const kHloText = R"(
HloModule repeat

ENTRY main {
o = f32[1,1] parameter(0)
i = s32[100,2] parameter(1)
ROOT g = f32[100] gather(o, i), collapsed_slice_dims={0,1},
start_index_map={0,1},
index_vector_dim=1,
offset_dims={},
slice_sizes={1,1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));
}
5801
// A variadic reduce with dimensions={} combines nothing, so each tuple output
// should simply be a reshape of the matching get-tuple-element input.
TEST_F(AlgebraicSimplifierTest, TupleReduceReshape) {
  constexpr char kHloText[] = R"(
HloModule module

reducer {
parameter.1 = f32[] parameter(0)
parameter.3 = f32[] parameter(2)
add.2 = f32[] add(parameter.1, parameter.3)
parameter.0 = f32[] parameter(1)
parameter.2 = f32[] parameter(3)
add.3 = f32[] add(parameter.0, parameter.2)
ROOT tuple.4 = (f32[], f32[]) tuple(add.2, add.3)
}

ENTRY entry {
parameter.6 = (f32[], f32[]) parameter(0)
get-tuple-element.10 = f32[] get-tuple-element(parameter.6), index=0
get-tuple-element.11 = f32[] get-tuple-element(parameter.6), index=1
constant = f32[] constant(0)
ROOT reduce = (f32[], f32[]) reduce(get-tuple-element.10, get-tuple-element.11, constant, constant), dimensions={}, to_apply=reducer
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifierOptions opts;
  AlgebraicSimplifier simplifier(opts);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  const auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Tuple(
                  m::Reshape(m::GetTupleElement(m::Parameter(), 0)),
                  m::Reshape(m::GetTupleElement(m::Parameter(), 1)))));
}
5835
// Dimension 0 of both reduce inputs has size zero, so there are no elements to
// combine. Each output of the variadic reduce should therefore become a
// broadcast of its init value (0 and 1 respectively).
TEST_F(AlgebraicSimplifierTest, TupleReduceBroadcast) {
  // Both reducer components are additions; the instruction names say so.
  const char* hlo_string = R"(
HloModule module

reducer {
parameter.1 = f32[] parameter(0)
parameter.3 = f32[] parameter(2)
add.2 = f32[] add(parameter.1, parameter.3)
parameter.0 = f32[] parameter(1)
parameter.2 = f32[] parameter(3)
add.3 = f32[] add(parameter.0, parameter.2)
ROOT tuple.4 = (f32[], f32[]) tuple(add.2, add.3)
}

ENTRY entry {
parameter.6 = (f32[0, 10, 10], f32[0, 10, 10]) parameter(0)
get-tuple-element.10 = f32[0, 10, 10] get-tuple-element(parameter.6), index=0
get-tuple-element.11 = f32[0, 10, 10] get-tuple-element(parameter.6), index=1
constant.0 = f32[] constant(0)
constant.1 = f32[] constant(1)
ROOT reduce = (f32[10, 10], f32[10, 10]) reduce(get-tuple-element.10, get-tuple-element.11, constant.0, constant.1), dimensions={0}, to_apply=reducer
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Tuple(m::Broadcast(m::ConstantScalar(0)),
                                        m::Broadcast(m::ConstantScalar(1)))));
}
5869
// A reshape that produces zero elements, built with a layout-less result
// shape, should still be simplified to a constant.
TEST_F(AlgebraicSimplifierTest, ZeroSizedReshapeWithoutLayout) {
  HloComputation::Builder builder(TestName());
  HloInstruction* const input =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {1}), "param"));
  HloInstruction* const zero_sized =
      builder.AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::MakeShape(F32, {0, 1}), input, {1}));

  // Zero-element reshape result whose layout is deliberately cleared.
  Shape empty_shape = ShapeUtil::MakeShape(F32, {0});
  empty_shape.clear_layout();
  builder.AddInstruction(
      HloInstruction::CreateReshape(empty_shape, zero_sized));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  AlgebraicSimplifierOptions opts;
  AlgebraicSimplifier simplifier(opts);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Constant()));
}
5894
// Dividing a scalar by the constant 20 should turn into a multiply, even when
// every shape involved has had its layout cleared.
TEST_F(AlgebraicSimplifierTest, DividedByConstantInstructionWithoutLayout) {
  Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
  scalar_shape.clear_layout();

  HloComputation::Builder builder(TestName());
  HloInstruction* const dividend = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "param"));
  HloInstruction* const divisor = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(20.0f)));
  builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_shape, HloOpcode::kDivide, dividend, divisor));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  AlgebraicSimplifierOptions opts;
  AlgebraicSimplifier simplifier(opts);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Multiply()));
}
5916
// Test that Y/sqrt(X) is simplified to Y*rsqrt(X).
TEST_F(AlgebraicSimplifierTest, RecipSqrt) {
  constexpr char kHloText[] = R"(
HloModule m
test {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
sqrt = f32[] sqrt(p0)
ROOT div = f32[] divide(p1, sqrt)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  // p1 / sqrt(p0) is expected to become p1 * rsqrt(p0), in either order.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::MultiplyAnyOrder(m::Parameter(1),
                                             m::Rsqrt(m::Parameter(0)))));
}
5934
// Test that Y/rsqrt(X) is simplified to Y*sqrt(X).
TEST_F(AlgebraicSimplifierTest, RecipRsqrt) {
  constexpr char kHloText[] = R"(
HloModule m
test {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
rsqrt = f32[] rsqrt(p0)
ROOT div = f32[] divide(p1, rsqrt)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  // p1 / rsqrt(p0) is expected to become p1 * sqrt(p0), in either order.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::MultiplyAnyOrder(m::Parameter(1),
                                             m::Sqrt(m::Parameter(0)))));
}
5952
// With layout sensitivity on and reshape-is-bitcast always false, the
// layout-changing copy of a reshape should be absorbed. The result is a reshape
// of the parameter whose shape equals the original root's shape.
TEST_F(AlgebraicSimplifierTest, CopyReshape) {
  constexpr char kHloText[] = R"(
HloModule m
test {
p0 = f32[168,168,48,48]{3,2,1,0} parameter(0)
r0 = f32[1,168,168,2304]{3,2,1,0} reshape(p0)
ROOT c0 = f32[1,168,168,2304]{3,0,2,1} copy(r0)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  const auto expected_shape =
      module->entry_computation()->root_instruction()->shape();

  AlgebraicSimplifierOptions opts(
      [](const Shape&, const Shape&) { return false; });
  opts.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(opts);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Reshape(m::Parameter(0))
                             .WithShapeEqualTo(&expected_shape)));
}
5971
// The transpose/reshape on the LHS only reorders its contracting dimension, so
// the simplifier should move that reordering onto the constant RHS.
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_RL) {
  constexpr char kHloText[] = R"(
HloModule m
test {
rhs = f32[6, 2] constant({{1, 2},{3, 4},{5, 6},{1, 1},{1, 1},{1, 1}})
t0 = f32[2, 2, 3] parameter(0)
t1 = f32[2, 3, 2] transpose(t0), dimensions={0, 2, 1}
lhs = f32[2, 6] reshape(t1)
ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Expected shapes: new LHS reshape, constant before and after the transpose.
  const auto lhs_shape = ShapeUtil::MakeShape(F32, {2, 6});
  const auto rhs_pre_transpose = ShapeUtil::MakeShape(F32, {3, 2, 2});
  const auto rhs_post_transpose = ShapeUtil::MakeShape(F32, {2, 3, 2});
  // Moving the transpose and reshape to the constant side is layout
  // insensitive, so shapes are compared without their layouts.
  const HloInstruction* transpose = nullptr;
  ASSERT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Dot(
          m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&lhs_shape),
          m::Reshape(m::Transpose(&transpose,
                                  m::Reshape(m::Constant())
                                      .WithShapeCompatibleTo(&rhs_pre_transpose))
                         .WithShapeCompatibleTo(&rhs_post_transpose)))));
  EXPECT_THAT(transpose->dimensions(), ElementsAre(1, 0, 2));
}
6000
// Same as the RL case, but the constant contracts on its second dimension.
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_RR) {
  constexpr char kHloText[] = R"(
HloModule m
test {
rhs = f32[2, 6] constant({{1, 2, 3, 4, 5, 6},
{1, 1, 1, 1, 1, 1}})
t0 = f32[2, 2, 3] parameter(0)
t1 = f32[2, 3, 2] transpose(t0), dimensions={0, 2, 1}
lhs = f32[2, 6] reshape(t1)
ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  const auto lhs_shape = ShapeUtil::MakeShape(F32, {2, 6});
  const auto rhs_pre_transpose = ShapeUtil::MakeShape(F32, {2, 3, 2});
  const auto rhs_post_transpose = ShapeUtil::MakeShape(F32, {2, 2, 3});
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Dot(
          m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&lhs_shape),
          m::Reshape(m::Transpose(m::Reshape(m::Constant())
                                      .WithShapeCompatibleTo(&rhs_pre_transpose))
                         .WithShapeCompatibleTo(&rhs_post_transpose)))));
}
6025
// The reordered operand contracts on its first dimension; the constant
// contracts on its second.
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_LR) {
  constexpr char kHloText[] = R"(
HloModule m
test {
rhs = f32[2, 6] constant({{1, 2, 3, 4, 5, 6},
{1, 1, 1, 1, 1, 1}})
t0 = f32[2, 3, 2] parameter(0)
t1 = f32[3, 2, 2] transpose(t0), dimensions={1, 0, 2}
lhs = f32[6, 2] reshape(t1)
ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={0}, rhs_contracting_dims={1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  const auto lhs_shape = ShapeUtil::MakeShape(F32, {6, 2});
  const auto rhs_pre_transpose = ShapeUtil::MakeShape(F32, {2, 3, 2});
  const auto rhs_post_transpose = ShapeUtil::MakeShape(F32, {2, 2, 3});
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Dot(
          m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&lhs_shape),
          m::Reshape(m::Transpose(m::Reshape(m::Constant())
                                      .WithShapeCompatibleTo(&rhs_pre_transpose))
                         .WithShapeCompatibleTo(&rhs_post_transpose)))));
}
6050
// Three size-2 dimensions are flattened into the contracting dimension; the
// constant must be reshaped to rank 4 and transposed to match.
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_LR2) {
  constexpr char kHloText[] = R"(
HloModule m
test {
rhs = f32[8, 2] constant({{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6},{7, 7},{8, 8}})
t0 = f32[2, 2, 2, 2] parameter(0)
t1 = f32[2, 2, 2, 2] transpose(t0), dimensions={0, 2, 3, 1}
lhs = f32[2, 8] reshape(t1)
ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1},
rhs_contracting_dims={0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  const auto lhs_shape = ShapeUtil::MakeShape(F32, {2, 8});
  const auto rhs_pre_transpose = ShapeUtil::MakeShape(F32, {2, 2, 2, 2});
  const HloInstruction* transpose = nullptr;
  ASSERT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Dot(
          m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&lhs_shape),
          m::Reshape(m::Transpose(
              &transpose, m::Reshape(m::Constant())
                              .WithShapeCompatibleTo(&rhs_pre_transpose))))));
  EXPECT_THAT(transpose->dimensions(), ElementsAre(2, 0, 1, 3));
}
6077
// Batched variant: dimension 0 is a batch dimension on both operands and must
// stay in place while the contracting reorder moves to the constant.
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_MM) {
  constexpr char kHloText[] = R"(
HloModule m
test {
rhs = f32[2, 6, 2] constant({{{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6}},
{{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6}}})
t0 = f32[2, 2, 3, 2] parameter(0)
t1 = f32[2, 3, 2, 2] transpose(t0), dimensions={0, 2, 1, 3}
lhs = f32[2, 6, 2] reshape(t1)
ROOT dot.5 = f32[2, 2, 2] dot(lhs, rhs), lhs_batch_dims={0}, lhs_contracting_dims={1},
rhs_batch_dims={0}, rhs_contracting_dims={1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  const auto lhs_shape = ShapeUtil::MakeShape(F32, {2, 6, 2});
  const auto rhs_pre_transpose = ShapeUtil::MakeShape(F32, {2, 3, 2, 2});
  const auto rhs_post_transpose = ShapeUtil::MakeShape(F32, {2, 2, 3, 2});
  const HloInstruction* transpose = nullptr;
  ASSERT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Dot(
          m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&lhs_shape),
          m::Reshape(m::Transpose(&transpose,
                                  m::Reshape(m::Constant())
                                      .WithShapeCompatibleTo(&rhs_pre_transpose))
                         .WithShapeCompatibleTo(&rhs_post_transpose)))));
  EXPECT_THAT(transpose->dimensions(), ElementsAre(0, 2, 1, 3));
}
6106
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_NegTranspose) {
  constexpr char kHloText[] = R"(
HloModule m
test {
rhs = f32[12, 2] constant({{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6},{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6}})
t0 = f32[3, 4, 2] parameter(0)
t1 = f32[2, 3, 4] transpose(t0), dimensions={2, 0, 1}
lhs = f32[2, 12] reshape(t1)
ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  // The transpose also moves the non-contracting dimension, so neither it nor
  // the reshape may be pushed onto the constant operand.
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}
6123
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_NegReshape) {
  constexpr char kHloText[] = R"(
HloModule m
test {
rhs = f32[8, 2] constant({{1, 1},{2, 2},{3, 3},{4, 4},{1, 1},{2, 2},{3, 3},{4, 4}})
t0 = f32[2, 4, 3] parameter(0)
t1 = f32[2, 3, 4] transpose(t0), dimensions={0, 2, 1}
lhs = f32[3, 8] reshape(t1)
ROOT dot.5 = f32[3, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  // The reshape mixes non-contracting dimensions into the contracting one, so
  // the transpose and reshape must stay where they are.
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}
6140
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_NegConstant) {
  constexpr char kHloText[] = R"(
HloModule m
test {
t0 = f32[2, 3, 4] parameter(0)
t1 = f32[2, 4, 3] transpose(t0), dimensions={0, 2, 1}
lhs = f32[2, 12] reshape(t1)
rhs = f32[12, 2] parameter(1)
ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  // Neither dot operand is a constant, so there is nowhere to move the
  // reordering and the module must be left unchanged.
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}
6156
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_NegLayout) {
  constexpr char kHloText[] = R"(
HloModule m
test {
rhs = f32[6, 2] constant({{1, 2},{3, 4},{5, 6},{1, 1},{1, 1},{1, 1}})
t0 = f32[2, 2, 3] parameter(0)
t1 = f32[2, 3, 2] transpose(t0), dimensions={0, 2, 1}
lhs = f32[2, 6] reshape(t1)
ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  // Reshape-to-bitcast conversion is turned off so no other rewrite touches
  // the reshape; any change reported below would therefore come from the
  // contracting-dimension reorder.
  AlgebraicSimplifierOptions opts(
      [](const Shape&, const Shape&) { return false; });
  opts.set_is_layout_sensitive(true);
  // That reorder is layout insensitive, so it must not fire when the
  // simplifier runs in layout-sensitive mode.
  AlgebraicSimplifier simplifier(opts);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}
6180
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_SizeOneDimsNoChange) {
  // The `2` and `3` dimensions keep their relative order through the
  // transpose, so nothing can be rewritten. The case is still worth covering
  // because of the size-1 dimensions.
  constexpr char kHloText[] = R"(
HloModule m
test {
param = f32[1,2,5,3] parameter(0)
transpose = f32[1,5,2,3] transpose(param), dimensions={0,2,1,3}
reshape = f32[5,6] reshape(transpose)
constant = f32[6,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4}})
ROOT dot = f32[5,4] dot(reshape, constant),
lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}
6199
// A leading size-1 dimension must not prevent the reorder from moving onto the
// constant operand.
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_SizeOneDims) {
  constexpr char kHloText[] = R"(
HloModule m
test {
param = f32[1,2,3,5] parameter(0)
transpose = f32[1,3,2,5] transpose(param), dimensions={0,2,1,3}
reshape = f32[6,5] reshape(transpose)
constant = f32[6,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4}})
ROOT dot = f32[5,4] dot(reshape, constant),
lhs_contracting_dims={0}, rhs_contracting_dims={0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  const auto lhs_shape = ShapeUtil::MakeShape(F32, {6, 5});
  const auto rhs_pre_transpose = ShapeUtil::MakeShape(F32, {1, 3, 2, 4});
  const auto rhs_post_transpose = ShapeUtil::MakeShape(F32, {1, 2, 3, 4});
  const HloInstruction* transpose = nullptr;
  ASSERT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Dot(
          m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&lhs_shape),
          m::Reshape(m::Transpose(&transpose,
                                  m::Reshape(m::Constant())
                                      .WithShapeCompatibleTo(&rhs_pre_transpose))
                         .WithShapeCompatibleTo(&rhs_post_transpose)))));
  EXPECT_THAT(transpose->dimensions(), ElementsAre(0, 2, 1, 3));
}
6227
TEST_F(AlgebraicSimplifierTest,
       DotContractingReorder_NoChangeInContractingDimsOrder) {
  // The transpose leaves the order of the contracting dims intact, so there is
  // nothing for the rewrite to do.
  constexpr char kHloText[] = R"(
HloModule m
test {
param = f32[2,5,1,3] parameter(0)
transpose = f32[1,5,2,3] transpose(param), dimensions={2,1,0,3}
reshape = f32[5,6] reshape(transpose)
constant = f32[6,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4}})
ROOT dot = f32[5,4] dot(reshape, constant),
lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}
6246
// iota values start at zero, so `iota < 0` is false everywhere and should fold
// to a broadcast of the constant false.
TEST_F(AlgebraicSimplifierTest, CompareIota) {
  constexpr char kHloText[] = R"(
HloModule m
test {
zero = s32[] constant(0)
iota = s32[128] iota(), iota_dimension=0
broad = s32[128] broadcast(zero), dimensions={}
ROOT compare = pred[128] compare(iota, broad), direction=LT
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::ConstantScalar(false))));
}
6261
// An unsigned value is never below zero: `x < 0` folds to false.
TEST_F(AlgebraicSimplifierTest, CompareLtZero) {
  constexpr char kHloText[] = R"(
HloModule m
test {
zero = u32[] constant(0)
param = u32[] parameter(0)
ROOT compare = pred[] compare(param, zero), direction=LT
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::ConstantScalar(false)));
}
6275
// For unsigned x, `x <= 0` depends on x, so the comparison must be preserved.
TEST_F(AlgebraicSimplifierTest, CompareLeZero) {
  constexpr char kHloText[] = R"(
HloModule m
test {
zero = u32[] constant(0)
param = u32[] parameter(0)
ROOT compare = pred[] compare(param, zero), direction=LE
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Le(m::Parameter(0), m::ConstantEffectiveScalar(0))));
}
6290
// An unsigned value is always at least zero: `x >= 0` folds to true.
TEST_F(AlgebraicSimplifierTest, CompareGeZero) {
  constexpr char kHloText[] = R"(
HloModule m
test {
zero = u32[] constant(0)
param = u32[] parameter(0)
ROOT compare = pred[] compare(param, zero), direction=GE
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::ConstantScalar(true)));
}
6304
// For unsigned x, `x > 0` depends on x, so the simplifier must leave the
// comparison untouched (mirrors CompareLeZero and CompareZeroLt).
TEST_F(AlgebraicSimplifierTest, CompareGtZero) {
  const char* kModuleStr = R"(
HloModule m
test {
zero = u32[] constant(0)
param = u32[] parameter(0)
ROOT compare = pred[] compare(param, zero), direction=GT
})";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  // Actually run the pass; without this the expectation below would only
  // check the parser's output.
  ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
  EXPECT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(m::Gt(m::Parameter(0), m::ConstantEffectiveScalar(0))));
}
6318
// Zero is never greater than an unsigned value: `0 > x` folds to false.
TEST_F(AlgebraicSimplifierTest, CompareZeroGt) {
  constexpr char kHloText[] = R"(
HloModule m
test {
zero = u32[] constant(0)
param = u32[] parameter(0)
ROOT compare = pred[] compare(zero, param), direction=GT
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::ConstantScalar(false)));
}
6332
// For unsigned x, `0 >= x` depends on x, so the comparison must be preserved.
TEST_F(AlgebraicSimplifierTest, CompareZeroGe) {
  constexpr char kHloText[] = R"(
HloModule m
test {
zero = u32[] constant(0)
param = u32[] parameter(0)
ROOT compare = pred[] compare(zero, param), direction=GE
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Ge(m::ConstantEffectiveScalar(0), m::Parameter(0))));
}
6347
// Zero is at most any unsigned value: `0 <= x` folds to true.
TEST_F(AlgebraicSimplifierTest, CompareZeroLe) {
  constexpr char kHloText[] = R"(
HloModule m
test {
zero = u32[] constant(0)
param = u32[] parameter(0)
ROOT compare = pred[] compare(zero, param), direction=LE
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::ConstantScalar(true)));
}
6361
// For unsigned x, `0 < x` depends on x, so the comparison must be preserved.
TEST_F(AlgebraicSimplifierTest, CompareZeroLt) {
  constexpr char kHloText[] = R"(
HloModule m
test {
zero = u32[] constant(0)
param = u32[] parameter(0)
ROOT compare = pred[] compare(zero, param), direction=LT
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Lt(m::ConstantEffectiveScalar(0), m::Parameter(0))));
}
6376
// Comparing an integer array with itself using GE is true element-wise, so
// the root should become a broadcast of true.
TEST_F(AlgebraicSimplifierTest, CompareSame) {
  constexpr char kHloText[] = R"(
HloModule m
test {
param = s32[123] parameter(0)
ROOT compare = pred[123] compare(param, param), direction=GE
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::ConstantScalar(true))));
}
6389
// (x < 10) && (x < 100) is implied by the tighter bound alone, so the root
// should reduce to a single `x < 10` comparison.
TEST_F(AlgebraicSimplifierTest, CompareSimplified) {
  constexpr char kHloText[] = R"(
HloModule m
test {
param = s32[] parameter(0)
c1 = s32[] constant(10)
c2 = s32[] constant(100)
cmp1 = pred[] compare(param, c1), direction=LT
cmp2 = pred[] compare(param, c2), direction=LT
ROOT out = pred[] and(cmp1, cmp2)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Compare(m::Op(), m::Op().IsConstantScalar(10))
                     .WithComparisonDirection(ComparisonDirection::kLt)));
}
6408
// Same as CompareSimplified, but the looser bound is written as `100 > x`.
TEST_F(AlgebraicSimplifierTest, CompareSimplifiedReversed) {
  constexpr char kHloText[] = R"(
HloModule m
test {
param = s32[] parameter(0)
c1 = s32[] constant(10)
c2 = s32[] constant(100)
cmp1 = pred[] compare(param, c1), direction=LT
cmp2 = pred[] compare(c2, param), direction=GT
ROOT out = pred[] and(cmp1, cmp2)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Compare(m::Op(), m::Op().IsConstantScalar(10))
                     .WithComparisonDirection(ComparisonDirection::kLt)));
}
6427
// An outer product (dot with no contracting dims) is rewritten to a multiply
// by default; set_enable_dot_to_multiply_rewrite(false) must suppress that.
TEST_F(AlgebraicSimplifierTest, CanDisableDotToMultiplyRewrite) {
  // Some backends may have better performance by treating an outer product as a
  // Dot, rather than a broadcast Multiply
  const char* kModuleStr = R"(
HloModule m
test {
param1 = f32[64] parameter(0)
param2 = f32[64] parameter(1)
ROOT dot = f32[64, 64] dot(param1, param2),
lhs_contracting_dims={}, rhs_contracting_dims={}
})";

  // Verify that the default is to re-write
  TF_ASSERT_OK_AND_ASSIGN(auto m1, ParseAndReturnVerifiedModule(kModuleStr));
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m1.get()).ValueOrDie());
  EXPECT_THAT(m1->entry_computation()->root_instruction(),
              GmockMatch(m::Multiply(m::Op(), m::Op())));

  // Verify that we can disable the re-write
  AlgebraicSimplifierOptions opts = default_options_;
  opts.set_enable_dot_to_multiply_rewrite(false);
  TF_ASSERT_OK_AND_ASSIGN(auto m2, ParseAndReturnVerifiedModule(kModuleStr));
  ASSERT_FALSE(AlgebraicSimplifier(opts).Run(m2.get()).ValueOrDie());
  // Check the outcome explicitly, not just the "changed" flag: the root must
  // still be the original dot of the two parameters.
  EXPECT_THAT(m2->entry_computation()->root_instruction(),
              GmockMatch(m::Dot(m::Parameter(0), m::Parameter(1))));
}
6452
// Along iota_dimension=0 of size 5 the iota values are 0..4, all already
// smaller than 5, so the remainder is a no-op and the root becomes the iota.
TEST_F(AlgebraicSimplifierTest, RemainderOfIota) {
  constexpr char kHloText[] = R"(
HloModule m
test {
iota = s32[5,1000] iota(), iota_dimension=0
five = s32[] constant(5)
five_bcast = s32[5,1000] broadcast(s32[] five), dimensions={}
ROOT remainder = s32[5,1000] remainder(iota, s32[5,1000] five_bcast)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Iota()));
}
6467
// (iota + 5) % 5 should drop the addition of the divisor, leaving a remainder
// of the iota by the broadcast divisor.
TEST_F(AlgebraicSimplifierTest, RemainderOfNPlusIota) {
  constexpr char kHloText[] = R"(
HloModule m
test {
iota = s32[5,1000] iota(), iota_dimension=0
five = s32[] constant(5)
five_bcast = s32[5,1000] broadcast(five), dimensions={}
sum = s32[5,1000] add(iota, five_bcast)
ROOT remainder = s32[5,1000] remainder(sum, s32[5,1000] five_bcast)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Remainder(m::Iota(), m::Broadcast())));
}
6483
// No simplification: the iota reaches 125, and 125 + 5 overflows S8.
TEST_F(AlgebraicSimplifierTest, RemainderOfNPlusIotaOverflow) {
  constexpr char kHloText[] = R"(
HloModule m
test {
iota = s8[126] iota(), iota_dimension=0
five = s8[] constant(5)
five_bcast = s8[126] broadcast(five), dimensions={}
sum = s8[126] add(iota, five_bcast)
ROOT remainder = s8[126] remainder(sum, s8[126] five_bcast)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  // The s8 addition can wrap, so dropping it would change results.
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}
6498
// (p % q) % q should collapse into a single remainder of the two parameters.
TEST_F(AlgebraicSimplifierTest, RepeatedRemainder) {
  constexpr char kHloText[] = R"(
HloModule m
test {
p = s32[1000] parameter(0)
q = s32[1000] parameter(1)
r = s32[1000] remainder(p, q)
ROOT rr = s32[1000] remainder(r, q)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Remainder(m::Parameter(), m::Parameter())));
}
6513
// In layout-sensitive mode, a pad with negative padding applied to a full
// slice becomes a slice whose shape (including layout) equals the original
// root's shape.
TEST_F(AlgebraicSimplifierTest, SlicePadLayout) {
  const char* const kHlo = R"(
HloModule m
test {
%param.0 = f32[128,9,9,1024]{0,3,2,1} parameter(0)
%param.1 = f32[] parameter(1)
%slice = f32[128,9,9,1024]{0,3,2,1} slice(%param.0),
slice={[0:128], [0:9], [0:9], [0:1024]}
ROOT %pad = f32[128,8,9,1024]{0,3,2,1} pad(%slice, %param.1),
padding=0_0x-1_0x0_0x0_0
})";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnVerifiedModule(kHlo));
  // Capture the expected shape before the pass replaces the root.
  const Shape expected_shape =
      hlo_module->entry_computation()->root_instruction()->shape();
  AlgebraicSimplifierOptions layout_sensitive_options;
  layout_sensitive_options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(layout_sensitive_options);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Slice().WithShapeEqualTo(&expected_shape)));
}
6533
// min(max(3, p0), 4) becomes clamp(3, p0, 4).
TEST_F(AlgebraicSimplifierTest, MinOfMaxToClamp) {
  const char* const kHlo = R"(
HloModule m
test {
p0 = f32[4] parameter(0)
c0 = f32[] constant(3.0)
c1 = f32[] constant(4.0)
b0 = f32[4] broadcast(c0), dimensions={}
b1 = f32[4] broadcast(c1), dimensions={}
m0 = f32[4] maximum(b0, p0)
ROOT m1 = f32[4] minimum(m0, b1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  const auto lower_bound = m::Broadcast(m::ConstantScalar(3.0));
  const auto upper_bound = m::Broadcast(m::ConstantScalar(4.0));
  EXPECT_THAT(root, GmockMatch(m::Clamp(lower_bound, m::Parameter(0),
                                        upper_bound)));
}
6554
// max(3, min(p0, 4)) becomes clamp(3, p0, 4).
TEST_F(AlgebraicSimplifierTest, MaxOfMinToClamp) {
  const char* const kHlo = R"(
HloModule m
test {
p0 = f32[4] parameter(0)
c0 = f32[] constant(3.0)
c1 = f32[] constant(4.0)
b0 = f32[4] broadcast(c0), dimensions={}
b1 = f32[4] broadcast(c1), dimensions={}
m0 = f32[4] minimum(p0, b1)
ROOT m1 = f32[4] maximum(b0, m0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  const auto lower_bound = m::Broadcast(m::ConstantScalar(3.0));
  const auto upper_bound = m::Broadcast(m::ConstantScalar(4.0));
  EXPECT_THAT(root, GmockMatch(m::Clamp(lower_bound, m::Parameter(0),
                                        upper_bound)));
}
6575
// clamp(p0, clamp(p0, p1, p2), p2) folds to the inner clamp(p0, p1, p2).
TEST_F(AlgebraicSimplifierTest, ClampOfClamp) {
  const char* const kHlo = R"(
HloModule m
test {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
p2 = f32[] parameter(2)
c0 = f32[] clamp(p0, p1, p2)
ROOT c1 = f32[] clamp(p0, c0, p2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Clamp(m::Parameter(0), m::Parameter(1),
                                        m::Parameter(2))));
}
6593
// max(p0, clamp(p0, p1, p2)) folds to clamp(p0, p1, p2).
TEST_F(AlgebraicSimplifierTest, MaxOfClamp) {
  const char* const kHlo = R"(
HloModule m
test {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
p2 = f32[] parameter(2)
c0 = f32[] clamp(p0, p1, p2)
ROOT m0 = f32[] maximum(p0, c0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Clamp(m::Parameter(0), m::Parameter(1),
                                        m::Parameter(2))));
}
6611
// A slice selecting exactly the rows contributed by p1 to a concatenate is
// replaced by p1 itself.
TEST_F(AlgebraicSimplifierTest, SliceOfConcat) {
  const char* const kHlo = R"(
HloModule m
test {
p0 = f32[100,50] parameter(0)
p1 = f32[50,50] parameter(1)
c0 = f32[150,50] concatenate(p0, p1), dimensions={0}
ROOT s0 = f32[50,50] slice(c0), slice={[100:150], [0:50]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Parameter(1)));
}
6627
// sqrt(p0 * p0) becomes abs(p0).
TEST_F(AlgebraicSimplifierTest, SqrtOfSelfMultiply) {
  const char* const kHlo = R"(
HloModule m
test {
p0 = f32[32]{0} parameter(0)
m0 = f32[32]{0} multiply(f32[32]{0} p0, f32[32]{0} p0)
ROOT s0 = f32[32]{0} sqrt(f32[32]{0} m0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Abs(m::Parameter(0))));
}
6642
// An add-reduce over batch dimension 0 of a dot is folded into the dot itself,
// leaving a single dot of the two parameters as the root.
TEST_F(AlgebraicSimplifierTest, ReduceOfBatchDotToContractingDimension) {
  const char* const kHlo = R"(
HloModule m
a {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
ROOT r = f32[] add(p0, p1)
}
test {
p0 = f32[32,8,5,6] parameter(0)
p1 = f32[8,32,6,7] parameter(1)
d = f32[32,8,5,7] dot(p0, p1),
lhs_batch_dims={0,1},
rhs_batch_dims={1,0},
rhs_contracting_dims={2},
lhs_contracting_dims={3}
c = f32[] constant(0)
ROOT r = f32[8,5,7] reduce(d,c), dimensions={0}, to_apply=a
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Dot(m::Parameter(0), m::Parameter(1))));
}
6668
// When the cuDNN batchnorm-training target is registered in the options,
// rsqrt(power(x, -2)) on the custom call's tuple element 2 is rewritten to
// abs(x).
TEST_F(AlgebraicSimplifierTest, RsqrtOfRPower) {
  const char* kModuleStr = R"(
HloModule m
test {
p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
p1 = f32[32]{0} parameter(1)
p2 = f32[32]{0} parameter(2)
c0 = f32[] constant(0.001)
c1 = s64[] constant(1)
custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, c0, c1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
c2 = f32[] constant(-2)
broadcast = f32[32]{0} broadcast(f32[] c2), dimensions={}
power = f32[32]{0} power(get-tuple-element, broadcast)
rsqrt = f32[32]{0} rsqrt(f32[32]{0} power)
ROOT tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, rsqrt)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  default_options_.set_cudnn_batchnorm_forward_training_metadata(
      "__cudnn$batchNormalizationForwardTraining");
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
  // Expect transformation: rsqrt(power(gte, -2)) -> abs(gte), where gte is the
  // instruction named `get-tuple-element` (tuple index 2 of the custom call),
  // not the instruction named `get-tuple-element.2` (tuple index 1).
  EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kPower), nullptr);
  EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kRsqrt), nullptr);
  auto computation = m->entry_computation();
  auto root = computation->root_instruction();
  // ASSERTs guard the operand dereferences that follow.
  ASSERT_EQ(root->opcode(), HloOpcode::kTuple);
  ASSERT_EQ(root->operand(2)->opcode(), HloOpcode::kAbs);
  ASSERT_EQ(root->operand(2)->operand(0)->opcode(),
            HloOpcode::kGetTupleElement);
  // Tuple index 1 has the same f32[32] shape, so pin the exact element.
  EXPECT_EQ(root->operand(2)->operand(0)->tuple_index(), 2);
}
6703
// When the cuDNN batchnorm-training target is registered in the options,
// rsqrt(1 / x) on the custom call's tuple element 2 is rewritten to sqrt(x).
TEST_F(AlgebraicSimplifierTest, RsqrtDivide) {
  const char* kModuleStr = R"(
HloModule m
test {
p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
p1 = f32[32]{0} parameter(1)
p2 = f32[32]{0} parameter(2)
constant = f32[] constant(0.001)
constant.1 = s64[] constant(1)
custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
constant.2 = f32[] constant(1)
broadcast.1 = f32[32]{0} broadcast(constant.2), dimensions={}
divide = f32[32]{0} divide(broadcast.1, get-tuple-element)
rsqrt = f32[32]{0} rsqrt(divide)
ROOT tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, rsqrt)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  default_options_.set_cudnn_batchnorm_forward_training_metadata(
      "__cudnn$batchNormalizationForwardTraining");
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
  // Expect transformation: rsqrt(divide(1, gte)) -> sqrt(gte), where gte is the
  // instruction named `get-tuple-element` (tuple index 2 of the custom call),
  // not the instruction named `get-tuple-element.2` (tuple index 1).
  EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kDivide), nullptr);
  EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kRsqrt), nullptr);
  auto computation = m->entry_computation();
  auto root = computation->root_instruction();
  // ASSERTs guard the operand dereferences that follow.
  ASSERT_EQ(root->opcode(), HloOpcode::kTuple);
  ASSERT_EQ(root->operand(2)->opcode(), HloOpcode::kSqrt);
  ASSERT_EQ(root->operand(2)->operand(0)->opcode(),
            HloOpcode::kGetTupleElement);
  // Tuple index 1 has the same f32[32] shape, so pin the exact element.
  EXPECT_EQ(root->operand(2)->operand(0)->tuple_index(), 2);
}
6738
// When the cuDNN batchnorm-training target is registered in the options,
// rsqrt(x) * rsqrt(x) on the custom call's tuple element 2 is rewritten to
// 1 / x (a divide whose numerator is a broadcast).
TEST_F(AlgebraicSimplifierTest, MultiplySelfRsqrt) {
  const char* kModuleStr = R"(
HloModule m
test {
p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
p1 = f32[32]{0} parameter(1)
p2 = f32[32]{0} parameter(2)
constant = f32[] constant(0.001)
constant.1 = s64[] constant(1)
custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
rsqrt = f32[32]{0} rsqrt(get-tuple-element)
multiply = f32[32]{0} multiply(rsqrt, rsqrt)
ROOT tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, multiply)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  default_options_.set_cudnn_batchnorm_forward_training_metadata(
      "__cudnn$batchNormalizationForwardTraining");
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());

  // Expect transformation: multiply(rsqrt(gte), rsqrt(gte)) -> divide(1, gte),
  // where gte is the instruction named `get-tuple-element` (tuple index 2 of
  // the custom call), not `get-tuple-element.2` (tuple index 1).
  EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kMultiply), nullptr);
  EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kRsqrt), nullptr);

  auto computation = m->entry_computation();
  auto root = computation->root_instruction();
  // ASSERTs guard the operand dereferences that follow.
  ASSERT_EQ(root->opcode(), HloOpcode::kTuple);
  ASSERT_EQ(root->operand(2)->opcode(), HloOpcode::kDivide);
  EXPECT_EQ(root->operand(2)->operand(0)->opcode(), HloOpcode::kBroadcast);
  ASSERT_EQ(root->operand(2)->operand(1)->opcode(),
            HloOpcode::kGetTupleElement);
  // Tuple index 1 has the same f32[32] shape, so pin the exact element.
  EXPECT_EQ(root->operand(2)->operand(1)->tuple_index(), 2);
}
6775
// The registered batchnorm target ("...ForwardTraining") differs from the one
// in the options ("...Forward"), so rsqrt(x) * rsqrt(x) must stay untouched.
TEST_F(AlgebraicSimplifierTest, MultiplySelfRsqrt_NegativeTestCase) {
  const char* const kHlo = R"(
HloModule m
test {
p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
p1 = f32[32]{0} parameter(1)
p2 = f32[32]{0} parameter(2)
constant = f32[] constant(0.001)
constant.1 = s64[] constant(1)
custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
rsqrt = f32[32]{0} rsqrt(get-tuple-element)
multiply = f32[32]{0} multiply(rsqrt, rsqrt)
ROOT tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, multiply)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnVerifiedModule(kHlo));
  default_options_.set_cudnn_batchnorm_forward_training_metadata(
      "__cudnn$batchNormalizationForward");
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(hlo_module.get()).ValueOrDie());

  HloModule* const module_ptr = hlo_module.get();
  // The original ops survive and no replacement ops were introduced.
  EXPECT_NE(FindInstruction(module_ptr, HloOpcode::kMultiply), nullptr);
  EXPECT_NE(FindInstruction(module_ptr, HloOpcode::kRsqrt), nullptr);
  EXPECT_EQ(FindInstruction(module_ptr, HloOpcode::kDivide), nullptr);
  EXPECT_EQ(FindInstruction(module_ptr, HloOpcode::kBroadcast), nullptr);
  const HloInstruction* root =
      module_ptr->entry_computation()->root_instruction();
  EXPECT_EQ(root->operand(2)->opcode(), HloOpcode::kMultiply);
}
6805
// With the matching batchnorm-training target registered, abs() applied to
// tuple element 2 of the custom call is removed.
TEST_F(AlgebraicSimplifierTest, AbsEliminationBatchnormTraining) {
  const char* const kHlo = R"(
HloModule m
test {
p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
p1 = f32[32]{0} parameter(1)
p2 = f32[32]{0} parameter(2)
constant = f32[] constant(0.001)
constant.1 = s64[] constant(1)
custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
abs = f32[32]{0} abs(get-tuple-element)
ROOT %tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, abs)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnVerifiedModule(kHlo));
  default_options_.set_cudnn_batchnorm_forward_training_metadata(
      "__cudnn$batchNormalizationForwardTraining");
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  // No abs instruction may remain anywhere in the module.
  EXPECT_EQ(FindInstruction(hlo_module.get(), HloOpcode::kAbs), nullptr);
  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_EQ(root->operand(2)->opcode(), HloOpcode::kGetTupleElement);
}
6832
// The options name the inference target while the custom call uses the
// training target, so the abs() must be kept.
TEST_F(AlgebraicSimplifierTest,
       AbsEliminationBatchnormTraining_NegativeTestCase) {
  const char* const kHlo = R"(
HloModule m
test {
p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
p1 = f32[32]{0} parameter(1)
p2 = f32[32]{0} parameter(2)
constant = f32[] constant(0.001)
constant.1 = s64[] constant(1)
custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
abs = f32[32]{0} abs(get-tuple-element)
ROOT %tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, abs)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnVerifiedModule(kHlo));
  default_options_.set_cudnn_batchnorm_forward_training_metadata(
      "__cudnn$batchNormalizationForwardInference");
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(hlo_module.get()).ValueOrDie());
  EXPECT_NE(FindInstruction(hlo_module.get(), HloOpcode::kAbs), nullptr);
}
6857
// abs(p * p) drops the redundant abs, leaving p * p.
TEST_F(AlgebraicSimplifierTest, AbsEliminationMultiply) {
  const char* const kHlo = R"(
HloModule m
test {
p = f32[32]{0} parameter(0)
m = f32[32]{0} multiply(p, p)
ROOT a = f32[32]{0} abs(m)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
}
6872
// For integer types, compare(broadcast(a) + x, broadcast(b)) is rewritten to
// compare(x, broadcast(b - a)). The same rewrite must not fire for
// floating-point types.
TEST_F(AlgebraicSimplifierTest, BroadcastCompareSimplification) {
  std::string module_string = R"(
HloModule m
test {
a = s32[] parameter(0)
b = s32[] parameter(1)
x = s32[10]{0} parameter(2)
broadcast_a = s32[10]{0} broadcast(a), dimensions={}
broadcast_b = s32[10]{0} broadcast(b), dimensions={}
add = s32[10]{0} add(broadcast_a, x)
ROOT cmp = pred[10]{0} compare(add, broadcast_b), direction=EQ
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_string));
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Compare(m::Parameter(2),
                                    m::Broadcast(m::Subtract(
                                        m::Parameter(1), m::Parameter(0))))));

  // Numerically unstable transformation shouldn't be applied to floating types.
  // Parse a fresh f32 copy of the module. Re-running on the already-simplified
  // s32 module would report "no change" regardless of the element type, so it
  // would not exercise the floating-point guard at all.
  std::string module_string_f32 =
      absl::StrReplaceAll(module_string, {{"s32", "f32"}});
  TF_ASSERT_OK_AND_ASSIGN(auto m_f32,
                          ParseAndReturnVerifiedModule(module_string_f32));
  ASSERT_FALSE(
      AlgebraicSimplifier(default_options_).Run(m_f32.get()).ValueOrDie());
}
6898
// abs(power(p0, 2)): the power becomes p0 * p0, and the abs of that product
// is then dropped, leaving just the multiply.
TEST_F(AlgebraicSimplifierTest, AbsEliminationPower2) {
  const char* const kHlo = R"(
HloModule m
test {
p0 = f32[32]{0} parameter(0)
c0 = f32[] constant(2)
b0 = f32[32]{0} broadcast(c0), dimensions={}
pow = f32[32]{0} power(p0, b0)
ROOT a = f32[32]{0} abs(pow)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
}
6917
// shared + scatter(0, i0, u0) + scatter(0, i1, u1) becomes a single scatter
// into `shared` with concatenated indices and updates. The first run merges
// the scatters; the second run folds the add with the zero init away.
TEST_F(AlgebraicSimplifierTest, ScatterAddCombined) {
  const char* const kHlo = R"(
HloModule m
apply {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT c = f32[] add(a, b)
}
test {
z = f32[] constant(0)
init = f32[100,4] broadcast(z), dimensions={}
shared = f32[100,4] parameter(0)
index0 = s32[20] parameter(1)
index1 = s32[10] parameter(2)
update0 = f32[20,4] parameter(3)
update1 = f32[10,4] parameter(4)
scatter.0 = f32[100,4] scatter(init, index0, update0),
to_apply=apply,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
scatter.1 = f32[100,4] scatter(init, index1, update1),
to_apply=apply,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
add.0 = f32[100,4] add(shared, scatter.0)
ROOT add.1 = f32[100,4] add(add.0, scatter.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnVerifiedModule(kHlo));
  // Pass 1: merge the two scatters.
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(hlo_module.get()).ValueOrDie());
  // Pass 2: eliminate the add with the zero-initialized operand.
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(hlo_module.get()).ValueOrDie());
  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  const auto indices = m::Concatenate(m::Parameter(1), m::Parameter(2));
  const auto updates = m::Concatenate(m::Parameter(3), m::Parameter(4));
  EXPECT_THAT(root, GmockMatch(m::Scatter(m::Parameter(0), indices, updates)));
}
6961
// Same as ScatterAddCombined, but with the outer add's operands swapped. The
// concatenation order follows the new operand order (scatter.1 first).
TEST_F(AlgebraicSimplifierTest, ScatterAddCombinedSwapped) {
  const char* const kHlo = R"(
HloModule m
apply {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT c = f32[] add(a, b)
}
test {
z = f32[] constant(0)
init = f32[100,4] broadcast(z), dimensions={}
shared = f32[100,4] parameter(0)
index0 = s32[20] parameter(1)
index1 = s32[10] parameter(2)
update0 = f32[20,4] parameter(3)
update1 = f32[10,4] parameter(4)
scatter.0 = f32[100,4] scatter(init, index0, update0),
to_apply=apply,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
scatter.1 = f32[100,4] scatter(init, index1, update1),
to_apply=apply,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
add.0 = f32[100,4] add(shared, scatter.0)
ROOT add.1 = f32[100,4] add(scatter.1, add.0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnVerifiedModule(kHlo));
  // Pass 1: merge the two scatters.
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(hlo_module.get()).ValueOrDie());
  // Pass 2: eliminate the add with the zero-initialized operand.
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(hlo_module.get()).ValueOrDie());
  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  const auto indices = m::Concatenate(m::Parameter(2), m::Parameter(1));
  const auto updates = m::Concatenate(m::Parameter(4), m::Parameter(3));
  EXPECT_THAT(root, GmockMatch(m::Scatter(m::Parameter(0), indices, updates)));
}
7005
// Scatter combining with index_vector_dim=0 (index vectors along the leading
// dimension). The merged scatter still writes into the broadcast init.
TEST_F(AlgebraicSimplifierTest, ScatterAddCombinedWeirdDnums) {
  const char* const kHlo = R"(
HloModule m
apply {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT c = f32[] add(a, b)
}
test {
z = f32[] constant(0)
init = f32[100,4] broadcast(z), dimensions={}
shared = f32[100,4] parameter(0)
index0 = s32[1,4,5] parameter(1)
index1 = s32[1,2,5] parameter(2)
update0 = f32[4,4,5] parameter(3)
update1 = f32[2,4,5] parameter(4)
scatter.0 = f32[100,4] scatter(init, index0, update0),
to_apply=apply,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=0
scatter.1 = f32[100,4] scatter(init, index1, update1),
to_apply=apply,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=0
ROOT add.1 = f32[100,4] add(scatter.0, scatter.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnVerifiedModule(kHlo));
  // Pass 1: merge the two scatters.
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(hlo_module.get()).ValueOrDie());
  // Pass 2: simplify the remaining add.
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(hlo_module.get()).ValueOrDie());
  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  const auto indices = m::Concatenate(m::Parameter(1), m::Parameter(2));
  const auto updates = m::Concatenate(m::Parameter(3), m::Parameter(4));
  EXPECT_THAT(root, GmockMatch(m::Scatter(m::Broadcast(), indices, updates)));
}
7048
// Scatter combining with index_vector_dim=2 and update_window_dims={0}. The
// merged scatter still writes into the broadcast init.
TEST_F(AlgebraicSimplifierTest, ScatterAddCombinedWeirdDnums2) {
  const char* const kHlo = R"(
HloModule m
apply {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT c = f32[] add(a, b)
}
test {
z = f32[] constant(0)
init = f32[100,4] broadcast(z), dimensions={}
shared = f32[100,4] parameter(0)
index0 = s32[4,3,1] parameter(1)
index1 = s32[4,5,1] parameter(2)
update0 = f32[4,4,3] parameter(3)
update1 = f32[4,4,5] parameter(4)
scatter.0 = f32[100,4] scatter(init, index0, update0),
to_apply=apply,
update_window_dims={0},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=2
scatter.1 = f32[100,4] scatter(init, index1, update1),
to_apply=apply,
update_window_dims={0},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=2
ROOT add.1 = f32[100,4] add(scatter.0, scatter.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnVerifiedModule(kHlo));
  // Pass 1: merge the two scatters.
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(hlo_module.get()).ValueOrDie());
  // Pass 2: simplify the remaining add.
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(hlo_module.get()).ValueOrDie());
  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  const auto indices = m::Concatenate(m::Parameter(1), m::Parameter(2));
  const auto updates = m::Concatenate(m::Parameter(3), m::Parameter(4));
  EXPECT_THAT(root, GmockMatch(m::Scatter(m::Broadcast(), indices, updates)));
}
7091
// Negative test for scatter combining. Each scatter here writes a single
// window: its s32[1] index has index_vector_dim=0, leaving no scatter batch
// dimension. The two scatters are therefore expected to stay separate.
TEST_F(AlgebraicSimplifierTest, ScalarScatter) {
  const char* hlo_string = R"(
HloModule m
apply {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT c = f32[] add(a, b)
}
test {
z = f32[] constant(0)
init = f32[100,4,20] broadcast(z), dimensions={}
shared = f32[100,4,20] parameter(0)
index0 = s32[1] parameter(1)
index1 = s32[1] parameter(2)
update0 = f32[4,20] parameter(3)
update1 = f32[4,20] parameter(4)
scatter.0 = f32[100,4,20] scatter(init, index0, update0),
to_apply=apply,
update_window_dims={0, 1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=0
scatter.1 = f32[100,4,20] scatter(init, index1, update1),
to_apply=apply,
update_window_dims={0, 1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=0
ROOT add.1 = f32[100,4,20] add(scatter.0, scatter.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
  // Scatters without a batch dimension must NOT be combined, so the pass
  // should report no change. NOTE(review): presumably because there is no
  // dimension along which to concatenate their indices/updates; confirm
  // against the simplifier.
  ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
}
7127
// The kernel's 32x32 spatial window is larger than the 3x3 spatial extent of
// the input, so the simplifier swaps the convolution's operands. The result
// has a reversed 3x3 window with padding 1 on every side.
TEST_F(AlgebraicSimplifierTest, SwapConvOperands) {
  const char* hlo_string = R"(
HloModule m
test {
a = f32[3,3,160,160] parameter(0)
b = f32[128,32,32,160] parameter(1)
ROOT c = f32[128,32,32,160] convolution(a,b),
window={size=32x32 pad=30_30x30_30 rhs_reversal=1x1},
dim_labels=01bf_o01i->f01b
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
  // Swap the convolution operands.
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
  const HloInstruction* conv = m->entry_computation()->root_instruction();
  EXPECT_THAT(conv,
              GmockMatch(m::Convolution(m::Parameter(1), m::Parameter(0))));
  EXPECT_EQ(conv->window().dimensions(0).size(), 3);
  EXPECT_EQ(conv->window().dimensions(1).size(), 3);
  EXPECT_TRUE(conv->window().dimensions(0).window_reversal());
  EXPECT_TRUE(conv->window().dimensions(1).window_reversal());
  EXPECT_EQ(conv->window().dimensions(0).padding_low(), 1);
  EXPECT_EQ(conv->window().dimensions(1).padding_low(), 1);
  EXPECT_EQ(conv->window().dimensions(0).padding_high(), 1);
  EXPECT_EQ(conv->window().dimensions(1).padding_high(), 1);
}
7154
// Dividing by a broadcast scalar becomes multiplication by the broadcast
// reciprocal: convert(p0) / bcast(p1) -> convert(p0) * bcast(1 / p1).
TEST_F(AlgebraicSimplifierTest, ScalarDividePredicate) {
  const char* const kHlo = R"(
HloModule m
test {
p0 = pred[2] parameter(0)
cvt = f32[2] convert(p0)
p1 = f32[] parameter(1)
bcast = f32[2] broadcast(p1), dimensions={}
ROOT div = f32[2] divide(cvt, bcast)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  const auto reciprocal =
      m::Broadcast(m::Divide(m::ConstantScalar(1), m::Parameter(1)));
  EXPECT_THAT(root, GmockMatch(m::MultiplyAnyOrder(
                        m::Convert(m::Parameter(0)), reciprocal)));
}
7174
// Two matrix-vector dots (one c64, one f64) in the same module are both
// rewritten in a single run. NOTE(review): the expected count of 3 is
// presumably the entry computation plus one reduction computation per
// rewritten dot; confirm against the simplifier.
TEST_F(AlgebraicSimplifierTest, MultipleDotStrengthReductions) {
  constexpr char kHlo[] = R"(
HloModule test
ENTRY test {
a = c64[2,2] parameter(0)
b = c64[2] parameter(1)
cd = c64[2] dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
c = f64[2,2] parameter(2)
d = f64[2] parameter(3)
dd = f64[2] dot(c, d), lhs_contracting_dims={1}, rhs_contracting_dims={0}
ROOT tuple = (c64[2], f64[2]) tuple(cd, dd)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  EXPECT_EQ(3, hlo_module->computation_count());
}
7192
// A variadic reduce with a single (tuple-wrapped) output becomes
// tuple(reduce(...)). The new reduce's reducer returns the add directly.
TEST_F(AlgebraicSimplifierTest, UnaryVariadicReduce) {
  const char* const kHlo = R"(
HloModule m
fn {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
a = f32[] add(p0, p1)
ROOT t = (f32[]) tuple(a)
}
test {
p0 = f32[32,8,6,7] parameter(0)
c = f32[] constant(0)
ROOT r = (f32[8,6,7]) reduce(p0, c), dimensions={0}, to_apply=fn
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  ASSERT_THAT(root, GmockMatch(m::Tuple(
                        m::Reduce(m::Parameter(0), m::ConstantScalar(0)))));
  const HloInstruction* reduce = root->operand(0);
  ASSERT_EQ(reduce->called_computations().size(), 1);
  EXPECT_THAT(reduce->called_computations()[0]->root_instruction(),
              GmockMatch(m::Add(m::Parameter(0), m::Parameter(1))));
}
7226
// pad(broadcast(c)) is reordered into broadcast(pad(broadcast(c))), so the
// padding is applied to a smaller intermediate.
TEST_F(AlgebraicSimplifierTest, BroadcastAndPadReorder) {
  const char* const kHlo = R"(
HloModule m
test {
c1 = pred[] constant(true)
b2 = pred[32,1,768]{2,1,0} broadcast(pred[] c1), dimensions={}
c3 = pred[] constant(false)
ROOT p4 = pred[4096,1,768]{2,1,0} pad(pred[32,1,768]{2,1,0} b2, pred[] c3), padding=0_4064x0_0x0_0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  const auto inner_pad = m::Pad(m::Broadcast(m::Constant()), m::Constant());
  EXPECT_THAT(root, GmockMatch(m::Broadcast(inner_pad)));
}
7243
// Same reordering as BroadcastAndPadReorder, but here the pad is consumed by
// a tuple instead of being the root.
TEST_F(AlgebraicSimplifierTest, BroadcastAndPadReorderWithUse) {
  const char* const kHlo = R"(
HloModule m
test {
c1 = pred[] constant(true)
b2 = pred[1,768,32]{2,1,0} broadcast(pred[] c1), dimensions={}
c3 = pred[] constant(false)
p4 = pred[1,768,4096]{2,1,0} pad(pred[1,768,32]{2,1,0} b2, pred[] c3), padding=0_0x0_0x0_4064
ROOT p5 = (pred[1,768,4096]{2,1,0}) tuple(pred[1,768,4096]{2,1,0} p4)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  const auto inner_pad = m::Pad(m::Broadcast(m::Constant()), m::Constant());
  EXPECT_THAT(root, GmockMatch(m::Tuple(m::Broadcast(inner_pad))));
}
7261
TEST_F(AlgebraicSimplifierTest, BroadcastAndPadReorderWithNonScalar) {
  // Variant where the broadcast source is a rank-1 parameter (mapped to
  // dimension 2) rather than a scalar constant. The pad pads that same
  // dimension, and the result is still expected to become
  // broadcast(pad(broadcast(parameter))) under the tuple.
  // Note: the HLO name `c1` refers to parameter(0), not a constant.
  const char* kModuleStr = R"(
HloModule m
test {
c1 = pred[32] parameter(0)
b2 = pred[1,768,32]{2,1,0} broadcast(pred[32] c1), dimensions={2}
c3 = pred[] constant(false)
p4 = pred[1,768,4096]{2,1,0} pad(pred[1,768,32]{2,1,0} b2, pred[] c3), padding=0_0x0_0x0_4064
ROOT p5 = (pred[1,768,4096]{2,1,0}) tuple(pred[1,768,4096]{2,1,0} p4)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));

  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(m.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Tuple(m::Broadcast(m::Pad(
                        m::Broadcast(m::Parameter()), m::Constant())))));
}
7279
7280 // Test that dynamic-update-slice with a scalar broadcast becomes a pad when the
7281 // start_indices are too big.
TEST_F(AlgebraicSimplifierTest,DynamicUpdateSliceOfBroadcastToPadOob)7282 TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceOfBroadcastToPadOob) {
7283 const char* hlo_string = R"(
7284 HloModule module
7285
7286 ENTRY f {
7287 constant.546 = f32[] constant(0)
7288 broadcast.467 = f32[2]{0} broadcast(constant.546), dimensions={}
7289 parameter.1 = f32[1]{0} parameter(0)
7290 constant.551 = s32[] constant(2)
7291 ROOT dynamic-update-slice.44 = f32[2]{0} dynamic-update-slice(broadcast.467, parameter.1, constant.551)
7292 }
7293 )";
7294 TF_ASSERT_OK_AND_ASSIGN(auto module,
7295 ParseAndReturnVerifiedModule(hlo_string));
7296 VLOG(2) << "Before rewrite dus->pad\n" << module->ToString();
7297 AlgebraicSimplifier simplifier(default_options_);
7298 ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
7299 VLOG(2) << "After rewrite dus->pad\n" << module->ToString();
7300 auto* pad = module->entry_computation()->root_instruction();
7301 EXPECT_THAT(pad,
7302 GmockMatch(m::Pad(m::Parameter(0), m::ConstantScalar(0.0f))));
7303 EXPECT_FALSE(HasInteriorPadding(pad->padding_config()));
7304 ASSERT_EQ(pad->padding_config().dimensions_size(), 1);
7305 EXPECT_EQ(pad->padding_config().dimensions(0).edge_padding_low(), 1);
7306 EXPECT_EQ(pad->padding_config().dimensions(0).edge_padding_high(), 0);
7307 }
7308
7309 } // namespace
7310 } // namespace xla
7311