1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <math.h>
17 #include <algorithm>
18 #include <memory>
19 #include <new>
20 #include <utility>
21 
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/types/span.h"
25 #include "tensorflow/compiler/xla/client/local_client.h"
26 #include "tensorflow/compiler/xla/literal.h"
27 #include "tensorflow/compiler/xla/primitive_util.h"
28 #include "tensorflow/compiler/xla/service/hlo_computation.h"
29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
30 #include "tensorflow/compiler/xla/service/hlo_module.h"
31 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
32 #include "tensorflow/compiler/xla/service/hlo_runner.h"
33 #include "tensorflow/compiler/xla/shape_util.h"
34 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
35 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
36 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
37 #include "tensorflow/compiler/xla/tests/test_macros.h"
38 #include "tensorflow/compiler/xla/tests/test_utils.h"
39 #include "tensorflow/compiler/xla/xla_data.pb.h"
40 #include "tensorflow/core/lib/core/status_test_util.h"
41 #include "tensorflow/core/platform/logging.h"
42 #include "tensorflow/core/platform/protobuf.h"
43 #include "tensorflow/core/platform/test.h"
44 #include "tensorflow/core/platform/test_benchmark.h"
45 #include "tensorflow/core/platform/types.h"
46 
47 namespace xla {
48 namespace {
49 
50 class MultiOutputFusionTest : public HloTestBase {
51  protected:
MultiOutputFusionTest()52   MultiOutputFusionTest() { error_spec_ = ErrorSpec{0.0001, 1e-2}; }
53 
54   // Layout assignment assumes that there are no fusions in the input graph.
55   // Since the purpose of this test is to send pre-fused graphs to XLA, we have
56   // to do layout assignment ourselves.
GetDebugOptionsForTest()57   DebugOptions GetDebugOptionsForTest() override {
58     auto opts = HloTestBase::GetDebugOptionsForTest();
59     opts.add_xla_disable_hlo_passes("layout-assignment");
60     return opts;
61   }
62 
RunTest2D(bool manual_fusion,int64 size)63   void RunTest2D(bool manual_fusion, int64 size) {
64     auto builder = HloComputation::Builder(TestName());
65     auto hlo_module = CreateNewUnverifiedModule();
66 
67     const Shape elem_shape0 = ShapeUtil::MakeShapeWithLayout(F32, {}, {});
68     const Shape elem_shape2 =
69         ShapeUtil::MakeShapeWithLayout(F32, {size, size}, {1, 0});
70 
71     auto const0 = builder.AddInstruction(
72         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(8.0f)));
73     auto param0 = builder.AddInstruction(
74         HloInstruction::CreateParameter(0, elem_shape0, "0"));
75 
76     auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
77         elem_shape0, HloOpcode::kAdd, param0, const0));
78 
79     HloInstruction* broadcast = builder.AddInstruction(
80         HloInstruction::CreateBroadcast(elem_shape2, add1, {}));
81 
82     auto param1 = builder.AddInstruction(
83         HloInstruction::CreateParameter(1, elem_shape2, "1"));
84 
85     HloInstruction* add2 = builder.AddInstruction(HloInstruction::CreateBinary(
86         elem_shape2, HloOpcode::kAdd, broadcast, param1));
87     HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
88         elem_shape2, HloOpcode::kSubtract, param1, broadcast));
89     DotDimensionNumbers dot_dnums;
90     dot_dnums.add_lhs_contracting_dimensions(1);
91     dot_dnums.add_rhs_contracting_dimensions(0);
92     HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
93         elem_shape2, sub, add2, dot_dnums, DefaultPrecisionConfig(2)));
94     auto computation = hlo_module->AddEntryComputation(builder.Build(dot));
95 
96     if (manual_fusion) {
97       auto tuple =
98           computation->AddInstruction(HloInstruction::CreateTuple({sub, add2}));
99       auto gte0 = computation->AddInstruction(
100           HloInstruction::CreateGetTupleElement(elem_shape2, tuple, 0));
101       auto gte1 = computation->AddInstruction(
102           HloInstruction::CreateGetTupleElement(elem_shape2, tuple, 1));
103       TF_CHECK_OK(dot->ReplaceOperandWith(0, gte0));
104       TF_CHECK_OK(dot->ReplaceOperandWith(1, gte1));
105 
106       CHECK_NE(
107           computation->CreateFusionInstruction(
108               {tuple, sub, add2, broadcast}, HloInstruction::FusionKind::kLoop),
109           nullptr);
110     }
111 
112     Literal arg1(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size}));
113     arg1.PopulateWithValue<float>(2.5f);
114 
115     Literal expect(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size}));
116     expect.PopulateWithValue<float>(size * 1.5f * 3.5f);
117     Literal literal_r0 = LiteralUtil::CreateR0<float>(-9.0f);
118     auto actual =
119         ExecuteAndTransfer(std::move(hlo_module), {&literal_r0, &arg1});
120     EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_));
121   }
122 
RunTest1D(bool manual_fusion,int size)123   void RunTest1D(bool manual_fusion, int size) {
124     auto builder = HloComputation::Builder(TestName());
125     auto hlo_module = CreateNewUnverifiedModule();
126 
127     const Shape elem_shape_F32 =
128         ShapeUtil::MakeShapeWithDescendingLayout(F32, {size});
129     const Shape elem_shape_U8 =
130         ShapeUtil::MakeShapeWithDescendingLayout(F64, {size});
131     auto param0 = builder.AddInstruction(
132         HloInstruction::CreateParameter(0, elem_shape_F32, "0"));
133     auto param1 = builder.AddInstruction(
134         HloInstruction::CreateParameter(1, elem_shape_U8, "1"));
135 
136     HloInstruction* param0_U8 = builder.AddInstruction(
137         HloInstruction::CreateConvert(elem_shape_U8, param0));
138     HloInstruction* param1_F32 = builder.AddInstruction(
139         HloInstruction::CreateConvert(elem_shape_F32, param1));
140     HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
141         elem_shape_F32, HloOpcode::kAdd, param0, param1_F32));
142     HloInstruction* sub_U8 =
143         builder.AddInstruction(HloInstruction::CreateBinary(
144             elem_shape_U8, HloOpcode::kSubtract, param0_U8, param1));
145     HloInstruction* sub = builder.AddInstruction(
146         HloInstruction::CreateConvert(elem_shape_F32, sub_U8));
147 
148     HloInstruction* reshape =
149         builder.AddInstruction(HloInstruction::CreateReshape(
150             ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, 1}), add));
151     DotDimensionNumbers dot_dnums;
152     dot_dnums.add_lhs_contracting_dimensions(0);
153     dot_dnums.add_rhs_contracting_dimensions(0);
154     HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
155         ShapeUtil::MakeShapeWithDescendingLayout(F32, {1}), sub, reshape,
156         dot_dnums, DefaultPrecisionConfig(2)));
157     auto computation = hlo_module->AddEntryComputation(builder.Build(dot));
158 
159     if (manual_fusion) {
160       auto tuple = computation->AddInstruction(
161           HloInstruction::CreateTuple({sub_U8, add}));
162 
163       auto gte0 = computation->AddInstruction(
164           HloInstruction::CreateGetTupleElement(elem_shape_U8, tuple, 0));
165       auto gte1 = computation->AddInstruction(
166           HloInstruction::CreateGetTupleElement(elem_shape_F32, tuple, 1));
167       TF_CHECK_OK(sub->ReplaceOperandWith(0, gte0));
168       TF_CHECK_OK(reshape->ReplaceOperandWith(0, gte1));
169 
170       CHECK_NE(computation->CreateFusionInstruction(
171                    {tuple, sub_U8, add, param0_U8, param1_F32},
172                    HloInstruction::FusionKind::kLoop),
173                nullptr);
174     }
175 
176     Literal input0(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size}));
177     input0.PopulateWithValue(2.5f);
178     Literal input1(ShapeUtil::MakeShapeWithDescendingLayout(F64, {size}));
179     input1.PopulateWithValue(1.);
180 
181     Literal expect = LiteralUtil::CreateR1<float>({size * 1.5f * 3.5f});
182     auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1});
183     EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_));
184   }
185 };
186 
187 XLA_TEST_F(MultiOutputFusionTest, 2DNofusion) { RunTest2D(false, 5); }
188 XLA_TEST_F(MultiOutputFusionTest, 2DFusion) { RunTest2D(true, 5); }
189 XLA_TEST_F(MultiOutputFusionTest, 2DFusionSize129) { RunTest2D(true, 129); }
// Runs the mixed f32/f64 RunTest1D graph on length-8 vectors without the
// manual fusion rewrite.
XLA_TEST_F(MultiOutputFusionTest, DiffentTypesNoFusion) {
  RunTest1D(/*manual_fusion=*/false, /*size=*/8);
}
// Runs the mixed f32/f64 RunTest1D graph on length-8 vectors with the manual
// fusion rewrite, so the fusion outputs differ in element type.
XLA_TEST_F(MultiOutputFusionTest, DiffentTypesFusion) {
  RunTest1D(/*manual_fusion=*/true, /*size=*/8);
}
192 
// Executes a module whose ENTRY root is a tuple-shaped kLoop fusion. The fused
// computation extracts the innermost s32 from a nested tuple parameter, copies
// it and wraps it in a one-element tuple; the test expects that value back.
XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) {
  const char* hlo_text = R"(
    HloModule m, is_scheduled=true

    fused_computation {
      x.param_0 = (((s32[]), f32[]), (f32[], s32[])) parameter(0)
      gte.3 = ((s32[]), f32[]) get-tuple-element(x.param_0), index=0
      gte.2 = (s32[]) get-tuple-element(gte.3), index=0
      gte.4 = s32[] get-tuple-element(gte.2), index=0
      copy = s32[] copy(gte.4)
      ROOT tuple = (s32[]) tuple(copy)
    }

    ENTRY thing.v3 {
      x = (((s32[]), f32[]), (f32[], s32[])) parameter(0)
      ROOT fusion = (s32[]) fusion(x), kind=kLoop, calls=fused_computation
    }
  )";
  auto hlo_module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();

  // Argument laid out as (((42), 1.0), (3.0, 4)) to match the parameter shape.
  Literal nested_input = LiteralUtil::MakeTupleOwned(
      LiteralUtil::MakeTupleOwned(
          LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42)),
          LiteralUtil::CreateR0<float>(1.0)),
      LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<float>(3.0),
                                  LiteralUtil::CreateR0<int32>(4)));
  Literal actual = ExecuteNoHloPasses(std::move(hlo_module), {&nested_input});

  Literal expected =
      LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
