1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/reshape_mover.h"
17
18 #include "absl/memory/memory.h"
19 #include "tensorflow/compiler/xla/layout_util.h"
20 #include "tensorflow/compiler/xla/literal.h"
21 #include "tensorflow/compiler/xla/service/hlo_computation.h"
22 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
23 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/test.h"
27 #include "tensorflow/compiler/xla/test_helpers.h"
28 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31
32 namespace xla {
33 namespace {
34
35 namespace op = xla::testing::opcode_matchers;
36
// Test fixture shared by all ReshapeMover tests in this file. It adds no
// members of its own; helpers the tests call unqualified (for example
// CreateNewVerifiedModule() and TestName()) are presumably inherited from
// HloTestBase.
class ReshapeMoverTest : public HloTestBase {};
38
// Both operands of the add are reshapes to f32[8,7], but they reshape
// parameters of different shapes (f32[1,8,1,7] vs. f32[1,8,7,1]).
//
// Verifies that the pass reports no change and that the root is still an add
// of the two reshaped parameters.
TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto root_shape = ShapeUtil::MakeShape(F32, {8, 7});
  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0"));
  auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {1, 8, 7, 1}), "param1"));
  auto reshape0 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
  auto reshape1 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1));
  builder.AddInstruction(HloInstruction::CreateBinary(
      root_shape, HloOpcode::kAdd, reshape0, reshape1));

  auto computation = m->AddEntryComputation(builder.Build());

  // Sanity-check the graph before running the pass.
  EXPECT_THAT(computation->root_instruction(),
              op::Add(op::Reshape(param0), op::Reshape(param1)));

  // Surface a pass error as a test failure instead of crashing in
  // ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(m.get()));
  EXPECT_FALSE(changed);

  EXPECT_THAT(computation->root_instruction(),
              op::Add(op::Reshape(param0), op::Reshape(param1)));
}
64
65 // For a graph that looks like:
66 //
67 // +- reshape0 - rng0
68 // |
69 // +- const1
70 // |
71 // add
72 //
73 // where rng0 has a different shape than reshape0.
74 //
75 // Verifies that the reshape is not moved, since rng0 is trivially reshapable
76 // and therefore there is no nontrivial reshapes to move.
77 TEST_F(ReshapeMoverTest, 1ConstantAnd1ReshapesOnRngNotMoved) {
78 auto m = CreateNewVerifiedModule();
79 HloComputation::Builder builder(TestName());
80 auto root_shape = ShapeUtil::MakeShape(F32, {8, 7});
81 auto rng0 = builder.AddInstruction(HloInstruction::CreateRng(
82 ShapeUtil::MakeShape(F32, {1, 8, 1, 7, 1}),
83 RandomDistribution::RNG_UNIFORM,
84 {builder.AddInstruction(
85 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f))),
86 builder.AddInstruction(HloInstruction::CreateConstant(
87 LiteralUtil::CreateR0<float>(1.0f)))}));
88 auto reshape0 =
89 builder.AddInstruction(HloInstruction::CreateReshape(root_shape, rng0));
90
91 auto const1 = builder.AddInstruction(
92 HloInstruction::CreateConstant(Literal::CreateFromShape(root_shape)));
93
94 builder.AddInstruction(HloInstruction::CreateBinary(
95 root_shape, HloOpcode::kAdd, reshape0, const1));
96
97 auto computation = m->AddEntryComputation(builder.Build());
98
99 EXPECT_THAT(computation->root_instruction(),
100 op::Add(op::Reshape(rng0), const1));
101
102 EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie());
103
104 EXPECT_THAT(computation->root_instruction(),
105 op::Add(op::Reshape(rng0), const1));
106 }
107
// Both operands of the add are reshapes of f32[1,1,1] parameters down to a
// scalar.
//
// Verifies that the pass reports no change and that the root is still an add
// of two reshaped parameters.
TEST_F(ReshapeMoverTest, ScalarReshapesNotMoved) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto root_shape = ShapeUtil::MakeShape(F32, {});
  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1, 1, 1}), "param0"));
  auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {1, 1, 1}), "param1"));
  auto reshape0 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
  auto reshape1 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1));
  builder.AddInstruction(HloInstruction::CreateBinary(
      root_shape, HloOpcode::kAdd, reshape0, reshape1));

  auto computation = m->AddEntryComputation(builder.Build());

  // Sanity-check the graph before running the pass.
  EXPECT_THAT(computation->root_instruction(),
              op::Add(op::Reshape(param0), op::Reshape(param1)));

  // Surface a pass error as a test failure instead of crashing in
  // ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(m.get()));
  EXPECT_FALSE(changed);

  EXPECT_THAT(
      computation->root_instruction(),
      op::Add(op::Reshape(op::Parameter()), op::Reshape(op::Parameter())));
}
134
// Both operands of the add are reshapes from identically shaped parameters
// (f32[1,8,1,7]) to f32[8,7].
//
// Verifies that the pass reports a change, that the root becomes a single
// reshape of an add of the two parameters, and that the root keeps the
// original f32[8,7] shape.
TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto root_shape = ShapeUtil::MakeShape(F32, {8, 7});
  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0"));
  auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param1"));
  auto reshape0 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
  auto reshape1 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1));
  builder.AddInstruction(HloInstruction::CreateBinary(
      root_shape, HloOpcode::kAdd, reshape0, reshape1));

  auto computation = m->AddEntryComputation(builder.Build());

  // Sanity-check the graph before running the pass.
  EXPECT_THAT(computation->root_instruction(),
              op::Add(op::Reshape(param0), op::Reshape(param1)));

  // Surface a pass error as a test failure instead of crashing in
  // ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(m.get()));
  EXPECT_TRUE(changed);

  EXPECT_THAT(computation->root_instruction(),
              op::Reshape(op::Add(param0, param1)));
  EXPECT_EQ(root_shape.DebugString(),
            computation->root_instruction()->shape().DebugString());
}
161
162 // For a graph that looks like:
163 //
164 // +- reshape2 - param2
165 // |
166 // +- reshape1 - param1
167 // |
168 // +- constant0
169 // |
170 // select
171 //
172 // Verifies that the reshape1 and reshape2 sink past select:
173 //
174 // +- param2
175 // |
176 // +- param1
177 // |
178 // +- reshape3(constant0)
179 // |
180 // select
181 // |
182 // reshape4
183 TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) {
184 auto m = CreateNewVerifiedModule();
185 HloComputation::Builder builder(TestName());
186 auto root_shape = ShapeUtil::MakeShape(F32, {2, 3});
187 auto const0 = builder.AddInstruction(
188 HloInstruction::CreateConstant(LiteralUtil::CreateR2<bool>(
189 {{true, true, false}, {false, false, true}})));
190
191 auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
192 0, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param1"));
193 auto reshape1 =
194 builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1));
195
196 auto param2 = builder.AddInstruction(HloInstruction::CreateParameter(
197 1, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param2"));
198 auto reshape2 =
199 builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param2));
200
201 builder.AddInstruction(HloInstruction::CreateTernary(
202 root_shape, HloOpcode::kSelect, const0, reshape1, reshape2));
203
204 auto computation = m->AddEntryComputation(builder.Build());
205
206 EXPECT_THAT(computation->root_instruction(),
207 op::Select(const0, reshape1, reshape2));
208
209 EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie());
210
211 EXPECT_THAT(computation->root_instruction(),
212 op::Reshape(op::Select(op::Reshape(const0), param1, param2)));
213
214 EXPECT_EQ(root_shape.DebugString(),
215 computation->root_instruction()->shape().DebugString());
216 }
217
218 // For a graph that looks like:
219 //
220 // +- reshape0 - param0
221 // |
222 // +- param1
223 // |
224 // add
225 //
226 // Verifies that the reshape0 does not sink below add, because param1 is not
227 // trivially reshapable nor is a Reshape/Transpose.
228 TEST_F(ReshapeMoverTest, 1ParameterAnd1ReshapeNotMoved) {
229 auto m = CreateNewVerifiedModule();
230 HloComputation::Builder builder(TestName());
231 auto root_shape = ShapeUtil::MakeShape(F32, {8, 7});
232 auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
233 0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0"));
234 auto reshape0 =
235 builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
236 auto param1 = builder.AddInstruction(
237 HloInstruction::CreateParameter(1, root_shape, "param1"));
238 builder.AddInstruction(HloInstruction::CreateBinary(
239 root_shape, HloOpcode::kAdd, reshape0, param1));
240
241 auto computation = m->AddEntryComputation(builder.Build());
242
243 EXPECT_THAT(computation->root_instruction(),
244 op::Add(op::Reshape(param0), param1));
245 EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie());
246
247 EXPECT_THAT(computation->root_instruction(),
248 op::Add(op::Reshape(param0), param1));
249 EXPECT_EQ(root_shape.DebugString(),
250 computation->root_instruction()->shape().DebugString());
251 }
252
253 // For a graph that looks like:
254 //
255 // +- pred
256 // |
257 // +- reshape0 - const0
258 // |
259 // +- reshape1 - const1
260 // |
261 // select
262 //
263 // Verifies that we don't unnecessarily sink reshapes, which are in fact
264 // trivial reshapes.
265 TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) {
266 auto m = CreateNewVerifiedModule();
267 HloComputation::Builder builder(TestName());
268 auto root_shape = ShapeUtil::MakeShape(F32, {3, 2});
269 auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
270 LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}})));
271 auto reshape0 =
272 builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const0));
273
274 auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
275 LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}})));
276 auto reshape1 =
277 builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const1));
278
279 auto pred = builder.AddInstruction(HloInstruction::CreateParameter(
280 0, ShapeUtil::MakeShape(PRED, {3, 2}), "pred"));
281
282 builder.AddInstruction(HloInstruction::CreateTernary(
283 root_shape, HloOpcode::kSelect, pred, reshape0, reshape1));
284
285 auto computation = m->AddEntryComputation(builder.Build());
286
287 EXPECT_THAT(computation->root_instruction(),
288 op::Select(pred, op::Reshape(const0), op::Reshape(const1)));
289
290 EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie());
291
292 EXPECT_THAT(computation->root_instruction(),
293 op::Select(pred, op::Reshape(const0), op::Reshape(const1)));
294 EXPECT_EQ(root_shape.DebugString(),
295 computation->root_instruction()->shape().DebugString());
296 }
297
298 // For a graph that looks like:
299 //
300 // +- reshape0 - param0
301 // |
302 // +- const1
303 // |
304 // add
305 //
306 // where there is only 1 non-trivial reshape (reshape0), we sink the reshape
307 // here for canonicalization benefit:
308 //
309 // +- param0
310 // |
311 // +- reshape1 - const1
312 // |
313 // add
314 // |
315 // reshape2
316 //
317 // (note that reshape1 here is trivial).
318 TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) {
319 auto m = CreateNewVerifiedModule();
320 HloComputation::Builder builder(TestName());
321 auto root_shape = ShapeUtil::MakeShape(F32, {2, 3});
322 auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
323 0, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param0"));
324 auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
325 LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}})));
326 auto reshape0 =
327 builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
328 builder.AddInstruction(HloInstruction::CreateBinary(
329 root_shape, HloOpcode::kAdd, reshape0, const1));
330
331 auto computation = m->AddEntryComputation(builder.Build());
332
333 EXPECT_THAT(computation->root_instruction(),
334 op::Add(op::Reshape(param0), const1));
335
336 EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie());
337
338 EXPECT_THAT(computation->root_instruction(),
339 op::Reshape(op::Add(param0, op::Reshape(const1))));
340 EXPECT_EQ(root_shape.DebugString(),
341 computation->root_instruction()->shape().DebugString());
342 }
343
344 // For a graph that looks like:
345 //
346 // +- reshape0 - param0 (shape A)
347 // |
348 // +- reshape1 - const1 (shape B)
349 // |
350 // add
351 //
352 // There is 1 non-trivial reshape (reshape0). It's not clear whether reshape1
353 // should be trivial or not; conceptually it's trivial, but handling it would
354 // complicate the rest of our logic.
355 //
356 // For now we treat it as non-trivial, so we verify that we don't sink the
357 // reshapes in this case.
358 TEST_F(ReshapeMoverTest, 1NonTrivialReshapeWith1ReshapedConstNotMoved) {
359 auto m = CreateNewVerifiedModule();
360 HloComputation::Builder builder(TestName());
361 auto root_shape = ShapeUtil::MakeShape(F32, {1, 1, 3});
362 auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
363 0, ShapeUtil::MakeShape(F32, {1, 3}), "param0"));
364 auto const1 = builder.AddInstruction(
365 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({9, 8, 7})));
366 auto reshape0 =
367 builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
368 auto reshape1 =
369 builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const1));
370
371 builder.AddInstruction(HloInstruction::CreateBinary(
372 root_shape, HloOpcode::kAdd, reshape0, reshape1));
373
374 auto computation = m->AddEntryComputation(builder.Build());
375
376 EXPECT_THAT(computation->root_instruction(),
377 op::Add(op::Reshape(param0), op::Reshape(const1)));
378
379 EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie());
380
381 EXPECT_THAT(computation->root_instruction(),
382 op::Add(op::Reshape(param0), op::Reshape(const1)));
383 EXPECT_EQ(root_shape.DebugString(),
384 computation->root_instruction()->shape().DebugString());
385 }
386
// Two identically shaped parameters (f32[1,8,1,7]) are each reshaped to
// f32[8,7] and added; the add is then wrapped in a kLoop fusion instruction
// before the pass runs.
//
// Verifies that the pass reports a change, that the root becomes a reshape of
// a fusion taking the two parameters directly, and that the root keeps the
// original f32[8,7] shape.
TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto root_shape = ShapeUtil::MakeShape(F32, {8, 7});
  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0"));
  auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param1"));
  auto reshape0 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
  auto reshape1 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1));
  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
      root_shape, HloOpcode::kAdd, reshape0, reshape1));

  auto computation = m->AddEntryComputation(builder.Build());
  // Fuse only the add; the reshapes stay outside as the fusion's operands.
  computation->CreateFusionInstruction({add},
                                       HloInstruction::FusionKind::kLoop);

  // Sanity-check the graph before running the pass.
  EXPECT_THAT(computation->root_instruction(),
              op::Fusion(op::Reshape(param0), op::Reshape(param1)));

  // Surface a pass error as a test failure instead of crashing in
  // ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(m.get()));
  EXPECT_TRUE(changed);

  EXPECT_THAT(computation->root_instruction(),
              op::Reshape(op::Fusion(param0, param1)));
  EXPECT_EQ(root_shape.DebugString(),
            computation->root_instruction()->shape().DebugString());
}
416
// All three operands of the select (predicate, on-true, on-false) are
// reshapes from [1,8,1,7] parameters to [8,7].
//
// Verifies that the pass reports a change, that the root becomes a single
// reshape of a select over the three parameters, and that the root keeps the
// original f32[8,7] shape.
TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossSelect) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto root_shape = ShapeUtil::MakeShape(F32, {8, 7});
  auto pred_shape = ShapeUtil::MakeShape(PRED, {8, 7});
  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0"));
  auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param1"));
  auto pred = builder.AddInstruction(HloInstruction::CreateParameter(
      2, ShapeUtil::MakeShape(PRED, {1, 8, 1, 7}), "pred"));
  auto reshape0 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
  auto reshape1 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1));
  auto reshape_pred =
      builder.AddInstruction(HloInstruction::CreateReshape(pred_shape, pred));
  builder.AddInstruction(HloInstruction::CreateTernary(
      root_shape, HloOpcode::kSelect, reshape_pred, reshape0, reshape1));

  auto computation = m->AddEntryComputation(builder.Build());

  // Sanity-check the graph before running the pass.
  EXPECT_THAT(
      computation->root_instruction(),
      op::Select(op::Reshape(pred), op::Reshape(param0), op::Reshape(param1)));

  // Surface a pass error as a test failure instead of crashing in
  // ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(m.get()));
  EXPECT_TRUE(changed);

  EXPECT_THAT(computation->root_instruction(),
              op::Reshape(op::Select(pred, param0, param1)));
  EXPECT_EQ(root_shape.DebugString(),
            computation->root_instruction()->shape().DebugString());
}
450
// Only the predicate operand of the select is a reshape (PRED[1,1,1] down to
// a scalar); the other two operands are scalar f32 parameters.
//
// Verifies that the pass reports no change and that the root is still the
// very same select instruction.
TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto root_shape = ShapeUtil::MakeShape(F32, {});
  auto pred_shape = ShapeUtil::MakeShape(PRED, {});
  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {}), "param0"));
  auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {}), "param1"));
  auto pred = builder.AddInstruction(HloInstruction::CreateParameter(
      2, ShapeUtil::MakeShape(PRED, {1, 1, 1}), "pred"));
  auto reshape_pred =
      builder.AddInstruction(HloInstruction::CreateReshape(pred_shape, pred));
  auto select = builder.AddInstruction(HloInstruction::CreateTernary(
      root_shape, HloOpcode::kSelect, reshape_pred, param0, param1));

  auto computation = m->AddEntryComputation(builder.Build());
  // Sanity-check the graph before running the pass.
  EXPECT_THAT(computation->root_instruction(),
              op::Select(op::Reshape(pred), param0, param1));

  // Surface a pass error as a test failure instead of crashing in
  // ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(m.get()));
  EXPECT_FALSE(changed);

  EXPECT_THAT(computation->root_instruction(),
              op::Select(op::Reshape(pred), param0, param1));
  // Pointer equality: the root must not have been replaced by a new select.
  EXPECT_EQ(select, computation->root_instruction());
}
477
478 // Tree looks like this:
479 //
480 // add1
481 // |
482 // +- reshape2 - param2
483 // |
484 // +- reshape3 - add0
485 // |
486 // + reshape0 - param0
487 // |
488 // + reshape1 - param1
489 //
490 // We expect reshape{0,1} AND reshape{2,3} to be lifted.
TEST_F(ReshapeMoverTest,MultiplePasses)491 TEST_F(ReshapeMoverTest, MultiplePasses) {
492 auto m = CreateNewVerifiedModule();
493 auto shape1 = ShapeUtil::MakeShape(F32, {1, 8, 1, 7});
494 auto shape2 = ShapeUtil::MakeShape(F32, {8, 7, 1});
495 auto shape3 = ShapeUtil::MakeShape(F32, {8, 7});
496 HloComputation::Builder builder(TestName());
497 auto param0 = builder.AddInstruction(
498 HloInstruction::CreateParameter(0, shape1, "param0"));
499 auto param1 = builder.AddInstruction(
500 HloInstruction::CreateParameter(1, shape1, "param1"));
501 auto param2 = builder.AddInstruction(
502 HloInstruction::CreateParameter(2, shape2, "param2"));
503 auto reshape0 =
504 builder.AddInstruction(HloInstruction::CreateReshape(shape2, param0));
505 auto reshape1 =
506 builder.AddInstruction(HloInstruction::CreateReshape(shape2, param1));
507 auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
508 shape2, HloOpcode::kAdd, reshape0, reshape1));
509 auto reshape2 =
510 builder.AddInstruction(HloInstruction::CreateReshape(shape3, param2));
511 auto reshape3 =
512 builder.AddInstruction(HloInstruction::CreateReshape(shape3, add0));
513 builder.AddInstruction(HloInstruction::CreateBinary(shape3, HloOpcode::kAdd,
514 reshape2, reshape3));
515
516 auto computation = m->AddEntryComputation(builder.Build());
517
518 EXPECT_THAT(
519 computation->root_instruction(),
520 op::Add(op::Reshape(param2),
521 op::Reshape(op::Add(op::Reshape(param0), op::Reshape(param1)))));
522
523 EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie());
524
525 EXPECT_THAT(
526 computation->root_instruction(),
527 op::Reshape(op::Add(param2, op::Reshape(op::Add(param0, param1)))));
528 }
529
// The multiply takes a transpose of a f32[20,8] parameter and a broadcast of
// a scalar parameter.
//
// Verifies that the pass reports a change and that the root becomes a
// transpose whose operand is a multiply.
TEST_F(ReshapeMoverTest, SinkTransposeAcrossBroadcastScalar) {
  const string module_text = R"(
    HloModule TransposeMulInversedTransposeModule
    ENTRY TransposeMulInversedTranspose {
      src0 = f32[20,8]{1,0} parameter(0)
      transpose0 = f32[8,20]{1,0} transpose(src0), dimensions={1,0}
      src1 = f32[] parameter(1)
      broadcast0 = f32[8,20]{1,0} broadcast(src1), dimensions={}
      ROOT multiply0 = f32[8,20]{1,0} multiply(transpose0, broadcast0)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(module_text));
  TF_ASSERT_OK_AND_ASSIGN(bool did_change, ReshapeMover().Run(module.get()));
  EXPECT_TRUE(did_change);

  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Transpose(op::Multiply()));
}
549
// reshape0 feeds add0 and add1, and reshape1 feeds add1 and add2. The other
// operand of add2 is a plain parameter (param4) -- presumably the "user
// outside the candidates" the test name refers to.
//
// Verifies that the pass reports no change for this module.
TEST_F(ReshapeMoverTest, ReshapeWithUsersOutsideCandidatesNotSink) {
  const string module_text = R"(
    HloModule ReshapeWithUsersOutsideCandidates
    ENTRY ReshapeWithMultipleUsers {
      param0 = f32[20,8]{1,0} parameter(0)
      reshape0 = f32[8,20]{1,0} reshape(param0)
      param1 = f32[] parameter(1)
      broadcast0 = f32[8,20]{1,0} broadcast(param1), dimensions={}
      param2 = f32[20,8]{1,0} parameter(2)
      reshape1 = f32[8,20]{1,0} reshape(param2)
      param3 = f32[20,8]{1,0} parameter(3)
      reshape2 = f32[8,20]{1,0} reshape(param3)
      param4 = f32[8,20]{1,0} parameter(4)
      add0 = f32[8,20]{1,0} add(reshape0, broadcast0)
      add1 = f32[8,20]{1,0} add(reshape0, reshape1)
      add2 = f32[8,20]{1,0} add(reshape1, param4)
      ROOT tuple = (f32[8,20]{1,0},f32[8,20]{1,0},
        f32[8,20]{1,0}) tuple(add0, add1, add2)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(module_text));
  TF_ASSERT_OK_AND_ASSIGN(bool did_change, ReshapeMover().Run(module.get()));
  EXPECT_FALSE(did_change);
}
575
// reshape0 feeds add0 and add1, and reshape1 feeds add1 and add2. Every
// operand of the three adds is either a reshape or a broadcast of a scalar.
//
// Verifies that the pass reports a change and that each element of the root
// tuple becomes a reshape.
TEST_F(ReshapeMoverTest, ReshapeNoUsersOutsideCandidatesSink1) {
  const string module_text = R"(
    HloModule ReshapeNoUsersOutsideCandidates1
    ENTRY ReshapeWithMultipleUsers1 {
      param0 = f32[20,8]{1,0} parameter(0)
      reshape0 = f32[8,20]{1,0} reshape(param0)
      param1 = f32[] parameter(1)
      broadcast0 = f32[8,20]{1,0} broadcast(param1), dimensions={}
      param2 = f32[20,8]{1,0} parameter(2)
      reshape1 = f32[8,20]{1,0} reshape(param2)
      param3 = f32[20,8]{1,0} parameter(3)
      reshape2 = f32[8,20]{1,0} reshape(param3)
      add0 = f32[8,20]{1,0} add(reshape0, broadcast0)
      add1 = f32[8,20]{1,0} add(reshape0, reshape1)
      add2 = f32[8,20]{1,0} add(reshape1, reshape2)
      ROOT tuple = (f32[8,20]{1,0},f32[8,20]{1,0},
        f32[8,20]{1,0}) tuple(add0, add1, add2)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(module_text));
  TF_ASSERT_OK_AND_ASSIGN(bool did_change, ReshapeMover().Run(module.get()));
  EXPECT_TRUE(did_change);

  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Reshape(), op::Reshape(), op::Reshape()));
}
602
// Both operands of the add are the same reshape instruction (reshape0).
//
// Verifies that the pass reports a change and that the root becomes a reshape
// whose operand is an add.
TEST_F(ReshapeMoverTest, ReshapeNoUsersOutsideCandidatesSink2) {
  const string module_text = R"(
    HloModule ReshapeNoUsersOutsideCandidates2
    ENTRY ReshapeWithMultipleUsers2 {
      param0 = f32[20,8]{1,0} parameter(0)
      reshape0 = f32[8,20]{1,0} reshape(param0)
      ROOT add0 = f32[8,20]{1,0} add(reshape0, reshape0)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(module_text));
  TF_ASSERT_OK_AND_ASSIGN(bool did_change, ReshapeMover().Run(module.get()));
  EXPECT_TRUE(did_change);

  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Reshape(op::Add()));
}
619
620 } // namespace
621 } // namespace xla
622