/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include <math.h>
17 #include <algorithm>
18 #include <memory>
19 #include <new>
20 #include <utility>
21
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/types/span.h"
25 #include "tensorflow/compiler/xla/client/local_client.h"
26 #include "tensorflow/compiler/xla/literal.h"
27 #include "tensorflow/compiler/xla/primitive_util.h"
28 #include "tensorflow/compiler/xla/service/hlo_computation.h"
29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
30 #include "tensorflow/compiler/xla/service/hlo_module.h"
31 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
32 #include "tensorflow/compiler/xla/service/hlo_runner.h"
33 #include "tensorflow/compiler/xla/shape_util.h"
34 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
35 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
36 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
37 #include "tensorflow/compiler/xla/tests/test_macros.h"
38 #include "tensorflow/compiler/xla/tests/test_utils.h"
39 #include "tensorflow/compiler/xla/xla_data.pb.h"
40 #include "tensorflow/core/lib/core/status_test_util.h"
41 #include "tensorflow/core/platform/logging.h"
42 #include "tensorflow/core/platform/protobuf.h"
43 #include "tensorflow/core/platform/test.h"
44 #include "tensorflow/core/platform/test_benchmark.h"
45 #include "tensorflow/core/platform/types.h"
46
47 namespace xla {
48 namespace {
49
50 class MultiOutputFusionTest : public HloTestBase {
51 protected:
MultiOutputFusionTest()52 MultiOutputFusionTest() { error_spec_ = ErrorSpec{0.0001, 1e-2}; }
53
54 // Layout assignment assumes that there are no fusions in the input graph.
55 // Since the purpose of this test is to send pre-fused graphs to XLA, we have
56 // to do layout assignment ourselves.
GetDebugOptionsForTest()57 DebugOptions GetDebugOptionsForTest() override {
58 auto opts = HloTestBase::GetDebugOptionsForTest();
59 opts.add_xla_disable_hlo_passes("layout-assignment");
60 return opts;
61 }
62
RunTest2D(bool manual_fusion,int64 size)63 void RunTest2D(bool manual_fusion, int64 size) {
64 auto builder = HloComputation::Builder(TestName());
65 auto hlo_module = CreateNewUnverifiedModule();
66
67 const Shape elem_shape0 = ShapeUtil::MakeShapeWithLayout(F32, {}, {});
68 const Shape elem_shape2 =
69 ShapeUtil::MakeShapeWithLayout(F32, {size, size}, {1, 0});
70
71 auto const0 = builder.AddInstruction(
72 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(8.0f)));
73 auto param0 = builder.AddInstruction(
74 HloInstruction::CreateParameter(0, elem_shape0, "0"));
75
76 auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
77 elem_shape0, HloOpcode::kAdd, param0, const0));
78
79 HloInstruction* broadcast = builder.AddInstruction(
80 HloInstruction::CreateBroadcast(elem_shape2, add1, {}));
81
82 auto param1 = builder.AddInstruction(
83 HloInstruction::CreateParameter(1, elem_shape2, "1"));
84
85 HloInstruction* add2 = builder.AddInstruction(HloInstruction::CreateBinary(
86 elem_shape2, HloOpcode::kAdd, broadcast, param1));
87 HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
88 elem_shape2, HloOpcode::kSubtract, param1, broadcast));
89 DotDimensionNumbers dot_dnums;
90 dot_dnums.add_lhs_contracting_dimensions(1);
91 dot_dnums.add_rhs_contracting_dimensions(0);
92 HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
93 elem_shape2, sub, add2, dot_dnums, DefaultPrecisionConfig(2)));
94 auto computation = hlo_module->AddEntryComputation(builder.Build(dot));
95
96 if (manual_fusion) {
97 auto tuple =
98 computation->AddInstruction(HloInstruction::CreateTuple({sub, add2}));
99 auto gte0 = computation->AddInstruction(
100 HloInstruction::CreateGetTupleElement(elem_shape2, tuple, 0));
101 auto gte1 = computation->AddInstruction(
102 HloInstruction::CreateGetTupleElement(elem_shape2, tuple, 1));
103 TF_CHECK_OK(dot->ReplaceOperandWith(0, gte0));
104 TF_CHECK_OK(dot->ReplaceOperandWith(1, gte1));
105
106 CHECK_NE(
107 computation->CreateFusionInstruction(
108 {tuple, sub, add2, broadcast}, HloInstruction::FusionKind::kLoop),
109 nullptr);
110 }
111
112 Literal arg1(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size}));
113 arg1.PopulateWithValue<float>(2.5f);
114
115 Literal expect(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size}));
116 expect.PopulateWithValue<float>(size * 1.5f * 3.5f);
117 Literal literal_r0 = LiteralUtil::CreateR0<float>(-9.0f);
118 auto actual =
119 ExecuteAndTransfer(std::move(hlo_module), {&literal_r0, &arg1});
120 EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_));
121 }
122
RunTest1D(bool manual_fusion,int size)123 void RunTest1D(bool manual_fusion, int size) {
124 auto builder = HloComputation::Builder(TestName());
125 auto hlo_module = CreateNewUnverifiedModule();
126
127 const Shape elem_shape_F32 =
128 ShapeUtil::MakeShapeWithDescendingLayout(F32, {size});
129 const Shape elem_shape_U8 =
130 ShapeUtil::MakeShapeWithDescendingLayout(F64, {size});
131 auto param0 = builder.AddInstruction(
132 HloInstruction::CreateParameter(0, elem_shape_F32, "0"));
133 auto param1 = builder.AddInstruction(
134 HloInstruction::CreateParameter(1, elem_shape_U8, "1"));
135
136 HloInstruction* param0_U8 = builder.AddInstruction(
137 HloInstruction::CreateConvert(elem_shape_U8, param0));
138 HloInstruction* param1_F32 = builder.AddInstruction(
139 HloInstruction::CreateConvert(elem_shape_F32, param1));
140 HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
141 elem_shape_F32, HloOpcode::kAdd, param0, param1_F32));
142 HloInstruction* sub_U8 =
143 builder.AddInstruction(HloInstruction::CreateBinary(
144 elem_shape_U8, HloOpcode::kSubtract, param0_U8, param1));
145 HloInstruction* sub = builder.AddInstruction(
146 HloInstruction::CreateConvert(elem_shape_F32, sub_U8));
147
148 HloInstruction* reshape =
149 builder.AddInstruction(HloInstruction::CreateReshape(
150 ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, 1}), add));
151 DotDimensionNumbers dot_dnums;
152 dot_dnums.add_lhs_contracting_dimensions(0);
153 dot_dnums.add_rhs_contracting_dimensions(0);
154 HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
155 ShapeUtil::MakeShapeWithDescendingLayout(F32, {1}), sub, reshape,
156 dot_dnums, DefaultPrecisionConfig(2)));
157 auto computation = hlo_module->AddEntryComputation(builder.Build(dot));
158
159 if (manual_fusion) {
160 auto tuple = computation->AddInstruction(
161 HloInstruction::CreateTuple({sub_U8, add}));
162
163 auto gte0 = computation->AddInstruction(
164 HloInstruction::CreateGetTupleElement(elem_shape_U8, tuple, 0));
165 auto gte1 = computation->AddInstruction(
166 HloInstruction::CreateGetTupleElement(elem_shape_F32, tuple, 1));
167 TF_CHECK_OK(sub->ReplaceOperandWith(0, gte0));
168 TF_CHECK_OK(reshape->ReplaceOperandWith(0, gte1));
169
170 CHECK_NE(computation->CreateFusionInstruction(
171 {tuple, sub_U8, add, param0_U8, param1_F32},
172 HloInstruction::FusionKind::kLoop),
173 nullptr);
174 }
175
176 Literal input0(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size}));
177 input0.PopulateWithValue(2.5f);
178 Literal input1(ShapeUtil::MakeShapeWithDescendingLayout(F64, {size}));
179 input1.PopulateWithValue(1.);
180
181 Literal expect = LiteralUtil::CreateR1<float>({size * 1.5f * 3.5f});
182 auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1});
183 EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_));
184 }
185 };
186
// Square f32 graph: unfused baseline, then the manually fused variant at two
// sizes.
XLA_TEST_F(MultiOutputFusionTest, 2DNofusion) { RunTest2D(false, 5); }
XLA_TEST_F(MultiOutputFusionTest, 2DFusion) { RunTest2D(true, 5); }
XLA_TEST_F(MultiOutputFusionTest, 2DFusionSize129) { RunTest2D(true, 129); }
// Mixed F32/F64 graph, unfused and manually fused. The "Diffent" spelling is
// kept as-is because test names are externally visible (filters, manifests).
XLA_TEST_F(MultiOutputFusionTest, DiffentTypesNoFusion) { RunTest1D(false, 8); }
XLA_TEST_F(MultiOutputFusionTest, DiffentTypesFusion) { RunTest1D(true, 8); }
192
// The fusion instruction is the ROOT of the entry computation. It digs the
// s32 scalar out of a nested tuple parameter and returns it wrapped in a
// one-element tuple.
XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) {
  const char* testcase = R"(
    HloModule m, is_scheduled=true

    fused_computation {
      x.param_0 = (((s32[]), f32[]), (f32[], s32[])) parameter(0)
      gte.3 = ((s32[]), f32[]) get-tuple-element(x.param_0), index=0
      gte.2 = (s32[]) get-tuple-element(gte.3), index=0
      gte.4 = s32[] get-tuple-element(gte.2), index=0
      copy = s32[] copy(gte.4)
      ROOT tuple = (s32[]) tuple(copy)
    }

    ENTRY thing.v3 {
      x = (((s32[]), f32[]), (f32[], s32[])) parameter(0)
      ROOT fusion = (s32[]) fusion(x), kind=kLoop, calls=fused_computation
    }
  )";
  auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
  // Argument shaped like x: (((42), 1.0), (3.0, 4)).
  auto param = LiteralUtil::MakeTupleOwned(
      LiteralUtil::MakeTupleOwned(
          LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42)),
          LiteralUtil::CreateR0<float>(1.0)),
      LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<float>(3.0),
                                  LiteralUtil::CreateR0<int32>(4)));
  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42)), result));
}
222
// A kLoop fusion producing (p < p*p, p*p); the entry computation selects p*p
// where the predicate holds and 0 elsewhere.
XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) {
  const char* testcase = R"(
    HloModule m, is_scheduled=true

    fused_computation {
      p = f32[4] parameter(0)
      multiply = f32[4] multiply(p, p)
      less-than = pred[4] compare(p, multiply), direction=LT
      ROOT tuple = (pred[4], f32[4]) tuple(less-than, multiply)
    }

    ENTRY PredFloatMOF {
      p0 = f32[4] parameter(0)
      fusion = (pred[4], f32[4]) fusion(p0), kind=kLoop, calls=fused_computation
      gte0 = pred[4] get-tuple-element(fusion), index=0
      gte1 = f32[4] get-tuple-element(fusion), index=1
      const = f32[4] constant({0, 0, 0, 0})
      ROOT select = f32[4] select(gte0, gte1, const)
    })";
  auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
  auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, -1.0});
  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
  // 1 < 1 is false, so the first element falls back to 0.
  LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0, 1.0}, result);
}
247
// Same (p < p*p, p*p) fusion as MultiOutputLoopFusion, but on scalars and
// called from inside the computation applied by a map instruction.
XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) {
  const char* testcase = R"(
    HloModule m, is_scheduled=true

    fused_computation {
      p = f32[] parameter(0)
      multiply = f32[] multiply(p, p)
      less-than = pred[] compare(p, multiply), direction=LT
      ROOT tuple = (pred[], f32[]) tuple(less-than, multiply)
    }

    map_computation {
      p0 = f32[] parameter(0)
      fusion = (pred[], f32[]) fusion(p0), kind=kLoop, calls=fused_computation
      gte0 = pred[] get-tuple-element(fusion), index=0
      gte1 = f32[] get-tuple-element(fusion), index=1
      const = f32[] constant(0)
      ROOT select = f32[] select(gte0, gte1, const)
    }

    ENTRY MapMOF {
      p1 = f32[3] parameter(0)
      ROOT map = f32[3] map(p1), to_apply=map_computation
    })";
  auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
  auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
  LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0}, result);
}
277
// Shared HLO prelude: the module header plus the scalar `Add` and `Max`
// computations that the reduce tests below reference through `to_apply`.
// Each test appends its own fused computation and ENTRY to this text.
const char* const kScalarOps = R"(
    HloModule m, is_scheduled=true

    Add {
      lhsadd = f32[] parameter(0)
      rhsadd = f32[] parameter(1)
      ROOT add = f32[] add(lhsadd, rhsadd)
    }

    Max {
      lhsmax = f32[] parameter(0)
      rhsmax = f32[] parameter(1)
      ROOT max = f32[] maximum(lhsmax, rhsmax)
    }
)";
293
// kInput fusion with two reduces over the minor dimension (2): a sum of p0
// and a max (init 5) of p0 squared.
XLA_TEST_F(MultiOutputFusionTest,
           DISABLED_ON_CPU(MultiOutputReduceFusionMinor)) {
  const string testcase = absl::StrCat(kScalarOps, R"(
    fused_reduce {
      p0 = f32[2,2,2]{2,1,0} parameter(0)
      c0 = f32[] constant(0)
      r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add
      mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
      c1 = f32[] constant(5)
      r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max
      ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
    }

    ENTRY reduce {
      p = f32[2,2,2]{2,1,0} parameter(0)
      ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput,
                                                        calls=fused_reduce
    })");
  auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
  auto param =
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
  // The max of {1, 4} with init 5 is 5, hence the leading 5 in the second
  // output.
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::MakeTupleOwned(
          LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
          LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}})),
      result));
}
322
// kInput fusion with two reduces over the major dimension (0): a sum of p0
// and a max (init 5) of p0 squared.
XLA_TEST_F(MultiOutputFusionTest,
           DISABLED_ON_CPU(MultiOutputReduceFusionMajor)) {
  const string testcase = absl::StrCat(kScalarOps, R"(
    fused_reduce {
      p0 = f32[2,2,2]{2,1,0} parameter(0)
      c0 = f32[] constant(0)
      r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add
      mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
      c1 = f32[] constant(5)
      r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max
      ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
    }

    ENTRY reduce {
      p = f32[2,2,2]{2,1,0} parameter(0)
      ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput,
                                                        calls=fused_reduce
    })");
  auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
  auto param =
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::MakeTupleOwned(
          LiteralUtil::CreateR2<float>({{6, 8}, {10, 12}}),
          LiteralUtil::CreateR2<float>({{25, 36}, {49, 64}})),
      result));
}
351
// kInput fusion with three reduces over dimensions {0, 2}: sum of p0, max of
// p0 squared (init 1.17549e-38), and sum of p0 squared.
XLA_TEST_F(MultiOutputFusionTest,
           DISABLED_ON_CPU(MultiOutputReduceFusionScalar)) {
  const string testcase = absl::StrCat(kScalarOps, R"(
    fused_reduce {
      p0 = f32[2,2,2]{2,1,0} parameter(0)
      c0 = f32[] constant(0)
      r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add
      mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
      c1 = f32[] constant(1.17549e-38)
      r2 = f32[2]{0} reduce(mul, c1), dimensions={0,2}, to_apply=Max
      r3 = f32[2]{0} reduce(mul, c0), dimensions={0,2}, to_apply=Add
      ROOT tuple = (f32[2]{0}, f32[2]{0}, f32[2]{0}) tuple(r1, r2, r3)
    }

    ENTRY reduce {
      p = f32[2,2,2]{2,1,0} parameter(0)
      ROOT fusion = (f32[2]{0}, f32[2]{0}, f32[2]{0}) fusion(p), kind=kInput,
                                                        calls=fused_reduce
    })");
  auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
  auto param =
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({14, 22}),
                                  LiteralUtil::CreateR1<float>({36, 64}),
                                  LiteralUtil::CreateR1<float>({66, 138})),
      result));
}
381
// Like MultiOutputReduceFusionMinor, but the fusion additionally returns the
// unreduced parameter p0 as its first output.
XLA_TEST_F(MultiOutputFusionTest,
           DISABLED_ON_CPU(MultiOutputReduceFusionMinorWithExtraOutput)) {
  const string testcase = absl::StrCat(kScalarOps, R"(
    fused_reduce {
      p0 = f32[2,2,2]{2,1,0} parameter(0)
      c0 = f32[] constant(0)
      r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add
      mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
      c1 = f32[] constant(5)
      r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max
      ROOT tuple = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0})
                     tuple(p0, r1, r2)
    }

    ENTRY reduce {
      p = f32[2,2,2]{2,1,0} parameter(0)
      ROOT fusion = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p),
        kind=kInput, calls=fused_reduce
    })");
  auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
  auto param =
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::MakeTupleOwned(
          LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}),
          LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
          LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}})),
      result));
}
412
// Like MultiOutputReduceFusionMajor, but the fusion additionally returns the
// unreduced elementwise square `mul` as its middle output.
XLA_TEST_F(MultiOutputFusionTest,
           DISABLED_ON_CPU(MultiOutputReduceFusionMajorWithExtraOutput)) {
  const string testcase = absl::StrCat(kScalarOps, R"(
    fused_reduce {
      p0 = f32[2,2,2]{2,1,0} parameter(0)
      c0 = f32[] constant(0)
      r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add
      mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
      c1 = f32[] constant(5)
      r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max
      ROOT tuple = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0})
                     tuple(r1, mul, r2)
    }

    ENTRY reduce {
      p = f32[2,2,2]{2,1,0} parameter(0)
      ROOT fusion = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0}) fusion(p),
        kind=kInput, calls=fused_reduce
    })");
  auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
  auto param =
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::MakeTupleOwned(
          LiteralUtil::CreateR2<float>({{6, 8}, {10, 12}}),
          LiteralUtil::CreateR3<float>(
              {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
          LiteralUtil::CreateR2<float>({{25, 36}, {49, 64}})),
      result));
}
444
// One reduce over dimensions {0, 2} plus two unreduced elementwise outputs:
// p0 squared and p0 multiplied by a broadcast constant 5.
XLA_TEST_F(MultiOutputFusionTest,
           DISABLED_ON_CPU(MultiOutputReduceFusionScalarWithExtraOutput)) {
  const string testcase = absl::StrCat(kScalarOps, R"(
    fused_reduce {
      p0 = f32[2,2,2]{2,1,0} parameter(0)
      c0 = f32[] constant(0)
      r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add
      mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
      c1 = f32[] constant(5)
      b1 = f32[2,2,2]{2,1,0} broadcast(c1), dimensions={}
      mul2 = f32[2,2,2]{2,1,0} multiply(p0, b1)
      ROOT tuple = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0})
        tuple(r1, mul, mul2)
    }

    ENTRY reduce {
      p = f32[2,2,2]{2,1,0} parameter(0)
      ROOT fusion = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) fusion(p),
        kind=kInput, calls=fused_reduce
    })");
  auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
  auto param =
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::MakeTupleOwned(
          LiteralUtil::CreateR1<float>({14, 22}),
          LiteralUtil::CreateR3<float>(
              {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
          LiteralUtil::CreateR3<float>(
              {{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})),
      result));
}
478
// Two reduces over the minor dimension whose init values are fusion
// parameters instead of constants.
XLA_TEST_F(MultiOutputFusionTest,
           DISABLED_ON_CPU(MultiOutputReduceFusionNonConstInit)) {
  const string testcase = absl::StrCat(kScalarOps, R"(
    fused_reduce {
      p0 = f32[2,2,2]{2,1,0} parameter(0)
      init1 = f32[] parameter(1)
      init2 = f32[] parameter(2)
      r1 = f32[2,2]{1,0} reduce(p0, init1), dimensions={2}, to_apply=Add
      r2 = f32[2,2]{1,0} reduce(p0, init2), dimensions={2}, to_apply=Max
      ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
    }

    ENTRY reduce {
      p = f32[2,2,2]{2,1,0} parameter(0)
      i = f32[] parameter(1)
      j = f32[] parameter(2)
      ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p, i, j), kind=kInput,
          calls=fused_reduce
    })");
  auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
  auto param =
      LiteralUtil::CreateR3<float>({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  auto init1 = LiteralUtil::CreateR0<float>(5);
  auto init2 = LiteralUtil::CreateR0<float>(6);
  Literal result =
      ExecuteNoHloPasses(std::move(module), {&param, &init1, &init2});
  // NOTE(review): each expected sum is the plain row sum (2, 7, 11, 15) plus
  // 165 = 33 * init1, i.e. the test pins the init value being folded in more
  // than once. Presumably this mirrors how the tested backend applies a
  // non-identity init value -- confirm before "correcting" these numbers.
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::MakeTupleOwned(
          LiteralUtil::CreateR2<float>({{167, 172}, {176, 180}}),
          LiteralUtil::CreateR2<float>({{6, 6}, {6, 8}})),
      result));
}
511
// The f16 parameter is converted to f32 for two minor-dimension reduces and
// is also returned unchanged, so the fusion's outputs mix f32 and f16.
XLA_TEST_F(MultiOutputFusionTest,
           DISABLED_ON_CPU(MultiOutputReduceFusionDifferentElementTypes)) {
  const string testcase = absl::StrCat(kScalarOps, R"(
    fused_reduce (p0: f16[2,2,2]) -> (f32[2,2], f32[2,2], f16[2,2,2]) {
      p0 = f16[2,2,2]{2,1,0} parameter(0)
      convert = f32[2,2,2]{2,1,0} convert(p0)
      c0 = f32[] constant(0)
      r1 = f32[2,2]{1,0} reduce(convert, c0), dimensions={2}, to_apply=Add
      mul = f32[2,2,2]{2,1,0} multiply(convert, convert)
      c1 = f32[] constant(5)
      r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max
      ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0})
                   tuple(r1, r2, p0)
    }

    ENTRY reduce {
      p = f16[2,2,2]{2,1,0} parameter(0)
      ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) fusion(p),
                    kind=kInput, calls=fused_reduce
    })");
  auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
  auto param = LiteralUtil::CreateR3<Eigen::half>(
      {{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}},
       {{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}});
  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::MakeTupleOwned(
          LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
          LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}}),
          LiteralUtil::CreateR3<Eigen::half>(
              {{{Eigen::half(1), Eigen::half(2)},
                {Eigen::half(3), Eigen::half(4)}},
               {{Eigen::half(5), Eigen::half(6)},
                {Eigen::half(7), Eigen::half(8)}}})),
      result));
}
548
549 } // namespace
550 } // namespace xla
551