1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/batch_dot_simplification.h"
17 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
18 #include "tensorflow/compiler/xla/test.h"
19 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
20 
21 namespace xla {
22 namespace {
23 
24 namespace op = xla::testing::opcode_matchers;
25 
26 class BatchDotSimplificationTest : public HloTestBase {};
27 
// Two f32[1,3] operands sharing a single size-1 batch dimension (dim 0) and
// contracting on dim 1. The pass is expected to rewrite the dot as
// reshape(dot(reshape(p0), reshape(p1))), where the inner dot contracts dim 0
// of both reshaped operands, i.e. the degenerate batch dim has been removed.
TEST_F(BatchDotSimplificationTest,
       ElideSingleDegenerateBatchDotDim_VectorVector) {
  const string hlo_text = R"(
HloModule BatchDot

main {
  a = f32[1,3] parameter(0)
  b = f32[1,3] parameter(1)
  ROOT dot = f32[1] dot(a, b), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={1}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(hlo_text));
  BatchDotSimplification pass;
  // Assign through TF_ASSERT_OK_AND_ASSIGN instead of ValueOrDie(): an error
  // status then fails this test with a message instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(m.get()));
  ASSERT_TRUE(changed);

  HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              op::Reshape(op::Dot(
                  op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)),
                  /*lhs_contracting_dim=*/0, /*rhs_contracting_dim=*/0)));
}
51 
// lhs f32[1,9,3] and rhs f32[1,3] share a single size-1 batch dimension
// (dim 0). After the pass, the inner dot should operate on reshaped operands
// with the batch dim removed, so the contracting dims shift from 2/1 to 1/0.
TEST_F(BatchDotSimplificationTest,
       ElideSingleDegenerateBatchDotDim_MatrixVector) {
  const string hlo_text = R"(
HloModule BatchDot

main {
  a = f32[1,9,3] parameter(0)
  b = f32[1,3] parameter(1)
  ROOT dot = f32[1,9] dot(a, b), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(hlo_text));
  BatchDotSimplification pass;
  // Assign through TF_ASSERT_OK_AND_ASSIGN instead of ValueOrDie(): an error
  // status then fails this test with a message instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(m.get()));
  ASSERT_TRUE(changed);

  HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              op::Reshape(op::Dot(
                  op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)),
                  /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0)));
}
75 
// lhs f32[1,9,3] and rhs f32[1,3,7] share a single size-1 batch dimension
// (dim 0). After the pass, the inner dot should operate on reshaped operands
// with the batch dim removed, so the contracting dims shift from 2/1 to 1/0.
TEST_F(BatchDotSimplificationTest,
       ElideSingleDegenerateBatchDotDim_MatrixMatrix) {
  const string hlo_text = R"(
HloModule BatchDot

main {
  a = f32[1,9,3] parameter(0)
  b = f32[1,3,7] parameter(1)
  ROOT dot = f32[1,9,7] dot(a, b), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(hlo_text));
  BatchDotSimplification pass;
  // Assign through TF_ASSERT_OK_AND_ASSIGN instead of ValueOrDie(): an error
  // status then fails this test with a message instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(m.get()));
  ASSERT_TRUE(changed);

  HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              op::Reshape(op::Dot(
                  op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)),
                  /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0)));
}
99 
// Both operands are f32[9,1,7,1,3] with batch dims {0,1,2,3}. Two of those
// (dims 1 and 3) have size 1. The pass is expected to elide them, which moves
// the contracting dim of each reshaped operand from 4 to 2.
TEST_F(BatchDotSimplificationTest,
       ElideMultipleDegenerateBatchDotDims_VectorVector) {
  const string hlo_text = R"(
HloModule BatchDot

main {
  a = f32[9,1,7,1,3] parameter(0)
  b = f32[9,1,7,1,3] parameter(1)
  ROOT dot = f32[9,1,7,1] dot(a, b), lhs_batch_dims={0,1,2,3}, rhs_batch_dims={0,1,2,3}, lhs_contracting_dims={4}, rhs_contracting_dims={4}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(hlo_text));
  BatchDotSimplification pass;
  // Assign through TF_ASSERT_OK_AND_ASSIGN instead of ValueOrDie(): an error
  // status then fails this test with a message instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(m.get()));
  ASSERT_TRUE(changed);

  HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              op::Reshape(op::Dot(
                  op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)),
                  /*lhs_contracting_dim=*/2, /*rhs_contracting_dim=*/2)));
}
123 
// lhs f32[9,1,7,1,3] and rhs f32[9,1,7,1,20,3] share batch dims {0,1,2,3}.
// Two of those (dims 1 and 3) have size 1. Eliding them should move the
// contracting dims from 4/5 to 2/3 on the reshaped operands.
TEST_F(BatchDotSimplificationTest,
       ElideMultipleDegenerateBatchDotDims_VectorMatrix) {
  const string hlo_text = R"(
HloModule BatchDot

main {
  a = f32[9,1,7,1,3] parameter(0)
  b = f32[9,1,7,1,20,3] parameter(1)
  ROOT dot = f32[9,1,7,1,20] dot(a, b), lhs_batch_dims={0,1,2,3}, rhs_batch_dims={0,1,2,3}, lhs_contracting_dims={4}, rhs_contracting_dims={5}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(hlo_text));
  BatchDotSimplification pass;
  // Assign through TF_ASSERT_OK_AND_ASSIGN instead of ValueOrDie(): an error
  // status then fails this test with a message instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(m.get()));
  ASSERT_TRUE(changed);

  HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              op::Reshape(op::Dot(
                  op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)),
                  /*lhs_contracting_dim=*/2, /*rhs_contracting_dim=*/3)));
}
147 
// lhs f32[9,1,7,1,19,3] and rhs f32[9,1,7,1,3,20] share batch dims {0,1,2,3}.
// Two of those (dims 1 and 3) have size 1. Eliding them should move the
// contracting dims from 5/4 to 3/2 on the reshaped operands.
TEST_F(BatchDotSimplificationTest,
       ElideMultipleDegenerateBatchDotDims_MatrixMatrix) {
  const string hlo_text = R"(
HloModule BatchDot

main {
  a = f32[9,1,7,1,19,3] parameter(0)
  b = f32[9,1,7,1,3,20] parameter(1)
  ROOT dot = f32[9,1,7,1,19,20] dot(a, b), lhs_batch_dims={0,1,2,3}, rhs_batch_dims={0,1,2,3}, lhs_contracting_dims={5}, rhs_contracting_dims={4}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(hlo_text));
  BatchDotSimplification pass;
  // Assign through TF_ASSERT_OK_AND_ASSIGN instead of ValueOrDie(): an error
  // status then fails this test with a message instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(m.get()));
  ASSERT_TRUE(changed);

  HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              op::Reshape(op::Dot(
                  op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)),
                  /*lhs_contracting_dim=*/3, /*rhs_contracting_dim=*/2)));
}
171 
// A dot with one size-1 batch dim (dim 0) and no contracting dimensions at
// all. The test asserts that the pass leaves the module unchanged (Run
// reports false).
// NOTE(review): despite the "Elide..." name, nothing is expected to be
// elided here; consider renaming if test names are not depended on.
TEST_F(BatchDotSimplificationTest,
       ElideMultipleDegenerateBatchDotDimsNonContracting) {
  // `const string` for consistency with the other tests in this file.
  const string hlo_text = R"(
HloModule BatchDot

main {
  a = f32[1,101] parameter(0)
  b = f32[1,101] parameter(1)
  ROOT dot = f32[1,101,101] dot(a,b), lhs_batch_dims={0},
                                      lhs_contracting_dims={},
                                      rhs_batch_dims={0},
                                      rhs_contracting_dims={}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(hlo_text));
  BatchDotSimplification pass;
  // Assign through TF_ASSERT_OK_AND_ASSIGN instead of ValueOrDie(): an error
  // status then fails this test with a message instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(m.get()));
  ASSERT_FALSE(changed);
}
192 
// A dot with two batch dims and two contracting dims:
// - The batch pairs are lhs {3,0} / rhs {2,0}; lhs dim 0 and rhs dim 0 are
//   size 1.
// - The contracting pairs are lhs {1,4} / rhs {5,3}.
// The test asserts that the pass leaves the module unchanged (Run reports
// false).
TEST_F(BatchDotSimplificationTest,
       ElideMultipleDegenerateBatchDotDimsMultipleContracting) {
  // `const string` for consistency with the other tests in this file.
  const string hlo_text = R"(
HloModule BatchDot

main {
  lhs = f32[1,5,17,10,13] parameter(0)
  rhs = f32[1,9,10,13,6,5] parameter(1)
  ROOT dot = f32[10,1,17,9,6] dot(lhs,rhs), lhs_batch_dims={3,0},
                                            rhs_batch_dims={2,0},
                                            lhs_contracting_dims={1,4},
                                            rhs_contracting_dims={5,3}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(hlo_text));
  BatchDotSimplification pass;
  // Assign through TF_ASSERT_OK_AND_ASSIGN instead of ValueOrDie(): an error
  // status then fails this test with a message instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(m.get()));
  ASSERT_FALSE(changed);
}
213 
214 }  // namespace
215 }  // namespace xla
216