1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
17 
18 #include <memory>
19 #include <utility>
20 
21 #include "tensorflow/compiler/xla/literal.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/test.h"
28 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/core/lib/core/status_test_util.h"
31 
32 namespace op = xla::testing::opcode_matchers;
33 
34 namespace xla {
35 namespace {
36 
37 class TupleSimplifierTest : public HloTestBase {
38  protected:
Run(HloModule * module,bool change_expected)39   void Run(HloModule* module, bool change_expected) {
40     TupleSimplifier simplifier;
41     auto changed_status = simplifier.Run(module);
42     TF_ASSERT_OK(changed_status.status());
43     EXPECT_EQ(change_expected, changed_status.ValueOrDie());
44   }
Run(HloModule * module,bool change_expected,bool exclude_entry)45   void Run(HloModule* module, bool change_expected, bool exclude_entry) {
46     TupleSimplifier simplifier(exclude_entry);
47     auto changed_status = simplifier.Run(module);
48     TF_ASSERT_OK(changed_status.status());
49     EXPECT_EQ(change_expected, changed_status.ValueOrDie());
50   }
51 
52   const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
53   const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
54       {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {}),
55        ShapeUtil::MakeShape(F32, {})});
56 };
57 
TEST_F(TupleSimplifierTest, TupleOfParameters) {
  // Tuple(p0, p1, p2) whose operands are plain parameters contains no
  // GTE/Tuple pattern, so the pass must leave the module alone.
  HloComputation::Builder b(TestName());
  HloInstruction* const p0 = b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* const p1 = b.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* const p2 = b.AddInstruction(
      HloInstruction::CreateParameter(2, scalar_shape_, "param2"));
  b.AddInstruction(HloInstruction::CreateTuple({p0, p1, p2}));

  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(b.Build());

  Run(m.get(), /*change_expected=*/false);
}
73 
TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) {
  // GTE(param, 1) where `param` is a tuple-shaped parameter (not a Tuple
  // instruction) cannot be forwarded, so nothing should change.
  HloComputation::Builder b(TestName());
  HloInstruction* const tuple_param = b.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape_, "param"));
  b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1));

  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(b.Build());

  Run(m.get(), /*change_expected=*/false);
}
86 
TEST_F(TupleSimplifierTest, GteOfTuple) {
  // GTE(Tuple(p0, p1, p2), 1) should be short-circuited to p1.
  HloComputation::Builder b(TestName());
  HloInstruction* const p0 = b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* const p1 = b.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* const p2 = b.AddInstruction(
      HloInstruction::CreateParameter(2, scalar_shape_, "param2"));
  HloInstruction* const packed =
      b.AddInstruction(HloInstruction::CreateTuple({p0, p1, p2}));
  HloInstruction* const extracted = b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, packed, 1));

  auto m = CreateNewVerifiedModule();
  HloComputation* const comp = m->AddEntryComputation(b.Build());

  // The GTE is the last instruction added, hence the root.
  EXPECT_EQ(comp->root_instruction(), extracted);

  Run(m.get(), /*change_expected=*/true);

  // After simplification the root is the tuple operand the GTE selected.
  EXPECT_EQ(comp->root_instruction(), p1);
}
110 
TEST_F(TupleSimplifierTest, GteOfTupleChain) {
  // Repeatedly wrap a value in a Tuple and pull it back out with a GTE; the
  // whole chain must collapse so Negate consumes the parameter directly.
  HloComputation::Builder b(TestName());
  HloInstruction* const input = b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));

  constexpr int kChainLength = 10;
  HloInstruction* current = input;
  for (int link = 0; link < kChainLength; ++link) {
    HloInstruction* const wrapper = b.AddInstruction(
        HloInstruction::CreateTuple({current, current, current}));
    current = b.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_shape_, wrapper, 1));
  }
  b.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, current));

  auto m = CreateNewVerifiedModule();
  HloComputation* const comp = m->AddEntryComputation(b.Build());

  EXPECT_THAT(comp->root_instruction(),
              op::Negate(op::GetTupleElement(op::Tuple())));

  Run(m.get(), /*change_expected=*/true);

  EXPECT_THAT(comp->root_instruction(), op::Negate(op::Parameter()));
}
138 
TEST_F(TupleSimplifierTest, NestedGteOfTuples) {
  // Nest the parameter kNestingDepth levels deep with Tuple instructions,
  // then unwrap it with the same number of GTEs (always index 0). The pass
  // must collapse everything so the root becomes the parameter itself.
  HloComputation::Builder b(TestName());
  HloInstruction* const input = b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));

  constexpr int kNestingDepth = 5;

  // Wrapping phase: each level is Tuple(prev, prev).
  HloInstruction* wrapped = input;
  for (int depth = 0; depth < kNestingDepth; ++depth) {
    wrapped =
        b.AddInstruction(HloInstruction::CreateTuple({wrapped, wrapped}));
  }

  // Unwrapping phase: peel one level per GTE, using element 0's shape.
  HloInstruction* unwrapped = wrapped;
  for (int depth = 0; depth < kNestingDepth; ++depth) {
    const Shape& inner_shape =
        ShapeUtil::GetTupleElementShape(unwrapped->shape(), 0);
    unwrapped = b.AddInstruction(
        HloInstruction::CreateGetTupleElement(inner_shape, unwrapped, 0));
  }

  auto m = CreateNewVerifiedModule();
  HloComputation* const comp = m->AddEntryComputation(b.Build());

  EXPECT_EQ(comp->root_instruction(), unwrapped);

  Run(m.get(), /*change_expected=*/true);

  EXPECT_EQ(comp->root_instruction(), input);
}
169 
TEST_F(TupleSimplifierTest, TupleOfGteInstructions) {
  // Tuple(GTE(p, 0), GTE(p, 1), GTE(p, 2)) rebuilds `p` element by element,
  // so the pass should replace it with `p` itself.
  HloComputation::Builder b(TestName());
  HloInstruction* const tuple_param = b.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape_, "param"));

  // Adds GTE(tuple_param, index) to the builder.
  auto extract = [&](int index) {
    return b.AddInstruction(HloInstruction::CreateGetTupleElement(
        scalar_shape_, tuple_param, index));
  };
  HloInstruction* const first = extract(0);
  HloInstruction* const second = extract(1);
  HloInstruction* const third = extract(2);
  HloInstruction* const rebuilt =
      b.AddInstruction(HloInstruction::CreateTuple({first, second, third}));

  auto m = CreateNewVerifiedModule();
  HloComputation* const comp = m->AddEntryComputation(b.Build());

  EXPECT_EQ(comp->root_instruction(), rebuilt);

  Run(m.get(), /*change_expected=*/true);

  EXPECT_EQ(comp->root_instruction(), tuple_param);
}
194 
TEST_F(TupleSimplifierTest, IncompatibleTuples) {
  // A Tuple built from GTEs of a tuple-shaped parameter must NOT be replaced
  // by that parameter when the shapes differ: here the output Tuple has two
  // elements while the parameter has three.
  HloComputation::Builder b(TestName());
  HloInstruction* const tuple_param = b.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape_, "param"));

  // Adds GTE(tuple_param, index) to the builder.
  auto extract = [&](int index) {
    return b.AddInstruction(HloInstruction::CreateGetTupleElement(
        scalar_shape_, tuple_param, index));
  };
  HloInstruction* const first = extract(0);
  HloInstruction* const second = extract(1);
  HloInstruction* const partial =
      b.AddInstruction(HloInstruction::CreateTuple({first, second}));

  auto m = CreateNewVerifiedModule();
  HloComputation* const comp = m->AddEntryComputation(b.Build());

  EXPECT_EQ(comp->root_instruction(), partial);

  Run(m.get(), /*change_expected=*/false);

  // The shorter tuple remains the root.
  EXPECT_EQ(comp->root_instruction(), partial);
}
219 
TEST_F(TupleSimplifierTest,CanExcludeEntryComputation)220 TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) {
221   //  Verify that the root computation can be excluded
222   auto module = CreateNewVerifiedModule();
223 
224   HloInstruction* p0;
225   HloInstruction* p1;
226   HloComputation* c0;
227   HloComputation* c1;
228   HloComputation* entry;
229 
230   {
231     HloComputation::Builder builder(TestName() + "_1");
232     p0 = builder.AddInstruction(
233         HloInstruction::CreateParameter(0, tuple_shape_, "param"));
234     HloInstruction* gte0 = builder.AddInstruction(
235         HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 0));
236     HloInstruction* gte1 = builder.AddInstruction(
237         HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 1));
238     HloInstruction* gte2 = builder.AddInstruction(
239         HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 2));
240 
241     builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2}));
242 
243     c0 = module->AddEmbeddedComputation(builder.Build());
244   }
245   {
246     HloComputation::Builder builder(TestName() + "_2");
247     p1 = builder.AddInstruction(
248         HloInstruction::CreateParameter(0, tuple_shape_, "param"));
249     HloInstruction* gte0 = builder.AddInstruction(
250         HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 0));
251     HloInstruction* gte1 = builder.AddInstruction(
252         HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 1));
253     HloInstruction* gte2 = builder.AddInstruction(
254         HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 2));
255 
256     builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2}));
257 
258     c1 = module->AddEmbeddedComputation(builder.Build());
259   }
260   {
261     HloComputation::Builder builder(TestName() + "_Entry");
262     HloInstruction* tuple_param = builder.AddInstruction(
263         HloInstruction::CreateParameter(0, tuple_shape_, "param"));
264     HloInstruction* call0 = builder.AddInstruction(
265         HloInstruction::CreateCall(tuple_shape_, {tuple_param}, c0));
266     HloInstruction* call1 = builder.AddInstruction(
267         HloInstruction::CreateCall(tuple_shape_, {tuple_param}, c1));
268     HloInstruction* gte0 = builder.AddInstruction(
269         HloInstruction::CreateGetTupleElement(scalar_shape_, call0, 0));
270     HloInstruction* gte1 = builder.AddInstruction(
271         HloInstruction::CreateGetTupleElement(scalar_shape_, call1, 1));
272     HloInstruction* tuple0 =
273         builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
274     HloInstruction* gte2 = builder.AddInstruction(
275         HloInstruction::CreateGetTupleElement(scalar_shape_, tuple0, 0));
276     HloInstruction* gte3 = builder.AddInstruction(
277         HloInstruction::CreateGetTupleElement(scalar_shape_, tuple0, 1));
278 
279     builder.AddInstruction(HloInstruction::CreateTuple({gte2, gte3}));
280 
281     entry = module->AddEntryComputation(builder.Build());
282   }
283 
284   Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/true);
285 
286   EXPECT_THAT(c0->root_instruction(), p0);
287   EXPECT_THAT(c1->root_instruction(), p1);
288   EXPECT_THAT(entry->instruction_count(), 9);
289 }
290 
291 }  // namespace
292 }  // namespace xla
293