1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/dot_decomposer.h"
17 
18 #include "absl/algorithm/container.h"
19 #include "absl/strings/str_join.h"
20 #include "tensorflow/compiler/xla/service/hlo_computation.h"
21 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
22 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
23 #include "tensorflow/compiler/xla/shape_util.h"
24 #include "tensorflow/compiler/xla/status_macros.h"
25 #include "tensorflow/compiler/xla/types.h"
26 #include "tensorflow/core/platform/logging.h"
27 
28 namespace xla {
29 
30 namespace {
31 
// TODO(b/69062148) Remove this code when all backends support BatchDot
// natively.

// Decomposes a batch Dot into 'batch_size' individual R2 Dots: lhs/rhs are
// reshaped to R3 ([batch, rows, cols]), each batch element is sliced out and
// reshaped to R2 (transposed first when its contracting dimension requires
// it), the R2 slices are multiplied with a plain Dot, and the per-element
// results are concatenated along the batch dimension and reshaped back to
// the original output shape. Finally replaces 'dot' in its computation with
// the decomposed graph and returns the status of that replacement.
//
// Preconditions (CHECKed): lhs/rhs batch dimension counts and sizes match,
// and each side has exactly one contracting dimension.
Status DecomposeBatchDot(HloInstruction* dot) {
  auto computation = dot->parent();
  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
  HloInstruction* lhs = dot->mutable_operand(0);
  HloInstruction* rhs = dot->mutable_operand(1);
  const Shape& lhs_shape = lhs->shape();
  const Shape& rhs_shape = rhs->shape();
  const Shape& dot_shape = dot->shape();

  // ShapeInference should guarantee that lhs/rhs batch dimensions match.
  CHECK_EQ(dnums.lhs_batch_dimensions_size(),
           dnums.rhs_batch_dimensions_size());
  const int64 num_batch_dims = dnums.lhs_batch_dimensions_size();
  // Calculate total batch size (note that ShapeInference requires that
  // the batch dimensions are most-major).
  int64 batch_size = 1;
  for (int i = 0; i < num_batch_dims; ++i) {
    CHECK_EQ(lhs_shape.dimensions(dnums.lhs_batch_dimensions(i)),
             rhs_shape.dimensions(dnums.rhs_batch_dimensions(i)));
    batch_size *= lhs_shape.dimensions(dnums.lhs_batch_dimensions(i));
  }

  // Set lhs/rhs_transpose. The R2 Dot created below always contracts lhs
  // dim 1 against rhs dim 0, so a side must be transposed whenever its
  // contracting dimension is on the other side of its (rows, cols) pair.
  CHECK_EQ(1, dnums.lhs_contracting_dimensions_size());
  const int64 lhs_contracting_dim_number = dnums.lhs_contracting_dimensions(0);
  const bool lhs_transpose = (lhs_contracting_dim_number - num_batch_dims) == 0;

  CHECK_EQ(1, dnums.rhs_contracting_dimensions_size());
  const int64 rhs_contracting_dim_number = dnums.rhs_contracting_dimensions(0);
  const bool rhs_transpose = (rhs_contracting_dim_number - num_batch_dims) == 1;

  // Compute R3 and R2 shapes for lhs.
  PrimitiveType lhs_type = lhs_shape.element_type();
  const int64 lhs_rows = lhs_shape.dimensions(num_batch_dims + 0);
  const int64 lhs_cols = lhs_shape.dimensions(num_batch_dims + 1);
  Shape lhs_shape_r3 =
      ShapeUtil::MakeShape(lhs_type, {batch_size, lhs_rows, lhs_cols});
  Shape lhs_slice_shape_r3 =
      ShapeUtil::MakeShape(lhs_type, {1, lhs_rows, lhs_cols});
  Shape lhs_slice_shape_r2 =
      ShapeUtil::MakeShape(lhs_type, {lhs_rows, lhs_cols});

  // Compute R3 and R2 shapes for rhs.
  PrimitiveType rhs_type = rhs_shape.element_type();
  const int64 rhs_rows = rhs_shape.dimensions(num_batch_dims + 0);
  const int64 rhs_cols = rhs_shape.dimensions(num_batch_dims + 1);
  Shape rhs_shape_r3 =
      ShapeUtil::MakeShape(rhs_type, {batch_size, rhs_rows, rhs_cols});
  Shape rhs_slice_shape_r3 =
      ShapeUtil::MakeShape(rhs_type, {1, rhs_rows, rhs_cols});
  Shape rhs_slice_shape_r2 =
      ShapeUtil::MakeShape(rhs_type, {rhs_rows, rhs_cols});

  // Compute R2, R3 and concatenated-output shapes for the dot output.
  PrimitiveType dot_type = dot_shape.element_type();
  const int64 dot_rows = dot_shape.dimensions(num_batch_dims + 0);
  const int64 dot_cols = dot_shape.dimensions(num_batch_dims + 1);
  Shape dot_shape_r2 = ShapeUtil::MakeShape(dot_type, {dot_rows, dot_cols});
  Shape dot_shape_r3 = ShapeUtil::MakeShape(dot_type, {1, dot_rows, dot_cols});
  Shape concat_shape_r3 =
      ShapeUtil::MakeShape(dot_type, {batch_size, dot_rows, dot_cols});

  // Reshape lhs/rhs into R3.
  auto lhs_r3 = computation->AddInstruction(
      HloInstruction::CreateReshape(lhs_shape_r3, lhs));
  auto rhs_r3 = computation->AddInstruction(
      HloInstruction::CreateReshape(rhs_shape_r3, rhs));

  // Loop through batch size, slicing out required lhs/rhs to compute each Dot.
  std::vector<HloInstruction*> output_slices(batch_size);
  for (int64 i = 0; i < batch_size; ++i) {
    // Slice R3 shape from 'lhs' and reshape to R2 (unit-stride slice of a
    // single batch element).
    auto lhs_slice_r3 = computation->AddInstruction(
        HloInstruction::CreateSlice(lhs_slice_shape_r3, lhs_r3, {i, 0, 0},
                                    {i + 1, lhs_rows, lhs_cols}, {1, 1, 1}));
    auto lhs_slice_r2 = computation->AddInstruction(
        HloInstruction::CreateReshape(lhs_slice_shape_r2, lhs_slice_r3));

    // Slice R3 shape from 'rhs' and reshape to R2.
    auto rhs_slice_r3 = computation->AddInstruction(
        HloInstruction::CreateSlice(rhs_slice_shape_r3, rhs_r3, {i, 0, 0},
                                    {i + 1, rhs_rows, rhs_cols}, {1, 1, 1}));
    auto rhs_slice_r2 = computation->AddInstruction(
        HloInstruction::CreateReshape(rhs_slice_shape_r2, rhs_slice_r3));

    // Transpose lhs/rhs (if needed) so the contracting dimension lines up
    // with the fixed (lhs dim 1, rhs dim 0) contraction below.
    if (lhs_transpose) {
      Shape lhs_slice_shape_r2_transpose =
          ShapeUtil::MakeShape(lhs_type, {lhs_cols, lhs_rows});
      lhs_slice_r2 =
          computation->AddInstruction(HloInstruction::CreateTranspose(
              lhs_slice_shape_r2_transpose, lhs_slice_r2, {1, 0}));
    }
    if (rhs_transpose) {
      Shape rhs_slice_shape_r2_transpose =
          ShapeUtil::MakeShape(rhs_type, {rhs_cols, rhs_rows});
      rhs_slice_r2 =
          computation->AddInstruction(HloInstruction::CreateTranspose(
              rhs_slice_shape_r2_transpose, rhs_slice_r2, {1, 0}));
    }

    // Compute Dot of lhs/rhs R2 slices, preserving the original dot's
    // precision configuration.
    DotDimensionNumbers dot_dnums;
    dot_dnums.add_lhs_contracting_dimensions(1);
    dot_dnums.add_rhs_contracting_dimensions(0);
    auto dot_r2 = computation->AddInstruction(
        HloInstruction::CreateDot(dot_shape_r2, lhs_slice_r2, rhs_slice_r2,
                                  dot_dnums, dot->precision_config()));

    // Reshape Dot to R3 so we can concat along batch dimension.
    auto dot_r3 = computation->AddInstruction(
        HloInstruction::CreateReshape(dot_shape_r3, dot_r2));

    output_slices[i] = dot_r3;
  }

  // Concatenate slices from 'output_slices' along batch dimension.
  auto concat = computation->AddInstruction(
      HloInstruction::CreateConcatenate(concat_shape_r3, output_slices, 0));
  // Reshape output 'new_dot' to original dimensions.
  auto new_dot = computation->AddInstruction(
      HloInstruction::CreateReshape(dot_shape, concat));

  // Replace all uses of 'dot' in 'computation' with 'new_dot'.
  return computation->ReplaceInstruction(dot, new_dot);
}
160 
161 // Convert a dot into a canonical form where non-contracting and contracting
162 // dimensions are reshaped together and batch dimensions are the most major
163 // dimensions. The requires transposing and reshapes the lhs and rhs and
164 // reshaping the output batch to the original shape.
CanonicalizeDot(HloInstruction * original_dot)165 Status CanonicalizeDot(HloInstruction* original_dot) {
166   auto computation = original_dot->parent();
167   const auto& original_dnums = original_dot->dot_dimension_numbers();
168   const int64 num_batch_dims = original_dnums.lhs_batch_dimensions_size();
169   const int64 num_contracting_dims =
170       original_dnums.lhs_contracting_dimensions_size();
171 
172   const auto& lhs_shape = original_dot->operand(0)->shape();
173   const int64 lhs_rank = lhs_shape.rank();
174   const int64 num_lhs_non_contracting_dims =
175       lhs_rank - num_batch_dims - num_contracting_dims;
176 
177   std::vector<int64> lhs_non_contracting_dims;
178   lhs_non_contracting_dims.reserve(num_lhs_non_contracting_dims);
179   int64 lhs_contracting_size = 1;
180   int64 lhs_non_contracting_size = 1;
181   std::vector<int64> batch_dim_sizes;
182   batch_dim_sizes.reserve(num_batch_dims);
183   for (int64 i = 0; i < lhs_rank; ++i) {
184     if (absl::c_linear_search(original_dnums.lhs_contracting_dimensions(), i)) {
185       lhs_contracting_size *= lhs_shape.dimensions(i);
186     } else if (absl::c_linear_search(original_dnums.lhs_batch_dimensions(),
187                                      i)) {
188       batch_dim_sizes.push_back(lhs_shape.dimensions(i));
189     } else {
190       lhs_non_contracting_dims.push_back(i);
191       lhs_non_contracting_size *= lhs_shape.dimensions(i);
192     }
193   }
194   // The canonical form of the lhs is
195   // [BatchDims, NonContractingDims, ContractingsDims]
196   std::vector<int64> lhs_transpose;
197   lhs_transpose.reserve(lhs_rank);
198   lhs_transpose.insert(lhs_transpose.end(),
199                        original_dnums.lhs_batch_dimensions().begin(),
200                        original_dnums.lhs_batch_dimensions().end());
201   lhs_transpose.insert(lhs_transpose.end(), lhs_non_contracting_dims.begin(),
202                        lhs_non_contracting_dims.end());
203   lhs_transpose.insert(lhs_transpose.end(),
204                        original_dnums.lhs_contracting_dimensions().begin(),
205                        original_dnums.lhs_contracting_dimensions().end());
206   HloInstruction* transposed_lhs =
207       computation->AddInstruction(HloInstruction::CreateTranspose(
208           ShapeUtil::PermuteDimensions(InversePermutation(lhs_transpose),
209                                        lhs_shape),
210           original_dot->mutable_operand(0), lhs_transpose));
211   std::vector<int64> lhs_reshape_dims = batch_dim_sizes;
212   lhs_reshape_dims.push_back(lhs_non_contracting_size);
213   lhs_reshape_dims.push_back(lhs_contracting_size);
214   // Reshape the contracting and non-contracting dimensions together.
215   HloInstruction* reshaped_lhs =
216       computation->AddInstruction(HloInstruction::CreateReshape(
217           ShapeUtil::MakeShape(lhs_shape.element_type(), lhs_reshape_dims),
218           transposed_lhs));
219 
220   const auto& rhs_shape = original_dot->operand(1)->shape();
221   const int64 rhs_rank = rhs_shape.rank();
222   const int64 num_rhs_non_contracting_dims =
223       rhs_rank - num_batch_dims - num_contracting_dims;
224   std::vector<int64> rhs_non_contracting_dims;
225   rhs_non_contracting_dims.reserve(num_rhs_non_contracting_dims);
226   int64 rhs_non_contracting_size = 1;
227   int64 rhs_contracting_size = 1;
228   for (int64 i = 0; i < rhs_rank; ++i) {
229     if (absl::c_linear_search(original_dnums.rhs_contracting_dimensions(), i)) {
230       rhs_contracting_size *= rhs_shape.dimensions(i);
231     } else if (!absl::c_linear_search(original_dnums.rhs_batch_dimensions(),
232                                       i)) {
233       rhs_non_contracting_dims.push_back(i);
234       rhs_non_contracting_size *= rhs_shape.dimensions(i);
235     }
236   }
237 
238   // The canonical form of the rhs is
239   // [BatchDims, ContractingsDims, NonContractingDims]
240   std::vector<int64> rhs_transpose;
241   rhs_transpose.reserve(rhs_rank);
242   rhs_transpose.insert(rhs_transpose.end(),
243                        original_dnums.rhs_batch_dimensions().begin(),
244                        original_dnums.rhs_batch_dimensions().end());
245   rhs_transpose.insert(rhs_transpose.end(),
246                        original_dnums.rhs_contracting_dimensions().begin(),
247                        original_dnums.rhs_contracting_dimensions().end());
248   rhs_transpose.insert(rhs_transpose.end(), rhs_non_contracting_dims.begin(),
249                        rhs_non_contracting_dims.end());
250   HloInstruction* transposed_rhs =
251       computation->AddInstruction(HloInstruction::CreateTranspose(
252           ShapeUtil::PermuteDimensions(InversePermutation(rhs_transpose),
253                                        rhs_shape),
254           original_dot->mutable_operand(1), rhs_transpose));
255 
256   std::vector<int64> rhs_reshape_dims = batch_dim_sizes;
257   rhs_reshape_dims.push_back(rhs_contracting_size);
258   rhs_reshape_dims.push_back(rhs_non_contracting_size);
259   // Reshape the contracting and non-contracting dimensions together.
260   HloInstruction* reshaped_rhs =
261       computation->AddInstruction(HloInstruction::CreateReshape(
262           ShapeUtil::MakeShape(rhs_shape.element_type(), rhs_reshape_dims),
263           transposed_rhs));
264 
265   std::vector<int64> dot_dims = batch_dim_sizes;
266   dot_dims.push_back(lhs_non_contracting_size);
267   dot_dims.push_back(rhs_non_contracting_size);
268 
269   DotDimensionNumbers dot_dnums;
270   for (int64 i = 0; i < num_batch_dims; ++i) {
271     dot_dnums.add_lhs_batch_dimensions(i);
272     dot_dnums.add_rhs_batch_dimensions(i);
273   }
274   dot_dnums.add_lhs_contracting_dimensions(num_batch_dims + 1);
275   dot_dnums.add_rhs_contracting_dimensions(num_batch_dims);
276 
277   HloInstruction* dot = computation->AddInstruction(HloInstruction::CreateDot(
278       ShapeUtil::MakeShape(original_dot->shape().element_type(), dot_dims),
279       reshaped_lhs, reshaped_rhs, dot_dnums, original_dot->precision_config()));
280 
281   return computation->ReplaceInstruction(
282       original_dot, computation->AddInstruction(HloInstruction::CreateReshape(
283                         original_dot->shape(), dot)));
284 }
285 
286 }  // namespace
287 
Run(HloModule * module)288 StatusOr<bool> DotDecomposer::Run(HloModule* module) {
289   XLA_VLOG_LINES(2, "DotDecomposer ENTRY\n" + module->ToString());
290   // Gather all Non-canonical Dot operations.
291   std::vector<HloInstruction*> non_canonical_dots;
292   for (auto* computation : module->MakeNonfusionComputations()) {
293     for (auto* instruction : computation->instructions()) {
294       if (instruction->opcode() != HloOpcode::kDot) {
295         continue;
296       }
297       const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers();
298       // A dot it not canonical if there are more than one contracting
299       // dimension.
300       if (dnums.lhs_contracting_dimensions_size() != 1) {
301         non_canonical_dots.push_back(instruction);
302         continue;
303       }
304       if (dnums.lhs_batch_dimensions().empty() &&
305           dnums.lhs_contracting_dimensions().empty()) {
306         non_canonical_dots.push_back(instruction);
307         continue;
308       }
309       if (dnums.lhs_batch_dimensions().empty()) {
310         continue;
311       }
312       std::vector<int64> canonical_batch_dims(
313           dnums.lhs_batch_dimensions_size());
314       absl::c_iota(canonical_batch_dims, 0);
315       if (!absl::c_equal(dnums.lhs_batch_dimensions(), canonical_batch_dims) ||
316           !absl::c_equal(dnums.rhs_batch_dimensions(), canonical_batch_dims)) {
317         non_canonical_dots.push_back(instruction);
318       }
319     }
320   }
321   bool changed = false;
322   for (auto* dot : non_canonical_dots) {
323     TF_RETURN_IF_ERROR(CanonicalizeDot(dot));
324     changed = true;
325   }
326 
327   if (decompose_batch_dot_) {
328     std::vector<HloInstruction*> batch_dots;
329     for (auto* computation : module->MakeNonfusionComputations()) {
330       for (auto* instruction : computation->instructions()) {
331         if (instruction->opcode() != HloOpcode::kDot) {
332           continue;
333         }
334         const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers();
335         if (!dnums.lhs_batch_dimensions().empty()) {
336           batch_dots.push_back(instruction);
337         }
338       }
339     }
340     // Decompose each batch Dot in 'batch_dots'.
341 
342     for (auto* dot : batch_dots) {
343       TF_RETURN_IF_ERROR(DecomposeBatchDot(dot));
344       changed = true;
345     }
346   }
347   XLA_VLOG_LINES(2, "DotDecompose EXIT\n" + module->ToString());
348   return changed;
349 }
350 
351 }  // namespace xla
352