1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/sort_simplifier.h"
17 
18 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
19 #include "tensorflow/compiler/xla/service/hlo_parser.h"
20 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
21 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
22 #include "tensorflow/compiler/xla/test.h"
23 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 
26 namespace xla {
27 namespace {
28 
// Short alias for the HLO pattern-matcher namespace used by the
// GmockMatch(...) expectations in the tests below.
namespace m = match;

// The tests use the stock HloTestBase fixture with no extra state.
using SortSimplifierTest = HloTestBase;
32 
TEST_F(SortSimplifierTest, RemoveUnusedSortOperandArrayResult) {
  // The comparator reads only operand 0 (keys) and only tuple element 0 of the
  // sort is consumed, so the `values` operand is dead. After simplification
  // the entry root is expected to be a single-operand sort over parameter 0.
  const char* hlo_string = R"(
   HloModule permutation_sort

   compare {
     p.0.lhs = f32[] parameter(0)
     p.0.rhs = f32[] parameter(1)
     p.1.lhs = s32[] parameter(2)
     p.1.rhs = s32[] parameter(3)
     ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
   }

   ENTRY sort_computation {
     keys = f32[64,8732]{1,0} parameter(0)
     values = s32[64,8732]{1,0} parameter(1)
     sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values),
       dimensions={1}, to_apply=compare
     ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0
   })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  SortSimplifier simplifier;
  // The first run must rewrite the sort and a second run must be a no-op.
  // Checking the two runs explicitly has two benefits over looping until Run()
  // returns false: the test cannot hang if the pass never reaches a fixed
  // point, and it avoids comparing an unsigned counter to a signed literal.
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Sort(m::Parameter(0))));
}
64 
TEST_F(SortSimplifierTest, RemoveUnusedSortOperandTuple) {
  // A three-operand sort keyed on operand 0 whose results 0 and 2 are both
  // consumed by the root tuple. Result 1 (the s32 `values.0`) is never read,
  // so the pass should shrink the sort to (keys, values.1).
  const char* const kHlo = R"(
   HloModule permutation_sort

   compare {
     p.0.lhs = f32[] parameter(0)
     p.0.rhs = f32[] parameter(1)
     p.1.lhs = s32[] parameter(2)
     p.1.rhs = s32[] parameter(3)
     p.2.lhs = u32[] parameter(4)
     p.2.rhs = u32[] parameter(5)
     ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
   }

   ENTRY sort_computation {
     keys = f32[64,87] parameter(0)
     values.0 = s32[64,87] parameter(1)
     values.1 = u32[64,87] parameter(2)
     sort = (f32[64,87], s32[64,87], u32[64,87]) sort(
         keys, values.0, values.1),
       dimensions={1}, to_apply=compare
     gte.0 = f32[64,87] get-tuple-element(sort), index=0
     gte.1 = u32[64,87] get-tuple-element(sort), index=2
     ROOT tuple = (f32[64,87], u32[64,87]) tuple(gte.0, gte.1)
   })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  SortSimplifier pass;
  const bool changed = pass.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  // Both tuple elements must now read from a two-operand sort over
  // parameters 0 and 2, at the renumbered indices 0 and 1.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Tuple(
                  m::GetTupleElement(
                      m::Sort(m::Parameter(0), m::Parameter(2)), 0),
                  m::GetTupleElement(
                      m::Sort(m::Parameter(0), m::Parameter(2)), 1))));
}
102 
TEST_F(SortSimplifierTest, DontRemoveUnusedSortKey) {
  // Only tuple element 1 (values) is consumed, but the comparator reads
  // operand 0 (p.0.lhs / p.0.rhs), so the keys operand must be kept.
  const char* hlo_string = R"(
   HloModule permutation_sort

   compare {
     p.0.lhs = f32[] parameter(0)
     p.0.rhs = f32[] parameter(1)
     p.1.lhs = s32[] parameter(2)
     p.1.rhs = s32[] parameter(3)
     ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
   }

   ENTRY sort_computation {
     keys = f32[64,8732]{1,0} parameter(0)
     values = s32[64,8732]{1,0} parameter(1)
     sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), dimensions={1}, to_apply=compare
     ROOT gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1
   })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  SortSimplifier simplifier;
  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
  // A false return alone would not catch a pass that edits the graph but
  // misreports "no change". Also verify the original structure is intact:
  // the root is still element 1 of the two-operand sort over both parameters.
  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::GetTupleElement(
                        m::Sort(m::Parameter(0), m::Parameter(1)), 1)));
}
127 
TEST_F(SortSimplifierTest, RemoveUnusedFirstOperand) {
  // The comparator reads only operand 1 (p.1.lhs / p.1.rhs) and only tuple
  // element 1 of the sort is consumed, so operand 0 (keys) is dead. After
  // simplification the entry root is expected to be a single-operand sort
  // over parameter 1.
  const char* hlo_string = R"(
   HloModule permutation_sort

   compare {
     p.0.lhs = f32[] parameter(0)
     p.0.rhs = f32[] parameter(1)
     p.1.lhs = s32[] parameter(2)
     p.1.rhs = s32[] parameter(3)
     ROOT lt = pred[] compare(p.1.lhs, p.1.rhs), direction=LT
   }

   ENTRY sort_computation {
     keys = f32[64,8732]{1,0} parameter(0)
     values = s32[64,8732]{1,0} parameter(1)
     sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values),
       dimensions={1}, to_apply=compare
     ROOT gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1
   })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  SortSimplifier simplifier;
  // The first run must rewrite the sort and a second run must be a no-op.
  // Checking the two runs explicitly has two benefits over looping until Run()
  // returns false: the test cannot hang if the pass never reaches a fixed
  // point, and it avoids comparing an unsigned counter to a signed literal.
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Sort(m::Parameter(1))));
}
159 }  // namespace
160 }  // namespace xla
161