1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/defuser.h"
17
18 #include "tensorflow/compiler/xla/literal.h"
19 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
20 #include "tensorflow/compiler/xla/shape_util.h"
21 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
22
23 namespace op = xla::testing::opcode_matchers;
24
25 namespace xla {
26 namespace {
27
28 class DefuserTest : public HloTestBase {
29 protected:
30 // Returns the number of fusion instructions in the module.
FusionCount(const HloModule * m)31 int FusionCount(const HloModule* m) {
32 int count = 0;
33 for (HloComputation* computation : m->computations()) {
34 if (computation->IsFusionComputation()) {
35 count++;
36 }
37 }
38 return count;
39 }
40
41 Defuser defuser_;
42 const Shape shape_ = ShapeUtil::MakeShape(F32, {2, 2});
43 };
44
// A module that never contained a fusion must come out of the pass unchanged,
// and Run() must report that nothing changed.
TEST_F(DefuserTest, NoFusionInstruction) {
  auto hlo_module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());
  HloInstruction* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
  HloInstruction* p1 =
      b.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
  b.AddInstruction(
      HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, p0, p1));
  hlo_module->AddEntryComputation(b.Build());

  EXPECT_EQ(0, FusionCount(hlo_module.get()));

  const bool changed = defuser_.Run(hlo_module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
}
60
// Wraps a lone add (the entry root) in a loop fusion, then checks that the
// defuser removes the fusion and restores add(p0, p1) as the root.
TEST_F(DefuserTest, TrivialFusionInstructionAsRoot) {
  auto hlo_module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());
  HloInstruction* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
  HloInstruction* p1 =
      b.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
  HloInstruction* sum = b.AddInstruction(
      HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, p0, p1));

  HloComputation* entry = hlo_module->AddEntryComputation(b.Build());
  // `sum` is moved into the fusion; the pointer must not be used afterwards.
  entry->CreateFusionInstruction({sum}, HloInstruction::FusionKind::kLoop);
  EXPECT_THAT(entry->root_instruction(), op::Fusion());

  EXPECT_EQ(1, FusionCount(hlo_module.get()));
  const bool changed = defuser_.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_EQ(0, FusionCount(hlo_module.get()));

  EXPECT_THAT(entry->root_instruction(),
              op::Add(op::Parameter(), op::Parameter()));
}
84
// Fuses an add that feeds a negate (the root), then checks that the defuser
// rewires the negate back onto a plain add(p0, p1).
TEST_F(DefuserTest, TrivialFusionInstructionNotAsRoot) {
  auto hlo_module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());
  HloInstruction* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
  HloInstruction* p1 =
      b.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
  HloInstruction* sum = b.AddInstruction(
      HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, p0, p1));
  b.AddInstruction(
      HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, sum));

  HloComputation* entry = hlo_module->AddEntryComputation(b.Build());
  // Only the add is fused; the negate stays in the entry computation.
  entry->CreateFusionInstruction({sum}, HloInstruction::FusionKind::kLoop);
  EXPECT_THAT(entry->root_instruction(), op::Negate(op::Fusion()));

  EXPECT_EQ(1, FusionCount(hlo_module.get()));
  const bool changed = defuser_.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_EQ(0, FusionCount(hlo_module.get()));

  EXPECT_THAT(entry->root_instruction(),
              op::Negate(op::Add(op::Parameter(), op::Parameter())));
}
110
// Fuses a seven-instruction expression (including a constant) into a single
// loop fusion, then checks that defusing rebuilds the full expression tree in
// the entry computation.
TEST_F(DefuserTest, NonTrivialFusionInstruction) {
  auto m = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());
  auto param0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
  auto param1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
  // Named after its parameter number so it matches "p2".
  auto param2 =
      builder.AddInstruction(HloInstruction::CreateParameter(2, shape_, "p2"));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add));
  auto sub = builder.AddInstruction(
      HloInstruction::CreateBinary(shape_, HloOpcode::kSubtract, add, negate));
  auto mul = builder.AddInstruction(
      HloInstruction::CreateBinary(shape_, HloOpcode::kMultiply, sub, param2));
  auto div = builder.AddInstruction(
      HloInstruction::CreateBinary(shape_, HloOpcode::kDivide, mul, param2));
  auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  auto add2 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div));

  auto computation = m->AddEntryComputation(builder.Build());
  // Instructions are listed root-first (reverse topological order).
  computation->CreateFusionInstruction(
      {add2, constant, div, mul, sub, negate, add},
      HloInstruction::FusionKind::kLoop);

  EXPECT_THAT(computation->root_instruction(), op::Fusion());

  EXPECT_EQ(1, FusionCount(m.get()));
  EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie());
  EXPECT_EQ(0, FusionCount(m.get()));

  EXPECT_THAT(computation->root_instruction(),
              op::Add(op::Constant(), op::Divide()));
  // Checking only the root would miss a defuser that drops or miswires inner
  // nodes, so verify the entire rebuilt tree.
  EXPECT_THAT(
      computation->root_instruction(),
      op::Add(op::Constant(),
              op::Divide(
                  op::Multiply(
                      op::Subtract(
                          op::Add(op::Parameter(0), op::Parameter(1)),
                          op::Negate(
                              op::Add(op::Parameter(0), op::Parameter(1)))),
                      op::Parameter(2)),
                  op::Parameter(2))));
}
149
// Splits one expression across two sibling loop fusions in the entry
// computation, then checks that defusing both rebuilds the full expression
// tree.
TEST_F(DefuserTest, MultipleFusionInstructions) {
  auto m = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());
  auto param0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
  auto param1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
  // Named after its parameter number so it matches "p2".
  auto param2 =
      builder.AddInstruction(HloInstruction::CreateParameter(2, shape_, "p2"));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add));
  auto sub = builder.AddInstruction(
      HloInstruction::CreateBinary(shape_, HloOpcode::kSubtract, add, negate));
  auto mul = builder.AddInstruction(
      HloInstruction::CreateBinary(shape_, HloOpcode::kMultiply, sub, param2));
  auto div = builder.AddInstruction(
      HloInstruction::CreateBinary(shape_, HloOpcode::kDivide, mul, param2));
  auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  auto add2 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div));

  auto computation = m->AddEntryComputation(builder.Build());
  // First fusion covers the upper part of the expression (the root). Its
  // result feeds nothing else in the entry computation.
  computation->CreateFusionInstruction({add2, constant, div, mul},
                                       HloInstruction::FusionKind::kLoop);
  // Second fusion covers the lower part, which feeds the first fusion.
  computation->CreateFusionInstruction({sub, negate, add},
                                       HloInstruction::FusionKind::kLoop);

  EXPECT_THAT(computation->root_instruction(), op::Fusion());

  EXPECT_EQ(2, FusionCount(m.get()));
  EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie());
  EXPECT_EQ(0, FusionCount(m.get()));

  EXPECT_THAT(computation->root_instruction(),
              op::Add(op::Constant(), op::Divide()));
  // Verify the edge between the two former fusions was rewired correctly by
  // matching the entire rebuilt tree, not just the root.
  EXPECT_THAT(
      computation->root_instruction(),
      op::Add(op::Constant(),
              op::Divide(
                  op::Multiply(
                      op::Subtract(
                          op::Add(op::Parameter(0), op::Parameter(1)),
                          op::Negate(
                              op::Add(op::Parameter(0), op::Parameter(1)))),
                      op::Parameter(2)),
                  op::Parameter(2))));
}
189
// Builds a fusion nested inside another fusion's body. It then checks that the
// defuser flattens both levels back into negate(add(...)) in the entry
// computation.
TEST_F(DefuserTest, NestedFusionInstructions) {
  auto hlo_module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());
  HloInstruction* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
  HloInstruction* p1 =
      b.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
  HloInstruction* sum = b.AddInstruction(
      HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, p0, p1));
  HloInstruction* neg = b.AddInstruction(
      HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, sum));

  HloComputation* entry = hlo_module->AddEntryComputation(b.Build());

  // Outer fusion: absorbs both the negate and the add.
  HloInstruction* outer = entry->CreateFusionInstruction(
      {neg, sum}, HloInstruction::FusionKind::kLoop);

  // Inner fusion: re-fuses the negate inside the outer fusion's body.
  HloInstruction* inner_root = outer->fused_expression_root();
  ASSERT_EQ(inner_root->opcode(), HloOpcode::kNegate);
  HloComputation* outer_body = outer->fused_instructions_computation();
  outer_body->CreateFusionInstruction({inner_root},
                                      HloInstruction::FusionKind::kLoop);

  EXPECT_THAT(entry->root_instruction(), op::Fusion());

  EXPECT_EQ(2, FusionCount(hlo_module.get()));
  const bool changed = defuser_.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_EQ(0, FusionCount(hlo_module.get()));

  EXPECT_THAT(entry->root_instruction(), op::Negate(op::Add()));
}
218
219 } // namespace
220 } // namespace xla
221