1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/stable_sort_expander.h"
17
18 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
19 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
20 #include "tensorflow/compiler/xla/service/hlo_parser.h"
21 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
22 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
23 #include "tensorflow/compiler/xla/test.h"
24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25 #include "tensorflow/core/lib/core/status_test_util.h"
26
27 namespace xla {
28 namespace {
29
30 namespace m = match;
31
32 using StableSortExpanderTest = HloTestBase;
33
34 // Checks whether 'a' and 'b' are roots of equivalent computations, except that
35 // parameters 2 * i and 2 * i + 1 are switched.
IsSameComputationExceptParams(const HloInstruction * a,const HloInstruction * b)36 bool IsSameComputationExceptParams(const HloInstruction* a,
37 const HloInstruction* b) {
38 if (a->opcode() != b->opcode() || a->operand_count() != b->operand_count()) {
39 return false;
40 }
41 if (a->opcode() == HloOpcode::kParameter) {
42 // Check that parameters were switched.
43 return a->parameter_number() == (b->parameter_number() ^ 1);
44 }
45 // If the operation has no operands, it should actually be the same.
46 if (a->operand_count() == 0) {
47 return a == b;
48 }
49 // Otherwise recursively compare all operands.
50 for (int64 i = 0; i < a->operand_count(); ++i) {
51 if (!IsSameComputationExceptParams(a->operand(i), b->operand(i))) {
52 return false;
53 }
54 }
55 return true;
56 }
57
58 // Check that the comparison computation has been modified to add a tie breaker
59 // using 'iota_parameter'.
CheckComputationHasTieBreaker(const HloInstruction * root,int64 iota_parameter)60 void CheckComputationHasTieBreaker(const HloInstruction* root,
61 int64 iota_parameter) {
62 // With the tie breaker, the root instruction should be
63 // Select(Eq(Comp(), CompReverse()), Lt(), Comp())
64 // with Comp() being the original comparison function, and CompReverse() being
65 // the copied comparison function where the parameters are reversed. Lt() is
66 // the tie breaker comparison using the Iota operand.
67 ASSERT_EQ(root->opcode(), HloOpcode::kSelect);
68 ASSERT_EQ(root->operand(0)->opcode(), HloOpcode::kCompare);
69 ASSERT_EQ(root->operand(0)->comparison_direction(), ComparisonDirection::kEq);
70
71 // Check that the tie breaker instruction is correct.
72 EXPECT_THAT(root->operand(1),
73 GmockMatch(m::Lt(m::Parameter(iota_parameter * 2),
74 m::Parameter(iota_parameter * 2 + 1))));
75 EXPECT_EQ(root->operand(2), root->operand(0)->operand(0));
76
77 // Check that Comp() and CompReverse() are equivalent except that
78 // CompReverse() has reversed parameters.
79 EXPECT_TRUE(IsSameComputationExceptParams(root->operand(0)->operand(0),
80 root->operand(0)->operand(1)));
81 }
82
// The sort already has an s32 iota operand along the sort dimension, so the
// pass is expected to reuse it (the sort keeps exactly two operands).
TEST_F(StableSortExpanderTest, StabilizeSortReuseIotaOperand) {
  const char* hlo_string = R"(
   HloModule permutation_sort

   compare {
     p.0.lhs = f32[] parameter(0)
     p.0.rhs = f32[] parameter(1)
     p.1.lhs = s32[] parameter(2)
     p.1.rhs = s32[] parameter(3)
     ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
   }

   ENTRY sort_computation {
     keys = f32[64,8732]{1,0} parameter(0)
     values = s32[64,8732]{1,0} iota(), iota_dimension=1
     sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values),
       dimensions={1}, to_apply=compare, is_stable=true
     ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0
   })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  StableSortExpander stabilizer;
  // Fail the test (rather than abort the binary) if the pass returns an error.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, stabilizer.Run(module.get()));
  EXPECT_TRUE(changed);
  auto root = module->entry_computation()->root_instruction();
  // ASSERT, not EXPECT: the next statement treats root->operand(0) as a sort.
  ASSERT_THAT(root, GmockMatch(m::GetTupleElement(
                        m::Sort(m::Parameter(0), m::Iota()), 0)));
  CheckComputationHasTieBreaker(
      root->operand(0)->to_apply()->root_instruction(), /*iota_parameter=*/1);
}
113
// Same as StabilizeSortReuseIotaOperand, but the comparator is a multi-step
// expression (including constants), so the reversed copy of the comparison
// has to mirror a whole expression tree.
TEST_F(StableSortExpanderTest,
       StabilizeSortReuseIotaOperandComplicatedComparison) {
  const char* hlo_string = R"(
   HloModule permutation_sort

   compare {
     p.0.lhs = f32[] parameter(0)
     p.0.rhs = f32[] parameter(1)
     p.1.lhs = s32[] parameter(2)
     p.1.rhs = s32[] parameter(3)
     max = u32[] constant(2147483647)
     zero = s32[] constant(0)
     lhs.signed = s32[] bitcast-convert(p.0.lhs)
     lhs.unsigned = u32[] bitcast-convert(p.0.lhs)
     lhs.flipped = u32[] subtract(max, lhs.unsigned)
     lhs.flipped.signed = s32[] bitcast-convert(lhs.flipped)
     lhs.is_negative = pred[] compare(lhs.flipped.signed, zero), direction=LT
     lhs.converted = s32[] select(lhs.is_negative, lhs.flipped.signed, lhs.signed)
     rhs.signed = s32[] bitcast-convert(p.0.rhs)
     rhs.unsigned = u32[] bitcast-convert(p.0.rhs)
     rhs.flipped = u32[] subtract(max, rhs.unsigned)
     rhs.flipped.signed = s32[] bitcast-convert(rhs.flipped)
     rhs.is_negative = pred[] compare(rhs.flipped.signed, zero), direction=LT
     rhs.converted = s32[] select(rhs.is_negative, rhs.flipped.signed, rhs.signed)
     ROOT lt = pred[] compare(lhs.converted, rhs.converted), direction=LT
   }

   ENTRY sort_computation {
     keys = f32[64,8732]{1,0} parameter(0)
     values = s32[64,8732]{1,0} iota(), iota_dimension=1
     sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values),
       dimensions={1}, to_apply=compare, is_stable=true
     ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0
   })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  StableSortExpander stabilizer;
  // Fail the test (rather than abort the binary) if the pass returns an error.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, stabilizer.Run(module.get()));
  EXPECT_TRUE(changed);
  auto root = module->entry_computation()->root_instruction();
  // ASSERT, not EXPECT: the next statement treats root->operand(0) as a sort.
  ASSERT_THAT(root, GmockMatch(m::GetTupleElement(
                        m::Sort(m::Parameter(0), m::Iota()), 0)));
  CheckComputationHasTieBreaker(
      root->operand(0)->to_apply()->root_instruction(), /*iota_parameter=*/1);
}
159
// No iota operand is present, so the pass must append one to the sort. Since
// the sort was the root, the root becomes a tuple of get-tuple-elements of the
// new sort that reproduces the original result shape.
TEST_F(StableSortExpanderTest, StabilizeSortAddIotaOperandAndChangeRoot) {
  const char* hlo_string = R"(
   HloModule permutation_sort

   compare {
     p.0.lhs = f32[] parameter(0)
     p.0.rhs = f32[] parameter(1)
     p.1.lhs = s32[] parameter(2)
     p.1.rhs = s32[] parameter(3)
     ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
   }

   ENTRY sort_computation {
     keys = f32[64,8732]{1,0} parameter(0)
     values = s32[64,8732]{1,0} parameter(1)
     ROOT sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values),
       dimensions={1}, to_apply=compare, is_stable=true
   })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  StableSortExpander stabilizer;
  // Fail the test (rather than abort the binary) if the pass returns an error.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, stabilizer.Run(module.get()));
  EXPECT_TRUE(changed);
  auto root = module->entry_computation()->root_instruction();
  // ASSERT, not EXPECT: the statements below dereference the matched operands.
  ASSERT_THAT(
      root, GmockMatch(m::Tuple(
                m::GetTupleElement(
                    m::Sort(m::Parameter(0), m::Parameter(1), m::Iota()), 0),
                m::GetTupleElement(
                    m::Sort(m::Parameter(0), m::Parameter(1), m::Iota()), 1))));
  // Both tuple elements must read from the same (single) new sort.
  EXPECT_EQ(root->operand(0)->operand(0), root->operand(1)->operand(0));
  CheckComputationHasTieBreaker(
      root->operand(0)->operand(0)->to_apply()->root_instruction(),
      /*iota_parameter=*/2);
}
194
// A sort with is_stable=false must be left alone: the pass reports no change.
TEST_F(StableSortExpanderTest, HonorIsStableFlag) {
  const char* hlo_string = R"(
   HloModule permutation_sort

   compare {
     p.0.lhs = f32[] parameter(0)
     p.0.rhs = f32[] parameter(1)
     p.1.lhs = s32[] parameter(2)
     p.1.rhs = s32[] parameter(3)
     ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
   }

   ENTRY sort_computation {
     keys = f32[64,8732]{1,0} parameter(0)
     values = s32[64,8732]{1,0} iota(), iota_dimension=1
     sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values),
       dimensions={1}, to_apply=compare, is_stable=false
     ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0
   })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  StableSortExpander stabilizer;
  // Fail the test (rather than abort the binary) if the pass returns an error.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, stabilizer.Run(module.get()));
  EXPECT_FALSE(changed);
}
220
// The existing iota runs along dimension 0 while the sort is along dimension
// 1, so it cannot serve as a tie breaker and a second iota operand is added.
TEST_F(StableSortExpanderTest,
       StabilizeSortDontReuseIotaOperandWrongDimension) {
  const char* hlo_string = R"(
   HloModule permutation_sort

   compare {
     p.0.lhs = f32[] parameter(0)
     p.0.rhs = f32[] parameter(1)
     p.1.lhs = s32[] parameter(2)
     p.1.rhs = s32[] parameter(3)
     ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
   }

   ENTRY sort_computation {
     keys = f32[64,8732]{1,0} parameter(0)
     values = s32[64,8732]{1,0} iota(), iota_dimension=0
     sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values),
       dimensions={1}, to_apply=compare, is_stable=true
     ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0
   })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  StableSortExpander stabilizer;
  // Fail the test (rather than abort the binary) if a pass returns an error.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, stabilizer.Run(module.get()));
  EXPECT_TRUE(changed);
  // Simplify away the "wrapper" tuple around the new sort.
  AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions(
      [](const Shape&, const Shape&) { return false; }));
  TF_ASSERT_OK_AND_ASSIGN(bool simplified, simplifier.Run(module.get()));
  ASSERT_TRUE(simplified);

  auto root = module->entry_computation()->root_instruction();
  // ASSERT, not EXPECT: the next statement treats root->operand(0) as a sort.
  ASSERT_THAT(root, GmockMatch(m::GetTupleElement(
                        m::Sort(m::Parameter(0), m::Iota(), m::Iota()), 0)));
  CheckComputationHasTieBreaker(
      root->operand(0)->to_apply()->root_instruction(),
      /*iota_parameter=*/2);
}
258
// The existing iota has element type f32 rather than s32, so it is not reused
// and a second iota operand is added.
TEST_F(StableSortExpanderTest, StabilizeSortDontReuseIotaOperandWrongType) {
  const char* hlo_string = R"(
   HloModule permutation_sort

   compare {
     p.0.lhs = f32[] parameter(0)
     p.0.rhs = f32[] parameter(1)
     p.1.lhs = f32[] parameter(2)
     p.1.rhs = f32[] parameter(3)
     ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
   }

   ENTRY sort_computation {
     keys = f32[64,8732]{1,0} parameter(0)
     values = f32[64,8732]{1,0} iota(), iota_dimension=1
     sort = (f32[64,8732]{1,0}, f32[64,8732]{1,0}) sort(keys, values),
       dimensions={1}, to_apply=compare, is_stable=true
     ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0
   })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  StableSortExpander stabilizer;
  // Fail the test (rather than abort the binary) if a pass returns an error.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, stabilizer.Run(module.get()));
  EXPECT_TRUE(changed);
  // Simplify away the "wrapper" tuple around the new sort.
  AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions(
      [](const Shape&, const Shape&) { return false; }));
  TF_ASSERT_OK_AND_ASSIGN(bool simplified, simplifier.Run(module.get()));
  ASSERT_TRUE(simplified);

  auto root = module->entry_computation()->root_instruction();
  // ASSERT, not EXPECT: the next statement treats root->operand(0) as a sort.
  ASSERT_THAT(root, GmockMatch(m::GetTupleElement(
                        m::Sort(m::Parameter(0), m::Iota(), m::Iota()), 0)));
  CheckComputationHasTieBreaker(
      root->operand(0)->to_apply()->root_instruction(),
      /*iota_parameter=*/2);
}
295
// A single-operand sort at the root: the pass adds an iota operand and wraps
// the new two-operand sort in a get-tuple-element to keep the original shape.
TEST_F(StableSortExpanderTest, StabilizeSortR1) {
  const char* hlo_string = R"(
   HloModule permutation_sort

   compare {
     p.0.lhs = s32[] parameter(0)
     p.0.rhs = s32[] parameter(1)
     mask = s32[] constant(65535)
     lhs = s32[] and(p.0.lhs, mask)
     rhs = s32[] and(p.0.rhs, mask)
     ROOT lt = pred[] compare(lhs, rhs), direction=LT
   }

   ENTRY sort_computation {
     keys = s32[64,8732]{1,0} parameter(0)
     ROOT sort = s32[64,8732]{1,0} sort(keys), dimensions={0}, to_apply=compare,
       is_stable=true
   })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  StableSortExpander stabilizer;
  // Fail the test (rather than abort the binary) if the pass returns an error.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, stabilizer.Run(module.get()));
  EXPECT_TRUE(changed);
  auto root = module->entry_computation()->root_instruction();
  // ASSERT, not EXPECT: the next statement treats root->operand(0) as a sort.
  ASSERT_THAT(root, GmockMatch(m::GetTupleElement(
                        m::Sort(m::Parameter(0), m::Iota()), 0)));
  CheckComputationHasTieBreaker(
      root->operand(0)->to_apply()->root_instruction(), /*iota_parameter=*/1);
}
325
// Like StabilizeSortR1, but the sort feeds a negate. The negate must be
// rewired to consume the get-tuple-element of the new sort.
TEST_F(StableSortExpanderTest, StabilizeSortR1NoRoot) {
  const char* hlo_string = R"(
   HloModule permutation_sort

   compare {
     p.0.lhs = s32[] parameter(0)
     p.0.rhs = s32[] parameter(1)
     mask = s32[] constant(65535)
     lhs = s32[] and(p.0.lhs, mask)
     rhs = s32[] and(p.0.rhs, mask)
     ROOT lt = pred[] compare(lhs, rhs), direction=LT
   }

   ENTRY sort_computation {
     keys = s32[64,8732]{1,0} parameter(0)
     sort = s32[64,8732]{1,0} sort(keys), dimensions={0}, to_apply=compare,
       is_stable=true
     ROOT neg = s32[64,8732]{1,0} negate(sort)
   })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  StableSortExpander stabilizer;
  // Fail the test (rather than abort the binary) if the pass returns an error.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, stabilizer.Run(module.get()));
  EXPECT_TRUE(changed);
  auto root = module->entry_computation()->root_instruction();
  // ASSERT, not EXPECT: the next statement dereferences the matched operands.
  ASSERT_THAT(root, GmockMatch(m::Negate(m::GetTupleElement(
                        m::Sort(m::Parameter(0), m::Iota()), 0))));
  CheckComputationHasTieBreaker(
      root->operand(0)->operand(0)->to_apply()->root_instruction(),
      /*iota_parameter=*/1);
}
357
358 } // namespace
359 } // namespace xla
360