1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h"
17 
18 #include "tensorflow/compiler/xla/literal.h"
19 #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
20 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
21 #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
22 #include "tensorflow/compiler/xla/service/hlo_dce.h"
23 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
24 #include "tensorflow/compiler/xla/service/hlo_parser.h"
25 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
26 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
27 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/test.h"
30 #include "tensorflow/compiler/xla/test_helpers.h"
31 #include "tensorflow/compiler/xla/tests/filecheck.h"
32 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
33 
34 namespace xla {
35 namespace gpu {
36 namespace {
37 
38 namespace op = xla::testing::opcode_matchers;
39 
// Test fixture for the GPU horizontal loop fusion pass; inherits module
// parsing/verification and execute-and-compare helpers from HloTestBase.
class HorizontalLoopFusionTest : public HloTestBase {};
41 
TEST_F(HorizontalLoopFusionTest, BasicTest) {
  // Two independent kLoop fusions with different output shapes should be
  // merged horizontally into one multi-output fusion.
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule BasicTest

 fused_computation.1 {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   ROOT mul.1 = f16[1024]{0} multiply(arg.1, arg.2)
 }

 fused_computation.2 {
   arg.1 = f16[123]{0} parameter(0)
   arg.2 = f16[123]{0} parameter(1)
   ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
 }

 ENTRY entry_computation {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   arg.3 = f16[123]{0} parameter(2)
   arg.4 = f16[123]{0} parameter(3)
   fusion.1 = f16[1024]{0}
       fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
   fusion.2 = f16[123]{0}
       fusion(arg.3, arg.4), kind=kLoop, calls=fused_computation.2
   ROOT tuple.1 = (f16[1024]{0}, f16[123]{0})
       tuple(fusion.1, fusion.2)
 }
)")
                    .ValueOrDie();

  GpuHorizontalLoopFusion horizontal_fusion;
  EXPECT_TRUE(horizontal_fusion.Run(module.get()).ValueOrDie());
  HloDCE dce;
  EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());

  // Each original output is now recovered via bitcast(gte(fused)).
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Bitcast(op::GetTupleElement(op::Fusion())),
                              op::Bitcast(op::GetTupleElement(op::Fusion()))));

  // The merged fusion concatenates the reshaped per-fusion results and
  // slices the individual outputs back out.
  const HloInstruction* fused = root->operand(0)->operand(0)->operand(0);
  ASSERT_TRUE(fused->IsMultiOutputFusion());
  EXPECT_THAT(
      fused->fused_expression_root(),
      op::Tuple(op::Slice(op::Concatenate(op::Reshape(), op::Reshape())),
                op::Slice(op::Concatenate(op::Reshape(), op::Reshape()))));
}
89 
90 // Horizontal fusion should not be triggered as fusion will create cycles.
TEST_F(HorizontalLoopFusionTest,NegativeTestForCycle)91 TEST_F(HorizontalLoopFusionTest, NegativeTestForCycle) {
92   auto module = ParseAndReturnVerifiedModule(R"(
93  HloModule NegativeTestForCycle
94 
95  fused_computation.1 {
96    arg.1 = f16[123]{0} parameter(0)
97    arg.2 = f16[123]{0} parameter(1)
98    ROOT mul.1 = f16[123]{0} multiply(arg.1, arg.2)
99  }
100 
101  fused_computation.2 {
102    arg.1 = f16[123]{0} parameter(0)
103    arg.2 = f16[123]{0} parameter(1)
104    ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
105  }
106 
107  ENTRY entry_computation {
108    arg.1 = f16[123]{0} parameter(0)
109    arg.2 = f16[123]{0} parameter(1)
110    arg.3 = f16[123]{0} parameter(2)
111    arg.4 = f16[123]{0} parameter(3)
112    // fusion.1 and fusion.2 will not be horizontally fused as it will create
113    // a cycle through fusion.1 -> add.2 -> fusion.2
114    fusion.1 = f16[123]{0}
115        fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
116    add.2 = f16[123]{0} add(fusion.1, arg.4)
117    fusion.2 = f16[123]{0}
118        fusion(add.2, arg.3), kind=kLoop, calls=fused_computation.2
119    ROOT tuple.1 = (f16[123]{0}, f16[123]{0}, f16[123]{0})
120        tuple(fusion.1, fusion.2, add.2)
121  }
122 )")
123                     .ValueOrDie();
124 
125   EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
126 }
127 
// Fusions whose outputs have different element types are not candidates for
// horizontal fusion.
TEST_F(HorizontalLoopFusionTest, NegativeTestForIncompatibleTypes) {
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule NegativeTestForIncompatibleTypes

 fused_computation.1 {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   ROOT mul.1 = f16[1024]{0} multiply(arg.1, arg.2)
 }

 fused_computation.2 {
   arg.1 = s32[123]{0} parameter(0)
   arg.2 = s32[123]{0} parameter(1)
   ROOT add.1 = s32[123]{0} add(arg.1, arg.2)
 }

 ENTRY entry_computation {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   arg.3 = s32[123]{0} parameter(2)
   arg.4 = s32[123]{0} parameter(3)
   // fusion.1 and fusion.2 will not be horizontally fused because their output
   // types are different.
   fusion.1 = f16[1024]{0}
       fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
   fusion.2 = s32[123]{0}
       fusion(arg.3, arg.4), kind=kLoop, calls=fused_computation.2
   ROOT tuple.1 = (f16[1024]{0}, s32[123]{0})
       tuple(fusion.1, fusion.2)
 }
)")
                    .ValueOrDie();

  // f16 vs. s32 outputs: the pass must report no change.
  GpuHorizontalLoopFusion horizontal_fusion;
  EXPECT_FALSE(horizontal_fusion.Run(module.get()).ValueOrDie());
}
163 
// Runs vertical (producer-consumer) instruction fusion first, then verifies
// that horizontal fusion still finds opportunities across the resulting
// fusions and that the numerics are unchanged.
TEST_F(HorizontalLoopFusionTest, HorizontalLoopFusionAfterVerticalFusion) {
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule MergeSharedFusionInstruction

 ENTRY MergeSharedFusionInstruction.Computation0 {
  param.1.1   = f32[4,1024]{1,0} parameter(0)
  param.1.2   = f32[4,1024]{1,0} parameter(1)
  param.1.3   = f32[4,1024]{1,0} parameter(2)
  param.2.1   = f32[321,5]{1,0} parameter(3)
  param.2.2   = f32[321,5]{1,0} parameter(4)
  param.2.3   = f32[321,5]{1,0} parameter(5)
  const.1     = f32[] constant(3)
  const.2     = f32[] constant(3)
  broadcast.1 = f32[4,1024]{1,0} broadcast(const.1), dimensions={}
  broadcast.2 = f32[321,5]{1,0} broadcast(const.2), dimensions={}
  mul.1.1     = f32[4,1024]{1,0} multiply(param.1.1, param.1.2)
  mul.1.2     = f32[4,1024]{1,0} multiply(param.1.3, broadcast.1)
  add.1       = f32[4,1024]{1,0} add(mul.1.1, mul.1.2)
  mul.2.1     = f32[321,5]{1,0} multiply(param.2.1, param.2.2)
  mul.2.2     = f32[321,5]{1,0} multiply(param.2.3, broadcast.2)
  add.2       = f32[321,5]{1,0} add(mul.2.1, mul.2.2)
  ROOT tuple = (f32[4,1024]{1,0}, f32[321,5]{1,0}) tuple(add.1, add.2)
})")
                    .ValueOrDie();

  // Vertical fusion: first without duplication, then allowing duplication.
  HloPassPipeline vertical_fusion("fusion");
  vertical_fusion.AddPass<xla::gpu::GpuInstructionFusion>(
      /*may_duplicate=*/false);
  vertical_fusion.AddPass<xla::gpu::GpuInstructionFusion>(
      /*may_duplicate=*/true);
  EXPECT_TRUE(vertical_fusion.Run(module.get()).ValueOrDie());
  EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());

  VLOG(2) << "Dump after horizontal fusion:";
  VLOG(2) << module->ToString();

  // Bit-exact comparison against the unfused reference.
  EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
}
200 
// Builds 128 independent SGD-style updates, var - alpha * delta, each with a
// distinct shape, so horizontal fusion has many same-pattern candidates.
TEST_F(HorizontalLoopFusionTest, GradientDescentOptimizerLike) {
  HloComputation::Builder builder(TestName());

  std::vector<HloInstruction*> var_outs;
  for (int64 i = 0; i < 128; ++i) {
    // Shapes are {1, 1024}, {2, 1024}, ..., {128, 1024}.
    auto shape = ShapeUtil::MakeShape(F32, {i + 1, 1024});
    auto* var_in = builder.AddInstruction(
        HloInstruction::CreateParameter(i * 3 + 0, shape, "var.in"));
    auto* alpha = builder.AddInstruction(HloInstruction::CreateParameter(
        i * 3 + 1, ShapeUtil::MakeShape(F32, {}), "alpha"));
    auto* delta = builder.AddInstruction(
        HloInstruction::CreateParameter(i * 3 + 2, shape, "delta"));
    auto* alpha_bcast = builder.AddInstruction(
        HloInstruction::CreateBroadcast(shape, alpha, {}));
    auto* scaled_delta = builder.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kMultiply, alpha_bcast, delta));
    var_outs.push_back(builder.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kSubtract, var_in, scaled_delta)));
  }
  builder.AddInstruction(HloInstruction::CreateTuple(var_outs));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Testing with the entire gpu optimization pipeline.
  EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0, 0}));
}
231 
// Multi-output fusions whose tuple outputs differ in shape should still be
// horizontally fusable; verify both that the pass fires and that the
// numerical results are unchanged.
TEST_F(HorizontalLoopFusionTest, FusingDifferentOutputs) {
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule HeterogeneousMultiOutputFusions

 fused_computation.1 {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   arg.3 = f16[1024]{0} parameter(2)
   arg.4 = f16[1024]{0} parameter(3)
   mul.1 = f16[1024]{0} multiply(arg.1, arg.2)
   mul.2 = f16[1024]{0} multiply(arg.3, arg.4)
   add.1 = f16[1024]{0} add(mul.1, mul.2)
   ROOT tuple.1 = (f16[1024]{0}, f16[1024]{0}) tuple(add.1, mul.1)
 }

 fused_computation.2 {
   arg.1 = f16[123]{0} parameter(0)
   arg.2 = f16[123]{0} parameter(1)
   arg.3 = f16[123]{0} parameter(2)
   arg.4 = f16[123]{0} parameter(3)
   add.1 = f16[123]{0} add(arg.1, arg.2)
   add.2 = f16[123]{0} add(arg.3, arg.4)
   mul.1 = f16[123]{0} multiply(add.1, add.2)
   ROOT tuple.1 = (f16[123]{0}, f16[123]{0}) tuple(mul.1, add.1)
 }

 ENTRY entry_computation {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   arg.3 = f16[1024]{0} parameter(2)
   arg.4 = f16[1024]{0} parameter(3)
   arg.5 = f16[123]{0} parameter(4)
   arg.6 = f16[123]{0} parameter(5)
   arg.7 = f16[123]{0} parameter(6)
   arg.8 = f16[123]{0} parameter(7)
   fusion.1 = (f16[1024]{0}, f16[1024]{0})
       fusion(arg.1, arg.2, arg.3, arg.4),
           kind=kLoop, calls=fused_computation.1
   fusion.2 = (f16[123]{0}, f16[123]{0})
       fusion(arg.5, arg.6, arg.7, arg.8),
           kind=kLoop, calls=fused_computation.2
   gte.1 = f16[1024]{0} get-tuple-element(fusion.1), index=0
   gte.2 = f16[1024]{0} get-tuple-element(fusion.1), index=1
   gte.3 = f16[123]{0} get-tuple-element(fusion.2), index=0
   gte.4 = f16[123]{0} get-tuple-element(fusion.2), index=1
   ROOT tuple.1 = (f16[1024]{0}, f16[1024]{0}, f16[123]{0}, f16[123]{0})
       tuple(gte.1, gte.2, gte.3, gte.4)
 }
)")
                    .ValueOrDie();

  GpuHorizontalLoopFusion horizontal_fusion;
  EXPECT_TRUE(horizontal_fusion.Run(module.get()).ValueOrDie());
  HloDCE dce;
  EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());

  VLOG(2) << "Dump after horizontal fusion:";
  VLOG(2) << module->ToString();

  // Bit-exact comparison against the unfused reference.
  EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
}
291 
// Builds 48 independent RMSProp-style optimizer updates (ms, mom, var per
// slot), each with a slightly different shape, and checks the full pipeline's
// numerics against the reference. Statement order below fixes both the
// parameter numbering (i * 9 + k) and the instruction order in the entry
// computation, so it must not be reordered.
TEST_F(HorizontalLoopFusionTest, RMSPropLike) {
  HloComputation::Builder builder(TestName());

  std::vector<HloInstruction*> all_outputs;
  for (int64 i = 0; i < 48; ++i) {
    // Shapes {2, 1024}, {2, 1025}, ..., {2, 1071}: similar but not identical.
    auto shape = ShapeUtil::MakeShape(F32, {2, 1024 + i});
    // ms <- grad**2 (1 - rho) + ms * rho
    HloInstruction* grad = builder.AddInstruction(
        HloInstruction::CreateParameter(i * 9 + 0, shape, "grad"));
    HloInstruction* ms = builder.AddInstruction(
        HloInstruction::CreateParameter(i * 9 + 1, shape, "ms"));
    HloInstruction* rho =
        builder.AddInstruction(HloInstruction::CreateParameter(
            i * 9 + 2, ShapeUtil::MakeShape(F32, {}), "rho"));
    HloInstruction* one_minus_rho =
        builder.AddInstruction(HloInstruction::CreateParameter(
            i * 9 + 3, ShapeUtil::MakeShape(F32, {}), "one_minus_rho"));
    // Scalars are broadcast to the update shape before elementwise math.
    auto rho_broadcasted =
        builder.AddInstruction(HloInstruction::CreateBroadcast(shape, rho, {}));
    auto one_mins_rho_broadcasted = builder.AddInstruction(
        HloInstruction::CreateBroadcast(shape, one_minus_rho, {}));
    auto grad_squared = builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, grad, grad));
    auto ms_1st_term = builder.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kMultiply, grad_squared, one_mins_rho_broadcasted));
    auto ms_2nd_term = builder.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kMultiply, ms, rho_broadcasted));
    auto ms_out = builder.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kAdd, ms_1st_term, ms_2nd_term));

    // mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
    // NOTE: the parameter-name string "momemtum" is a pre-existing typo; it
    // is only a debug name and does not affect the computation.
    HloInstruction* momentum = builder.AddInstruction(
        HloInstruction::CreateParameter(i * 9 + 4, shape, "momemtum"));
    HloInstruction* mom = builder.AddInstruction(
        HloInstruction::CreateParameter(i * 9 + 5, shape, "mom"));
    HloInstruction* lr = builder.AddInstruction(HloInstruction::CreateParameter(
        i * 9 + 6, ShapeUtil::MakeShape(F32, {}), "lr"));
    HloInstruction* epsilon =
        builder.AddInstruction(HloInstruction::CreateParameter(
            i * 9 + 7, ShapeUtil::MakeShape(F32, {}), "epsilon"));
    auto lr_broadcasted =
        builder.AddInstruction(HloInstruction::CreateBroadcast(shape, lr, {}));
    auto epsilon_broadcasted = builder.AddInstruction(
        HloInstruction::CreateBroadcast(shape, epsilon, {}));
    auto mom_1st_term = builder.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kMultiply, momentum, mom));
    // 1/sqrt(ms + epsilon) computed via rsqrt.
    auto ms_eps = builder.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kAdd, ms_out, epsilon_broadcasted));
    auto ms_eps_rsq = builder.AddInstruction(
        HloInstruction::CreateUnary(shape, HloOpcode::kRsqrt, ms_eps));
    auto grad_ms_eps_rsq = builder.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kMultiply, grad, ms_eps_rsq));
    auto mom_2nd_term = builder.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kMultiply, lr_broadcasted, grad_ms_eps_rsq));
    auto mom_out = builder.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kAdd, mom_1st_term, mom_2nd_term));

    // var <- var - mom
    HloInstruction* var = builder.AddInstruction(
        HloInstruction::CreateParameter(i * 9 + 8, shape, "var"));
    auto var_out = builder.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kSubtract, var, mom_out));

    // All three state updates are live outputs of the computation.
    all_outputs.push_back(ms_out);
    all_outputs.push_back(mom_out);
    all_outputs.push_back(var_out);
  }
  builder.AddInstruction(HloInstruction::CreateTuple(all_outputs));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Non-zero tolerance: rsqrt/multiply reassociation after fusion can
  // legitimately perturb low-order float bits.
  EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1.0e-5, 1.0e-5}));
}
366 
// Fusions rooted at dynamic-update-slice are excluded from horizontal
// fusion; the pass must leave this module untouched.
TEST_F(HorizontalLoopFusionTest, NegativeTestForDynamicUpdateSlice) {
  auto module = ParseAndReturnVerifiedModule(R"(
  HloModule NegativeTestForDynamicUpdateSlice

  fusion.1 {
    p.0 = f16[5,9,10]{2,1,0} parameter(0)
    p.1 = s32[1]{0} parameter(1)
    p.2 = f16[1,9,10]{2,1,0} parameter(2)
    c.0 = s32[] constant(0)
    pad = s32[3]{0} pad(p.1, c.0), padding=0_2
    ROOT %dynamic-update-slice = f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, pad)
  }

  fusion.2 {
    p.0 = f16[5,9,10]{2,1,0} parameter(0)
    p.1 = s32[1]{0} parameter(1)
    p.2 = f16[1,9,10]{2,1,0} parameter(2)
    c.0 = s32[] constant(0)
    pad = s32[3]{0} pad(p.1, c.0), padding=0_2
    ROOT %dynamic-update-slice = f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, pad)
  }

  ENTRY entry {
    p.00 = f16[5,9,10]{2,1,0} parameter(0)
    p.01 = f16[5,9,10]{2,1,0} parameter(1)
    p.10 = s32[1]{0} parameter(2)
    p.11 = s32[1]{0} parameter(3)
    p.20 = f16[1,9,10]{2,1,0} parameter(4)
    p.21 = f16[1,9,10]{2,1,0} parameter(5)

    f1 = f16[5,9,10] fusion(p.00, p.10, p.20), kind=kLoop, calls=fusion.1
    f2 = f16[5,9,10] fusion(p.01, p.11, p.21), kind=kLoop, calls=fusion.2
    ROOT tuple = (f16[5,9,10],f16[5,9,10]) tuple(f1, f2)
  })")
                    .ValueOrDie();

  GpuHorizontalLoopFusion horizontal_fusion;
  EXPECT_FALSE(horizontal_fusion.Run(module.get()).ValueOrDie());
}
405 
406 }  // namespace
407 }  // namespace gpu
408 }  // namespace xla
409