/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/dot_decomposer.h"

#include "absl/algorithm/container.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

namespace {

// TODO(b/69062148) Remove this code when all backends support BatchDot
// natively.
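// Decomposes a batch Dot into a sequence of R2 Dots, one per batch element,
// whose results are concatenated and reshaped back to the original shape.
// Illustrative example (shapes invented for this comment, not taken from any
// test): lhs f32[2,3,4] dot rhs f32[2,4,5] with batch dimension 0 becomes two
// f32[3,4] x f32[4,5] dots; each f32[3,5] result is reshaped to f32[1,3,5],
// the slices are concatenated along dimension 0, and the f32[2,3,5] result
// matches the original dot's shape.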
Status DecomposeBatchDot(HloInstruction* dot) {
  auto computation = dot->parent();
  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
  HloInstruction* lhs = dot->mutable_operand(0);
  HloInstruction* rhs = dot->mutable_operand(1);
  const Shape& lhs_shape = lhs->shape();
  const Shape& rhs_shape = rhs->shape();
  const Shape& dot_shape = dot->shape();

  // ShapeInference should guarantee that lhs/rhs batch dimensions match.
  CHECK_EQ(dnums.lhs_batch_dimensions_size(),
           dnums.rhs_batch_dimensions_size());
  const int64 num_batch_dims = dnums.lhs_batch_dimensions_size();
  // Calculate total batch size (note that ShapeInference requires that
  // the batch dimensions are most-major).
  int64 batch_size = 1;
  for (int i = 0; i < num_batch_dims; ++i) {
    CHECK_EQ(lhs_shape.dimensions(dnums.lhs_batch_dimensions(i)),
             rhs_shape.dimensions(dnums.rhs_batch_dimensions(i)));
    batch_size *= lhs_shape.dimensions(dnums.lhs_batch_dimensions(i));
  }
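  // For example (values chosen for illustration only): batch dimensions of
  // sizes {2, 3} flatten to batch_size = 6, so the loop below emits 6 R2 dots.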

  // Set lhs/rhs_transpose.
  CHECK_EQ(1, dnums.lhs_contracting_dimensions_size());
  const int64 lhs_contracting_dim_number = dnums.lhs_contracting_dimensions(0);
  const bool lhs_transpose = (lhs_contracting_dim_number - num_batch_dims) == 0;

  CHECK_EQ(1, dnums.rhs_contracting_dimensions_size());
  const int64 rhs_contracting_dim_number = dnums.rhs_contracting_dimensions(0);
  const bool rhs_transpose = (rhs_contracting_dim_number - num_batch_dims) == 1;
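  // Illustrative example (not from the original source): with one batch
  // dimension, lhs f32[B,K,M] contracting on dimension 1 gives
  // lhs_transpose == true, so each [K,M] slice is transposed to [M,K]
  // before the R2 dot below.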

  // Compute R3 and R2 shapes for lhs.
  PrimitiveType lhs_type = lhs_shape.element_type();
  const int64 lhs_rows = lhs_shape.dimensions(num_batch_dims + 0);
  const int64 lhs_cols = lhs_shape.dimensions(num_batch_dims + 1);
  Shape lhs_shape_r3 =
      ShapeUtil::MakeShape(lhs_type, {batch_size, lhs_rows, lhs_cols});
  Shape lhs_slice_shape_r3 =
      ShapeUtil::MakeShape(lhs_type, {1, lhs_rows, lhs_cols});
  Shape lhs_slice_shape_r2 =
      ShapeUtil::MakeShape(lhs_type, {lhs_rows, lhs_cols});

  // Compute R3 and R2 shapes for rhs.
  PrimitiveType rhs_type = rhs_shape.element_type();
  const int64 rhs_rows = rhs_shape.dimensions(num_batch_dims + 0);
  const int64 rhs_cols = rhs_shape.dimensions(num_batch_dims + 1);
  Shape rhs_shape_r3 =
      ShapeUtil::MakeShape(rhs_type, {batch_size, rhs_rows, rhs_cols});
  Shape rhs_slice_shape_r3 =
      ShapeUtil::MakeShape(rhs_type, {1, rhs_rows, rhs_cols});
  Shape rhs_slice_shape_r2 =
      ShapeUtil::MakeShape(rhs_type, {rhs_rows, rhs_cols});

  // Compute R2 and R3 shapes for the dot output.
  PrimitiveType dot_type = dot_shape.element_type();
  const int64 dot_rows = dot_shape.dimensions(num_batch_dims + 0);
  const int64 dot_cols = dot_shape.dimensions(num_batch_dims + 1);
  Shape dot_shape_r2 = ShapeUtil::MakeShape(dot_type, {dot_rows, dot_cols});
  Shape dot_shape_r3 = ShapeUtil::MakeShape(dot_type, {1, dot_rows, dot_cols});
  Shape concat_shape_r3 =
      ShapeUtil::MakeShape(dot_type, {batch_size, dot_rows, dot_cols});

  // Reshape lhs/rhs into R3.
  auto lhs_r3 = computation->AddInstruction(
      HloInstruction::CreateReshape(lhs_shape_r3, lhs));
  auto rhs_r3 = computation->AddInstruction(
      HloInstruction::CreateReshape(rhs_shape_r3, rhs));

  // Loop over the batch, slicing out the lhs/rhs pair needed for each Dot.
  std::vector<HloInstruction*> output_slices(batch_size);
  for (int64 i = 0; i < batch_size; ++i) {
    // Slice R3 shape from 'lhs' and reshape to R2.
    auto lhs_slice_r3 = computation->AddInstruction(
        HloInstruction::CreateSlice(lhs_slice_shape_r3, lhs_r3, {i, 0, 0},
                                    {i + 1, lhs_rows, lhs_cols}, {1, 1, 1}));
    auto lhs_slice_r2 = computation->AddInstruction(
        HloInstruction::CreateReshape(lhs_slice_shape_r2, lhs_slice_r3));

    // Slice R3 shape from 'rhs' and reshape to R2.
    auto rhs_slice_r3 = computation->AddInstruction(
        HloInstruction::CreateSlice(rhs_slice_shape_r3, rhs_r3, {i, 0, 0},
                                    {i + 1, rhs_rows, rhs_cols}, {1, 1, 1}));
    auto rhs_slice_r2 = computation->AddInstruction(
        HloInstruction::CreateReshape(rhs_slice_shape_r2, rhs_slice_r3));

    // Transpose lhs/rhs (if needed).
    if (lhs_transpose) {
      Shape lhs_slice_shape_r2_transpose =
          ShapeUtil::MakeShape(lhs_type, {lhs_cols, lhs_rows});
      lhs_slice_r2 =
          computation->AddInstruction(HloInstruction::CreateTranspose(
              lhs_slice_shape_r2_transpose, lhs_slice_r2, {1, 0}));
    }
    if (rhs_transpose) {
      Shape rhs_slice_shape_r2_transpose =
          ShapeUtil::MakeShape(rhs_type, {rhs_cols, rhs_rows});
      rhs_slice_r2 =
          computation->AddInstruction(HloInstruction::CreateTranspose(
              rhs_slice_shape_r2_transpose, rhs_slice_r2, {1, 0}));
    }

    // Compute Dot of lhs/rhs R2 slices.
    DotDimensionNumbers dot_dnums;
    dot_dnums.add_lhs_contracting_dimensions(1);
    dot_dnums.add_rhs_contracting_dimensions(0);
    auto dot_r2 = computation->AddInstruction(
        HloInstruction::CreateDot(dot_shape_r2, lhs_slice_r2, rhs_slice_r2,
                                  dot_dnums, dot->precision_config()));

    // Reshape Dot to R3 so we can concat along the batch dimension.
    auto dot_r3 = computation->AddInstruction(
        HloInstruction::CreateReshape(dot_shape_r3, dot_r2));

    output_slices[i] = dot_r3;
  }

  // Concatenate slices from 'output_slices' along the batch dimension.
  auto concat = computation->AddInstruction(
      HloInstruction::CreateConcatenate(concat_shape_r3, output_slices, 0));
  // Reshape output 'new_dot' to the original dimensions.
  auto new_dot = computation->AddInstruction(
      HloInstruction::CreateReshape(dot_shape, concat));

  // Replace all uses of 'dot' in 'computation' with 'new_dot'.
  return computation->ReplaceInstruction(dot, new_dot);
}

// Converts a dot into a canonical form where non-contracting and contracting
// dimensions are reshaped together and batch dimensions are the most-major
// dimensions. This requires transposing and reshaping the lhs and rhs, and
// reshaping the output back to the original shape.
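// Illustrative example (shapes invented for this comment): a dot with
// lhs f32[3,4,5] and rhs f32[3,6,5], contracting dimensions {0,2} on both
// sides and no batch dimensions, has non-contracting dimension {1} on each
// side. The lhs is transposed with permutation {1,0,2} to f32[4,3,5] and
// reshaped to f32[4,15]; the rhs is transposed with {0,2,1} to f32[3,5,6]
// and reshaped to f32[15,6]. The canonical dot then produces f32[4,6],
// which is reshaped to the original output shape.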
Status CanonicalizeDot(HloInstruction* original_dot) {
  auto computation = original_dot->parent();
  const auto& original_dnums = original_dot->dot_dimension_numbers();
  const int64 num_batch_dims = original_dnums.lhs_batch_dimensions_size();
  const int64 num_contracting_dims =
      original_dnums.lhs_contracting_dimensions_size();

  const auto& lhs_shape = original_dot->operand(0)->shape();
  const int64 lhs_rank = lhs_shape.rank();
  const int64 num_lhs_non_contracting_dims =
      lhs_rank - num_batch_dims - num_contracting_dims;

  std::vector<int64> lhs_non_contracting_dims;
  lhs_non_contracting_dims.reserve(num_lhs_non_contracting_dims);
  int64 lhs_contracting_size = 1;
  int64 lhs_non_contracting_size = 1;
  std::vector<int64> batch_dim_sizes;
  batch_dim_sizes.reserve(num_batch_dims);
  for (int64 i = 0; i < lhs_rank; ++i) {
    if (absl::c_linear_search(original_dnums.lhs_contracting_dimensions(),
                              i)) {
      lhs_contracting_size *= lhs_shape.dimensions(i);
    } else if (absl::c_linear_search(original_dnums.lhs_batch_dimensions(),
                                     i)) {
      batch_dim_sizes.push_back(lhs_shape.dimensions(i));
    } else {
      lhs_non_contracting_dims.push_back(i);
      lhs_non_contracting_size *= lhs_shape.dimensions(i);
    }
  }
  // The canonical form of the lhs is
  // [BatchDims, NonContractingDims, ContractingDims]
  std::vector<int64> lhs_transpose;
  lhs_transpose.reserve(lhs_rank);
  lhs_transpose.insert(lhs_transpose.end(),
                       original_dnums.lhs_batch_dimensions().begin(),
                       original_dnums.lhs_batch_dimensions().end());
  lhs_transpose.insert(lhs_transpose.end(), lhs_non_contracting_dims.begin(),
                       lhs_non_contracting_dims.end());
  lhs_transpose.insert(lhs_transpose.end(),
                       original_dnums.lhs_contracting_dimensions().begin(),
                       original_dnums.lhs_contracting_dimensions().end());
  HloInstruction* transposed_lhs =
      computation->AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::PermuteDimensions(InversePermutation(lhs_transpose),
                                       lhs_shape),
          original_dot->mutable_operand(0), lhs_transpose));
  std::vector<int64> lhs_reshape_dims = batch_dim_sizes;
  lhs_reshape_dims.push_back(lhs_non_contracting_size);
  lhs_reshape_dims.push_back(lhs_contracting_size);
  // Reshape the contracting and non-contracting dimensions together.
  HloInstruction* reshaped_lhs =
      computation->AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(lhs_shape.element_type(), lhs_reshape_dims),
          transposed_lhs));

  const auto& rhs_shape = original_dot->operand(1)->shape();
  const int64 rhs_rank = rhs_shape.rank();
  const int64 num_rhs_non_contracting_dims =
      rhs_rank - num_batch_dims - num_contracting_dims;
  std::vector<int64> rhs_non_contracting_dims;
  rhs_non_contracting_dims.reserve(num_rhs_non_contracting_dims);
  int64 rhs_non_contracting_size = 1;
  int64 rhs_contracting_size = 1;
  for (int64 i = 0; i < rhs_rank; ++i) {
    if (absl::c_linear_search(original_dnums.rhs_contracting_dimensions(),
                              i)) {
      rhs_contracting_size *= rhs_shape.dimensions(i);
    } else if (!absl::c_linear_search(original_dnums.rhs_batch_dimensions(),
                                      i)) {
      rhs_non_contracting_dims.push_back(i);
      rhs_non_contracting_size *= rhs_shape.dimensions(i);
    }
  }

  // The canonical form of the rhs is
  // [BatchDims, ContractingDims, NonContractingDims]
  std::vector<int64> rhs_transpose;
  rhs_transpose.reserve(rhs_rank);
  rhs_transpose.insert(rhs_transpose.end(),
                       original_dnums.rhs_batch_dimensions().begin(),
                       original_dnums.rhs_batch_dimensions().end());
  rhs_transpose.insert(rhs_transpose.end(),
                       original_dnums.rhs_contracting_dimensions().begin(),
                       original_dnums.rhs_contracting_dimensions().end());
  rhs_transpose.insert(rhs_transpose.end(), rhs_non_contracting_dims.begin(),
                       rhs_non_contracting_dims.end());
  HloInstruction* transposed_rhs =
      computation->AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::PermuteDimensions(InversePermutation(rhs_transpose),
                                       rhs_shape),
          original_dot->mutable_operand(1), rhs_transpose));

  std::vector<int64> rhs_reshape_dims = batch_dim_sizes;
  rhs_reshape_dims.push_back(rhs_contracting_size);
  rhs_reshape_dims.push_back(rhs_non_contracting_size);
  // Reshape the contracting and non-contracting dimensions together.
  HloInstruction* reshaped_rhs =
      computation->AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(rhs_shape.element_type(), rhs_reshape_dims),
          transposed_rhs));

  std::vector<int64> dot_dims = batch_dim_sizes;
  dot_dims.push_back(lhs_non_contracting_size);
  dot_dims.push_back(rhs_non_contracting_size);

  DotDimensionNumbers dot_dnums;
  for (int64 i = 0; i < num_batch_dims; ++i) {
    dot_dnums.add_lhs_batch_dimensions(i);
    dot_dnums.add_rhs_batch_dimensions(i);
  }
  dot_dnums.add_lhs_contracting_dimensions(num_batch_dims + 1);
  dot_dnums.add_rhs_contracting_dimensions(num_batch_dims);
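  // The canonical dot therefore contracts the last lhs dimension with the
  // first non-batch rhs dimension (e.g. with one batch dimension, lhs dim 2
  // against rhs dim 1; clarifying example, not from the original source).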

  HloInstruction* dot = computation->AddInstruction(HloInstruction::CreateDot(
      ShapeUtil::MakeShape(original_dot->shape().element_type(), dot_dims),
      reshaped_lhs, reshaped_rhs, dot_dnums, original_dot->precision_config()));

  return computation->ReplaceInstruction(
      original_dot, computation->AddInstruction(HloInstruction::CreateReshape(
                        original_dot->shape(), dot)));
}

}  // namespace

StatusOr<bool> DotDecomposer::Run(HloModule* module) {
  XLA_VLOG_LINES(2, "DotDecomposer ENTRY\n" + module->ToString());
  // Gather all non-canonical Dot operations.
  std::vector<HloInstruction*> non_canonical_dots;
  for (auto* computation : module->MakeNonfusionComputations()) {
    for (auto* instruction : computation->instructions()) {
      if (instruction->opcode() != HloOpcode::kDot) {
        continue;
      }
      const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers();
      // A dot is not canonical if it has more than one contracting
      // dimension.
      if (dnums.lhs_contracting_dimensions_size() != 1) {
        non_canonical_dots.push_back(instruction);
        continue;
      }
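      // E.g. dnums with lhs_contracting_dimensions = {0, 2} (invented
      // values) takes the branch above and is later rewritten by
      // CanonicalizeDot.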
      if (dnums.lhs_batch_dimensions().empty() &&
          dnums.lhs_contracting_dimensions().empty()) {
        non_canonical_dots.push_back(instruction);
        continue;
      }
      if (dnums.lhs_batch_dimensions().empty()) {
        continue;
      }
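      // The batch dimensions must also be the most-major dimensions, in
      // order: e.g. lhs_batch_dimensions = {1, 2} (illustrative values)
      // differs from {0, 1} below and marks the dot non-canonical.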
      std::vector<int64> canonical_batch_dims(
          dnums.lhs_batch_dimensions_size());
      absl::c_iota(canonical_batch_dims, 0);
      if (!absl::c_equal(dnums.lhs_batch_dimensions(), canonical_batch_dims) ||
          !absl::c_equal(dnums.rhs_batch_dimensions(), canonical_batch_dims)) {
        non_canonical_dots.push_back(instruction);
      }
    }
  }
  bool changed = false;
  for (auto* dot : non_canonical_dots) {
    TF_RETURN_IF_ERROR(CanonicalizeDot(dot));
    changed = true;
  }

  if (decompose_batch_dot_) {
    std::vector<HloInstruction*> batch_dots;
    for (auto* computation : module->MakeNonfusionComputations()) {
      for (auto* instruction : computation->instructions()) {
        if (instruction->opcode() != HloOpcode::kDot) {
          continue;
        }
        const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers();
        if (!dnums.lhs_batch_dimensions().empty()) {
          batch_dots.push_back(instruction);
        }
      }
    }
    // Decompose each batch Dot in 'batch_dots'.
    for (auto* dot : batch_dots) {
      TF_RETURN_IF_ERROR(DecomposeBatchDot(dot));
      changed = true;
    }
  }
  XLA_VLOG_LINES(2, "DotDecomposer EXIT\n" + module->ToString());
  return changed;
}

}  // namespace xla