1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
17
18 #include <memory>
19 #include <utility>
20
21 #include "tensorflow/compiler/xla/literal.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/test.h"
28 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/core/lib/core/status_test_util.h"
31
32 namespace op = xla::testing::opcode_matchers;
33
34 namespace xla {
35 namespace {
36
37 class TupleSimplifierTest : public HloTestBase {
38 protected:
Run(HloModule * module,bool change_expected)39 void Run(HloModule* module, bool change_expected) {
40 TupleSimplifier simplifier;
41 auto changed_status = simplifier.Run(module);
42 TF_ASSERT_OK(changed_status.status());
43 EXPECT_EQ(change_expected, changed_status.ValueOrDie());
44 }
Run(HloModule * module,bool change_expected,bool exclude_entry)45 void Run(HloModule* module, bool change_expected, bool exclude_entry) {
46 TupleSimplifier simplifier(exclude_entry);
47 auto changed_status = simplifier.Run(module);
48 TF_ASSERT_OK(changed_status.status());
49 EXPECT_EQ(change_expected, changed_status.ValueOrDie());
50 }
51
52 const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
53 const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
54 {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {}),
55 ShapeUtil::MakeShape(F32, {})});
56 };
57
TEST_F(TupleSimplifierTest, TupleOfParameters) {
  // A tuple whose operands are plain parameters offers nothing to
  // short-circuit, so the pass must report no change.
  HloComputation::Builder b(TestName());
  auto* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  auto* p1 =
      b.AddInstruction(HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  auto* p2 =
      b.AddInstruction(HloInstruction::CreateParameter(2, scalar_shape_, "param2"));
  b.AddInstruction(HloInstruction::CreateTuple({p0, p1, p2}));

  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(b.Build());

  Run(hlo_module.get(), /*change_expected=*/false);
}
73
TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) {
  // Reading an element out of a tuple-shaped parameter cannot be simplified:
  // there is no Tuple instruction whose operand could be forwarded.
  HloComputation::Builder b(TestName());
  HloInstruction* tuple_param =
      b.AddInstruction(HloInstruction::CreateParameter(0, tuple_shape_, "param"));
  b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1));

  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(b.Build());

  Run(hlo_module.get(), /*change_expected=*/false);
}
86
TEST_F(TupleSimplifierTest, GteOfTuple) {
  // GTE(Tuple(param0, param1, param2), 1) forwards directly to param1, so the
  // pass should make that parameter the new root.
  HloComputation::Builder b(TestName());
  const char* const kParamNames[] = {"param0", "param1", "param2"};
  HloInstruction* params[3];
  for (int i = 0; i < 3; ++i) {
    params[i] = b.AddInstruction(
        HloInstruction::CreateParameter(i, scalar_shape_, kParamNames[i]));
  }
  HloInstruction* packed = b.AddInstruction(
      HloInstruction::CreateTuple({params[0], params[1], params[2]}));
  HloInstruction* extracted = b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, packed, 1));

  auto hlo_module = CreateNewVerifiedModule();
  HloComputation* entry = hlo_module->AddEntryComputation(b.Build());

  EXPECT_EQ(entry->root_instruction(), extracted);

  Run(hlo_module.get(), /*change_expected=*/true);

  EXPECT_EQ(entry->root_instruction(), params[1]);
}
110
TEST_F(TupleSimplifierTest, GteOfTupleChain) {
  // Repeatedly packing a value into a tuple and immediately unpacking it should
  // collapse all the way back to the original parameter.
  constexpr int kChainLength = 10;
  HloComputation::Builder b(TestName());
  HloInstruction* current =
      b.AddInstruction(HloInstruction::CreateParameter(0, scalar_shape_, "param"));

  for (int link = 0; link < kChainLength; ++link) {
    HloInstruction* wrapped = b.AddInstruction(
        HloInstruction::CreateTuple({current, current, current}));
    current = b.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_shape_, wrapped, 1));
  }
  b.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, current));

  auto hlo_module = CreateNewVerifiedModule();
  HloComputation* entry = hlo_module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(),
              op::Negate(op::GetTupleElement(op::Tuple())));

  Run(hlo_module.get(), /*change_expected=*/true);

  EXPECT_THAT(entry->root_instruction(), op::Negate(op::Parameter()));
}
138
TEST_F(TupleSimplifierTest, NestedGteOfTuples) {
  // Wrap a parameter in kNestingDepth levels of tuples, then peel every level
  // back off with element-0 GTEs. The pass should see through the whole nest
  // and leave the parameter as the root.
  constexpr int kNestingDepth = 5;
  HloComputation::Builder b(TestName());
  HloInstruction* const param =
      b.AddInstruction(HloInstruction::CreateParameter(0, scalar_shape_, "param"));

  // Packing phase: each level is a pair of the previous level.
  HloInstruction* value = param;
  for (int depth = 0; depth < kNestingDepth; ++depth) {
    value = b.AddInstruction(HloInstruction::CreateTuple({value, value}));
  }

  // Unpacking phase: each GTE strips one level of nesting.
  for (int depth = kNestingDepth; depth > 0; --depth) {
    const Shape& inner_shape =
        ShapeUtil::GetTupleElementShape(value->shape(), 0);
    value = b.AddInstruction(
        HloInstruction::CreateGetTupleElement(inner_shape, value, 0));
  }

  auto hlo_module = CreateNewVerifiedModule();
  HloComputation* entry = hlo_module->AddEntryComputation(b.Build());

  EXPECT_EQ(entry->root_instruction(), value);

  Run(hlo_module.get(), /*change_expected=*/true);

  EXPECT_EQ(entry->root_instruction(), param);
}
169
TEST_F(TupleSimplifierTest, TupleOfGteInstructions) {
  // Tuple(GTE(p, 0), GTE(p, 1), GTE(p, 2)) rebuilds p element-for-element, so
  // the pass should replace it with p itself.
  HloComputation::Builder b(TestName());
  HloInstruction* source =
      b.AddInstruction(HloInstruction::CreateParameter(0, tuple_shape_, "param"));

  HloInstruction* elements[3];
  for (int index = 0; index < 3; ++index) {
    elements[index] = b.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_shape_, source, index));
  }
  HloInstruction* repacked = b.AddInstruction(
      HloInstruction::CreateTuple({elements[0], elements[1], elements[2]}));

  auto hlo_module = CreateNewVerifiedModule();
  HloComputation* entry = hlo_module->AddEntryComputation(b.Build());

  EXPECT_EQ(entry->root_instruction(), repacked);

  Run(hlo_module.get(), /*change_expected=*/true);

  EXPECT_EQ(entry->root_instruction(), source);
}
194
TEST_F(TupleSimplifierTest, IncompatibleTuples) {
  // A Tuple-of-GTEs must NOT be collapsed to its source when the shapes
  // differ. Here the source has three elements but only two are repacked.
  HloComputation::Builder b(TestName());
  HloInstruction* source =
      b.AddInstruction(HloInstruction::CreateParameter(0, tuple_shape_, "param"));
  HloInstruction* first = b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, source, 0));
  HloInstruction* second = b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, source, 1));
  // Two-element output vs. three-element input: not a drop-in replacement.
  HloInstruction* partial =
      b.AddInstruction(HloInstruction::CreateTuple({first, second}));

  auto hlo_module = CreateNewVerifiedModule();
  HloComputation* entry = hlo_module->AddEntryComputation(b.Build());

  EXPECT_EQ(entry->root_instruction(), partial);

  Run(hlo_module.get(), /*change_expected=*/false);

  EXPECT_EQ(entry->root_instruction(), partial);
}
219
TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) {
  // Verify that the entry computation can be excluded from simplification
  // while embedded (called) computations are still simplified.
  auto module = CreateNewVerifiedModule();

  // Adds an embedded computation whose root repacks every element of its
  // tuple-shaped parameter: Tuple(GTE(p, 0), GTE(p, 1), GTE(p, 2)). That root
  // is simplifiable to `p`. The parameter is stored in `*param` and the new
  // computation is returned.
  auto add_repacking_computation = [&](const char* name_suffix,
                                       HloInstruction** param) {
    HloComputation::Builder builder(TestName() + name_suffix);
    *param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, tuple_shape_, "param"));
    HloInstruction* gte0 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_shape_, *param, 0));
    HloInstruction* gte1 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_shape_, *param, 1));
    HloInstruction* gte2 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_shape_, *param, 2));
    builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2}));
    return module->AddEmbeddedComputation(builder.Build());
  };

  HloInstruction* p0 = nullptr;
  HloInstruction* p1 = nullptr;
  HloComputation* c0 = add_repacking_computation("_1", &p0);
  HloComputation* c1 = add_repacking_computation("_2", &p1);

  // The entry computation calls both embedded computations and contains its
  // own simplifiable GTE-of-Tuple pattern, which must survive because the
  // entry is excluded.
  HloComputation* entry;
  {
    HloComputation::Builder builder(TestName() + "_Entry");
    HloInstruction* tuple_param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, tuple_shape_, "param"));
    HloInstruction* call0 = builder.AddInstruction(
        HloInstruction::CreateCall(tuple_shape_, {tuple_param}, c0));
    HloInstruction* call1 = builder.AddInstruction(
        HloInstruction::CreateCall(tuple_shape_, {tuple_param}, c1));
    HloInstruction* gte0 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_shape_, call0, 0));
    HloInstruction* gte1 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_shape_, call1, 1));
    HloInstruction* tuple0 =
        builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
    HloInstruction* gte2 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_shape_, tuple0, 0));
    HloInstruction* gte3 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_shape_, tuple0, 1));

    builder.AddInstruction(HloInstruction::CreateTuple({gte2, gte3}));

    entry = module->AddEntryComputation(builder.Build());
  }

  Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/true);

  // Embedded computations were simplified down to their parameters.
  EXPECT_THAT(c0->root_instruction(), p0);
  EXPECT_THAT(c1->root_instruction(), p1);
  // The entry computation was left untouched. All nine instructions remain,
  // and the root still reads through the intermediate tuple.
  EXPECT_EQ(entry->instruction_count(), 9);
  EXPECT_THAT(entry->root_instruction(),
              op::Tuple(op::GetTupleElement(op::Tuple()),
                        op::GetTupleElement(op::Tuple())));
}
290
291 } // namespace
292 } // namespace xla
293