1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h"
17
18 #include "tensorflow/compiler/xla/literal.h"
19 #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
20 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
21 #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
22 #include "tensorflow/compiler/xla/service/hlo_dce.h"
23 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
24 #include "tensorflow/compiler/xla/service/hlo_parser.h"
25 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
26 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
27 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/test.h"
30 #include "tensorflow/compiler/xla/test_helpers.h"
31 #include "tensorflow/compiler/xla/tests/filecheck.h"
32 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
33
34 namespace xla {
35 namespace gpu {
36 namespace {
37
38 namespace op = xla::testing::opcode_matchers;
39
// Test fixture for GpuHorizontalLoopFusion; HloTestBase supplies HLO
// parsing/verification and reference-comparison helpers used below.
class HorizontalLoopFusionTest : public HloTestBase {};
41
TEST_F(HorizontalLoopFusionTest, BasicTest) {
  // Two independent kLoop fusions with the same element type (f16) but
  // different sizes (1024 vs. 123). Horizontal fusion should pack them into
  // a single multi-output fusion.
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule BasicTest

 fused_computation.1 {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   ROOT mul.1 = f16[1024]{0} multiply(arg.1, arg.2)
 }

 fused_computation.2 {
   arg.1 = f16[123]{0} parameter(0)
   arg.2 = f16[123]{0} parameter(1)
   ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
 }

 ENTRY entry_computation {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   arg.3 = f16[123]{0} parameter(2)
   arg.4 = f16[123]{0} parameter(3)
   fusion.1 = f16[1024]{0}
       fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
   fusion.2 = f16[123]{0}
       fusion(arg.3, arg.4), kind=kLoop, calls=fused_computation.2
   ROOT tuple.1 = (f16[1024]{0}, f16[123]{0})
       tuple(fusion.1, fusion.2)
 }
)")
                    .ValueOrDie();

  // The pass must report a change; DCE then removes the dead original
  // fusions so the matchers below see only the fused form.
  EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
  EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());

  // Each original result is recovered from the merged fusion through a
  // get-tuple-element followed by a bitcast back to its original shape.
  const HloInstruction* entry_root =
      module->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root,
              op::Tuple(op::Bitcast(op::GetTupleElement(op::Fusion())),
                        op::Bitcast(op::GetTupleElement(op::Fusion()))));

  // Walk Tuple -> Bitcast -> GetTupleElement -> Fusion to reach the merged
  // fusion instruction itself.
  const HloInstruction* fusion = entry_root->operand(0)->operand(0)->operand(0);
  ASSERT_TRUE(fusion->IsMultiOutputFusion());
  // Inside the fused computation, the two per-fusion results are reshaped,
  // concatenated into one flat buffer, and sliced back apart.
  EXPECT_THAT(
      fusion->fused_expression_root(),
      op::Tuple(op::Slice(op::Concatenate(op::Reshape(), op::Reshape())),
                op::Slice(op::Concatenate(op::Reshape(), op::Reshape()))));
}
89
// Horizontal fusion should not be triggered as fusion will create cycles.
TEST_F(HorizontalLoopFusionTest, NegativeTestForCycle) {
  // fusion.2 transitively consumes fusion.1 (via add.2), so merging the two
  // fusions would make the merged node both a producer and a consumer of
  // add.2 — a cycle. The pass must decline and report no change.
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule NegativeTestForCycle

 fused_computation.1 {
   arg.1 = f16[123]{0} parameter(0)
   arg.2 = f16[123]{0} parameter(1)
   ROOT mul.1 = f16[123]{0} multiply(arg.1, arg.2)
 }

 fused_computation.2 {
   arg.1 = f16[123]{0} parameter(0)
   arg.2 = f16[123]{0} parameter(1)
   ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
 }

 ENTRY entry_computation {
   arg.1 = f16[123]{0} parameter(0)
   arg.2 = f16[123]{0} parameter(1)
   arg.3 = f16[123]{0} parameter(2)
   arg.4 = f16[123]{0} parameter(3)
   // fusion.1 and fusion.2 will not be horizontally fused as it will create
   // a cycle through fusion.1 -> add.2 -> fusion.2
   fusion.1 = f16[123]{0}
       fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
   add.2 = f16[123]{0} add(fusion.1, arg.4)
   fusion.2 = f16[123]{0}
       fusion(add.2, arg.3), kind=kLoop, calls=fused_computation.2
   ROOT tuple.1 = (f16[123]{0}, f16[123]{0}, f16[123]{0})
       tuple(fusion.1, fusion.2, add.2)
 }
)")
                    .ValueOrDie();

  EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
}
127
TEST_F(HorizontalLoopFusionTest, NegativeTestForIncompatibleTypes) {
  // The two candidate fusions produce different element types (f16 vs. s32),
  // so horizontal fusion must not merge them; the pass reports no change.
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule NegativeTestForIncompatibleTypes

 fused_computation.1 {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   ROOT mul.1 = f16[1024]{0} multiply(arg.1, arg.2)
 }

 fused_computation.2 {
   arg.1 = s32[123]{0} parameter(0)
   arg.2 = s32[123]{0} parameter(1)
   ROOT add.1 = s32[123]{0} add(arg.1, arg.2)
 }

 ENTRY entry_computation {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   arg.3 = s32[123]{0} parameter(2)
   arg.4 = s32[123]{0} parameter(3)
   // fusion.1 and fusion.2 will not be horizontally fused because their output
   // types are different.
   fusion.1 = f16[1024]{0}
       fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
   fusion.2 = s32[123]{0}
       fusion(arg.3, arg.4), kind=kLoop, calls=fused_computation.2
   ROOT tuple.1 = (f16[1024]{0}, s32[123]{0})
       tuple(fusion.1, fusion.2)
 }
)")
                    .ValueOrDie();

  EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
}
163
TEST_F(HorizontalLoopFusionTest, HorizontalLoopFusionAfterVerticalFusion) {
  // Starts from unfused HLO: two independent multiply-add chains on tensors
  // of different shapes. Vertical (instruction) fusion runs first to form
  // loop fusions; horizontal fusion should then merge those.
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule MergeSharedFusionInstruction

 ENTRY MergeSharedFusionInstruction.Computation0 {
  param.1.1   = f32[4,1024]{1,0} parameter(0)
  param.1.2   = f32[4,1024]{1,0} parameter(1)
  param.1.3   = f32[4,1024]{1,0} parameter(2)
  param.2.1   = f32[321,5]{1,0} parameter(3)
  param.2.2   = f32[321,5]{1,0} parameter(4)
  param.2.3   = f32[321,5]{1,0} parameter(5)
  const.1     = f32[] constant(3)
  const.2     = f32[] constant(3)
  broadcast.1 = f32[4,1024]{1,0} broadcast(const.1), dimensions={}
  broadcast.2 = f32[321,5]{1,0} broadcast(const.2), dimensions={}
  mul.1.1     = f32[4,1024]{1,0} multiply(param.1.1, param.1.2)
  mul.1.2     = f32[4,1024]{1,0} multiply(param.1.3, broadcast.1)
  add.1       = f32[4,1024]{1,0} add(mul.1.1, mul.1.2)
  mul.2.1     = f32[321,5]{1,0} multiply(param.2.1, param.2.2)
  mul.2.2     = f32[321,5]{1,0} multiply(param.2.3, broadcast.2)
  add.2       = f32[321,5]{1,0} add(mul.2.1, mul.2.2)
  ROOT tuple = (f32[4,1024]{1,0}, f32[321,5]{1,0}) tuple(add.1, add.2)
})")
                    .ValueOrDie();

  // Vertical fusion: one pass without duplication, then one allowing it.
  HloPassPipeline fusion("fusion");
  fusion.AddPass<xla::gpu::GpuInstructionFusion>(/*may_duplicate=*/false);
  fusion.AddPass<xla::gpu::GpuInstructionFusion>(/*may_duplicate=*/true);
  EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
  EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());

  VLOG(2) << "Dump after horizontal fusion:";
  VLOG(2) << module->ToString();

  // Exact-match numerical comparison: fusion must not change any computed
  // value, hence the zero error tolerance.
  EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
}
200
TEST_F(HorizontalLoopFusionTest, GradientDescentOptimizerLike) {
  // Builds 128 independent SGD-style updates, var <- var - alpha * delta,
  // each on a differently shaped tensor, then checks that the full GPU
  // optimization pipeline preserves the computed values exactly.
  HloComputation::Builder b(TestName());

  std::vector<HloInstruction*> updated_vars;
  for (int64 idx = 0; idx < 128; ++idx) {
    // Shapes run through {1, 1024}, {2, 1024}, ..., {128, 1024}.
    auto shape = ShapeUtil::MakeShape(F32, {idx + 1, 1024});
    // Three parameters per update: the variable, the scalar learning rate,
    // and the gradient step.
    HloInstruction* var_in = b.AddInstruction(
        HloInstruction::CreateParameter(idx * 3 + 0, shape, "var.in"));
    HloInstruction* alpha = b.AddInstruction(HloInstruction::CreateParameter(
        idx * 3 + 1, ShapeUtil::MakeShape(F32, {}), "alpha"));
    HloInstruction* delta = b.AddInstruction(
        HloInstruction::CreateParameter(idx * 3 + 2, shape, "delta"));
    HloInstruction* alpha_bcast = b.AddInstruction(
        HloInstruction::CreateBroadcast(shape, alpha, {}));
    HloInstruction* scaled_delta =
        b.AddInstruction(HloInstruction::CreateBinary(
            shape, HloOpcode::kMultiply, alpha_bcast, delta));
    HloInstruction* new_var = b.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kSubtract, var_in, scaled_delta));
    updated_vars.push_back(new_var);
  }
  b.AddInstruction(HloInstruction::CreateTuple(updated_vars));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(b.Build());

  // Testing with the entire gpu optimization pipeline.
  EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0, 0}));
}
231
TEST_F(HorizontalLoopFusionTest, FusingDifferentOutputs) {
  // Two multi-output fusions with different shapes and different internal
  // structure. Horizontal fusion should still merge them, and the merged
  // module must compute identical values.
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule HeterogeneousMultiOutputFusions

 fused_computation.1 {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   arg.3 = f16[1024]{0} parameter(2)
   arg.4 = f16[1024]{0} parameter(3)
   mul.1 = f16[1024]{0} multiply(arg.1, arg.2)
   mul.2 = f16[1024]{0} multiply(arg.3, arg.4)
   add.1 = f16[1024]{0} add(mul.1, mul.2)
   ROOT tuple.1 = (f16[1024]{0}, f16[1024]{0}) tuple(add.1, mul.1)
 }

 fused_computation.2 {
   arg.1 = f16[123]{0} parameter(0)
   arg.2 = f16[123]{0} parameter(1)
   arg.3 = f16[123]{0} parameter(2)
   arg.4 = f16[123]{0} parameter(3)
   add.1 = f16[123]{0} add(arg.1, arg.2)
   add.2 = f16[123]{0} add(arg.3, arg.4)
   mul.1 = f16[123]{0} multiply(add.1, add.2)
   ROOT tuple.1 = (f16[123]{0}, f16[123]{0}) tuple(mul.1, add.1)
 }

 ENTRY entry_computation {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   arg.3 = f16[1024]{0} parameter(2)
   arg.4 = f16[1024]{0} parameter(3)
   arg.5 = f16[123]{0} parameter(4)
   arg.6 = f16[123]{0} parameter(5)
   arg.7 = f16[123]{0} parameter(6)
   arg.8 = f16[123]{0} parameter(7)
   fusion.1 = (f16[1024]{0}, f16[1024]{0})
       fusion(arg.1, arg.2, arg.3, arg.4),
           kind=kLoop, calls=fused_computation.1
   fusion.2 = (f16[123]{0}, f16[123]{0})
       fusion(arg.5, arg.6, arg.7, arg.8),
           kind=kLoop, calls=fused_computation.2
   gte.1 = f16[1024]{0} get-tuple-element(fusion.1), index=0
   gte.2 = f16[1024]{0} get-tuple-element(fusion.1), index=1
   gte.3 = f16[123]{0} get-tuple-element(fusion.2), index=0
   gte.4 = f16[123]{0} get-tuple-element(fusion.2), index=1
   ROOT tuple.1 = (f16[1024]{0}, f16[1024]{0}, f16[123]{0}, f16[123]{0})
       tuple(gte.1, gte.2, gte.3, gte.4)
 }
)")
                    .ValueOrDie();

  // The pass must fire, and DCE must clean up the replaced fusions.
  EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
  EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());

  VLOG(2) << "Dump after horizontal fusion:";
  VLOG(2) << module->ToString();

  // Exact-match comparison against the reference: fusion must be value-
  // preserving bit-for-bit here.
  EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
}
291
TEST_F(HorizontalLoopFusionTest, RMSPropLike) {
  // Emulates 48 independent RMSProp parameter updates, each on a slightly
  // different shape, and verifies the optimized GPU executable matches the
  // reference within a small tolerance (the computation mixes rsqrt and
  // several multiplies, so exact equality is not expected).
  HloComputation::Builder b(TestName());

  std::vector<HloInstruction*> outputs;
  for (int64 idx = 0; idx < 48; ++idx) {
    auto shape = ShapeUtil::MakeShape(F32, {2, 1024 + idx});
    // ms <- grad**2 (1 - rho) + ms * rho
    HloInstruction* grad = b.AddInstruction(
        HloInstruction::CreateParameter(idx * 9 + 0, shape, "grad"));
    HloInstruction* ms = b.AddInstruction(
        HloInstruction::CreateParameter(idx * 9 + 1, shape, "ms"));
    HloInstruction* rho = b.AddInstruction(HloInstruction::CreateParameter(
        idx * 9 + 2, ShapeUtil::MakeShape(F32, {}), "rho"));
    HloInstruction* one_minus_rho =
        b.AddInstruction(HloInstruction::CreateParameter(
            idx * 9 + 3, ShapeUtil::MakeShape(F32, {}), "one_minus_rho"));
    HloInstruction* rho_bcast = b.AddInstruction(
        HloInstruction::CreateBroadcast(shape, rho, {}));
    HloInstruction* one_minus_rho_bcast = b.AddInstruction(
        HloInstruction::CreateBroadcast(shape, one_minus_rho, {}));
    HloInstruction* grad_sq = b.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, grad, grad));
    HloInstruction* ms_term1 = b.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kMultiply, grad_sq, one_minus_rho_bcast));
    HloInstruction* ms_term2 = b.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kMultiply, ms, rho_bcast));
    HloInstruction* ms_out = b.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kAdd, ms_term1, ms_term2));

    // mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
    HloInstruction* momentum = b.AddInstruction(
        HloInstruction::CreateParameter(idx * 9 + 4, shape, "momemtum"));
    HloInstruction* mom = b.AddInstruction(
        HloInstruction::CreateParameter(idx * 9 + 5, shape, "mom"));
    HloInstruction* lr = b.AddInstruction(HloInstruction::CreateParameter(
        idx * 9 + 6, ShapeUtil::MakeShape(F32, {}), "lr"));
    HloInstruction* epsilon = b.AddInstruction(
        HloInstruction::CreateParameter(
            idx * 9 + 7, ShapeUtil::MakeShape(F32, {}), "epsilon"));
    HloInstruction* lr_bcast = b.AddInstruction(
        HloInstruction::CreateBroadcast(shape, lr, {}));
    HloInstruction* epsilon_bcast = b.AddInstruction(
        HloInstruction::CreateBroadcast(shape, epsilon, {}));
    HloInstruction* mom_term1 = b.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kMultiply, momentum, mom));
    HloInstruction* ms_plus_eps = b.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kAdd, ms_out, epsilon_bcast));
    HloInstruction* rsqrt_ms_eps = b.AddInstruction(
        HloInstruction::CreateUnary(shape, HloOpcode::kRsqrt, ms_plus_eps));
    HloInstruction* grad_rsqrt = b.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kMultiply, grad, rsqrt_ms_eps));
    HloInstruction* mom_term2 = b.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kMultiply, lr_bcast, grad_rsqrt));
    HloInstruction* mom_out = b.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kAdd, mom_term1, mom_term2));

    // var <- var - mom
    HloInstruction* var = b.AddInstruction(
        HloInstruction::CreateParameter(idx * 9 + 8, shape, "var"));
    HloInstruction* var_out = b.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kSubtract, var, mom_out));

    outputs.push_back(ms_out);
    outputs.push_back(mom_out);
    outputs.push_back(var_out);
  }
  b.AddInstruction(HloInstruction::CreateTuple(outputs));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(b.Build());

  EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1.0e-5, 1.0e-5}));
}
366
TEST_F(HorizontalLoopFusionTest, NegativeTestForDynamicUpdateSlice) {
  // Both candidates are dynamic-update-slice fusions. The pass is expected
  // to leave them alone — presumably because DUS fusions get special
  // (in-place) handling elsewhere in the GPU backend; confirm against the
  // pass implementation if this behavior changes.
  auto module = ParseAndReturnVerifiedModule(R"(
  HloModule NegativeTestForDynamicUpdateSlice

  fusion.1 {
    p.0 = f16[5,9,10]{2,1,0} parameter(0)
    p.1 = s32[1]{0} parameter(1)
    p.2 = f16[1,9,10]{2,1,0} parameter(2)
    c.0 = s32[] constant(0)
    pad = s32[3]{0} pad(p.1, c.0), padding=0_2
    ROOT %dynamic-update-slice = f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, pad)
  }

  fusion.2 {
    p.0 = f16[5,9,10]{2,1,0} parameter(0)
    p.1 = s32[1]{0} parameter(1)
    p.2 = f16[1,9,10]{2,1,0} parameter(2)
    c.0 = s32[] constant(0)
    pad = s32[3]{0} pad(p.1, c.0), padding=0_2
    ROOT %dynamic-update-slice = f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, pad)
  }

  ENTRY entry {
    p.00 = f16[5,9,10]{2,1,0} parameter(0)
    p.01 = f16[5,9,10]{2,1,0} parameter(1)
    p.10 = s32[1]{0} parameter(2)
    p.11 = s32[1]{0} parameter(3)
    p.20 = f16[1,9,10]{2,1,0} parameter(4)
    p.21 = f16[1,9,10]{2,1,0} parameter(5)

    f1 = f16[5,9,10] fusion(p.00, p.10, p.20), kind=kLoop, calls=fusion.1
    f2 = f16[5,9,10] fusion(p.01, p.11, p.21), kind=kLoop, calls=fusion.2
    ROOT tuple = (f16[5,9,10],f16[5,9,10]) tuple(f1, f2)
  })")
                    .ValueOrDie();

  EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
}
405
406 } // namespace
407 } // namespace gpu
408 } // namespace xla
409