222 
// Executes a kLoop fusion with two outputs of different element types:
// pred[4] (p < p*p) and f32[4] (p*p). The entry computation selects p*p where
// the predicate holds and 0 elsewhere.
XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) {
  const char* hlo_text = R"(
    HloModule m, is_scheduled=true

    fused_computation {
      p = f32[4] parameter(0)
      multiply = f32[4] multiply(p, p)
      less-than = pred[4] compare(p, multiply), direction=LT
      ROOT tuple = (pred[4], f32[4]) tuple(less-than, multiply)
    }

    ENTRY PredFloatMOF {
      p0 = f32[4] parameter(0)
      fusion = (pred[4], f32[4]) fusion(p0), kind=kLoop, calls=fused_computation
      gte0 = pred[4] get-tuple-element(fusion), index=0
      gte1 = f32[4] get-tuple-element(fusion), index=1
      const = f32[4] constant({0, 0, 0, 0})
      ROOT select = f32[4] select(gte0, gte1, const)
    })";
  auto hlo_module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();

  Literal input = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, -1.0});
  Literal actual = ExecuteNoHloPasses(std::move(hlo_module), {&input});

  // 1 < 1 is false, so the first element falls back to 0; the others are
  // strictly smaller than their squares.
  LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0, 1.0}, actual);
}
247 
// Same (p < p*p, p*p) two-output kLoop fusion as the vector test, but on
// scalars and placed inside the computation applied by a `map` over an f32[3]
// input.
XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) {
  const char* hlo_text = R"(
    HloModule m, is_scheduled=true

    fused_computation {
      p = f32[] parameter(0)
      multiply = f32[] multiply(p, p)
      less-than = pred[] compare(p, multiply), direction=LT
      ROOT tuple = (pred[], f32[]) tuple(less-than, multiply)
    }

    map_computation {
      p0 = f32[] parameter(0)
      fusion = (pred[], f32[]) fusion(p0), kind=kLoop, calls=fused_computation
      gte0 = pred[] get-tuple-element(fusion), index=0
      gte1 = f32[] get-tuple-element(fusion), index=1
      const = f32[] constant(0)
      ROOT select = f32[] select(gte0, gte1, const)
    }

    ENTRY MapMOF {
      p1 = f32[3] parameter(0)
      ROOT map = f32[3] map(p1), to_apply=map_computation
    })";
  auto hlo_module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();

  Literal input = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
  Literal actual = ExecuteNoHloPasses(std::move(hlo_module), {&input});

  // 1 < 1 is false, so the first element maps to 0; 2 and 3 map to their
  // squares.
  LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0}, actual);
}
277 
// Shared HLO text prefix for the reduce-fusion tests in this file: the module
// header followed by two scalar f32 computations, `Add` and `Max`, which the
// tests reference through `to_apply=`. Each test appends its own fused
// computation and ENTRY computation to this prefix with absl::StrCat.
const char* const kScalarOps = R"(
    HloModule m, is_scheduled=true

    Add {
      lhsadd = f32[] parameter(0)
      rhsadd = f32[] parameter(1)
      ROOT add = f32[] add(lhsadd, rhsadd)
    }

    Max {
      lhsmax = f32[] parameter(0)
      rhsmax = f32[] parameter(1)
      ROOT max = f32[] maximum(lhsmax, rhsmax)
    }
)";
293 
// kInput fusion with two reductions over dimension 2 of an f32[2,2,2] input:
// Add over p0 (init 0) and Max over p0*p0 (init 5).
XLA_TEST_F(MultiOutputFusionTest,
           DISABLED_ON_CPU(MultiOutputReduceFusionMinor)) {
  const string hlo_text = absl::StrCat(kScalarOps, R"(
    fused_reduce {
      p0 = f32[2,2,2]{2,1,0} parameter(0)
      c0 = f32[] constant(0)
      r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add
      mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
      c1 = f32[] constant(5)
      r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max
      ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
    }

    ENTRY reduce {
      p = f32[2,2,2]{2,1,0} parameter(0)
      ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput,
                                                        calls=fused_reduce
    })");
  auto hlo_module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();

  Literal input =
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  Literal actual = ExecuteNoHloPasses(std::move(hlo_module), {&input});

  // Row sums, then row maxima of the squares; the Max init of 5 wins for the
  // first row, where the squares are 1 and 4.
  Literal expected = LiteralUtil::MakeTupleOwned(
      LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
      LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
322 
// kInput fusion with two reductions over dimension 0 of an f32[2,2,2] input:
// Add over p0 (init 0) and Max over p0*p0 (init 5).
XLA_TEST_F(MultiOutputFusionTest,
           DISABLED_ON_CPU(MultiOutputReduceFusionMajor)) {
  const string hlo_text = absl::StrCat(kScalarOps, R"(
    fused_reduce {
      p0 = f32[2,2,2]{2,1,0} parameter(0)
      c0 = f32[] constant(0)
      r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add
      mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
      c1 = f32[] constant(5)
      r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max
      ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
    }

    ENTRY reduce {
      p = f32[2,2,2]{2,1,0} parameter(0)
      ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput,
                                                        calls=fused_reduce
    })");
  auto hlo_module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();

  Literal input =
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  Literal actual = ExecuteNoHloPasses(std::move(hlo_module), {&input});

  // Element-wise sum of the two 2x2 slices, then element-wise max of their
  // squares (the second slice dominates everywhere).
  Literal expected = LiteralUtil::MakeTupleOwned(
      LiteralUtil::CreateR2<float>({{6, 8}, {10, 12}}),
      LiteralUtil::CreateR2<float>({{25, 36}, {49, 64}}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
351 
// kInput fusion with three reductions over dimensions {0,2} of an f32[2,2,2]
// input, each producing f32[2]: Add over p0, Max over p0*p0 (init
// 1.17549e-38) and Add over p0*p0.
XLA_TEST_F(MultiOutputFusionTest,
           DISABLED_ON_CPU(MultiOutputReduceFusionScalar)) {
  const string hlo_text = absl::StrCat(kScalarOps, R"(
    fused_reduce {
      p0 = f32[2,2,2]{2,1,0} parameter(0)
      c0 = f32[] constant(0)
      r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add
      mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
      c1 = f32[] constant(1.17549e-38)
      r2 = f32[2]{0} reduce(mul, c1), dimensions={0,2}, to_apply=Max
      r3 = f32[2]{0} reduce(mul, c0), dimensions={0,2}, to_apply=Add
      ROOT tuple = (f32[2]{0}, f32[2]{0}, f32[2]{0}) tuple(r1, r2, r3)
    }

    ENTRY reduce {
      p = f32[2,2,2]{2,1,0} parameter(0)
      ROOT fusion = (f32[2]{0}, f32[2]{0}, f32[2]{0}) fusion(p), kind=kInput,
                                                        calls=fused_reduce
    })");
  auto hlo_module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();

  Literal input =
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  Literal actual = ExecuteNoHloPasses(std::move(hlo_module), {&input});

  // Index 0 of the kept dimension covers {1, 2, 5, 6}; index 1 covers
  // {3, 4, 7, 8}.
  Literal expected =
      LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({14, 22}),
                                  LiteralUtil::CreateR1<float>({36, 64}),
                                  LiteralUtil::CreateR1<float>({66, 138}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
381 
// Dimension-2 reduce fusion (Add over p0, Max over p0*p0 with init 5) that
// additionally forwards the unreduced parameter p0 as its first output.
XLA_TEST_F(MultiOutputFusionTest,
           DISABLED_ON_CPU(MultiOutputReduceFusionMinorWithExtraOutput)) {
  const string hlo_text = absl::StrCat(kScalarOps, R"(
    fused_reduce {
      p0 = f32[2,2,2]{2,1,0} parameter(0)
      c0 = f32[] constant(0)
      r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add
      mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
      c1 = f32[] constant(5)
      r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max
      ROOT tuple = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0})
                     tuple(p0, r1, r2)
    }

    ENTRY reduce {
      p = f32[2,2,2]{2,1,0} parameter(0)
      ROOT fusion = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p),
                                                 kind=kInput, calls=fused_reduce
    })");
  auto hlo_module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();

  Literal input =
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  Literal actual = ExecuteNoHloPasses(std::move(hlo_module), {&input});

  // Output 0 is the input passed through unchanged.
  Literal expected = LiteralUtil::MakeTupleOwned(
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}),
      LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
      LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
412 
// Dimension-0 reduce fusion (Add over p0, Max over p0*p0 with init 5) that
// additionally emits the unreduced product p0*p0 as its middle output.
XLA_TEST_F(MultiOutputFusionTest,
           DISABLED_ON_CPU(MultiOutputReduceFusionMajorWithExtraOutput)) {
  const string hlo_text = absl::StrCat(kScalarOps, R"(
    fused_reduce {
      p0 = f32[2,2,2]{2,1,0} parameter(0)
      c0 = f32[] constant(0)
      r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add
      mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
      c1 = f32[] constant(5)
      r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max
      ROOT tuple = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0})
                     tuple(r1, mul, r2)
    }

    ENTRY reduce {
      p = f32[2,2,2]{2,1,0} parameter(0)
      ROOT fusion = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0}) fusion(p),
                                                 kind=kInput, calls=fused_reduce
    })");
  auto hlo_module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();

  Literal input =
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  Literal actual = ExecuteNoHloPasses(std::move(hlo_module), {&input});

  // Output 1 is the element-wise square of the input.
  Literal expected = LiteralUtil::MakeTupleOwned(
      LiteralUtil::CreateR2<float>({{6, 8}, {10, 12}}),
      LiteralUtil::CreateR3<float>({{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
      LiteralUtil::CreateR2<float>({{25, 36}, {49, 64}}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
444 
// kInput fusion with a single Add reduction over dimensions {0,2} plus two
// unreduced element-wise outputs: p0*p0 and p0*broadcast(5).
XLA_TEST_F(MultiOutputFusionTest,
           DISABLED_ON_CPU(MultiOutputReduceFusionScalarWithExtraOutput)) {
  const string hlo_text = absl::StrCat(kScalarOps, R"(
    fused_reduce {
      p0 = f32[2,2,2]{2,1,0} parameter(0)
      c0 = f32[] constant(0)
      r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add
      mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
      c1 = f32[] constant(5)
      b1 = f32[2,2,2]{2,1,0} broadcast(c1), dimensions={}
      mul2 = f32[2,2,2]{2,1,0} multiply(p0, b1)
      ROOT tuple = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0})
                                                           tuple(r1, mul, mul2)
    }

    ENTRY reduce {
      p = f32[2,2,2]{2,1,0} parameter(0)
      ROOT fusion = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) fusion(p),
                                                 kind=kInput, calls=fused_reduce
    })");
  auto hlo_module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();

  Literal input =
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  Literal actual = ExecuteNoHloPasses(std::move(hlo_module), {&input});

  // Reduced sums, then the squares, then the input scaled by 5.
  Literal expected = LiteralUtil::MakeTupleOwned(
      LiteralUtil::CreateR1<float>({14, 22}),
      LiteralUtil::CreateR3<float>({{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
      LiteralUtil::CreateR3<float>(
          {{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
478 
// Dimension-2 reduce fusion (Add and Max over the same input) whose init
// values are fusion parameters supplied at run time rather than constants.
XLA_TEST_F(MultiOutputFusionTest,
           DISABLED_ON_CPU(MultiOutputReduceFusionNonConstInit)) {
  const string hlo_text = absl::StrCat(kScalarOps, R"(
    fused_reduce {
      p0 = f32[2,2,2]{2,1,0} parameter(0)
      init1 = f32[] parameter(1)
      init2 = f32[] parameter(2)
      r1 = f32[2,2]{1,0} reduce(p0, init1), dimensions={2}, to_apply=Add
      r2 = f32[2,2]{1,0} reduce(p0, init2), dimensions={2}, to_apply=Max
      ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
    }

    ENTRY reduce {
      p = f32[2,2,2]{2,1,0} parameter(0)
      i = f32[] parameter(1)
      j = f32[] parameter(2)
      ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p, i, j), kind=kInput,
                                                              calls=fused_reduce
    })");
  auto hlo_module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();

  Literal input =
      LiteralUtil::CreateR3<float>({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  Literal add_init = LiteralUtil::CreateR0<float>(5);
  Literal max_init = LiteralUtil::CreateR0<float>(6);
  Literal actual = ExecuteNoHloPasses(std::move(hlo_module),
                                      {&input, &add_init, &max_init});

  // NOTE(review): the expected Add results are not row_sum + 5 (7, 12, 16,
  // 20); they equal row_sum + 33 * 5. Presumably this pins how the targeted
  // backend applies a non-identity init value -- confirm before changing
  // these numbers. The Max results are the plain max(row, 6).
  Literal expected = LiteralUtil::MakeTupleOwned(
      LiteralUtil::CreateR2<float>({{167, 172}, {176, 180}}),
      LiteralUtil::CreateR2<float>({{6, 6}, {6, 8}}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
511 
// Dimension-2 reduce fusion over an f16 input that is converted to f32 before
// reducing (Add over the converted input, Max over its square with init 5).
// The f16 parameter itself is forwarded as a third output, so the fusion
// outputs mix f32 and f16.
XLA_TEST_F(MultiOutputFusionTest,
           DISABLED_ON_CPU(MultiOutputReduceFusionDifferentElementTypes)) {
  const string hlo_text = absl::StrCat(kScalarOps, R"(
    fused_reduce (p0: f16[2,2,2]) -> (f32[2,2], f32[2,2], f16[2,2,2]) {
      p0 = f16[2,2,2]{2,1,0} parameter(0)
      convert = f32[2,2,2]{2,1,0} convert(p0)
      c0 = f32[] constant(0)
      r1 = f32[2,2]{1,0} reduce(convert, c0), dimensions={2}, to_apply=Add
      mul = f32[2,2,2]{2,1,0} multiply(convert, convert)
      c1 = f32[] constant(5)
      r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max
      ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0})
                   tuple(r1, r2, p0)
    }

    ENTRY reduce {
      p = f16[2,2,2]{2,1,0} parameter(0)
      ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) fusion(p),
                    kind=kInput, calls=fused_reduce
    })");
  auto hlo_module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();

  Literal input = LiteralUtil::CreateR3<Eigen::half>(
      {{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}},
       {{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}});
  Literal actual = ExecuteNoHloPasses(std::move(hlo_module), {&input});

  // Output 2 is the f16 input passed through unchanged.
  Literal expected = LiteralUtil::MakeTupleOwned(
      LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
      LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}}),
      LiteralUtil::CreateR3<Eigen::half>(
          {{{Eigen::half(1), Eigen::half(2)},
            {Eigen::half(3), Eigen::half(4)}},
           {{Eigen::half(5), Eigen::half(6)},
            {Eigen::half(7), Eigen::half(8)}}}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
548 
549 }  // namespace
550 }  // namespace xla
551