1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/defuser.h"
17 
18 #include "tensorflow/compiler/xla/literal.h"
19 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
20 #include "tensorflow/compiler/xla/shape_util.h"
21 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
22 
23 namespace op = xla::testing::opcode_matchers;
24 
25 namespace xla {
26 namespace {
27 
28 class DefuserTest : public HloTestBase {
29  protected:
30   // Returns the number of fusion instructions in the module.
FusionCount(const HloModule * m)31   int FusionCount(const HloModule* m) {
32     int count = 0;
33     for (HloComputation* computation : m->computations()) {
34       if (computation->IsFusionComputation()) {
35         count++;
36       }
37     }
38     return count;
39   }
40 
41   Defuser defuser_;
42   const Shape shape_ = ShapeUtil::MakeShape(F32, {2, 2});
43 };
44 
TEST_F(DefuserTest, NoFusionInstruction) {
  // Build `p0 + p1` with no fusion at all; the defuser must leave it alone.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* lhs =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
  HloInstruction* rhs =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
  builder.AddInstruction(
      HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, lhs, rhs));
  module->AddEntryComputation(builder.Build());

  auto* raw_module = module.get();
  EXPECT_EQ(0, FusionCount(raw_module));

  // With nothing to defuse, Run() is expected to report no change.
  EXPECT_FALSE(defuser_.Run(raw_module).ValueOrDie());
}
60 
TEST_F(DefuserTest, TrivialFusionInstructionAsRoot) {
  // Fuse the root add on its own, then check defusing restores a plain add.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* lhs =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
  HloInstruction* rhs =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, lhs, rhs));

  HloComputation* entry = module->AddEntryComputation(builder.Build());
  entry->CreateFusionInstruction({sum}, HloInstruction::FusionKind::kLoop);

  // The fusion has taken the add's place as the root.
  EXPECT_THAT(entry->root_instruction(), op::Fusion());

  auto* raw_module = module.get();
  EXPECT_EQ(1, FusionCount(raw_module));
  EXPECT_TRUE(defuser_.Run(raw_module).ValueOrDie());
  EXPECT_EQ(0, FusionCount(raw_module));

  EXPECT_THAT(entry->root_instruction(),
              op::Add(op::Parameter(), op::Parameter()));
}
84 
TEST_F(DefuserTest, TrivialFusionInstructionNotAsRoot) {
  // Fuse an add that feeds a negate, so the fusion is an interior node.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* lhs =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
  HloInstruction* rhs =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, lhs, rhs));
  builder.AddInstruction(
      HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, sum));

  HloComputation* entry = module->AddEntryComputation(builder.Build());
  entry->CreateFusionInstruction({sum}, HloInstruction::FusionKind::kLoop);

  // The negate now consumes the fusion instead of the add.
  EXPECT_THAT(entry->root_instruction(), op::Negate(op::Fusion()));

  auto* raw_module = module.get();
  EXPECT_EQ(1, FusionCount(raw_module));
  EXPECT_TRUE(defuser_.Run(raw_module).ValueOrDie());
  EXPECT_EQ(0, FusionCount(raw_module));

  EXPECT_THAT(entry->root_instruction(),
              op::Negate(op::Add(op::Parameter(), op::Parameter())));
}
110 
TEST_F(DefuserTest, NonTrivialFusionInstruction) {
  // Fuse an entire multi-level expression into one fusion, then check that
  // defusing reproduces the whole original expression tree.
  auto m = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());
  auto param0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
  auto param1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
  // Named after its parameter number (2) and HLO name ("p2").
  auto param2 =
      builder.AddInstruction(HloInstruction::CreateParameter(2, shape_, "p2"));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add));
  auto sub = builder.AddInstruction(
      HloInstruction::CreateBinary(shape_, HloOpcode::kSubtract, add, negate));
  auto mul = builder.AddInstruction(
      HloInstruction::CreateBinary(shape_, HloOpcode::kMultiply, sub, param2));
  auto div = builder.AddInstruction(
      HloInstruction::CreateBinary(shape_, HloOpcode::kDivide, mul, param2));
  auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  auto add2 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div));

  auto computation = m->AddEntryComputation(builder.Build());
  computation->CreateFusionInstruction(
      {add2, constant, div, mul, sub, negate, add},
      HloInstruction::FusionKind::kLoop);

  EXPECT_THAT(computation->root_instruction(), op::Fusion());

  EXPECT_EQ(1, FusionCount(m.get()));
  EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie());
  EXPECT_EQ(0, FusionCount(m.get()));

  // Match the full defused tree, not just its root, so a defuser that drops,
  // reorders, or mis-wires an inner operand is caught.
  EXPECT_THAT(
      computation->root_instruction(),
      op::Add(op::Constant(),
              op::Divide(
                  op::Multiply(
                      op::Subtract(op::Add(op::Parameter(0), op::Parameter(1)),
                                   op::Negate(op::Add())),
                      op::Parameter(2)),
                  op::Parameter(2))));
}
149 
TEST_F(DefuserTest, MultipleFusionInstructions) {
  // Split one expression across two sibling fusions, then check that defusing
  // both reproduces the whole original expression tree.
  auto m = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());
  auto param0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
  auto param1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
  // Named after its parameter number (2) and HLO name ("p2").
  auto param2 =
      builder.AddInstruction(HloInstruction::CreateParameter(2, shape_, "p2"));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add));
  auto sub = builder.AddInstruction(
      HloInstruction::CreateBinary(shape_, HloOpcode::kSubtract, add, negate));
  auto mul = builder.AddInstruction(
      HloInstruction::CreateBinary(shape_, HloOpcode::kMultiply, sub, param2));
  auto div = builder.AddInstruction(
      HloInstruction::CreateBinary(shape_, HloOpcode::kDivide, mul, param2));
  auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  auto add2 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div));

  auto computation = m->AddEntryComputation(builder.Build());
  computation->CreateFusionInstruction({add2, constant, div, mul},
                                       HloInstruction::FusionKind::kLoop);
  computation->CreateFusionInstruction({sub, negate, add},
                                       HloInstruction::FusionKind::kLoop);

  EXPECT_THAT(computation->root_instruction(), op::Fusion());

  EXPECT_EQ(2, FusionCount(m.get()));
  EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie());
  EXPECT_EQ(0, FusionCount(m.get()));

  // Match the full defused tree, including the part that came from the
  // second fusion, so mis-wiring across the fusion boundary is caught.
  EXPECT_THAT(
      computation->root_instruction(),
      op::Add(op::Constant(),
              op::Divide(
                  op::Multiply(
                      op::Subtract(op::Add(op::Parameter(0), op::Parameter(1)),
                                   op::Negate(op::Add())),
                      op::Parameter(2)),
                  op::Parameter(2))));
}
189 
TEST_F(DefuserTest, NestedFusionInstructions) {
  // Build a fusion that itself contains a fusion, then check that defusing
  // removes both levels and restores the original expression.
  auto m = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());
  auto param0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
  auto param1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add));

  auto computation = m->AddEntryComputation(builder.Build());
  auto outer_fusion = computation->CreateFusionInstruction(
      {negate, add}, HloInstruction::FusionKind::kLoop);
  HloInstruction* fused_negate = outer_fusion->fused_expression_root();
  ASSERT_EQ(fused_negate->opcode(), HloOpcode::kNegate);
  // Fuse the negate again, this time inside the outer fusion's computation.
  outer_fusion->fused_instructions_computation()->CreateFusionInstruction(
      {fused_negate}, HloInstruction::FusionKind::kLoop);

  EXPECT_THAT(computation->root_instruction(), op::Fusion());
  // Confirm the nesting really happened: the outer fusion's root is now the
  // inner fusion rather than the negate.
  EXPECT_THAT(outer_fusion->fused_expression_root(), op::Fusion());

  EXPECT_EQ(2, FusionCount(m.get()));
  EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie());
  EXPECT_EQ(0, FusionCount(m.get()));

  // Both levels are gone and the add is wired to the original parameters.
  EXPECT_THAT(computation->root_instruction(),
              op::Negate(op::Add(op::Parameter(0), op::Parameter(1))));
}
218 
219 }  // namespace
220 }  // namespace xla
221