1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/stable_sort_expander.h"
17 
18 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
19 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
20 #include "tensorflow/compiler/xla/service/hlo_parser.h"
21 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
22 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
23 #include "tensorflow/compiler/xla/test.h"
24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25 #include "tensorflow/core/lib/core/status_test_util.h"
26 
27 namespace xla {
28 namespace {
29 
30 namespace m = match;
31 
32 using StableSortExpanderTest = HloTestBase;
33 
34 // Checks whether 'a' and 'b' are roots of equivalent computations, except that
35 // parameters 2 * i and 2 * i + 1 are switched.
IsSameComputationExceptParams(const HloInstruction * a,const HloInstruction * b)36 bool IsSameComputationExceptParams(const HloInstruction* a,
37                                    const HloInstruction* b) {
38   if (a->opcode() != b->opcode() || a->operand_count() != b->operand_count()) {
39     return false;
40   }
41   if (a->opcode() == HloOpcode::kParameter) {
42     // Check that parameters were switched.
43     return a->parameter_number() == (b->parameter_number() ^ 1);
44   }
45   // If the operation has no operands, it should actually be the same.
46   if (a->operand_count() == 0) {
47     return a == b;
48   }
49   // Otherwise recursively compare all operands.
50   for (int64 i = 0; i < a->operand_count(); ++i) {
51     if (!IsSameComputationExceptParams(a->operand(i), b->operand(i))) {
52       return false;
53     }
54   }
55   return true;
56 }
57 
58 // Check that the comparison computation has been modified to add a tie breaker
59 // using 'iota_parameter'.
CheckComputationHasTieBreaker(const HloInstruction * root,int64 iota_parameter)60 void CheckComputationHasTieBreaker(const HloInstruction* root,
61                                    int64 iota_parameter) {
62   // With the tie breaker, the root instruction should be
63   //   Select(Eq(Comp(), CompReverse()), Lt(), Comp())
64   // with Comp() being the original comparison function, and CompReverse() being
65   // the copied comparison function where the parameters are reversed. Lt() is
66   // the tie breaker comparison using the Iota operand.
67   ASSERT_EQ(root->opcode(), HloOpcode::kSelect);
68   ASSERT_EQ(root->operand(0)->opcode(), HloOpcode::kCompare);
69   ASSERT_EQ(root->operand(0)->comparison_direction(), ComparisonDirection::kEq);
70 
71   // Check that the tie breaker instruction is correct.
72   EXPECT_THAT(root->operand(1),
73               GmockMatch(m::Lt(m::Parameter(iota_parameter * 2),
74                                m::Parameter(iota_parameter * 2 + 1))));
75   EXPECT_EQ(root->operand(2), root->operand(0)->operand(0));
76 
77   // Check that Comp() and CompReverse() are equivalent except that
78   // CompReverse() has reversed parameters.
79   EXPECT_TRUE(IsSameComputationExceptParams(root->operand(0)->operand(0),
80                                             root->operand(0)->operand(1)));
81 }
82 
// A stable sort whose second operand is already an s32 iota along the sort
// dimension: the sort is expected to keep its two operands, with the existing
// iota (operand 1) used by the tie breaker.
TEST_F(StableSortExpanderTest, StabilizeSortReuseIotaOperand) {
  const char* hlo_string = R"(
   HloModule permutation_sort

   compare {
     p.0.lhs = f32[] parameter(0)
     p.0.rhs = f32[] parameter(1)
     p.1.lhs = s32[] parameter(2)
     p.1.rhs = s32[] parameter(3)
     ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
   }

   ENTRY sort_computation {
     keys = f32[64,8732]{1,0} parameter(0)
     values = s32[64,8732]{1,0} iota(), iota_dimension=1
     sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values),
       dimensions={1}, to_apply=compare, is_stable=true
     ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0
   })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  StableSortExpander stabilizer;
  // Report a pass error as a test failure rather than aborting the process.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, stabilizer.Run(module.get()));
  EXPECT_TRUE(changed);
  auto root = module->entry_computation()->root_instruction();
  // Fatal on mismatch: the dereferences below are only valid when the root
  // really is GetTupleElement(Sort(...)).
  ASSERT_THAT(root, GmockMatch(m::GetTupleElement(
                        m::Sort(m::Parameter(0), m::Iota()), 0)));
  CheckComputationHasTieBreaker(
      root->operand(0)->to_apply()->root_instruction(), /*iota_parameter=*/1);
}
113 
// Same setup as the simple iota-reuse case, but the comparator is a larger
// expression (bitcast-converts, subtracts and selects on both keys) so that
// the parameter-swapped copy of the comparison is exercised on a non-trivial
// computation.
TEST_F(StableSortExpanderTest,
       StabilizeSortReuseIotaOperandComplicatedComparison) {
  const char* hlo_string = R"(
   HloModule permutation_sort

   compare {
     p.0.lhs = f32[] parameter(0)
     p.0.rhs = f32[] parameter(1)
     p.1.lhs = s32[] parameter(2)
     p.1.rhs = s32[] parameter(3)
     max = u32[] constant(2147483647)
     zero = s32[] constant(0)
     lhs.signed = s32[] bitcast-convert(p.0.lhs)
     lhs.unsigned = u32[] bitcast-convert(p.0.lhs)
     lhs.flipped = u32[] subtract(max, lhs.unsigned)
     lhs.flipped.signed = s32[] bitcast-convert(lhs.flipped)
     lhs.is_negative = pred[] compare(lhs.flipped.signed, zero), direction=LT
     lhs.converted = s32[] select(lhs.is_negative, lhs.flipped.signed, lhs.signed)
     rhs.signed = s32[] bitcast-convert(p.0.rhs)
     rhs.unsigned = u32[] bitcast-convert(p.0.rhs)
     rhs.flipped = u32[] subtract(max, rhs.unsigned)
     rhs.flipped.signed = s32[] bitcast-convert(rhs.flipped)
     rhs.is_negative = pred[] compare(rhs.flipped.signed, zero), direction=LT
     rhs.converted = s32[] select(rhs.is_negative, rhs.flipped.signed, rhs.signed)
     ROOT lt = pred[] compare(lhs.converted, rhs.converted), direction=LT
   }

   ENTRY sort_computation {
     keys = f32[64,8732]{1,0} parameter(0)
     values = s32[64,8732]{1,0} iota(), iota_dimension=1
     sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values),
       dimensions={1}, to_apply=compare, is_stable=true
     ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0
   })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  StableSortExpander stabilizer;
  // Report a pass error as a test failure rather than aborting the process.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, stabilizer.Run(module.get()));
  EXPECT_TRUE(changed);
  auto root = module->entry_computation()->root_instruction();
  // Fatal on mismatch: the dereferences below are only valid when the root
  // really is GetTupleElement(Sort(...)).
  ASSERT_THAT(root, GmockMatch(m::GetTupleElement(
                        m::Sort(m::Parameter(0), m::Iota()), 0)));
  CheckComputationHasTieBreaker(
      root->operand(0)->to_apply()->root_instruction(), /*iota_parameter=*/1);
}
159 
// The sort is the entry root and none of its operands is an iota. The test
// expects a third Iota operand on the sort and a new root that is a tuple of
// get-tuple-elements 0 and 1 of that sort.
TEST_F(StableSortExpanderTest, StabilizeSortAddIotaOperandAndChangeRoot) {
  const char* hlo_string = R"(
   HloModule permutation_sort

   compare {
     p.0.lhs = f32[] parameter(0)
     p.0.rhs = f32[] parameter(1)
     p.1.lhs = s32[] parameter(2)
     p.1.rhs = s32[] parameter(3)
     ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
   }

   ENTRY sort_computation {
     keys = f32[64,8732]{1,0} parameter(0)
     values = s32[64,8732]{1,0} parameter(1)
     ROOT sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values),
       dimensions={1}, to_apply=compare, is_stable=true
   })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  StableSortExpander stabilizer;
  // Report a pass error as a test failure rather than aborting the process.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, stabilizer.Run(module.get()));
  EXPECT_TRUE(changed);
  auto root = module->entry_computation()->root_instruction();
  // Fatal on mismatch: the operand chain below is only valid when the root
  // really is Tuple(GetTupleElement(Sort(...)), ...).
  ASSERT_THAT(
      root, GmockMatch(m::Tuple(
                m::GetTupleElement(
                    m::Sort(m::Parameter(0), m::Parameter(1), m::Iota()), 0),
                m::GetTupleElement(
                    m::Sort(m::Parameter(0), m::Parameter(1), m::Iota()), 1))));
  CheckComputationHasTieBreaker(
      root->operand(0)->operand(0)->to_apply()->root_instruction(),
      /*iota_parameter=*/2);
}
194 
// A sort marked is_stable=false: the pass is expected to report that it
// changed nothing.
TEST_F(StableSortExpanderTest, HonorIsStableFlag) {
  const char* const unstable_sort_hlo = R"(
   HloModule permutation_sort

   compare {
     p.0.lhs = f32[] parameter(0)
     p.0.rhs = f32[] parameter(1)
     p.1.lhs = s32[] parameter(2)
     p.1.rhs = s32[] parameter(3)
     ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
   }

   ENTRY sort_computation {
     keys = f32[64,8732]{1,0} parameter(0)
     values = s32[64,8732]{1,0} iota(), iota_dimension=1
     sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values),
       dimensions={1}, to_apply=compare, is_stable=false
     ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0
   })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(unstable_sort_hlo));

  StableSortExpander expander;
  const bool module_changed = expander.Run(parsed_module.get()).ValueOrDie();
  EXPECT_FALSE(module_changed);
}
220 
// The existing iota operand runs along dimension 0 while the sort dimension is
// 1. The test expects a second Iota operand (index 2) to be added and used by
// the tie breaker.
TEST_F(StableSortExpanderTest,
       StabilizeSortDontReuseIotaOperandWrongDimension) {
  const char* hlo_string = R"(
   HloModule permutation_sort

   compare {
     p.0.lhs = f32[] parameter(0)
     p.0.rhs = f32[] parameter(1)
     p.1.lhs = s32[] parameter(2)
     p.1.rhs = s32[] parameter(3)
     ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
   }

   ENTRY sort_computation {
     keys = f32[64,8732]{1,0} parameter(0)
     values = s32[64,8732]{1,0} iota(), iota_dimension=0
     sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values),
       dimensions={1}, to_apply=compare, is_stable=true
     ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0
   })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  StableSortExpander stabilizer;
  // Report a pass error as a test failure rather than aborting the process.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, stabilizer.Run(module.get()));
  EXPECT_TRUE(changed);
  // Simplify away the "wrapper" tuple around the new sort.
  AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions(
      [](const Shape&, const Shape&) { return false; }));
  TF_ASSERT_OK_AND_ASSIGN(bool simplified, simplifier.Run(module.get()));
  ASSERT_TRUE(simplified);

  auto root = module->entry_computation()->root_instruction();
  // Fatal on mismatch: the dereferences below are only valid when the root
  // really is GetTupleElement(Sort(...)).
  ASSERT_THAT(root, GmockMatch(m::GetTupleElement(
                        m::Sort(m::Parameter(0), m::Iota(), m::Iota()), 0)));
  CheckComputationHasTieBreaker(
      root->operand(0)->to_apply()->root_instruction(),
      /*iota_parameter=*/2);
}
258 
// The existing iota operand has element type f32 instead of s32. The test
// expects a second Iota operand (index 2) to be added and used by the tie
// breaker.
TEST_F(StableSortExpanderTest, StabilizeSortDontReuseIotaOperandWrongType) {
  const char* hlo_string = R"(
   HloModule permutation_sort

   compare {
     p.0.lhs = f32[] parameter(0)
     p.0.rhs = f32[] parameter(1)
     p.1.lhs = f32[] parameter(2)
     p.1.rhs = f32[] parameter(3)
     ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
   }

   ENTRY sort_computation {
     keys = f32[64,8732]{1,0} parameter(0)
     values = f32[64,8732]{1,0} iota(), iota_dimension=1
     sort = (f32[64,8732]{1,0}, f32[64,8732]{1,0}) sort(keys, values),
       dimensions={1}, to_apply=compare, is_stable=true
     ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0
   })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  StableSortExpander stabilizer;
  // Report a pass error as a test failure rather than aborting the process.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, stabilizer.Run(module.get()));
  EXPECT_TRUE(changed);
  // Simplify away the "wrapper" tuple around the new sort.
  AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions(
      [](const Shape&, const Shape&) { return false; }));
  TF_ASSERT_OK_AND_ASSIGN(bool simplified, simplifier.Run(module.get()));
  ASSERT_TRUE(simplified);

  auto root = module->entry_computation()->root_instruction();
  // Fatal on mismatch: the dereferences below are only valid when the root
  // really is GetTupleElement(Sort(...)).
  ASSERT_THAT(root, GmockMatch(m::GetTupleElement(
                        m::Sort(m::Parameter(0), m::Iota(), m::Iota()), 0)));
  CheckComputationHasTieBreaker(
      root->operand(0)->to_apply()->root_instruction(),
      /*iota_parameter=*/2);
}
295 
// A single-operand stable sort that is the entry root. The test expects the
// root to become get-tuple-element 0 of a two-operand sort whose second
// operand is an Iota used by the tie breaker.
TEST_F(StableSortExpanderTest, StabilizeSortR1) {
  const char* hlo_string = R"(
   HloModule permutation_sort

   compare {
     p.0.lhs = s32[] parameter(0)
     p.0.rhs = s32[] parameter(1)
     mask = s32[] constant(65535)
     lhs = s32[] and(p.0.lhs, mask)
     rhs = s32[] and(p.0.rhs, mask)
     ROOT lt = pred[] compare(lhs, rhs), direction=LT
   }

   ENTRY sort_computation {
     keys = s32[64,8732]{1,0} parameter(0)
     ROOT sort = s32[64,8732]{1,0} sort(keys), dimensions={0}, to_apply=compare,
       is_stable=true
   })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  StableSortExpander stabilizer;
  // Report a pass error as a test failure rather than aborting the process.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, stabilizer.Run(module.get()));
  EXPECT_TRUE(changed);
  auto root = module->entry_computation()->root_instruction();
  // Fatal on mismatch: the dereferences below are only valid when the root
  // really is GetTupleElement(Sort(...)).
  ASSERT_THAT(root, GmockMatch(m::GetTupleElement(
                        m::Sort(m::Parameter(0), m::Iota()), 0)));
  CheckComputationHasTieBreaker(
      root->operand(0)->to_apply()->root_instruction(), /*iota_parameter=*/1);
}
325 
// A single-operand stable sort that feeds a negate instead of being the entry
// root. The test expects the negate to consume get-tuple-element 0 of a
// two-operand sort whose second operand is an Iota used by the tie breaker.
TEST_F(StableSortExpanderTest, StabilizeSortR1NoRoot) {
  const char* hlo_string = R"(
   HloModule permutation_sort

   compare {
     p.0.lhs = s32[] parameter(0)
     p.0.rhs = s32[] parameter(1)
     mask = s32[] constant(65535)
     lhs = s32[] and(p.0.lhs, mask)
     rhs = s32[] and(p.0.rhs, mask)
     ROOT lt = pred[] compare(lhs, rhs), direction=LT
   }

   ENTRY sort_computation {
     keys = s32[64,8732]{1,0} parameter(0)
     sort = s32[64,8732]{1,0} sort(keys), dimensions={0}, to_apply=compare,
       is_stable=true
     ROOT neg = s32[64,8732]{1,0} negate(sort)
   })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  StableSortExpander stabilizer;
  // Report a pass error as a test failure rather than aborting the process.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, stabilizer.Run(module.get()));
  EXPECT_TRUE(changed);
  auto root = module->entry_computation()->root_instruction();
  // Fatal on mismatch: the operand chain below is only valid when the root
  // really is Negate(GetTupleElement(Sort(...))).
  ASSERT_THAT(root, GmockMatch(m::Negate(m::GetTupleElement(
                        m::Sort(m::Parameter(0), m::Iota()), 0))));
  CheckComputationHasTieBreaker(
      root->operand(0)->operand(0)->to_apply()->root_instruction(),
      /*iota_parameter=*/1);
}
357 
358 }  // namespace
359 }  // namespace xla
360