1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/sort_simplifier.h"
17
18 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
19 #include "tensorflow/compiler/xla/service/hlo_parser.h"
20 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
21 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
22 #include "tensorflow/compiler/xla/test.h"
23 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25
26 namespace xla {
27 namespace {
28
29 namespace m = match;
30
31 using SortSimplifierTest = HloTestBase;
32
// Only output 0 (the keys) of a two-operand sort is consumed, and the
// comparator reads only operand 0, so the pass should drop the unused
// `values` operand and leave a single-operand sort over parameter 0 as root.
TEST_F(SortSimplifierTest, RemoveUnusedSortOperandArrayResult) {
  const char* hlo_string = R"(
    HloModule permutation_sort

    compare {
      p.0.lhs = f32[] parameter(0)
      p.0.rhs = f32[] parameter(1)
      p.1.lhs = s32[] parameter(2)
      p.1.rhs = s32[] parameter(3)
      ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
    }

    ENTRY sort_computation {
      keys = f32[64,8732]{1,0} parameter(0)
      values = s32[64,8732]{1,0} parameter(1)
      sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values),
        dimensions={1}, to_apply=compare
      ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  SortSimplifier simplifier;
  // The first run must rewrite the sort. The two runs are checked explicitly
  // rather than looping until Run() returns false. That way a pass that never
  // reaches a fixed point fails the test instead of hanging it, and a non-OK
  // status is reported as a test failure instead of aborting via ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
  EXPECT_TRUE(changed);
  // A second run must be a no-op: one application reaches the fixed point.
  TF_ASSERT_OK_AND_ASSIGN(changed, simplifier.Run(module.get()));
  EXPECT_FALSE(changed);

  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Sort(m::Parameter(0))));
}
64
// A three-operand sort whose output 1 (the s32 values) is never read, while
// the comparator only reads operand 0. The pass should rebuild the sort over
// parameters 0 and 2 only, and the surviving get-tuple-elements should then
// refer to tuple indices 0 and 1 of the reduced sort.
TEST_F(SortSimplifierTest, RemoveUnusedSortOperandTuple) {
  static constexpr char kHloText[] = R"(
    HloModule permutation_sort

    compare {
      p.0.lhs = f32[] parameter(0)
      p.0.rhs = f32[] parameter(1)
      p.1.lhs = s32[] parameter(2)
      p.1.rhs = s32[] parameter(3)
      p.2.lhs = u32[] parameter(4)
      p.2.rhs = u32[] parameter(5)
      ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
    }

    ENTRY sort_computation {
      keys = f32[64,87] parameter(0)
      values.0 = s32[64,87] parameter(1)
      values.1 = u32[64,87] parameter(2)
      sort = (f32[64,87], s32[64,87], u32[64,87]) sort(
          keys, values.0, values.1),
        dimensions={1}, to_apply=compare
      gte.0 = f32[64,87] get-tuple-element(sort), index=0
      gte.1 = u32[64,87] get-tuple-element(sort), index=2
      ROOT tuple = (f32[64,87], u32[64,87]) tuple(gte.0, gte.1)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  SortSimplifier pass;
  const bool changed = pass.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  // Build a fresh pattern for each use site. The expected sort keeps only the
  // keys and the u32 values.
  auto reduced_sort = [] { return m::Sort(m::Parameter(0), m::Parameter(2)); };
  const auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Tuple(
                        m::GetTupleElement(reduced_sort(), 0),
                        m::GetTupleElement(reduced_sort(), 1))));
}
102
// Only output 1 (the values) is consumed, but the comparator reads operand 0
// (the keys). The keys therefore still determine the ordering and must not be
// removed, so the pass is expected to report no change.
TEST_F(SortSimplifierTest, DontRemoveUnusedSortKey) {
  static constexpr char kHloText[] = R"(
    HloModule permutation_sort

    compare {
      p.0.lhs = f32[] parameter(0)
      p.0.rhs = f32[] parameter(1)
      p.1.lhs = s32[] parameter(2)
      p.1.rhs = s32[] parameter(3)
      ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
    }

    ENTRY sort_computation {
      keys = f32[64,8732]{1,0} parameter(0)
      values = s32[64,8732]{1,0} parameter(1)
      sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), dimensions={1}, to_apply=compare
      ROOT gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  SortSimplifier pass;
  const bool changed = pass.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
}
127
// The comparator reads only operand 1 (the s32 values) and only output 1 is
// consumed. The f32 first operand is therefore dead and should be removed,
// leaving a single-operand sort over parameter 1 as the root.
TEST_F(SortSimplifierTest, RemoveUnusedFirstOperand) {
  const char* hlo_string = R"(
    HloModule permutation_sort

    compare {
      p.0.lhs = f32[] parameter(0)
      p.0.rhs = f32[] parameter(1)
      p.1.lhs = s32[] parameter(2)
      p.1.rhs = s32[] parameter(3)
      ROOT lt = pred[] compare(p.1.lhs, p.1.rhs), direction=LT
    }

    ENTRY sort_computation {
      keys = f32[64,8732]{1,0} parameter(0)
      values = s32[64,8732]{1,0} parameter(1)
      sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values),
        dimensions={1}, to_apply=compare
      ROOT gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  SortSimplifier simplifier;
  // The first run must rewrite the sort. The two runs are checked explicitly
  // rather than looping until Run() returns false. That way a pass that never
  // reaches a fixed point fails the test instead of hanging it, and a non-OK
  // status is reported as a test failure instead of aborting via ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
  EXPECT_TRUE(changed);
  // A second run must be a no-op: one application reaches the fixed point.
  TF_ASSERT_OK_AND_ASSIGN(changed, simplifier.Run(module.get()));
  EXPECT_FALSE(changed);

  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Sort(m::Parameter(1))));
}
159 } // namespace
160 } // namespace xla
161