1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
17 
18 #include <memory>
19 #include <utility>
20 
21 #include "tensorflow/compiler/xla/literal_util.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/test.h"
28 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/core/lib/core/status_test_util.h"
31 
32 namespace op = xla::testing::opcode_matchers;
33 
34 namespace xla {
35 namespace {
36 
37 class TupleSimplifierTest : public HloTestBase {
38  protected:
39   void Run(HloModule* module, bool change_expected) {
40     TupleSimplifier simplifier;
41     auto changed_status = simplifier.Run(module);
42     TF_ASSERT_OK(changed_status.status());
43     EXPECT_EQ(change_expected, changed_status.ValueOrDie());
44   }
45 
46   const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
47   const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
48       {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {}),
49        ShapeUtil::MakeShape(F32, {})});
50 };
51 
52 TEST_F(TupleSimplifierTest, TupleOfParameters) {
53   // A Tuple constructed of a bunch of parameters should not be changed.
54   HloComputation::Builder builder(TestName());
55   HloInstruction* param0 = builder.AddInstruction(
56       HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
57   HloInstruction* param1 = builder.AddInstruction(
58       HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
59   HloInstruction* param2 = builder.AddInstruction(
60       HloInstruction::CreateParameter(2, scalar_shape_, "param2"));
61   builder.AddInstruction(HloInstruction::CreateTuple({param0, param1, param2}));
62   auto module = CreateNewModule();
63   module->AddEntryComputation(builder.Build());
64 
65   Run(module.get(), /*change_expected=*/false);
66 }
67 
68 TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) {
69   // A GTE of a tuple parameter should not be changed.
70   HloComputation::Builder builder(TestName());
71   HloInstruction* param = builder.AddInstruction(
72       HloInstruction::CreateParameter(0, tuple_shape_, "param"));
73   builder.AddInstruction(
74       HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
75   auto module = CreateNewModule();
76   module->AddEntryComputation(builder.Build());
77 
78   Run(module.get(), /*change_expected=*/false);
79 }
80 
81 TEST_F(TupleSimplifierTest, GteOfTuple) {
82   // A GTE of a Tuple should be short-circuited.
83   HloComputation::Builder builder(TestName());
84   HloInstruction* param0 = builder.AddInstruction(
85       HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
86   HloInstruction* param1 = builder.AddInstruction(
87       HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
88   HloInstruction* param2 = builder.AddInstruction(
89       HloInstruction::CreateParameter(2, scalar_shape_, "param2"));
90   HloInstruction* tuple = builder.AddInstruction(
91       HloInstruction::CreateTuple({param0, param1, param2}));
92   HloInstruction* gte = builder.AddInstruction(
93       HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1));
94 
95   auto module = CreateNewModule();
96   auto computation = module->AddEntryComputation(builder.Build());
97 
98   EXPECT_THAT(computation->root_instruction(), gte);
99 
100   Run(module.get(), /*change_expected=*/true);
101 
102   EXPECT_THAT(computation->root_instruction(), param1);
103 }
104 
105 TEST_F(TupleSimplifierTest, GteOfTupleChain) {
106   // Verify a chain of GTE/Tuple instructions is collapsed.
107   HloComputation::Builder builder(TestName());
108   HloInstruction* param = builder.AddInstruction(
109       HloInstruction::CreateParameter(0, scalar_shape_, "param"));
110 
111   const int kChainLength = 10;
112   HloInstruction* element = param;
113   for (int i = 0; i < kChainLength; ++i) {
114     HloInstruction* tuple = builder.AddInstruction(
115         HloInstruction::CreateTuple({element, element, element}));
116     element = builder.AddInstruction(
117         HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1));
118   }
119   builder.AddInstruction(
120       HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, element));
121 
122   auto module = CreateNewModule();
123   auto computation = module->AddEntryComputation(builder.Build());
124 
125   EXPECT_THAT(computation->root_instruction(),
126               op::Negate(op::GetTupleElement(op::Tuple())));
127 
128   Run(module.get(), /*change_expected=*/true);
129 
130   EXPECT_THAT(computation->root_instruction(), op::Negate(op::Parameter()));
131 }
132 
133 TEST_F(TupleSimplifierTest, NestedGteOfTuples) {
134   // Verify a nesting of GTE/Tuple instructions is collapsed. Tuples are nested
135   // to some depth with a chain of Tuple instructions, then extracted with a
136   // chain of GTE instructions.
137   HloComputation::Builder builder(TestName());
138   HloInstruction* param = builder.AddInstruction(
139       HloInstruction::CreateParameter(0, scalar_shape_, "param"));
140 
141   const int kNestingDepth = 5;
142   HloInstruction* nested_tuple = param;
143   for (int i = 0; i < kNestingDepth; ++i) {
144     nested_tuple = builder.AddInstruction(
145         HloInstruction::CreateTuple({nested_tuple, nested_tuple}));
146   }
147 
148   HloInstruction* element = nested_tuple;
149   for (int i = 0; i < kNestingDepth; ++i) {
150     element = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
151         ShapeUtil::GetTupleElementShape(element->shape(), 0), element, 0));
152   }
153 
154   auto module = CreateNewModule();
155   auto computation = module->AddEntryComputation(builder.Build());
156 
157   EXPECT_THAT(computation->root_instruction(), element);
158 
159   Run(module.get(), /*change_expected=*/true);
160 
161   EXPECT_THAT(computation->root_instruction(), param);
162 }
163 
164 TEST_F(TupleSimplifierTest, TupleOfGteInstructions) {
165   // Verify that a tuple constructed of GTE instructions operating on the same
166   // tuple are collapsed.
167   HloComputation::Builder builder(TestName());
168   HloInstruction* tuple_param = builder.AddInstruction(
169       HloInstruction::CreateParameter(0, tuple_shape_, "param"));
170   HloInstruction* gte0 = builder.AddInstruction(
171       HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 0));
172   HloInstruction* gte1 = builder.AddInstruction(
173       HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1));
174   HloInstruction* gte2 = builder.AddInstruction(
175       HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 2));
176   HloInstruction* tuple =
177       builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2}));
178 
179   auto module = CreateNewModule();
180   auto computation = module->AddEntryComputation(builder.Build());
181 
182   EXPECT_THAT(computation->root_instruction(), tuple);
183 
184   Run(module.get(), /*change_expected=*/true);
185 
186   EXPECT_THAT(computation->root_instruction(), tuple_param);
187 }
188 
189 TEST_F(TupleSimplifierTest, IncompatibleTuples) {
190   // Verify that a tuple->GTE->tuple construct is not simplified if the input
191   // and output tuple are not compatible shapes.
192   HloComputation::Builder builder(TestName());
193   HloInstruction* tuple_param = builder.AddInstruction(
194       HloInstruction::CreateParameter(0, tuple_shape_, "param"));
195   HloInstruction* gte0 = builder.AddInstruction(
196       HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 0));
197   HloInstruction* gte1 = builder.AddInstruction(
198       HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1));
199   // Output tuple has only two elements. Parameter tuple has three elements so
200   // simplification is not possible.
201   HloInstruction* tuple =
202       builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
203 
204   auto module = CreateNewModule();
205   auto computation = module->AddEntryComputation(builder.Build());
206 
207   EXPECT_THAT(computation->root_instruction(), tuple);
208 
209   Run(module.get(), /*change_expected=*/false);
210 
211   EXPECT_THAT(computation->root_instruction(), tuple);
212 }
213 
214 }  // namespace
215 }  // namespace xla
216