1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
17
18 #include "tensorflow/compiler/xla/service/hlo_computation.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/service/hlo_module.h"
21 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
24 #include "tensorflow/compiler/xla/tests/test_utils.h"
25
26 namespace xla {
27
28 class HloSubcomputationUnificationTest : public HloTestBase {
29 protected:
HloSubcomputationUnificationTest()30 HloSubcomputationUnificationTest() {}
31
CreateR0S32IdentityComputation()32 std::unique_ptr<HloComputation> CreateR0S32IdentityComputation() {
33 auto builder = HloComputation::Builder("Identity");
34 builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32_, "x"));
35 return builder.Build();
36 }
37
CreateR0S32AdditionComputation()38 std::unique_ptr<HloComputation> CreateR0S32AdditionComputation() {
39 auto builder = HloComputation::Builder("Addition");
40 auto x =
41 builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32_, "x"));
42 auto y =
43 builder.AddInstruction(HloInstruction::CreateParameter(1, r0s32_, "y"));
44 builder.AddInstruction(
45 HloInstruction::CreateBinary(r0s32_, HloOpcode::kAdd, x, y));
46 return builder.Build();
47 }
48
CreateR1S32AdditionComputation(const Shape & shape)49 std::unique_ptr<HloComputation> CreateR1S32AdditionComputation(
50 const Shape& shape) {
51 auto builder = HloComputation::Builder("Addition");
52 auto x =
53 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
54 auto y =
55 builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "y"));
56 builder.AddInstruction(
57 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, x, y));
58 return builder.Build();
59 }
60
61 Shape r0s32_ = ShapeUtil::MakeShape(S32, {});
62 Shape r0f32_ = ShapeUtil::MakeShape(S32, {});
63 Shape r1s32_5_ = ShapeUtil::MakeShape(S32, {5});
64 Shape r1s32_3_ = ShapeUtil::MakeShape(S32, {3});
65 };
66
// Two separately added but identical "Identity" computations, each invoked by
// its own call in the entry computation, should be merged by the pass so that
// both calls end up targeting a single shared callee.
TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder entry_builder(TestName());

  HloComputation* const first_identity =
      module->AddEmbeddedComputation(CreateR0S32IdentityComputation());
  HloComputation* const second_identity =
      module->AddEmbeddedComputation(CreateR0S32IdentityComputation());

  HloInstruction* const five = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(5)));
  HloInstruction* const call_a = entry_builder.AddInstruction(
      HloInstruction::CreateCall(r0s32_, {five}, first_identity));
  HloInstruction* const call_b = entry_builder.AddInstruction(
      HloInstruction::CreateCall(r0s32_, {five}, second_identity));
  entry_builder.AddInstruction(
      HloInstruction::CreateBinary(r0s32_, HloOpcode::kAdd, call_a, call_b));
  module->AddEntryComputation(entry_builder.Build());

  // Before: entry + two distinct identity computations.
  EXPECT_EQ(3, module->computation_count());
  EXPECT_NE(call_a->to_apply(), call_b->to_apply());

  const bool changed =
      HloSubcomputationUnification().Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  // After: one identity computation survives and both calls share it.
  EXPECT_EQ(2, module->computation_count());
  EXPECT_EQ(call_a->to_apply(), call_b->to_apply());
}
93
// Two separately added but identical binary "Addition" computations, each
// invoked by its own call in the entry computation, should be merged by the
// pass so that both calls end up targeting a single shared callee.
TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder entry_builder(TestName());

  HloComputation* const first_adder =
      module->AddEmbeddedComputation(CreateR0S32AdditionComputation());
  HloComputation* const second_adder =
      module->AddEmbeddedComputation(CreateR0S32AdditionComputation());

  HloInstruction* const five = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(5)));
  HloInstruction* const three = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(3)));
  HloInstruction* const call_a = entry_builder.AddInstruction(
      HloInstruction::CreateCall(r0s32_, {five, three}, first_adder));
  HloInstruction* const call_b = entry_builder.AddInstruction(
      HloInstruction::CreateCall(r0s32_, {five, three}, second_adder));
  entry_builder.AddInstruction(
      HloInstruction::CreateBinary(r0s32_, HloOpcode::kAdd, call_a, call_b));
  module->AddEntryComputation(entry_builder.Build());

  // Before: entry + two distinct addition computations.
  EXPECT_EQ(3, module->computation_count());
  EXPECT_NE(call_a->to_apply(), call_b->to_apply());

  const bool changed =
      HloSubcomputationUnification().Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  // After: one addition computation survives and both calls share it.
  EXPECT_EQ(2, module->computation_count());
  EXPECT_EQ(call_a->to_apply(), call_b->to_apply());
}
122
123 // Do not unify subcomputations with different parameter shapes.
TEST_F(HloSubcomputationUnificationTest,DifferentParameterShapes)124 TEST_F(HloSubcomputationUnificationTest, DifferentParameterShapes) {
125 auto module = CreateNewUnverifiedModule();
126 auto builder = HloComputation::Builder(TestName());
127
128 auto callee1 =
129 module->AddEmbeddedComputation(CreateR1S32AdditionComputation(r1s32_5_));
130 auto callee2 =
131 module->AddEmbeddedComputation(CreateR1S32AdditionComputation(r1s32_3_));
132
133 auto param1 = builder.AddInstruction(
134 HloInstruction::CreateParameter(0, r1s32_5_, "param1"));
135 auto param2 = builder.AddInstruction(
136 HloInstruction::CreateParameter(1, r1s32_5_, "param2"));
137 auto x = builder.AddInstruction(
138 HloInstruction::CreateCall(r1s32_5_, {param1, param1}, callee1));
139 auto y = builder.AddInstruction(
140 HloInstruction::CreateCall(r1s32_3_, {param2, param2}, callee2));
141 builder.AddInstruction(HloInstruction::CreateConcatenate(
142 ShapeUtil::MakeShape(S32, {8}), {x, y}, 0));
143
144 module->AddEntryComputation(builder.Build());
145
146 EXPECT_EQ(3, module->computation_count());
147 EXPECT_NE(x->to_apply(), y->to_apply());
148 EXPECT_FALSE(HloSubcomputationUnification().Run(module.get()).ValueOrDie());
149 EXPECT_EQ(3, module->computation_count());
150 EXPECT_NE(x->to_apply(), y->to_apply());
151 }
152
153 // Regression test for b/31466798. Checks that entry_computation is still valid
154 // after unification.
TEST_F(HloSubcomputationUnificationTest,TwoIdenticalComputations)155 TEST_F(HloSubcomputationUnificationTest, TwoIdenticalComputations) {
156 auto module = CreateNewVerifiedModule();
157 for (int i = 0; i < 2; ++i) {
158 HloComputation::Builder builder("pow");
159 auto x =
160 builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
161 auto y =
162 builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "y"));
163 builder.AddInstruction(
164 HloInstruction::CreateBinary(r0f32_, HloOpcode::kPower, x, y));
165 if (i == 0) {
166 module->AddEmbeddedComputation(builder.Build());
167 } else {
168 module->AddEntryComputation(builder.Build());
169 }
170 }
171
172 EXPECT_TRUE(HloSubcomputationUnification().Run(module.get()).ValueOrDie());
173 EXPECT_EQ(1, module->computation_count());
174 EXPECT_EQ(*module->computations().begin(), module->entry_computation());
175 }
176
177 } // namespace xla
178