1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/batch_dot_simplification.h"
17 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
18 #include "tensorflow/compiler/xla/test.h"
19 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
20
21 namespace xla {
22 namespace {
23
24 namespace op = xla::testing::opcode_matchers;
25
26 class BatchDotSimplificationTest : public HloTestBase {};
27
// Batch dot of two f32[1,3] operands whose single batch dimension has size 1.
// The pass must report a change, and the new root must have the form
// reshape(dot(reshape(param0), reshape(param1))) with the inner dot
// contracting dimension 0 of both operands.
TEST_F(BatchDotSimplificationTest,
       ElideSingleDegenerateBatchDotDim_VectorVector) {
  const string hlo_text = R"(
HloModule BatchDot

main {
a = f32[1,3] parameter(0)
b = f32[1,3] parameter(1)
ROOT dot = f32[1] dot(a, b), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={1}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(hlo_text));
  BatchDotSimplification pass;
  // Unwrap the pass result through the test macro so an error status is
  // reported as a failure of this test rather than crashing in ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(m.get()));
  ASSERT_TRUE(changed);

  HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              op::Reshape(op::Dot(
                  op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)),
                  /*lhs_contracting_dim=*/0, /*rhs_contracting_dim=*/0)));
}
51
// Batch dot of f32[1,9,3] with f32[1,3] whose single batch dimension has
// size 1. The pass must report a change, and the new root must have the form
// reshape(dot(reshape(param0), reshape(param1))) with the inner dot
// contracting lhs dimension 1 against rhs dimension 0.
TEST_F(BatchDotSimplificationTest,
       ElideSingleDegenerateBatchDotDim_MatrixVector) {
  const string hlo_text = R"(
HloModule BatchDot

main {
a = f32[1,9,3] parameter(0)
b = f32[1,3] parameter(1)
ROOT dot = f32[1,9] dot(a, b), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(hlo_text));
  BatchDotSimplification pass;
  // Unwrap the pass result through the test macro so an error status is
  // reported as a failure of this test rather than crashing in ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(m.get()));
  ASSERT_TRUE(changed);

  HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              op::Reshape(op::Dot(
                  op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)),
                  /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0)));
}
75
// Batch dot of f32[1,9,3] with f32[1,3,7] whose single batch dimension has
// size 1. The pass must report a change, and the new root must have the form
// reshape(dot(reshape(param0), reshape(param1))) with the inner dot
// contracting lhs dimension 1 against rhs dimension 0.
TEST_F(BatchDotSimplificationTest,
       ElideSingleDegenerateBatchDotDim_MatrixMatrix) {
  const string hlo_text = R"(
HloModule BatchDot

main {
a = f32[1,9,3] parameter(0)
b = f32[1,3,7] parameter(1)
ROOT dot = f32[1,9,7] dot(a, b), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(hlo_text));
  BatchDotSimplification pass;
  // Unwrap the pass result through the test macro so an error status is
  // reported as a failure of this test rather than crashing in ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(m.get()));
  ASSERT_TRUE(changed);

  HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              op::Reshape(op::Dot(
                  op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)),
                  /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0)));
}
99
// Batch dot of two f32[9,1,7,1,3] operands with batch dimensions {0,1,2,3},
// two of which (1 and 3) have size 1. The pass must report a change, and the
// new root must have the form reshape(dot(reshape(param0), reshape(param1)))
// with the inner dot contracting dimension 2 of both operands -- consistent
// with the two size-1 batch dimensions having been removed ahead of the
// original contracting dimension 4.
TEST_F(BatchDotSimplificationTest,
       ElideMultipleDegenerateBatchDotDims_VectorVector) {
  const string hlo_text = R"(
HloModule BatchDot

main {
a = f32[9,1,7,1,3] parameter(0)
b = f32[9,1,7,1,3] parameter(1)
ROOT dot = f32[9,1,7,1] dot(a, b), lhs_batch_dims={0,1,2,3}, rhs_batch_dims={0,1,2,3}, lhs_contracting_dims={4}, rhs_contracting_dims={4}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(hlo_text));
  BatchDotSimplification pass;
  // Unwrap the pass result through the test macro so an error status is
  // reported as a failure of this test rather than crashing in ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(m.get()));
  ASSERT_TRUE(changed);

  HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              op::Reshape(op::Dot(
                  op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)),
                  /*lhs_contracting_dim=*/2, /*rhs_contracting_dim=*/2)));
}
123
// Batch dot of f32[9,1,7,1,3] with f32[9,1,7,1,20,3], batch dimensions
// {0,1,2,3}, two of which (1 and 3) have size 1. The pass must report a
// change, and the new root must have the form
// reshape(dot(reshape(param0), reshape(param1))) with the inner dot
// contracting lhs dimension 2 against rhs dimension 3 -- consistent with the
// two size-1 batch dimensions having been removed from each operand.
TEST_F(BatchDotSimplificationTest,
       ElideMultipleDegenerateBatchDotDims_VectorMatrix) {
  const string hlo_text = R"(
HloModule BatchDot

main {
a = f32[9,1,7,1,3] parameter(0)
b = f32[9,1,7,1,20,3] parameter(1)
ROOT dot = f32[9,1,7,1,20] dot(a, b), lhs_batch_dims={0,1,2,3}, rhs_batch_dims={0,1,2,3}, lhs_contracting_dims={4}, rhs_contracting_dims={5}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(hlo_text));
  BatchDotSimplification pass;
  // Unwrap the pass result through the test macro so an error status is
  // reported as a failure of this test rather than crashing in ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(m.get()));
  ASSERT_TRUE(changed);

  HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              op::Reshape(op::Dot(
                  op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)),
                  /*lhs_contracting_dim=*/2, /*rhs_contracting_dim=*/3)));
}
147
// Batch dot of f32[9,1,7,1,19,3] with f32[9,1,7,1,3,20], batch dimensions
// {0,1,2,3}, two of which (1 and 3) have size 1. The pass must report a
// change, and the new root must have the form
// reshape(dot(reshape(param0), reshape(param1))) with the inner dot
// contracting lhs dimension 3 against rhs dimension 2 -- consistent with the
// two size-1 batch dimensions having been removed from each operand.
TEST_F(BatchDotSimplificationTest,
       ElideMultipleDegenerateBatchDotDims_MatrixMatrix) {
  const string hlo_text = R"(
HloModule BatchDot

main {
a = f32[9,1,7,1,19,3] parameter(0)
b = f32[9,1,7,1,3,20] parameter(1)
ROOT dot = f32[9,1,7,1,19,20] dot(a, b), lhs_batch_dims={0,1,2,3}, rhs_batch_dims={0,1,2,3}, lhs_contracting_dims={5}, rhs_contracting_dims={4}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(hlo_text));
  BatchDotSimplification pass;
  // Unwrap the pass result through the test macro so an error status is
  // reported as a failure of this test rather than crashing in ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(m.get()));
  ASSERT_TRUE(changed);

  HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              op::Reshape(op::Dot(
                  op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)),
                  /*lhs_contracting_dim=*/3, /*rhs_contracting_dim=*/2)));
}
171
// Negative case: despite the "Elide" prefix in its name, this test expects the
// pass to leave the module alone. The dot has a size-1 batch dimension but
// empty contracting-dimension lists on both operands, and Run() must report
// that nothing changed.
TEST_F(BatchDotSimplificationTest,
       ElideMultipleDegenerateBatchDotDimsNonContracting) {
  const char* hlo_text = R"(
HloModule BatchDot

main {
a = f32[1,101] parameter(0)
b = f32[1,101] parameter(1)
ROOT dot = f32[1,101,101] dot(a,b), lhs_batch_dims={0},
lhs_contracting_dims={},
rhs_batch_dims={0},
rhs_contracting_dims={}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(hlo_text));
  BatchDotSimplification pass;
  // Unwrap the pass result through the test macro so an error status is
  // reported as a failure of this test rather than crashing in ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(m.get()));
  ASSERT_FALSE(changed);
}
192
// Negative case: despite the "Elide" prefix in its name, this test expects the
// pass to leave the module alone. The dot has two batch dimensions (one of
// them of size 1) and two contracting dimensions per operand, and Run() must
// report that nothing changed.
TEST_F(BatchDotSimplificationTest,
       ElideMultipleDegenerateBatchDotDimsMultipleContracting) {
  const char* hlo_text = R"(
HloModule BatchDot

main {
lhs = f32[1,5,17,10,13] parameter(0)
rhs = f32[1,9,10,13,6,5] parameter(1)
ROOT dot = f32[10,1,17,9,6] dot(lhs,rhs), lhs_batch_dims={3,0},
rhs_batch_dims={2,0},
lhs_contracting_dims={1,4},
rhs_contracting_dims={5,3}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(hlo_text));
  BatchDotSimplification pass;
  // Unwrap the pass result through the test macro so an error status is
  // reported as a failure of this test rather than crashing in ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(m.get()));
  ASSERT_FALSE(changed);
}
213
214 } // namespace
215 } // namespace xla
216