/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/reshape_mover.h"

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace {

namespace op = xla::testing::opcode_matchers;

class ReshapeMoverTest : public HloTestBase {};

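// Verifies that reshapes are not sunk below the add when their operands have
// different shapes ({1, 8, 1, 7} vs. {1, 8, 7, 1}).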
TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto root_shape = ShapeUtil::MakeShape(F32, {8, 7});
  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0"));
  auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {1, 8, 7, 1}), "param1"));
  auto reshape0 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
  auto reshape1 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1));
  builder.AddInstruction(HloInstruction::CreateBinary(
      root_shape, HloOpcode::kAdd, reshape0, reshape1));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              op::Add(op::Reshape(param0), op::Reshape(param1)));

  EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              op::Add(op::Reshape(param0), op::Reshape(param1)));
}

// For a graph that looks like:
//
// +- reshape0 - rng0
// |
// +- const1
// |
// add
//
// where rng0 has a different shape than reshape0.
//
// Verifies that the reshape is not moved, since rng0 is trivially reshapable
// and therefore there are no nontrivial reshapes to move.
TEST_F(ReshapeMoverTest, 1ConstantAnd1ReshapesOnRngNotMoved) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto root_shape = ShapeUtil::MakeShape(F32, {8, 7});
  auto rng0 = builder.AddInstruction(HloInstruction::CreateRng(
      ShapeUtil::MakeShape(F32, {1, 8, 1, 7, 1}),
      RandomDistribution::RNG_UNIFORM,
      {builder.AddInstruction(
           HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f))),
       builder.AddInstruction(HloInstruction::CreateConstant(
           LiteralUtil::CreateR0<float>(1.0f)))}));
  auto reshape0 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, rng0));

  auto const1 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateFromShape(root_shape)));

  builder.AddInstruction(HloInstruction::CreateBinary(
      root_shape, HloOpcode::kAdd, reshape0, const1));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              op::Add(op::Reshape(rng0), const1));

  EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              op::Add(op::Reshape(rng0), const1));
}

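// Verifies that reshapes of {1, 1, 1}-shaped parameters down to a scalar are
// not sunk below the add.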
TEST_F(ReshapeMoverTest, ScalarReshapesNotMoved) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto root_shape = ShapeUtil::MakeShape(F32, {});
  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1, 1, 1}), "param0"));
  auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {1, 1, 1}), "param1"));
  auto reshape0 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
  auto reshape1 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1));
  builder.AddInstruction(HloInstruction::CreateBinary(
      root_shape, HloOpcode::kAdd, reshape0, reshape1));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              op::Add(op::Reshape(param0), op::Reshape(param1)));

  EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie());

  EXPECT_THAT(
      computation->root_instruction(),
      op::Add(op::Reshape(op::Parameter()), op::Reshape(op::Parameter())));
}

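// Verifies that two equivalent reshapes feeding an add are sunk below the add,
// leaving a single reshape of the add as the root.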
TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto root_shape = ShapeUtil::MakeShape(F32, {8, 7});
  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0"));
  auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param1"));
  auto reshape0 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
  auto reshape1 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1));
  builder.AddInstruction(HloInstruction::CreateBinary(
      root_shape, HloOpcode::kAdd, reshape0, reshape1));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              op::Add(op::Reshape(param0), op::Reshape(param1)));
  EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              op::Reshape(op::Add(param0, param1)));
  EXPECT_EQ(root_shape.DebugString(),
            computation->root_instruction()->shape().DebugString());
}

// For a graph that looks like:
//
// +- reshape2 - param2
// |
// +- reshape1 - param1
// |
// +- constant0
// |
// select
//
// Verifies that reshape1 and reshape2 sink past the select:
//
// +- param2
// |
// +- param1
// |
// +- reshape3(constant0)
// |
// select
// |
// reshape4
TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto root_shape = ShapeUtil::MakeShape(F32, {2, 3});
  auto const0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<bool>(
          {{true, true, false}, {false, false, true}})));

  auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param1"));
  auto reshape1 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1));

  auto param2 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param2"));
  auto reshape2 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param2));

  builder.AddInstruction(HloInstruction::CreateTernary(
      root_shape, HloOpcode::kSelect, const0, reshape1, reshape2));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              op::Select(const0, reshape1, reshape2));

  EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              op::Reshape(op::Select(op::Reshape(const0), param1, param2)));

  EXPECT_EQ(root_shape.DebugString(),
            computation->root_instruction()->shape().DebugString());
}

// For a graph that looks like:
//
// +- reshape0 - param0
// |
// +- param1
// |
// add
//
// Verifies that reshape0 does not sink below the add, because param1 is
// neither trivially reshapable nor a Reshape/Transpose.
TEST_F(ReshapeMoverTest, 1ParameterAnd1ReshapeNotMoved) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto root_shape = ShapeUtil::MakeShape(F32, {8, 7});
  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0"));
  auto reshape0 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, root_shape, "param1"));
  builder.AddInstruction(HloInstruction::CreateBinary(
      root_shape, HloOpcode::kAdd, reshape0, param1));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              op::Add(op::Reshape(param0), param1));
  EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              op::Add(op::Reshape(param0), param1));
  EXPECT_EQ(root_shape.DebugString(),
            computation->root_instruction()->shape().DebugString());
}

// For a graph that looks like:
//
// +- pred
// |
// +- reshape0 - const0
// |
// +- reshape1 - const1
// |
// select
//
// Verifies that we don't unnecessarily sink the reshapes, since they are in
// fact trivial reshapes of constants.
TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto root_shape = ShapeUtil::MakeShape(F32, {3, 2});
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}})));
  auto reshape0 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const0));

  auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}})));
  auto reshape1 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const1));

  auto pred = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(PRED, {3, 2}), "pred"));

  builder.AddInstruction(HloInstruction::CreateTernary(
      root_shape, HloOpcode::kSelect, pred, reshape0, reshape1));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              op::Select(pred, op::Reshape(const0), op::Reshape(const1)));

  EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              op::Select(pred, op::Reshape(const0), op::Reshape(const1)));
  EXPECT_EQ(root_shape.DebugString(),
            computation->root_instruction()->shape().DebugString());
}

// For a graph that looks like:
//
// +- reshape0 - param0
// |
// +- const1
// |
// add
//
// where there is only 1 non-trivial reshape (reshape0), we sink the reshape
// here for canonicalization benefit:
//
// +- param0
// |
// +- reshape1 - const1
// |
// add
// |
// reshape2
//
// (note that reshape1 here is trivial).
TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto root_shape = ShapeUtil::MakeShape(F32, {2, 3});
  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param0"));
  auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}})));
  auto reshape0 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
  builder.AddInstruction(HloInstruction::CreateBinary(
      root_shape, HloOpcode::kAdd, reshape0, const1));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              op::Add(op::Reshape(param0), const1));

  EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              op::Reshape(op::Add(param0, op::Reshape(const1))));
  EXPECT_EQ(root_shape.DebugString(),
            computation->root_instruction()->shape().DebugString());
}

// For a graph that looks like:
//
// +- reshape0 - param0 (shape A)
// |
// +- reshape1 - const1 (shape B)
// |
// add
//
// There is 1 non-trivial reshape (reshape0). It's not clear whether reshape1
// should be trivial or not; conceptually it's trivial, but handling it would
// complicate the rest of our logic.
//
// For now we treat it as non-trivial, so we verify that we don't sink the
// reshapes in this case.
TEST_F(ReshapeMoverTest, 1NonTrivialReshapeWith1ReshapedConstNotMoved) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto root_shape = ShapeUtil::MakeShape(F32, {1, 1, 3});
  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1, 3}), "param0"));
  auto const1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({9, 8, 7})));
  auto reshape0 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
  auto reshape1 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const1));

  builder.AddInstruction(HloInstruction::CreateBinary(
      root_shape, HloOpcode::kAdd, reshape0, reshape1));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              op::Add(op::Reshape(param0), op::Reshape(const1)));

  EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              op::Add(op::Reshape(param0), op::Reshape(const1)));
  EXPECT_EQ(root_shape.DebugString(),
            computation->root_instruction()->shape().DebugString());
}

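// Verifies that equivalent reshapes are still sunk when the consuming add has
// been wrapped in a loop fusion instruction.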
TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto root_shape = ShapeUtil::MakeShape(F32, {8, 7});
  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0"));
  auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param1"));
  auto reshape0 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
  auto reshape1 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1));
  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
      root_shape, HloOpcode::kAdd, reshape0, reshape1));

  auto computation = m->AddEntryComputation(builder.Build());
  computation->CreateFusionInstruction({add},
                                       HloInstruction::FusionKind::kLoop);

  EXPECT_THAT(computation->root_instruction(),
              op::Fusion(op::Reshape(param0), op::Reshape(param1)));

  EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              op::Reshape(op::Fusion(param0, param1)));
  EXPECT_EQ(root_shape.DebugString(),
            computation->root_instruction()->shape().DebugString());
}

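// Verifies that equivalent reshapes on all three select operands (including
// the predicate) are sunk below the select.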
TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossSelect) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto root_shape = ShapeUtil::MakeShape(F32, {8, 7});
  auto pred_shape = ShapeUtil::MakeShape(PRED, {8, 7});
  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0"));
  auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param1"));
  auto pred = builder.AddInstruction(HloInstruction::CreateParameter(
      2, ShapeUtil::MakeShape(PRED, {1, 8, 1, 7}), "pred"));
  auto reshape0 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
  auto reshape1 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1));
  auto reshape_pred =
      builder.AddInstruction(HloInstruction::CreateReshape(pred_shape, pred));
  builder.AddInstruction(HloInstruction::CreateTernary(
      root_shape, HloOpcode::kSelect, reshape_pred, reshape0, reshape1));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(
      computation->root_instruction(),
      op::Select(op::Reshape(pred), op::Reshape(param0), op::Reshape(param1)));

  EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              op::Reshape(op::Select(pred, param0, param1)));
  EXPECT_EQ(root_shape.DebugString(),
            computation->root_instruction()->shape().DebugString());
}

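// Verifies that the reshape of the predicate down to a scalar is not sunk
// below the select, whose other operands are already scalars.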
TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  auto root_shape = ShapeUtil::MakeShape(F32, {});
  auto pred_shape = ShapeUtil::MakeShape(PRED, {});
  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {}), "param0"));
  auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {}), "param1"));
  auto pred = builder.AddInstruction(HloInstruction::CreateParameter(
      2, ShapeUtil::MakeShape(PRED, {1, 1, 1}), "pred"));
  auto reshape_pred =
      builder.AddInstruction(HloInstruction::CreateReshape(pred_shape, pred));
  auto select = builder.AddInstruction(HloInstruction::CreateTernary(
      root_shape, HloOpcode::kSelect, reshape_pred, param0, param1));

  auto computation = m->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              op::Select(op::Reshape(pred), param0, param1));

  EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              op::Select(op::Reshape(pred), param0, param1));
  EXPECT_EQ(select, computation->root_instruction());
}

// Tree looks like this:
//
// add1
// |
// +- reshape2 - param2
// |
// +- reshape3 - add0
//               |
//               + reshape0 - param0
//               |
//               + reshape1 - param1
//
// We expect reshape{0,1} AND reshape{2,3} to be lifted.
TEST_F(ReshapeMoverTest, MultiplePasses) {
  auto m = CreateNewVerifiedModule();
  auto shape1 = ShapeUtil::MakeShape(F32, {1, 8, 1, 7});
  auto shape2 = ShapeUtil::MakeShape(F32, {8, 7, 1});
  auto shape3 = ShapeUtil::MakeShape(F32, {8, 7});
  HloComputation::Builder builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape1, "param0"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape1, "param1"));
  auto param2 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, shape2, "param2"));
  auto reshape0 =
      builder.AddInstruction(HloInstruction::CreateReshape(shape2, param0));
  auto reshape1 =
      builder.AddInstruction(HloInstruction::CreateReshape(shape2, param1));
  auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
      shape2, HloOpcode::kAdd, reshape0, reshape1));
  auto reshape2 =
      builder.AddInstruction(HloInstruction::CreateReshape(shape3, param2));
  auto reshape3 =
      builder.AddInstruction(HloInstruction::CreateReshape(shape3, add0));
  builder.AddInstruction(HloInstruction::CreateBinary(shape3, HloOpcode::kAdd,
                                                      reshape2, reshape3));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(
      computation->root_instruction(),
      op::Add(op::Reshape(param2),
              op::Reshape(op::Add(op::Reshape(param0), op::Reshape(param1)))));

  EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie());

  EXPECT_THAT(
      computation->root_instruction(),
      op::Reshape(op::Add(param2, op::Reshape(op::Add(param0, param1)))));
}

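// Verifies that a transpose feeding a multiply is sunk below the multiply when
// the other operand is a broadcast of a scalar.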
TEST_F(ReshapeMoverTest, SinkTransposeAcrossBroadcastScalar) {
  const string hlo_string = R"(
    HloModule TransposeMulInversedTransposeModule
    ENTRY TransposeMulInversedTranspose {
      src0 = f32[20,8]{1,0} parameter(0)
      transpose0 = f32[8,20]{1,0} transpose(src0), dimensions={1,0}
      src1 = f32[] parameter(1)
      broadcast0 = f32[8,20]{1,0} broadcast(src1), dimensions={}
      ROOT multiply0 = f32[8,20]{1,0} multiply(transpose0, broadcast0)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
  TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(m.get()));
  EXPECT_TRUE(changed);

  EXPECT_THAT(m->entry_computation()->root_instruction(),
              op::Transpose(op::Multiply()));
}

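// Verifies that reshapes are not sunk when one of them (reshape1) also feeds
// add2, whose other operand (param4) is a plain parameter rather than a
// reshape.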
TEST_F(ReshapeMoverTest, ReshapeWithUsersOutsideCandidatesNotSink) {
  const string hlo_string = R"(
    HloModule ReshapeWithUsersOutsideCandidates
    ENTRY ReshapeWithMultipleUsers {
      param0 = f32[20,8]{1,0} parameter(0)
      reshape0 = f32[8,20]{1,0} reshape(param0)
      param1 = f32[] parameter(1)
      broadcast0 = f32[8,20]{1,0} broadcast(param1), dimensions={}
      param2 = f32[20,8]{1,0} parameter(2)
      reshape1 = f32[8,20]{1,0} reshape(param2)
      param3 = f32[20,8]{1,0} parameter(3)
      reshape2 = f32[8,20]{1,0} reshape(param3)
      param4 = f32[8,20]{1,0} parameter(4)
      add0 = f32[8,20]{1,0} add(reshape0, broadcast0)
      add1 = f32[8,20]{1,0} add(reshape0, reshape1)
      add2 = f32[8,20]{1,0} add(reshape1, param4)
      ROOT tuple = (f32[8,20]{1,0},f32[8,20]{1,0},
        f32[8,20]{1,0}) tuple(add0, add1, add2)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
  TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(m.get()));
  EXPECT_FALSE(changed);
}

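// Verifies that the reshapes are sunk when none of them has users outside the
// candidate adds; each add in the tuple ends up wrapped in a reshape.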
TEST_F(ReshapeMoverTest, ReshapeNoUsersOutsideCandidatesSink1) {
  const string hlo_string = R"(
    HloModule ReshapeNoUsersOutsideCandidates1
    ENTRY ReshapeWithMultipleUsers1 {
      param0 = f32[20,8]{1,0} parameter(0)
      reshape0 = f32[8,20]{1,0} reshape(param0)
      param1 = f32[] parameter(1)
      broadcast0 = f32[8,20]{1,0} broadcast(param1), dimensions={}
      param2 = f32[20,8]{1,0} parameter(2)
      reshape1 = f32[8,20]{1,0} reshape(param2)
      param3 = f32[20,8]{1,0} parameter(3)
      reshape2 = f32[8,20]{1,0} reshape(param3)
      add0 = f32[8,20]{1,0} add(reshape0, broadcast0)
      add1 = f32[8,20]{1,0} add(reshape0, reshape1)
      add2 = f32[8,20]{1,0} add(reshape1, reshape2)
      ROOT tuple = (f32[8,20]{1,0},f32[8,20]{1,0},
        f32[8,20]{1,0}) tuple(add0, add1, add2)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
  TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(m.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              op::Tuple(op::Reshape(), op::Reshape(), op::Reshape()));
}

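// Verifies that a reshape feeding both operands of the same add is sunk below
// the add.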
TEST_F(ReshapeMoverTest, ReshapeNoUsersOutsideCandidatesSink2) {
  const string hlo_string = R"(
    HloModule ReshapeNoUsersOutsideCandidates2
    ENTRY ReshapeWithMultipleUsers2 {
      param0 = f32[20,8]{1,0} parameter(0)
      reshape0 = f32[8,20]{1,0} reshape(param0)
      ROOT add0 = f32[8,20]{1,0} add(reshape0, reshape0)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
  TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(m.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              op::Reshape(op::Add()));
}

}  // namespace
}  // namespace xla
622