1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "absl/algorithm/container.h"
17 #include "absl/container/flat_hash_map.h"
18 #include "absl/container/flat_hash_set.h"
19 #include "absl/types/optional.h"
20 #include "tensorflow/compiler/xla/literal_util.h"
21 #include "tensorflow/compiler/xla/service/hlo_computation.h"
22 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
27 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
28 #include "tensorflow/compiler/xla/service/shape_inference.h"
29 #include "tensorflow/compiler/xla/service/spmd/convolution_handler.h"
30 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
31 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/compiler/xla/window_util.h"
35 #include "tensorflow/compiler/xla/xla_data.pb.h"
36 #include "tensorflow/core/lib/gtl/cleanup.h"
37 #include "tensorflow/core/platform/numbers.h"
38
39 namespace xla {
40 namespace spmd {
41
HandleDot(HloInstruction * hlo)42 Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) {
43 DotConvDimsMapping mapping;
44 const auto& dnums = hlo->dot_dimension_numbers();
45 int64 next_output_dim = 0;
46 for (int64 i = 0; i < dnums.lhs_batch_dimensions_size(); ++i) {
47 mapping.batch_dims.emplace_back();
48 mapping.batch_dims.back().lhs = dnums.lhs_batch_dimensions(i);
49 mapping.batch_dims.back().rhs = dnums.rhs_batch_dimensions(i);
50 mapping.batch_dims.back().output = next_output_dim++;
51 }
52 for (int64 i = 0; i < dnums.lhs_contracting_dimensions_size(); ++i) {
53 mapping.contracting_dims.emplace_back();
54 mapping.contracting_dims.back().lhs = dnums.lhs_contracting_dimensions(i);
55 mapping.contracting_dims.back().rhs = dnums.rhs_contracting_dimensions(i);
56 mapping.contracting_dims.back().output = -1;
57 }
58 for (int64 i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
59 if (absl::c_linear_search(dnums.lhs_batch_dimensions(), i) ||
60 absl::c_linear_search(dnums.lhs_contracting_dimensions(), i)) {
61 continue;
62 }
63 mapping.lhs_non_contracting_dims.emplace_back();
64 mapping.lhs_non_contracting_dims.back().lhs = i;
65 mapping.lhs_non_contracting_dims.back().rhs = -1;
66 mapping.lhs_non_contracting_dims.back().output = next_output_dim++;
67 }
68 for (int64 i = 0; i < hlo->operand(1)->shape().rank(); ++i) {
69 if (absl::c_linear_search(dnums.rhs_batch_dimensions(), i) ||
70 absl::c_linear_search(dnums.rhs_contracting_dimensions(), i)) {
71 continue;
72 }
73 mapping.rhs_non_contracting_dims.emplace_back();
74 mapping.rhs_non_contracting_dims.back().lhs = -1;
75 mapping.rhs_non_contracting_dims.back().rhs = i;
76 mapping.rhs_non_contracting_dims.back().output = next_output_dim++;
77 }
78 auto create_sharded_dot =
79 [&](HloInstruction* l, HloInstruction* r, SpmdBuilder* b,
80 const Window& conv_window) -> StatusOr<HloInstruction*> {
81 TF_ASSIGN_OR_RETURN(
82 auto sharded_dot_shape,
83 ShapeInference::InferDotOpShape(
84 l->shape(), r->shape(), hlo->dot_dimension_numbers(),
85 /*preferred_element_type=*/hlo->shape().element_type()));
86 return b->AddInstruction(HloInstruction::CreateDot(
87 sharded_dot_shape, l, r, hlo->dot_dimension_numbers(),
88 hlo->precision_config()));
89 };
90 return HandleDotHelper(hlo, mapping, create_sharded_dot);
91 }
92
93 namespace {
94
95 enum class WindowedEinsumOperand { LHS, RHS };
96
97 struct WindowedEinsumConfig {
98 WindowedEinsumOperand windowed_op;
99 bool windowed_at_contracting_dims;
100 bool windowed_at_batch_dims;
101 bool operands_sharded_at_contracting_dims;
102 };
103
104 struct DotDimensionIndexMapping {
105 std::vector<int64> lhs_to_rhs_indices;
106 std::vector<int64> lhs_to_output_indices;
107 std::vector<int64> rhs_to_lhs_indices;
108 std::vector<int64> rhs_to_output_indices;
109 std::vector<int64> output_to_lhs_indices;
110 std::vector<int64> output_to_rhs_indices;
111 };
112
UpdateDDNums(DotDimensionNumbers * new_ddnums,int64 reshaped_dim,bool lhs)113 void UpdateDDNums(DotDimensionNumbers* new_ddnums, int64 reshaped_dim,
114 bool lhs) {
115 auto update_dims =
116 [&reshaped_dim](tensorflow::protobuf::RepeatedField<int64>* dims) {
117 for (int64 i = 0; i < dims->size(); ++i) {
118 auto dim = dims->at(i);
119 if (reshaped_dim <= dim) {
120 dims->Set(i, dim + 1);
121 }
122 }
123 if (absl::c_linear_search(*dims, reshaped_dim)) {
124 dims->Add(reshaped_dim);
125 }
126 };
127
128 if (lhs) {
129 update_dims(new_ddnums->mutable_lhs_contracting_dimensions());
130 update_dims(new_ddnums->mutable_lhs_batch_dimensions());
131 } else { // rhs
132 update_dims(new_ddnums->mutable_rhs_contracting_dimensions());
133 update_dims(new_ddnums->mutable_rhs_batch_dimensions());
134 }
135 }
136
GenNewWindow(const HloInstruction * original_dot,const HloInstruction * dot_lhs,const HloInstruction * dot_rhs,int64 lhs_concat_dim,int64 rhs_concat_dim,bool windowed_at_contracting_dims,bool windowed_at_batch_dims)137 Window GenNewWindow(const HloInstruction* original_dot,
138 const HloInstruction* dot_lhs,
139 const HloInstruction* dot_rhs, int64 lhs_concat_dim,
140 int64 rhs_concat_dim, bool windowed_at_contracting_dims,
141 bool windowed_at_batch_dims) {
142 auto new_window = original_dot->window();
143 const ConvolutionDimensionNumbers& conv_dnums =
144 original_dot->convolution_dimension_numbers();
145 if (lhs_concat_dim != -1) {
146 for (int64 i = 0; i < conv_dnums.input_spatial_dimensions_size(); ++i) {
147 if (conv_dnums.input_spatial_dimensions(i) == lhs_concat_dim) {
148 auto wd = new_window.mutable_dimensions(i);
149 auto lhs_size = dot_lhs->shape().dimensions(lhs_concat_dim + 1);
150 if (windowed_at_contracting_dims) {
151 wd->set_size(lhs_size);
152 }
153 if (windowed_at_batch_dims) {
154 wd->set_size(lhs_size);
155 wd->set_padding_low(0);
156 wd->set_padding_high(0);
157 wd->set_stride(std::max<int64>(1, lhs_size - 1));
158 wd->set_window_dilation(1);
159 wd->set_base_dilation(lhs_size);
160 wd->set_window_reversal(false);
161 }
162 }
163 }
164 }
165 if (rhs_concat_dim != -1) {
166 for (int64 i = 0; i < conv_dnums.kernel_spatial_dimensions_size(); ++i) {
167 if (conv_dnums.kernel_spatial_dimensions(i) == rhs_concat_dim &&
168 !windowed_at_contracting_dims && !windowed_at_batch_dims &&
169 lhs_concat_dim == -1) {
170 auto wd = new_window.mutable_dimensions(i);
171 auto rhs_size = dot_rhs->shape().dimensions(rhs_concat_dim + 1);
172 wd->set_size(rhs_size);
173 wd->set_padding_low(rhs_size - 1);
174 wd->set_padding_high(rhs_size - 1);
175 }
176 }
177 }
178 // Add the extra dimension to window.
179 WindowDimension* new_dim = new_window.add_dimensions();
180 if (windowed_at_contracting_dims) {
181 new_dim->set_size(2);
182 new_dim->set_padding_low(0);
183 new_dim->set_padding_high(0);
184 new_dim->set_stride(1);
185 new_dim->set_window_dilation(1);
186 new_dim->set_base_dilation(1);
187 new_dim->set_window_reversal(false);
188 } else if (windowed_at_batch_dims) {
189 new_dim->set_size(2);
190 new_dim->set_padding_low(0);
191 new_dim->set_padding_high(0);
192 new_dim->set_stride(1); // std::max<int64>(1, 2 - 1)
193 new_dim->set_window_dilation(1);
194 new_dim->set_base_dilation(2);
195 new_dim->set_window_reversal(false);
196 } else {
197 if (lhs_concat_dim != -1) {
198 new_dim->set_size(1);
199 new_dim->set_padding_low(0);
200 new_dim->set_padding_high(0);
201 new_dim->set_stride(1);
202 new_dim->set_window_dilation(1);
203 new_dim->set_base_dilation(1);
204 new_dim->set_window_reversal(false);
205 }
206 if (rhs_concat_dim != -1) {
207 new_dim->set_size(2); // rhs_size
208 new_dim->set_padding_low(1); // rhs_size - 1
209 new_dim->set_padding_high(1); // rhs_size - 1
210 new_dim->set_stride(1);
211 new_dim->set_window_dilation(1);
212 new_dim->set_base_dilation(1);
213 new_dim->set_window_reversal(true);
214 }
215 }
216
217 VLOG(2) << "new_window: " << new_window.ShortDebugString();
218 return new_window;
219 }
220
GenNewConvDNums(const HloInstruction * original_dot,const HloInstruction * dot_lhs,const HloInstruction * dot_rhs,int64 lhs_concat_dim,int64 rhs_concat_dim,bool windowed_at_contracting_dims,bool windowed_at_batch_dims,const std::vector<int64> & lhs_to_output_indices,const std::vector<int64> & rhs_to_output_indices,const Shape & new_dot_shape)221 ConvolutionDimensionNumbers GenNewConvDNums(
222 const HloInstruction* original_dot, const HloInstruction* dot_lhs,
223 const HloInstruction* dot_rhs, int64 lhs_concat_dim, int64 rhs_concat_dim,
224 bool windowed_at_contracting_dims, bool windowed_at_batch_dims,
225 const std::vector<int64>& lhs_to_output_indices,
226 const std::vector<int64>& rhs_to_output_indices,
227 const Shape& new_dot_shape) {
228 // Generate the new conv dimension numbers.
229 const ConvolutionDimensionNumbers& dnums =
230 original_dot->convolution_dimension_numbers();
231 // Handle the LHS dimension numbers.
232 int64 input_batch_dimension = dnums.input_batch_dimension();
233 int64 input_feature_dimension = dnums.input_feature_dimension();
234 std::vector<int64> input_spatial_dimensions(
235 dnums.input_spatial_dimensions().begin(),
236 dnums.input_spatial_dimensions().end());
237 if (lhs_concat_dim != -1) {
238 if (lhs_concat_dim <= input_batch_dimension) {
239 input_batch_dimension++;
240 }
241 if (lhs_concat_dim <= input_feature_dimension) {
242 input_feature_dimension++;
243 }
244 for (int64 i = 0; i < input_spatial_dimensions.size(); ++i) {
245 if (lhs_concat_dim <= input_spatial_dimensions[i]) {
246 input_spatial_dimensions[i]++;
247 }
248 }
249 input_spatial_dimensions.push_back(lhs_concat_dim);
250 }
251 if (rhs_concat_dim != -1 && !windowed_at_contracting_dims &&
252 !windowed_at_batch_dims) {
253 input_spatial_dimensions.push_back(dot_lhs->shape().dimensions_size() - 1);
254 }
255 // Handle the RHS dimension numbers.
256 int64 kernel_input_feature_dimension = dnums.kernel_input_feature_dimension();
257 int64 kernel_output_feature_dimension =
258 dnums.kernel_output_feature_dimension();
259 std::vector<int64> kernel_spatial_dimensions(
260 dnums.kernel_spatial_dimensions().begin(),
261 dnums.kernel_spatial_dimensions().end());
262 if (rhs_concat_dim != -1) {
263 if (rhs_concat_dim <= kernel_input_feature_dimension) {
264 kernel_input_feature_dimension++;
265 }
266 if (rhs_concat_dim <= kernel_output_feature_dimension) {
267 kernel_output_feature_dimension++;
268 }
269 for (int64 i = 0; i < kernel_spatial_dimensions.size(); ++i) {
270 if (rhs_concat_dim <= kernel_spatial_dimensions[i]) {
271 kernel_spatial_dimensions[i]++;
272 }
273 }
274 kernel_spatial_dimensions.push_back(rhs_concat_dim);
275 }
276 if (lhs_concat_dim != -1 && !windowed_at_contracting_dims &&
277 !windowed_at_batch_dims) {
278 kernel_spatial_dimensions.push_back(dot_rhs->shape().dimensions_size() - 1);
279 }
280 // Handle the Output dimension numbers.
281 int64 output_batch_dimension = dnums.output_batch_dimension();
282 int64 output_feature_dimension = dnums.output_feature_dimension();
283 std::vector<int64> output_spatial_dimensions(
284 dnums.output_spatial_dimensions().begin(),
285 dnums.output_spatial_dimensions().end());
286 if (!windowed_at_contracting_dims) {
287 auto output_slice_dim = lhs_concat_dim != -1
288 ? lhs_to_output_indices[lhs_concat_dim]
289 : rhs_to_output_indices[rhs_concat_dim];
290 if (output_slice_dim <= output_batch_dimension) {
291 output_batch_dimension++;
292 }
293 if (output_slice_dim <= output_feature_dimension) {
294 output_feature_dimension++;
295 }
296 for (int64 i = 0; i < output_spatial_dimensions.size(); ++i) {
297 if (output_slice_dim <= output_spatial_dimensions[i]) {
298 output_spatial_dimensions[i]++;
299 }
300 }
301 output_spatial_dimensions.push_back(output_slice_dim);
302 } else {
303 output_spatial_dimensions.push_back(new_dot_shape.dimensions_size() - 1);
304 }
305 // Construct the new dot dimension numbers.
306 ConvolutionDimensionNumbers new_dnums;
307 new_dnums.set_input_batch_dimension(input_batch_dimension);
308 new_dnums.set_input_feature_dimension(input_feature_dimension);
309 for (auto dim : input_spatial_dimensions) {
310 new_dnums.add_input_spatial_dimensions(dim);
311 }
312 new_dnums.set_kernel_input_feature_dimension(kernel_input_feature_dimension);
313 new_dnums.set_kernel_output_feature_dimension(
314 kernel_output_feature_dimension);
315 for (auto dim : kernel_spatial_dimensions) {
316 new_dnums.add_kernel_spatial_dimensions(dim);
317 }
318 new_dnums.set_output_batch_dimension(output_batch_dimension);
319 new_dnums.set_output_feature_dimension(output_feature_dimension);
320 for (auto dim : output_spatial_dimensions) {
321 new_dnums.add_output_spatial_dimensions(dim);
322 }
323
324 return new_dnums;
325 }
326
FirstShardingDimWithPartitionOfSize(int64 num_partitions,const HloSharding & sharding)327 int64 FirstShardingDimWithPartitionOfSize(int64 num_partitions,
328 const HloSharding& sharding) {
329 int64 sharding_dim = -1;
330 for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
331 if (sharding.tile_assignment().dim(i) == num_partitions) {
332 sharding_dim = i;
333 break;
334 }
335 }
336 return sharding_dim;
337 }
338
ComputeDimensionIndexMapping(const DotConvDimsMapping & dims_mapping,int64 lhs_rank,int64 rhs_rank,int64 output_rank)339 DotDimensionIndexMapping ComputeDimensionIndexMapping(
340 const DotConvDimsMapping& dims_mapping, int64 lhs_rank, int64 rhs_rank,
341 int64 output_rank) {
342 std::vector<int64> lhs_to_rhs_indices(lhs_rank, -1);
343 std::vector<int64> lhs_to_output_indices(lhs_rank, -1);
344 std::vector<int64> rhs_to_lhs_indices(rhs_rank, -1);
345 std::vector<int64> rhs_to_output_indices(rhs_rank, -1);
346 std::vector<int64> output_to_lhs_indices(output_rank, -1);
347 std::vector<int64> output_to_rhs_indices(output_rank, -1);
348 auto populate_indices_mapping =
349 [&](const DotConvDimsMapping::DimsMapping& mapping) {
350 if (mapping.lhs >= 0) {
351 lhs_to_rhs_indices[mapping.lhs] = mapping.rhs;
352 lhs_to_output_indices[mapping.lhs] = mapping.output;
353 }
354 if (mapping.rhs >= 0) {
355 rhs_to_lhs_indices[mapping.rhs] = mapping.lhs;
356 rhs_to_output_indices[mapping.rhs] = mapping.output;
357 }
358 if (mapping.output >= 0) {
359 output_to_lhs_indices[mapping.output] = mapping.lhs;
360 output_to_rhs_indices[mapping.output] = mapping.rhs;
361 }
362 };
363 for (const auto& mapping : dims_mapping.batch_dims) {
364 populate_indices_mapping(mapping);
365 }
366 for (const auto& mapping : dims_mapping.contracting_dims) {
367 populate_indices_mapping(mapping);
368 }
369 for (const auto& mapping : dims_mapping.lhs_non_contracting_dims) {
370 populate_indices_mapping(mapping);
371 }
372 for (const auto& mapping : dims_mapping.rhs_non_contracting_dims) {
373 populate_indices_mapping(mapping);
374 }
375 for (const auto& mapping : dims_mapping.conv_spatial_dims) {
376 populate_indices_mapping(mapping);
377 }
378 return DotDimensionIndexMapping{lhs_to_rhs_indices, lhs_to_output_indices,
379 rhs_to_lhs_indices, rhs_to_output_indices,
380 output_to_lhs_indices, output_to_rhs_indices};
381 }
382
GetWindowedEinsumConfiguration(int64 num_partitions,int64 output_lhs_non_contracting_partitions,int64 output_rhs_non_contracting_partitions,int64 rhs_contracting_partitions,int64 rhs_non_contracting_partitions,int64 rhs_batch_partitions,int64 lhs_contracting_partitions,int64 lhs_non_contracting_partitions,int64 lhs_batch_partitions,int64 output_sharding_dim,int64 rhs_shape_size,int64 lhs_shape_size,int64 output_shape_size,int64 einsum_threshold_mib,const absl::optional<HloSharding> & output_sharding_transposed_to_match_lhs,const absl::optional<HloSharding> & output_sharding_transposed_to_match_rhs,const HloSharding & lhs_sharding,const HloSharding & rhs_sharding)383 absl::optional<WindowedEinsumConfig> GetWindowedEinsumConfiguration(
384 int64 num_partitions, int64 output_lhs_non_contracting_partitions,
385 int64 output_rhs_non_contracting_partitions,
386 int64 rhs_contracting_partitions, int64 rhs_non_contracting_partitions,
387 int64 rhs_batch_partitions, int64 lhs_contracting_partitions,
388 int64 lhs_non_contracting_partitions, int64 lhs_batch_partitions,
389 int64 output_sharding_dim, int64 rhs_shape_size, int64 lhs_shape_size,
390 int64 output_shape_size, int64 einsum_threshold_mib,
391 const absl::optional<HloSharding>& output_sharding_transposed_to_match_lhs,
392 const absl::optional<HloSharding>& output_sharding_transposed_to_match_rhs,
393 const HloSharding& lhs_sharding, const HloSharding& rhs_sharding) {
394 if (output_lhs_non_contracting_partitions == num_partitions &&
395 output_sharding_transposed_to_match_lhs == lhs_sharding &&
396 rhs_shape_size >= einsum_threshold_mib * 1024 * 1024) {
397 if (rhs_contracting_partitions == num_partitions) {
398 return WindowedEinsumConfig{
399 /*windowed_op=*/WindowedEinsumOperand::RHS,
400 /*windowed_at_contracting_dims*/ true,
401 /*windowed_at_batch_dims=*/false,
402 /*operands_sharded_at_contracting_dims=*/false};
403 }
404 if (rhs_non_contracting_partitions == num_partitions) {
405 return WindowedEinsumConfig{
406 /*windowed_op=*/WindowedEinsumOperand::RHS,
407 /*windowed_at_contracting_dims*/ false,
408 /*windowed_at_batch_dims=*/false,
409 /*operands_sharded_at_contracting_dims=*/false};
410 }
411 if (rhs_batch_partitions == num_partitions) {
412 return WindowedEinsumConfig{
413 /*windowed_op=*/WindowedEinsumOperand::RHS,
414 /*windowed_at_contracting_dims*/ false,
415 /*windowed_at_batch_dims=*/true,
416 /*operands_sharded_at_contracting_dims=*/false};
417 }
418 }
419 if (output_rhs_non_contracting_partitions == num_partitions &&
420 output_sharding_transposed_to_match_rhs == rhs_sharding &&
421 lhs_shape_size >= einsum_threshold_mib * 1024 * 1024) {
422 if (lhs_contracting_partitions == num_partitions) {
423 return WindowedEinsumConfig{
424 /*windowed_op=*/WindowedEinsumOperand::LHS,
425 /*windowed_at_contracting_dims*/ true,
426 /*windowed_at_batch_dims=*/false,
427 /*operands_sharded_at_contracting_dims=*/false};
428 }
429 if (lhs_non_contracting_partitions == num_partitions) {
430 return WindowedEinsumConfig{
431 /*windowed_op=*/WindowedEinsumOperand::LHS,
432 /*windowed_at_contracting_dims*/ false,
433 /*windowed_at_batch_dims=*/false,
434 /*operands_sharded_at_contracting_dims=*/false};
435 }
436 if (lhs_batch_partitions == num_partitions) {
437 return WindowedEinsumConfig{
438 /*windowed_op=*/WindowedEinsumOperand::LHS,
439 /*windowed_at_contracting_dims*/ false,
440 /*windowed_at_batch_dims=*/true,
441 /*operands_sharded_at_contracting_dims=*/false};
442 }
443 }
444 if (lhs_contracting_partitions == rhs_contracting_partitions &&
445 lhs_contracting_partitions == num_partitions &&
446 output_sharding_dim > -1 &&
447 output_shape_size >= einsum_threshold_mib * 1024 * 1024) {
448 if (output_lhs_non_contracting_partitions == num_partitions) {
449 return WindowedEinsumConfig{
450 /*windowed_op=*/WindowedEinsumOperand::RHS,
451 /*windowed_at_contracting_dims*/ false,
452 /*windowed_at_batch_dims=*/false,
453 /*operands_sharded_at_contracting_dims=*/true};
454 }
455 if (output_rhs_non_contracting_partitions == num_partitions) {
456 return WindowedEinsumConfig{
457 /*windowed_op=*/WindowedEinsumOperand::LHS,
458 /*windowed_at_contracting_dims*/ false,
459 /*windowed_at_batch_dims=*/false,
460 /*operands_sharded_at_contracting_dims=*/true};
461 }
462 }
463 return absl::nullopt;
464 }
465
PartitionBaseCase(PartitionedHlo lhs,PartitionedHlo rhs,const Shape & output_base_shape,const HloSharding & output_sharding,const DotConvDimsMapping & dims_mapping,int64 num_partitions,const std::function<StatusOr<HloInstruction * > (HloInstruction *,HloInstruction *,SpmdBuilder *,const Window & conv_window)> & create_sharded_dot,const Window & conv_window,HloModule * module,HloInstruction * original_hlo,int64 lhs_batch_partitions,int64 rhs_batch_partitions,int64 output_batch_partitions,int64 lhs_contracting_partitions,int64 rhs_contracting_partitions,int64 lhs_non_contracting_partitions,int64 rhs_non_contracting_partitions,int64 output_lhs_non_contracting_partitions,int64 output_rhs_non_contracting_partitions,const SpmdPartitionerOptions & options,SpmdBuilder * b,std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop> * windowed_dot_general_loops,bool may_reshard_without_detecting_match)466 StatusOr<HloInstruction*> PartitionBaseCase(
467 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
468 const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
469 int64 num_partitions,
470 const std::function<StatusOr<HloInstruction*>(
471 HloInstruction*, HloInstruction*, SpmdBuilder*,
472 const Window& conv_window)>& create_sharded_dot,
473 const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
474 int64 lhs_batch_partitions, int64 rhs_batch_partitions,
475 int64 output_batch_partitions, int64 lhs_contracting_partitions,
476 int64 rhs_contracting_partitions, int64 lhs_non_contracting_partitions,
477 int64 rhs_non_contracting_partitions,
478 int64 output_lhs_non_contracting_partitions,
479 int64 output_rhs_non_contracting_partitions,
480 const SpmdPartitionerOptions& options, SpmdBuilder* b,
481 std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
482 windowed_dot_general_loops,
483 bool may_reshard_without_detecting_match) {
484 const HloSharding& lhs_sharding = lhs.sharding();
485 const HloSharding& rhs_sharding = rhs.sharding();
486 if (lhs_sharding.ReplicateOnLastTileDim() ||
487 rhs_sharding.ReplicateOnLastTileDim() ||
488 output_sharding.ReplicateOnLastTileDim()) {
489 return nullptr;
490 }
491 DotDimensionIndexMapping indices_map = ComputeDimensionIndexMapping(
492 dims_mapping, lhs.base_shape().rank(), rhs.base_shape().rank(),
493 output_base_shape.rank());
494 auto lhs_sharding_transposed_to_match_rhs =
495 hlo_sharding_util::TransposeShardingWithCollapsedDims(
496 lhs_sharding, indices_map.lhs_to_rhs_indices,
497 indices_map.rhs_to_lhs_indices);
498 auto rhs_sharding_transposed_to_match_lhs =
499 hlo_sharding_util::TransposeShardingWithCollapsedDims(
500 rhs_sharding, indices_map.rhs_to_lhs_indices,
501 indices_map.lhs_to_rhs_indices);
502 auto lhs_sharding_transposed_to_match_output =
503 hlo_sharding_util::TransposeShardingWithCollapsedDims(
504 lhs_sharding, indices_map.lhs_to_output_indices,
505 indices_map.output_to_lhs_indices);
506 auto rhs_sharding_transposed_to_match_output =
507 hlo_sharding_util::TransposeShardingWithCollapsedDims(
508 rhs_sharding, indices_map.rhs_to_output_indices,
509 indices_map.output_to_rhs_indices);
510 auto output_sharding_transposed_to_match_lhs =
511 hlo_sharding_util::TransposeShardingWithCollapsedDims(
512 output_sharding, indices_map.output_to_lhs_indices,
513 indices_map.lhs_to_output_indices);
514 auto output_sharding_transposed_to_match_rhs =
515 hlo_sharding_util::TransposeShardingWithCollapsedDims(
516 output_sharding, indices_map.output_to_rhs_indices,
517 indices_map.rhs_to_output_indices);
518
519 // LHS and RHS are partitioned the same way and only partitioned in batch
520 // dimensions.
521 if (lhs_batch_partitions == rhs_batch_partitions &&
522 rhs_batch_partitions == num_partitions &&
523 lhs_sharding_transposed_to_match_rhs == rhs_sharding) {
524 TF_ASSIGN_OR_RETURN(
525 auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
526 dot->set_sharding(*lhs_sharding_transposed_to_match_output);
527 return PartitionedHlo(dot, output_base_shape, lhs.state())
528 .Reshard(output_sharding)
529 .hlo();
530 }
531
532 // Try emit batch-partitioned einsum with one operand resharded. Returns
533 // partitioned HLO or nullptr if the attempt fails. If
534 // may_reshard_with_allreduce is false, reshard must be done using
535 // all-to-all/collective-permute; otherwise this attempt fails.
536 auto try_emit_output_batch_partitioned_einsum_with_reshard =
537 [&](bool may_reshard_with_allreduce) -> StatusOr<HloInstruction*> {
538 // LHS and output are batch partitioned in the same way.
539 if (lhs_batch_partitions == num_partitions &&
540 output_batch_partitions == num_partitions &&
541 lhs_sharding_transposed_to_match_output == output_sharding) {
542 if (!may_reshard_with_allreduce &&
543 !CanReshardWithCollectivePermute(
544 rhs.sharding(), *lhs_sharding_transposed_to_match_rhs) &&
545 !GetReshardAllToAllSourceTargetDims(
546 rhs.sharding(), *lhs_sharding_transposed_to_match_rhs)) {
547 return nullptr;
548 }
549 auto resharded_rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs);
550 TF_ASSIGN_OR_RETURN(
551 auto dot,
552 create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), b, conv_window));
553 return dot;
554 }
555 // RHS and output are batch partitioned in the same way.
556 if (rhs_batch_partitions == num_partitions &&
557 output_batch_partitions == num_partitions &&
558 rhs_sharding_transposed_to_match_output == output_sharding) {
559 if (!may_reshard_with_allreduce &&
560 !CanReshardWithCollectivePermute(
561 lhs.sharding(), *rhs_sharding_transposed_to_match_lhs) &&
562 !GetReshardAllToAllSourceTargetDims(
563 lhs.sharding(), *rhs_sharding_transposed_to_match_lhs)) {
564 return nullptr;
565 }
566 auto resharded_lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs);
567 TF_ASSIGN_OR_RETURN(
568 auto dot,
569 create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), b, conv_window));
570 return dot;
571 }
572 return nullptr;
573 };
574
575 {
576 // Try batch-parallel by resharding one operand, and not using all-reduce.
577 TF_ASSIGN_OR_RETURN(
578 HloInstruction * partitioned_dot,
579 try_emit_output_batch_partitioned_einsum_with_reshard(false));
580 if (partitioned_dot) {
581 return partitioned_dot;
582 }
583 }
584
585 const int64 output_sharding_dim =
586 FirstShardingDimWithPartitionOfSize(num_partitions, output_sharding);
587 // Try to emit windowed DotGeneral when one operand is partitioned in the same
588 // way as the output along non-contracting dimensions, but the other operand
589 // is tiled in other dimensions. Or both operands are partitioned in the same
590 // way along contracting dimensions, but the output is partitioned along
591 // non-contracting dimensions.
592 auto emit_windowed_dot_general =
593 [&](const WindowedEinsumConfig& einsum_config)
594 -> StatusOr<HloInstruction*> {
595 CHECK(!einsum_config.windowed_at_batch_dims ||
596 !einsum_config.windowed_at_contracting_dims);
597 const bool windowed_at_batch_dims = einsum_config.windowed_at_batch_dims;
598 const bool windowed_at_contracting_dims =
599 einsum_config.windowed_at_contracting_dims;
600 const bool operands_sharded_at_contracting_dims =
601 einsum_config.operands_sharded_at_contracting_dims;
602 auto unpadded_result_buffer_shape =
603 MakePartitionedShape(output_base_shape, output_sharding);
604 auto padded_result_buffer_shape = unpadded_result_buffer_shape;
605 const bool windowed_op_is_lhs =
606 einsum_config.windowed_op == WindowedEinsumOperand::LHS;
607 // For windowing at batch/non-contracting dims, we produce the result one
608 // partition at a time, so we need to pad the shape in case of uneven
609 // partitioning in order to make dynamic-update-slice in-bound.
610 if (!windowed_at_contracting_dims &&
611 !operands_sharded_at_contracting_dims) {
612 padded_result_buffer_shape = GetPaddedShapeForUnevenPartitioning(
613 padded_result_buffer_shape,
614 windowed_op_is_lhs ? *lhs_sharding_transposed_to_match_output
615 : *rhs_sharding_transposed_to_match_output);
616 }
617 // Mask the padding area of the windowed operand with zero if there is
618 // uneven partitioning.
619 if (windowed_at_contracting_dims) {
620 auto& to_mask = windowed_op_is_lhs ? lhs : rhs;
621 to_mask =
622 to_mask.PadWithValue(b->AddInstruction(HloInstruction::CreateConstant(
623 LiteralUtil::Zero(output_base_shape.element_type()))));
624 }
625 if (operands_sharded_at_contracting_dims) {
626 auto zero = b->AddInstruction(HloInstruction::CreateConstant(
627 LiteralUtil::Zero(output_base_shape.element_type())));
628 lhs = lhs.PadWithValue(zero);
629 rhs = rhs.PadWithValue(zero);
630 }
631 auto result_buffer = CreateZero(padded_result_buffer_shape, b);
632 auto extra_buffer =
633 (!(options.bidirectional_windowed_einsum && num_partitions % 4 == 0) ||
634 operands_sharded_at_contracting_dims)
635 ? CreateZero(padded_result_buffer_shape, b)
636 : windowed_op_is_lhs ? lhs.hlo()
637 : rhs.hlo();
638
639 if (options.bidirectional_windowed_einsum && num_partitions % 4 == 0 &&
640 !operands_sharded_at_contracting_dims) {
641 std::vector<std::pair<int64, int64>> pre_sd_pairs(num_partitions);
642 for (int64 source = 0; source < num_partitions; ++source) {
643 // 0 -> 1, 1 -> 2, 2 -> 3, ...
644 pre_sd_pairs[source] = {source, (source + 1) % num_partitions};
645 }
646 extra_buffer =
647 lhs.state()
648 .collective_ops_creator.create_cross_partition_collective_permute(
649 b, extra_buffer, pre_sd_pairs,
650 (*lhs.state().next_channel_id)++);
651 }
652
653 auto iteration = b->AddInstruction(
654 HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(0)));
655
656 // Create a while loop that computes one window per iteration. During each
657 // iteration, each partition sends its input window to its neighbor using
658 // collective-permute for the next iteration.
659 SpmdBuilder body_b("windowed_dot_general_body", original_hlo);
660
661 // Generate partial results used by bidirectional algorithm.
662 auto get_partial_bid_results =
663 [&](HloInstruction* l, HloInstruction* r, HloInstruction* o,
664 HloInstruction* extra_inout, HloInstruction* cw_cp_output,
665 HloInstruction* i) -> StatusOr<std::vector<HloInstruction*>> {
666 auto partition_id =
667 lhs.state().collective_ops_creator.create_partition_id(&body_b);
668 auto partition_count =
669 body_b.AddInstruction(HloInstruction::CreateConstant(
670 LiteralUtil::CreateR0<uint32>(num_partitions)));
671 auto ccw_data_partition_id =
672 body_b.AddInstruction(HloInstruction::CreateBinary(
673 i->shape(), HloOpcode::kAdd, i, partition_id));
674 auto cw_data_partition_id =
675 body_b.AddInstruction(HloInstruction::CreateBinary(
676 i->shape(), HloOpcode::kAdd, partition_count, partition_id));
677 if (operands_sharded_at_contracting_dims) {
678 ccw_data_partition_id =
679 body_b.AddInstruction(HloInstruction::CreateBinary(
680 i->shape(), HloOpcode::kAdd, ccw_data_partition_id,
681 body_b.AddInstruction(HloInstruction::CreateConstant(
682 LiteralUtil::CreateR0<uint32>(num_partitions / 2 + 1)))));
683 cw_data_partition_id =
684 body_b.AddInstruction(HloInstruction::CreateBinary(
685 i->shape(), HloOpcode::kSubtract, cw_data_partition_id,
686 body_b.AddInstruction(HloInstruction::CreateConstant(
687 LiteralUtil::CreateR0<uint32>(num_partitions / 2)))));
688 } else {
689 cw_data_partition_id =
690 body_b.AddInstruction(HloInstruction::CreateBinary(
691 i->shape(), HloOpcode::kSubtract, cw_data_partition_id,
692 CreateOne(cw_data_partition_id->shape(), &body_b)));
693 }
694 ccw_data_partition_id = body_b.AddInstruction(
695 HloInstruction::CreateBinary(i->shape(), HloOpcode::kRemainder,
696 ccw_data_partition_id, partition_count));
697 cw_data_partition_id = body_b.AddInstruction(HloInstruction::CreateBinary(
698 i->shape(), HloOpcode::kSubtract, cw_data_partition_id, i));
699 cw_data_partition_id = body_b.AddInstruction(
700 HloInstruction::CreateBinary(i->shape(), HloOpcode::kRemainder,
701 cw_data_partition_id, partition_count));
702 // Calculate concat dim.
703 const HloSharding* slice_sharding;
704 if (operands_sharded_at_contracting_dims) {
705 slice_sharding = windowed_op_is_lhs
706 ? &*output_sharding_transposed_to_match_rhs
707 : &*output_sharding_transposed_to_match_lhs;
708 } else if (windowed_at_contracting_dims || windowed_at_batch_dims) {
709 slice_sharding = windowed_op_is_lhs
710 ? &*lhs_sharding_transposed_to_match_rhs
711 : &*rhs_sharding_transposed_to_match_lhs;
712 } else {
713 slice_sharding = windowed_op_is_lhs
714 ? &*lhs_sharding_transposed_to_match_output
715 : &*rhs_sharding_transposed_to_match_output;
716 }
717 CHECK_EQ(Product(slice_sharding->tile_assignment().dimensions()),
718 num_partitions);
719 int64 slice_sharding_dim = -1;
720 for (int64 i = 0; i < slice_sharding->tile_assignment().num_dimensions();
721 ++i) {
722 if (slice_sharding->tile_assignment().dim(i) > 1) {
723 slice_sharding_dim = i;
724 break;
725 }
726 }
727 int64 lhs_concat_dim = -1;
728 int64 rhs_concat_dim = -1;
729 if (operands_sharded_at_contracting_dims) {
730 if (windowed_op_is_lhs) {
731 rhs_concat_dim = slice_sharding_dim;
732 } else {
733 lhs_concat_dim = slice_sharding_dim;
734 }
735 } else if (windowed_at_contracting_dims || windowed_at_batch_dims) {
736 lhs_concat_dim =
737 windowed_op_is_lhs
738 ? indices_map.rhs_to_lhs_indices[slice_sharding_dim]
739 : slice_sharding_dim;
740 rhs_concat_dim =
741 windowed_op_is_lhs
742 ? slice_sharding_dim
743 : indices_map.lhs_to_rhs_indices[slice_sharding_dim];
744 } else {
745 if (windowed_op_is_lhs) {
746 lhs_concat_dim =
747 indices_map.output_to_lhs_indices[slice_sharding_dim];
748 } else {
749 rhs_concat_dim =
750 indices_map.output_to_rhs_indices[slice_sharding_dim];
751 }
752 }
753
754 DotDimensionNumbers new_ddnums;
755 if (original_hlo->opcode() == HloOpcode::kDot) {
756 new_ddnums = original_hlo->dot_dimension_numbers();
757 }
758
759 auto dot_lhs = l;
760 auto dot_rhs = r;
761 auto original_dot_lhs = l;
762 auto original_dot_rhs = r;
763 if (windowed_at_contracting_dims || windowed_at_batch_dims ||
764 operands_sharded_at_contracting_dims) {
765 // Slice the matching operand according to the partitioned dimensions
766 // on the windowed operand or the output.
767 auto slice_operand = !windowed_op_is_lhs ? l : r;
768
769 // Pad the sharding dim first (then the concat dim) for correctness.
770 auto sharding_dim_size =
771 slice_operand->shape().dimensions(slice_sharding_dim);
772 if (sharding_dim_size % num_partitions != 0) {
773 slice_operand = PadBaseShapeBeforeUnevenTiledSharding(
774 slice_operand, *slice_sharding, &body_b);
775 }
776
777 // We do this by treating the matching operand as replicated, and
778 // resharding it to match the windowed operand or the output.
779 auto gen_slice = [&](HloInstruction* data_partition_id,
780 bool ccw) -> HloInstruction* {
781 std::vector<int64> new_dims;
782 for (int64 i = 0; i < slice_operand->shape().dimensions_size(); ++i) {
783 if (i == slice_sharding_dim) {
784 new_dims.push_back(1);
785 }
786 new_dims.push_back(slice_operand->shape().dimensions(i));
787 }
788 auto reshaped_slice_operand =
789 body_b.AddInstruction(HloInstruction::CreateReshape(
790 ShapeUtil::MakeShape(slice_operand->shape().element_type(),
791 new_dims),
792 slice_operand));
793 auto min = body_b.AddInstruction(
794 HloInstruction::CreateConstant(LiteralUtil::MinValue(
795 reshaped_slice_operand->shape().element_type())));
796 std::vector<int64> min_padding(
797 reshaped_slice_operand->shape().rank());
798 auto padded_slice_operand = reshaped_slice_operand;
799 auto padded_shape = padded_slice_operand->shape();
800 int64 padding_dim = slice_sharding_dim;
801 padded_shape.set_dimensions(padding_dim, 2);
802 if (ccw) {
803 // ccw pad high
804 PaddingConfig ccw_pad_config =
805 window_util::MakeSymmetricPadding(min_padding);
806 ccw_pad_config.mutable_dimensions(padding_dim)
807 ->set_edge_padding_low(0);
808 ccw_pad_config.mutable_dimensions(padding_dim)
809 ->set_edge_padding_high(1);
810 padded_slice_operand =
811 body_b.AddInstruction(HloInstruction::CreatePad(
812 padded_shape, padded_slice_operand, min, ccw_pad_config));
813 } else {
814 // cw pad low
815 PaddingConfig cw_pad_config =
816 window_util::MakeSymmetricPadding(min_padding);
817 cw_pad_config.mutable_dimensions(padding_dim)
818 ->set_edge_padding_low(1);
819 cw_pad_config.mutable_dimensions(padding_dim)
820 ->set_edge_padding_high(0);
821 padded_slice_operand =
822 body_b.AddInstruction(HloInstruction::CreatePad(
823 padded_shape, padded_slice_operand, min, cw_pad_config));
824 }
825
826 padded_slice_operand->set_sharding(HloSharding::Replicate());
827 auto state = lhs.state();
828 state.b = &body_b;
829 state.partition_id = data_partition_id;
830 state.reshard_cache->per_hlo_cache.erase(padded_slice_operand);
831 auto padded_slice_sharding = hlo_sharding_util::ReshapeSharding(
832 slice_operand->shape(), reshaped_slice_operand->shape(),
833 *slice_sharding);
834 auto padded_slice =
835 PartitionedHlo(padded_slice_operand,
836 padded_slice_operand->shape(), state)
837 .Reshard(*padded_slice_sharding)
838 .hlo();
839 padded_slice_operand->clear_sharding();
840 return padded_slice;
841 };
842
843 auto ccw_slice = gen_slice(ccw_data_partition_id, true);
844 auto cw_slice = gen_slice(cw_data_partition_id, false);
845 auto slice = body_b.AddInstruction(HloInstruction::CreateBinary(
846 ccw_slice->shape(), HloOpcode::kMaximum, ccw_slice, cw_slice));
847 // Reshape. The reshaped slice will not be used to produce the final
848 // result, but used as a hint for the shape inference.
849 std::vector<int64> reshaped_slice_dims;
850 for (int64 i = 0; i < slice->shape().dimensions_size(); ++i) {
851 auto dim_size = slice->shape().dimensions(i);
852 if (i == (slice_sharding_dim + 1)) {
853 reshaped_slice_dims.push_back(dim_size * 2);
854 } else if (i != slice_sharding_dim) {
855 reshaped_slice_dims.push_back(dim_size);
856 }
857 }
858 auto reshaped_slice =
859 body_b.AddInstruction(HloInstruction::CreateReshape(
860 ShapeUtil::MakeShape(slice->shape().element_type(),
861 reshaped_slice_dims),
862 slice));
863
864 if (!windowed_op_is_lhs) {
865 dot_lhs = slice;
866 original_dot_lhs = reshaped_slice;
867 if (original_hlo->opcode() == HloOpcode::kDot) {
868 UpdateDDNums(&new_ddnums, slice_sharding_dim, true);
869 }
870 } else {
871 dot_rhs = slice;
872 original_dot_rhs = reshaped_slice;
873 if (original_hlo->opcode() == HloOpcode::kDot) {
874 UpdateDDNums(&new_ddnums, slice_sharding_dim, false);
875 }
876 }
877 }
878
879 auto ccw_dot_lhs = l;
880 auto ccw_dot_rhs = r;
881 auto cw_dot_lhs = windowed_op_is_lhs ? extra_inout : l;
882 auto cw_dot_rhs = windowed_op_is_lhs ? r : extra_inout;
883 if (lhs_concat_dim != -1 && windowed_op_is_lhs) {
884 auto lhs_concat_shape = ccw_dot_lhs->shape();
885 lhs_concat_shape.set_dimensions(
886 lhs_concat_dim,
887 ccw_dot_lhs->shape().dimensions(lhs_concat_dim) * 2);
888 dot_lhs = body_b.AddInstruction(HloInstruction::CreateConcatenate(
889 lhs_concat_shape, {ccw_dot_lhs, cw_dot_lhs}, lhs_concat_dim));
890 original_dot_lhs = dot_lhs;
891
892 // Reshape
893 std::vector<int64> reshaped_dims(dot_lhs->shape().dimensions().begin(),
894 dot_lhs->shape().dimensions().end());
895 reshaped_dims[lhs_concat_dim] /= 2;
896 reshaped_dims.insert(reshaped_dims.begin() + lhs_concat_dim, 2);
897 dot_lhs = body_b.AddInstruction(HloInstruction::CreateReshape(
898 ShapeUtil::MakeShape(dot_lhs->shape().element_type(),
899 reshaped_dims),
900 dot_lhs));
901
902 if (original_hlo->opcode() == HloOpcode::kDot) {
903 UpdateDDNums(&new_ddnums, lhs_concat_dim, true);
904 }
905 }
906 if (rhs_concat_dim != -1 && !windowed_op_is_lhs) {
907 auto rhs_concat_shape = ccw_dot_rhs->shape();
908 rhs_concat_shape.set_dimensions(
909 rhs_concat_dim,
910 ccw_dot_rhs->shape().dimensions(rhs_concat_dim) * 2);
911 dot_rhs = body_b.AddInstruction(HloInstruction::CreateConcatenate(
912 rhs_concat_shape, {ccw_dot_rhs, cw_dot_rhs}, rhs_concat_dim));
913 original_dot_rhs = dot_rhs;
914
915 // Reshape
916 std::vector<int64> reshaped_dims(dot_rhs->shape().dimensions().begin(),
917 dot_rhs->shape().dimensions().end());
918 reshaped_dims[rhs_concat_dim] /= 2;
919 reshaped_dims.insert(reshaped_dims.begin() + rhs_concat_dim, 2);
920 dot_rhs = body_b.AddInstruction(HloInstruction::CreateReshape(
921 ShapeUtil::MakeShape(dot_rhs->shape().element_type(),
922 reshaped_dims),
923 dot_rhs));
924
925 if (original_hlo->opcode() == HloOpcode::kDot) {
926 UpdateDDNums(&new_ddnums, rhs_concat_dim, false);
927 }
928 }
929
930 // The generated original dot will not be used.
931 TF_ASSIGN_OR_RETURN(auto original_dot,
932 create_sharded_dot(original_dot_lhs, original_dot_rhs,
933 &body_b, conv_window));
934 VLOG(2) << original_dot->ToString();
935
936 // Generate the correct shape of the new dot/conv.
937 auto original_sharded_dot_shape = original_dot->shape();
938 auto new_dot_shape = original_sharded_dot_shape;
939 std::vector<int64> new_dims(new_dot_shape.dimensions().begin(),
940 new_dot_shape.dimensions().end());
941 if (!windowed_at_contracting_dims) {
942 auto slice_dim =
943 lhs_concat_dim != -1
944 ? indices_map.lhs_to_output_indices[lhs_concat_dim]
945 : indices_map.rhs_to_output_indices[rhs_concat_dim];
946 new_dims[slice_dim] /= 2;
947 new_dims.insert(new_dims.begin() + slice_dim, 2);
948 } else {
949 new_dims.push_back(1);
950 }
951 new_dot_shape =
952 ShapeUtil::MakeShape(original_hlo->shape().element_type(), new_dims);
953
954 HloInstruction* dot;
955 if (original_hlo->opcode() == HloOpcode::kDot) {
956 dot = body_b.AddInstruction(HloInstruction::CreateDot(
957 new_dot_shape, dot_lhs, dot_rhs, new_ddnums,
958 original_hlo->precision_config()));
959 } else {
960 if (!windowed_at_contracting_dims && !windowed_at_batch_dims) {
961 if (lhs_concat_dim != -1) {
962 std::vector<int64> new_dims(dot_rhs->shape().dimensions().begin(),
963 dot_rhs->shape().dimensions().end());
964 new_dims.push_back(1);
965 dot_rhs = body_b.AddInstruction(HloInstruction::CreateReshape(
966 ShapeUtil::MakeShape(dot_rhs->shape().element_type(), new_dims),
967 dot_rhs));
968 }
969 if (rhs_concat_dim != -1) {
970 std::vector<int64> new_dims(dot_lhs->shape().dimensions().begin(),
971 dot_lhs->shape().dimensions().end());
972 new_dims.push_back(1);
973 dot_lhs = body_b.AddInstruction(HloInstruction::CreateReshape(
974 ShapeUtil::MakeShape(dot_lhs->shape().element_type(), new_dims),
975 dot_lhs));
976 }
977 }
978
979 dot = body_b.AddInstruction(HloInstruction::CreateConvolve(
980 new_dot_shape, dot_lhs, dot_rhs,
981 original_dot->feature_group_count(),
982 original_dot->batch_group_count(),
983 GenNewWindow(original_dot, dot_lhs, dot_rhs, lhs_concat_dim,
984 rhs_concat_dim, windowed_at_contracting_dims,
985 windowed_at_batch_dims),
986 GenNewConvDNums(original_dot, dot_lhs, dot_rhs, lhs_concat_dim,
987 rhs_concat_dim, windowed_at_contracting_dims,
988 windowed_at_batch_dims,
989 indices_map.lhs_to_output_indices,
990 indices_map.rhs_to_output_indices, new_dot_shape),
991 original_dot->precision_config()));
992 }
993 VLOG(2) << dot->ToString();
994
995 // Reshape to the original sharded dot shape.
996 dot = body_b.AddInstruction(
997 HloInstruction::CreateReshape(original_sharded_dot_shape, dot));
998
999 if (windowed_at_contracting_dims) {
1000 // Accumulate the partial output to the result buffer.
1001 o = body_b.AddInstruction(
1002 HloInstruction::CreateBinary(o->shape(), HloOpcode::kAdd, o, dot));
1003 } else {
1004 // The windowing operand is partitioned along batch/non-contracting
1005 // dimensions, so we need a dynamic-update-slice to save the partial
1006 // output in the result buffer.
1007 auto slice_shape = dot->shape();
1008 auto slice_dim =
1009 lhs_concat_dim != -1
1010 ? indices_map.lhs_to_output_indices[lhs_concat_dim]
1011 : indices_map.rhs_to_output_indices[rhs_concat_dim];
1012 slice_shape.set_dimensions(slice_dim,
1013 dot->shape().dimensions(slice_dim) / 2);
1014 std::vector<int64> ccw_start_indices(dot->shape().rank(), 0);
1015 std::vector<int64> cw_start_indices(dot->shape().rank(), 0);
1016 cw_start_indices[slice_dim] = dot->shape().dimensions(slice_dim) / 2;
1017 auto ccw_dot = body_b.AddInstruction(HloInstruction::CreateSlice(
1018 slice_shape, dot, ccw_start_indices, slice_shape.dimensions(),
1019 std::vector<int64>(dot->shape().rank(), 1)));
1020 auto cw_dot = body_b.AddInstruction(HloInstruction::CreateSlice(
1021 slice_shape, dot, cw_start_indices, dot->shape().dimensions(),
1022 std::vector<int64>(dot->shape().rank(), 1)));
1023
1024 if (operands_sharded_at_contracting_dims) {
1025 // Accumulate the partial output to the result buffer.
1026 o = body_b.AddInstruction(HloInstruction::CreateBinary(
1027 o->shape(), HloOpcode::kAdd, o, ccw_dot));
1028 cw_cp_output = body_b.AddInstruction(HloInstruction::CreateBinary(
1029 o->shape(), HloOpcode::kAdd, cw_cp_output, cw_dot));
1030 } else {
1031 auto ccw_offsets = MakePartitionOffsets(
1032 o->shape(),
1033 windowed_op_is_lhs ? *lhs_sharding_transposed_to_match_output
1034 : *rhs_sharding_transposed_to_match_output,
1035 ccw_data_partition_id, &body_b);
1036 auto cw_offsets = MakePartitionOffsets(
1037 o->shape(),
1038 windowed_op_is_lhs ? *lhs_sharding_transposed_to_match_output
1039 : *rhs_sharding_transposed_to_match_output,
1040 cw_data_partition_id, &body_b);
1041 o = body_b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
1042 o->shape(), o, ccw_dot, ccw_offsets));
1043 o = body_b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
1044 o->shape(), o, cw_dot, cw_offsets));
1045 }
1046 }
1047
1048 std::vector<HloInstruction*> partial_results;
1049 partial_results.push_back(o);
1050 partial_results.push_back(cw_cp_output);
1051 return partial_results;
1052 };
1053
1054 // Generate partial result used by unidirectional algorithm.
1055 auto get_partial_unid_result =
1056 [&](HloInstruction* l, HloInstruction* r, HloInstruction* o,
1057 HloInstruction* i) -> StatusOr<HloInstruction*> {
1058 auto partition_id =
1059 lhs.state().collective_ops_creator.create_partition_id(&body_b);
1060 auto data_partition_id =
1061 body_b.AddInstruction(HloInstruction::CreateBinary(
1062 i->shape(), HloOpcode::kAdd, i, partition_id));
1063 auto partition_count =
1064 body_b.AddInstruction(HloInstruction::CreateConstant(
1065 LiteralUtil::CreateR0<uint32>(num_partitions)));
1066 data_partition_id = body_b.AddInstruction(
1067 HloInstruction::CreateBinary(i->shape(), HloOpcode::kRemainder,
1068 data_partition_id, partition_count));
1069 auto dot_lhs = l;
1070 auto dot_rhs = r;
1071 if (windowed_at_contracting_dims || windowed_at_batch_dims ||
1072 operands_sharded_at_contracting_dims) {
1073 // Slice the matching operand according to the partitioned dimensions on
1074 // the windowed operand or the output.
1075 auto slice_operand = !windowed_op_is_lhs ? l : r;
1076 // We do this by treating the matching operand as replicated, and
1077 // resharding it to match the windowed operand or the output.
1078 slice_operand->set_sharding(HloSharding::Replicate());
1079 auto state = lhs.state();
1080 state.b = &body_b;
1081 state.partition_id = data_partition_id;
1082 state.reshard_cache->per_hlo_cache.erase(slice_operand);
1083 const HloSharding* slice_sharding;
1084 if (operands_sharded_at_contracting_dims) {
1085 slice_sharding = windowed_op_is_lhs
1086 ? &*output_sharding_transposed_to_match_rhs
1087 : &*output_sharding_transposed_to_match_lhs;
1088 } else {
1089 slice_sharding = windowed_op_is_lhs
1090 ? &*lhs_sharding_transposed_to_match_rhs
1091 : &*rhs_sharding_transposed_to_match_lhs;
1092 }
1093 auto slice =
1094 PartitionedHlo(slice_operand, slice_operand->shape(), state)
1095 .Reshard(*slice_sharding)
1096 .hlo();
1097 slice_operand->clear_sharding();
1098 if (!windowed_op_is_lhs) {
1099 dot_lhs = slice;
1100 } else {
1101 dot_rhs = slice;
1102 }
1103 }
1104 TF_ASSIGN_OR_RETURN(
1105 auto dot, create_sharded_dot(dot_lhs, dot_rhs, &body_b, conv_window));
1106 if (windowed_at_contracting_dims ||
1107 operands_sharded_at_contracting_dims) {
1108 // Accumulate the partial output to the result buffer.
1109 o = body_b.AddInstruction(
1110 HloInstruction::CreateBinary(o->shape(), HloOpcode::kAdd, o, dot));
1111 } else {
1112 // The windowing operand is partitioned along batch/non-contracting
1113 // dimensions, so we need a dynamic-update-slice to save the partial
1114 // output in the result buffer.
1115 auto offsets = MakePartitionOffsets(
1116 o->shape(),
1117 windowed_op_is_lhs ? *lhs_sharding_transposed_to_match_output
1118 : *rhs_sharding_transposed_to_match_output,
1119 data_partition_id, &body_b);
1120 o = body_b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
1121 o->shape(), o, dot, offsets));
1122 }
1123 return o;
1124 };
1125
1126 auto param = body_b.AddInstruction(HloInstruction::CreateParameter(
1127 /*parameter_number=*/0,
1128 ShapeUtil::MakeTupleShape({lhs.hlo()->shape(), rhs.hlo()->shape(),
1129 result_buffer->shape(),
1130 extra_buffer->shape(), iteration->shape()}),
1131 "param"));
1132 auto l = body_b.AddInstruction(
1133 HloInstruction::CreateGetTupleElement(lhs.hlo()->shape(), param, 0));
1134 auto r = body_b.AddInstruction(
1135 HloInstruction::CreateGetTupleElement(rhs.hlo()->shape(), param, 1));
1136 auto o = body_b.AddInstruction(HloInstruction::CreateGetTupleElement(
1137 result_buffer->shape(), param, 2));
1138 auto extra_inout = body_b.AddInstruction(
1139 HloInstruction::CreateGetTupleElement(extra_buffer->shape(), param, 3));
1140 auto i = body_b.AddInstruction(
1141 HloInstruction::CreateGetTupleElement(iteration->shape(), param, 4));
1142
1143 // The bidirectional collective permute implementation has loop unrolling
1144 // of degree 2, so num_partitions is required to be a multiple of 4.
1145 if (options.bidirectional_windowed_einsum && num_partitions % 4 == 0) {
1146 std::vector<std::pair<int64, int64>> ccw_sd_pairs(num_partitions);
1147 for (int64 source = 0; source < num_partitions; ++source) {
1148 // 0 -> n-1, 1 -> 0, 2 -> 1, ...
1149 ccw_sd_pairs[source] = {source,
1150 (source - 1 + num_partitions) % num_partitions};
1151 }
1152 std::vector<std::pair<int64, int64>> cw_sd_pairs(num_partitions);
1153 for (int64 source = 0; source < num_partitions; ++source) {
1154 // 0 -> 1, 1 -> 2, 2 -> 3, ...
1155 cw_sd_pairs[source] = {source, (source + 1) % num_partitions};
1156 }
1157
1158 // Even number iteration.
1159 auto next_l = l;
1160 auto next_r = r;
1161 auto ccw_cp_input = operands_sharded_at_contracting_dims ? o
1162 : windowed_op_is_lhs ? l
1163 : r;
1164 auto ccw_cp_output =
1165 lhs.state()
1166 .collective_ops_creator.create_cross_partition_collective_permute(
1167 &body_b, ccw_cp_input, ccw_sd_pairs,
1168 (*lhs.state().next_channel_id)++);
1169 if (operands_sharded_at_contracting_dims) {
1170 o = ccw_cp_output;
1171 } else if (windowed_op_is_lhs) {
1172 next_l = ccw_cp_output;
1173 } else {
1174 next_r = ccw_cp_output;
1175 }
1176 auto cw_cp_input = extra_inout;
1177 auto cw_cp_output =
1178 lhs.state()
1179 .collective_ops_creator.create_cross_partition_collective_permute(
1180 &body_b, cw_cp_input, cw_sd_pairs,
1181 (*lhs.state().next_channel_id)++);
1182
1183 TF_ASSIGN_OR_RETURN(
1184 auto outputs,
1185 get_partial_bid_results(l, r, o, extra_inout, cw_cp_output, i));
1186 o = outputs[0];
1187 cw_cp_output = outputs[1];
1188
1189 // ++i
1190 i = body_b.AddInstruction(HloInstruction::CreateBinary(
1191 i->shape(), HloOpcode::kAdd, i, CreateOne(i->shape(), &body_b)));
1192
1193 // Odd number iteration.
1194 auto second_next_l = next_l;
1195 auto second_next_r = next_r;
1196 ccw_cp_input = operands_sharded_at_contracting_dims ? o
1197 : windowed_op_is_lhs ? next_l
1198 : next_r;
1199 ccw_cp_output =
1200 lhs.state()
1201 .collective_ops_creator.create_cross_partition_collective_permute(
1202 &body_b, ccw_cp_input, ccw_sd_pairs,
1203 (*lhs.state().next_channel_id)++);
1204 if (operands_sharded_at_contracting_dims) {
1205 o = ccw_cp_output;
1206 } else if (windowed_op_is_lhs) {
1207 second_next_l = ccw_cp_output;
1208 } else {
1209 second_next_r = ccw_cp_output;
1210 }
1211 auto next_cw_cp_input = cw_cp_output;
1212 auto next_cw_cp_output =
1213 lhs.state()
1214 .collective_ops_creator.create_cross_partition_collective_permute(
1215 &body_b, next_cw_cp_input, cw_sd_pairs,
1216 (*lhs.state().next_channel_id)++);
1217
1218 TF_ASSIGN_OR_RETURN(
1219 outputs, get_partial_bid_results(next_l, next_r, o, cw_cp_output,
1220 next_cw_cp_output, i));
1221 o = outputs[0];
1222 next_cw_cp_output = outputs[1];
1223
1224 // ++i
1225 i = body_b.AddInstruction(HloInstruction::CreateBinary(
1226 i->shape(), HloOpcode::kAdd, i, CreateOne(i->shape(), &body_b)));
1227
1228 body_b.AddInstruction(HloInstruction::CreateTuple(
1229 {second_next_l, second_next_r, o, next_cw_cp_output, i}));
1230
1231 } else if (options.unroll_windowed_einsum && num_partitions % 2 == 0) {
1232 if (operands_sharded_at_contracting_dims) {
1233 std::vector<std::pair<int64, int64>> output_sd_pairs(num_partitions);
1234 for (int64 source = 0; source < num_partitions; ++source) {
1235 // 0 -> n-2, 1 -> n-1, 2 -> 0, ...
1236 output_sd_pairs[source] = {
1237 source, (source - 2 + num_partitions) % num_partitions};
1238 }
1239
1240 o = lhs.state()
1241 .collective_ops_creator
1242 .create_cross_partition_collective_permute(
1243 &body_b, o, output_sd_pairs,
1244 (*lhs.state().next_channel_id)++);
1245
1246 TF_ASSIGN_OR_RETURN(extra_inout,
1247 get_partial_unid_result(l, r, extra_inout, i));
1248
1249 extra_inout = lhs.state()
1250 .collective_ops_creator
1251 .create_cross_partition_collective_permute(
1252 &body_b, extra_inout, output_sd_pairs,
1253 (*lhs.state().next_channel_id)++);
1254
1255 // i+2
1256 i = body_b.AddInstruction(HloInstruction::CreateBinary(
1257 i->shape(), HloOpcode::kAdd, i,
1258 body_b.AddInstruction(HloInstruction::CreateConstant(
1259 LiteralUtil::CreateR0<uint32>(2)))));
1260 auto real_i = body_b.AddInstruction(HloInstruction::CreateBinary(
1261 i->shape(), HloOpcode::kAdd, i,
1262 body_b.AddInstruction(HloInstruction::CreateConstant(
1263 LiteralUtil::CreateR0<uint32>(1)))));
1264
1265 TF_ASSIGN_OR_RETURN(o, get_partial_unid_result(l, r, o, real_i));
1266 body_b.AddInstruction(
1267 HloInstruction::CreateTuple({l, r, o, extra_inout, i}));
1268 } else {
1269 std::vector<std::pair<int64, int64>> sd_pairs(num_partitions);
1270 for (int64 source = 0; source < num_partitions; ++source) {
1271 // 0 -> n-1, 1 -> 0, 2 -> 1, ...
1272 sd_pairs[source] = {source,
1273 (source - 1 + num_partitions) % num_partitions};
1274 }
1275
1276 // Even number iteration.
1277 auto next_l = l;
1278 auto next_r = r;
1279 auto cp_input = windowed_op_is_lhs ? l : r;
1280 auto cp_output = lhs.state()
1281 .collective_ops_creator
1282 .create_cross_partition_collective_permute(
1283 &body_b, cp_input, sd_pairs,
1284 (*lhs.state().next_channel_id)++);
1285 if (windowed_op_is_lhs) {
1286 next_l = cp_output;
1287 } else {
1288 next_r = cp_output;
1289 }
1290 TF_ASSIGN_OR_RETURN(o, get_partial_unid_result(l, r, o, i));
1291
1292 // ++i
1293 i = body_b.AddInstruction(HloInstruction::CreateBinary(
1294 i->shape(), HloOpcode::kAdd, i,
1295 body_b.AddInstruction(HloInstruction::CreateConstant(
1296 LiteralUtil::CreateR0<uint32>(1)))));
1297
1298 // Odd number iteration.
1299 auto second_next_l = next_l;
1300 auto second_next_r = next_r;
1301 cp_input = windowed_op_is_lhs ? next_l : next_r;
1302 cp_output = lhs.state()
1303 .collective_ops_creator
1304 .create_cross_partition_collective_permute(
1305 &body_b, cp_input, sd_pairs,
1306 (*lhs.state().next_channel_id)++);
1307 if (windowed_op_is_lhs) {
1308 second_next_l = cp_output;
1309 } else {
1310 second_next_r = cp_output;
1311 }
1312 TF_ASSIGN_OR_RETURN(o, get_partial_unid_result(next_l, next_r, o, i));
1313
1314 // ++i
1315 i = body_b.AddInstruction(HloInstruction::CreateBinary(
1316 i->shape(), HloOpcode::kAdd, i,
1317 body_b.AddInstruction(HloInstruction::CreateConstant(
1318 LiteralUtil::CreateR0<uint32>(1)))));
1319
1320 body_b.AddInstruction(HloInstruction::CreateTuple(
1321 {second_next_l, second_next_r, o, extra_inout, i}));
1322 }
1323 } else {
1324 auto real_i = i;
1325 if (operands_sharded_at_contracting_dims) {
1326 // For reduce-scatter case, start from the data_partition_id + 1 to make
1327 // the data_partition_id of the final data shard in each partition the
1328 // same as the corresponding partition_id.
1329 real_i = body_b.AddInstruction(HloInstruction::CreateBinary(
1330 real_i->shape(), HloOpcode::kAdd, real_i,
1331 CreateOne(real_i->shape(), &body_b)));
1332 }
1333 TF_ASSIGN_OR_RETURN(o, get_partial_unid_result(l, r, o, real_i));
1334
1335 // ++i
1336 i = body_b.AddInstruction(HloInstruction::CreateBinary(
1337 i->shape(), HloOpcode::kAdd, i,
1338 body_b.AddInstruction(HloInstruction::CreateConstant(
1339 LiteralUtil::CreateR0<uint32>(1)))));
1340 auto has_more = body_b.AddInstruction(HloInstruction::CreateCompare(
1341 ShapeUtil::MakeShape(PRED, {}), i,
1342 body_b.AddInstruction(HloInstruction::CreateConstant(
1343 LiteralUtil::CreateR0<uint32>(num_partitions))),
1344 ComparisonDirection::kLt));
1345 // Collective-permute for the next window. We don't need it for the last
1346 // iteration, so we use a conditional around the collective-permute.
1347 HloInstruction* conditional;
1348 {
1349 SpmdBuilder cp_b("window_collective_permute", original_hlo);
1350 {
1351 auto p = cp_b.AddInstruction(HloInstruction::CreateParameter(
1352 0,
1353 operands_sharded_at_contracting_dims ? o->shape()
1354 : windowed_op_is_lhs ? l->shape()
1355 : r->shape(),
1356 "window"));
1357 std::vector<std::pair<int64, int64>> sd_pairs(num_partitions);
1358 for (int64 source = 0; source < num_partitions; ++source) {
1359 // 0 -> n-1, 1 -> 0, 2 -> 1, ...
1360 sd_pairs[source] = {source,
1361 (source - 1 + num_partitions) % num_partitions};
1362 }
1363 lhs.state()
1364 .collective_ops_creator.create_cross_partition_collective_permute(
1365 &cp_b, p, sd_pairs, (*lhs.state().next_channel_id)++);
1366 }
1367 SpmdBuilder ncp_b("last_iteration_noop", original_hlo);
1368 {
1369 ncp_b.AddInstruction(HloInstruction::CreateParameter(
1370 0,
1371 operands_sharded_at_contracting_dims ? o->shape()
1372 : windowed_op_is_lhs ? l->shape()
1373 : r->shape(),
1374 "window"));
1375 }
1376 conditional = body_b.AddInstruction(HloInstruction::CreateConditional(
1377 operands_sharded_at_contracting_dims ? o->shape()
1378 : windowed_op_is_lhs ? l->shape()
1379 : r->shape(),
1380 has_more,
1381 operands_sharded_at_contracting_dims ? o
1382 : windowed_op_is_lhs ? l
1383 : r,
1384 module->AddEmbeddedComputation(cp_b.Build()),
1385 operands_sharded_at_contracting_dims ? o
1386 : windowed_op_is_lhs ? l
1387 : r,
1388 module->AddEmbeddedComputation(ncp_b.Build())));
1389 }
1390 if (operands_sharded_at_contracting_dims) {
1391 o = conditional;
1392 } else if (windowed_op_is_lhs) {
1393 l = conditional;
1394 } else {
1395 r = conditional;
1396 }
1397 body_b.AddInstruction(
1398 HloInstruction::CreateTuple({l, r, o, extra_inout, i}));
1399 }
1400
1401 SpmdBuilder cond_b("windowed_dot_general_cond", original_hlo);
1402 auto cond_param = cond_b.AddInstruction(HloInstruction::CreateParameter(
1403 /*parameter_number=*/0,
1404 ShapeUtil::MakeTupleShape({lhs.hlo()->shape(), rhs.hlo()->shape(),
1405 result_buffer->shape(),
1406 extra_buffer->shape(), iteration->shape()}),
1407 "param"));
1408 auto cond_i = cond_b.AddInstruction(HloInstruction::CreateGetTupleElement(
1409 iteration->shape(), cond_param, 4));
1410 int64 adapted_num_partitions =
1411 (options.bidirectional_windowed_einsum && num_partitions % 4 == 0)
1412 ? num_partitions / 2
1413 : num_partitions;
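  // The bidirectional loop body advances the induction variable by 2 while
  // accumulating into both the main and the extra buffer (covering windows
  // from both directions), so its trip condition only needs to reach
  // num_partitions / 2.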
1414 cond_b.AddInstruction(HloInstruction::CreateCompare(
1415 ShapeUtil::MakeShape(PRED, {}), cond_i,
1416 cond_b.AddInstruction(HloInstruction::CreateConstant(
1417 LiteralUtil::CreateR0<uint32>(adapted_num_partitions))),
1418 ComparisonDirection::kLt));
1419 auto while_loop = b->AddInstruction(HloInstruction::CreateWhile(
1420 cond_param->shape(), module->AddEmbeddedComputation(cond_b.Build()),
1421 module->AddEmbeddedComputation(body_b.Build()),
1422 b->AddInstruction(HloInstruction::CreateTuple(
1423 {lhs.hlo(), rhs.hlo(), result_buffer, extra_buffer, iteration}))));
1424 windowed_dot_general_loops->push_back(
1425 {while_loop, windowed_op_is_lhs ? 0 : 1, windowed_at_contracting_dims,
1426 windowed_at_batch_dims, operands_sharded_at_contracting_dims});
1427 auto result = b->AddInstruction(HloInstruction::CreateGetTupleElement(
1428 result_buffer->shape(), while_loop, 2));
1429 if (((options.bidirectional_windowed_einsum && num_partitions % 4 == 0) ||
1430 (options.unroll_windowed_einsum && num_partitions % 2 == 0)) &&
1431 operands_sharded_at_contracting_dims) {
1432 std::vector<std::pair<int64, int64>> extra_sd_pairs(num_partitions);
1433 for (int64 source = 0; source < num_partitions; ++source) {
1434 // 0 -> 1, 1 -> 2, 2 -> 3, ...
1435 extra_sd_pairs[source] = {source, (source + 1) % num_partitions};
1436 }
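    // In the reduce-scatter case the unrolled/bidirectional loops leave the
    // two accumulators offset from each other by one partition; rotate one of
    // them forward by a single step before summing them into the final
    // result.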
1437 auto extra_result =
1438 b->AddInstruction(HloInstruction::CreateGetTupleElement(
1439 extra_buffer->shape(), while_loop, 3));
1440 if (options.bidirectional_windowed_einsum && num_partitions % 4 == 0) {
1441 extra_result = lhs.state()
1442 .collective_ops_creator
1443 .create_cross_partition_collective_permute(
1444 b, extra_result, extra_sd_pairs,
1445 (*lhs.state().next_channel_id)++);
1446 }
1447 if (options.unroll_windowed_einsum && num_partitions % 2 == 0) {
1448 result = lhs.state()
1449 .collective_ops_creator
1450 .create_cross_partition_collective_permute(
1451 b, result, extra_sd_pairs,
1452 (*lhs.state().next_channel_id)++);
1453 }
1454 result = b->AddInstruction(HloInstruction::CreateBinary(
1455 result->shape(), HloOpcode::kAdd, result, extra_result));
1456 }
1457 if (!ShapeUtil::Compatible(padded_result_buffer_shape,
1458 unpadded_result_buffer_shape)) {
1459 result = b->AddInstruction(HloInstruction::CreateSlice(
1460 unpadded_result_buffer_shape, result,
1461 std::vector<int64>(padded_result_buffer_shape.rank(), 0),
1462 unpadded_result_buffer_shape.dimensions(),
1463 std::vector<int64>(padded_result_buffer_shape.rank(), 1)));
1464 }
1465 return result;
1466 };
1467 absl::optional<WindowedEinsumConfig> e_config =
1468 GetWindowedEinsumConfiguration(
1469 num_partitions, output_lhs_non_contracting_partitions,
1470 output_rhs_non_contracting_partitions, rhs_contracting_partitions,
1471 rhs_non_contracting_partitions, rhs_batch_partitions,
1472 lhs_contracting_partitions, lhs_non_contracting_partitions,
1473 lhs_batch_partitions, output_sharding_dim,
1474 ShapeSizeInBytes(rhs.base_shape()),
1475 ShapeSizeInBytes(lhs.base_shape()),
1476 ShapeSizeInBytes(output_base_shape),
1477 options.threshold_for_windowed_einsum_mib,
1478 output_sharding_transposed_to_match_lhs,
1479 output_sharding_transposed_to_match_rhs, lhs_sharding, rhs_sharding);
1480 if (e_config) {
1481 return emit_windowed_dot_general(*e_config);
1482 }
1483
1484 {
1485 // Try batch-parallel by resharding one operand, and allowing all-reduce.
1486 TF_ASSIGN_OR_RETURN(
1487 HloInstruction * partitioned_dot,
1488 try_emit_output_batch_partitioned_einsum_with_reshard(true));
1489 if (partitioned_dot) {
1490 return partitioned_dot;
1491 }
1492 }
1493
1494 // LHS and RHS have the same partitioned contracting dimensions.
1495 if (lhs_contracting_partitions == rhs_contracting_partitions &&
1496 lhs_contracting_partitions == num_partitions) {
1497 auto zero = b->AddInstruction(HloInstruction::CreateConstant(
1498 LiteralUtil::Zero(output_base_shape.element_type())));
1499 // Pad both sides with zero, since NaN at one side cannot be masked by zero
1500 // on the other side.
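    // (e.g. 0 * NaN is still NaN, so zero-padding only one operand would not
    // mask garbage elements coming from the other.)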
1501 if (ShapeSizeInBytes(lhs.base_shape()) <
1502 ShapeSizeInBytes(rhs.base_shape())) {
1503 lhs =
1504 lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero);
1505 rhs = rhs.PadWithValue(zero);
1506 } else {
1507 lhs = lhs.PadWithValue(zero);
1508 rhs =
1509 rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero);
1510 }
1511 TF_ASSIGN_OR_RETURN(
1512 auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
1513 std::vector<int64> lhs_contracting_dims;
1514 lhs_contracting_dims.reserve(lhs.base_shape().rank());
1515 for (const auto& cd : dims_mapping.contracting_dims) {
1516 lhs_contracting_dims.push_back(cd.lhs);
1517 }
1518 auto ar = lhs.state().partitioner->AllReduceAlongShardingDims(
1519 b, dot, lhs.sharding(), lhs.state().next_channel_id,
1520 lhs_contracting_dims, lhs.state().collective_ops_creator,
1521 MakeBinaryAdd(output_base_shape.element_type(), module));
1522 ar->set_sharding(HloSharding::Replicate());
1523 return PartitionedHlo(ar, output_base_shape, lhs.state())
1524 .Reshard(output_sharding)
1525 .hlo();
1526 }
1527
1528 // LHS and output have the same partitioned non-contracting dimensions.
1529 if (lhs_non_contracting_partitions == num_partitions &&
1530 output_lhs_non_contracting_partitions == num_partitions &&
1531 lhs_sharding_transposed_to_match_output == output_sharding) {
1532 auto rhs_replicated = rhs.Reshard(HloSharding::Replicate()).hlo();
1533 TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs_replicated,
1534 b, conv_window));
1535 return dot;
1536 }
1537
1538 // RHS and output have the same partitioned non-contracting dimensions.
1539 if (rhs_non_contracting_partitions == num_partitions &&
1540 output_rhs_non_contracting_partitions == num_partitions &&
1541 rhs_sharding_transposed_to_match_output == output_sharding) {
1542 auto lhs_replicated = lhs.Reshard(HloSharding::Replicate()).hlo();
1543 TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs_replicated, rhs.hlo(),
1544 b, conv_window));
1545 return dot;
1546 }
1547
1548 if (may_reshard_without_detecting_match) {
1549 // Output is batch partitioned.
1550 if (output_batch_partitions == num_partitions) {
1551 auto resharded_lhs =
1552 lhs.Reshard(*output_sharding_transposed_to_match_lhs);
1553 auto resharded_rhs =
1554 rhs.Reshard(*output_sharding_transposed_to_match_rhs);
1555 TF_ASSIGN_OR_RETURN(
1556 auto dot, create_sharded_dot(resharded_lhs.hlo(), resharded_rhs.hlo(),
1557 b, conv_window));
1558 return dot;
1559 }
1560 // Output is partitioned along LHS non-contracting dimensions.
1561 if (output_lhs_non_contracting_partitions == num_partitions) {
1562 auto resharded_lhs =
1563 lhs.Reshard(*output_sharding_transposed_to_match_lhs);
1564 auto replicated_rhs = rhs.Reshard(HloSharding::Replicate());
1565 TF_ASSIGN_OR_RETURN(
1566 auto dot, create_sharded_dot(resharded_lhs.hlo(),
1567 replicated_rhs.hlo(), b, conv_window));
1568 return dot;
1569 }
1570 // Output is partitioned along RHS non-contracting dimensions.
1571 if (output_rhs_non_contracting_partitions == num_partitions) {
1572 auto replicated_lhs = lhs.Reshard(HloSharding::Replicate());
1573 auto resharded_rhs =
1574 rhs.Reshard(*output_sharding_transposed_to_match_rhs);
1575 TF_ASSIGN_OR_RETURN(
1576 auto dot, create_sharded_dot(replicated_lhs.hlo(),
1577 resharded_rhs.hlo(), b, conv_window));
1578 return dot;
1579 }
1580 }
1581
1582 // Returns true if it is beneficial to reshard the operand at `operand_idx`
1583 // across the contracting dimension.
1584 const auto should_partition_contracting_dim = [&](int64 operand_idx) {
1585 if (!output_sharding.IsReplicated()) {
1586 return false;
1587 }
1588
1589 if (operand_idx == 0) {
1590 // If LHS and output are replicated, we compare the cost of all-gather
1591 // on RHS vs all-reduce on the output.
1592 return (rhs_contracting_partitions == num_partitions) &&
1593 lhs.sharding().IsReplicated() &&
1594 ShapeUtil::ElementsIn(rhs.base_shape()) >
1595 ShapeUtil::ElementsIn(output_base_shape);
1596 } else {
1597 return (lhs_contracting_partitions == num_partitions) &&
1598 rhs.sharding().IsReplicated() &&
1599 ShapeUtil::ElementsIn(lhs.base_shape()) >
1600 ShapeUtil::ElementsIn(output_base_shape);
1601 }
1602 };
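  // Illustrative example for operand_idx == 0: with a replicated output and a
  // replicated LHS, while the RHS is partitioned on its contracting dims
  // across all devices, resharding the LHS to match trades an all-gather of
  // the RHS for an all-reduce of the output, which only pays off when the RHS
  // has more elements than the output.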
1603
1604 // When the output is replicated and one of the operands is partitioned along
1605 // contracting dimension, align the other operand to be partitioned along
1606 // the contracting dimensions.
1607 if (output_sharding.IsReplicated() && (should_partition_contracting_dim(0) ||
1608 should_partition_contracting_dim(1))) {
1609 auto zero = b->AddInstruction(HloInstruction::CreateConstant(
1610 LiteralUtil::Zero(output_base_shape.element_type())));
1611 if (should_partition_contracting_dim(0)) {
1612 lhs =
1613 lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero);
1614 rhs = rhs.PadWithValue(zero);
1615 } else {
1616 lhs = lhs.PadWithValue(zero);
1617 rhs =
1618 rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero);
1619 }
1620 TF_ASSIGN_OR_RETURN(
1621 auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
1622
1623 std::vector<int64> lhs_contracting_dims;
1624 lhs_contracting_dims.reserve(lhs.base_shape().rank());
1625 for (const auto& cd : dims_mapping.contracting_dims) {
1626 lhs_contracting_dims.push_back(cd.lhs);
1627 }
1628 return lhs.state().partitioner->AllReduceAlongShardingDims(
1629 b, dot, lhs.sharding(), lhs.state().next_channel_id,
1630 lhs_contracting_dims, lhs.state().collective_ops_creator,
1631 MakeBinaryAdd(output_base_shape.element_type(), module));
1632 }
1633 return nullptr;
1634 }
1635
1636 StatusOr<HloInstruction*> PartitionDot(
1637 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
1638 const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
1639 int64 num_partitions,
1640 const std::function<StatusOr<HloInstruction*>(
1641 HloInstruction*, HloInstruction*, SpmdBuilder*,
1642 const Window& conv_window)>& create_sharded_dot,
1643 const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
1644 const SpmdPartitionerOptions& options, SpmdBuilder* b,
1645 std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
1646 windowed_dot_general_loops);
1647
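// Partitions a dot by grouping devices along batch dimensions that are
// (compatibly) partitioned in the operands and the output, then recursively
// partitioning the per-group dot with the reduced partition count.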
1648 StatusOr<HloInstruction*> PartitionDotGroupOnBatch(
1649 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
1650 const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
1651 int64 num_partitions, int64 lhs_contracting_partitions,
1652 int64 rhs_contracting_partitions, int64 lhs_non_contracting_partitions,
1653 int64 rhs_non_contracting_partitions,
1654 const std::function<StatusOr<HloInstruction*>(
1655 HloInstruction*, HloInstruction*, SpmdBuilder*,
1656 const Window& conv_window)>& create_sharded_dot,
1657 const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
1658 bool require_matching_devices_to_group,
1659 const SpmdPartitionerOptions& options, SpmdBuilder* b,
1660 std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
1661 windowed_dot_general_loops) {
1662 std::vector<std::pair<HloInstruction*, HloSharding>>
1663 top_level_sharding_to_reset;
1664 auto cleaner = tensorflow::gtl::MakeCleanup([&] {
1665 for (auto& to_reset : top_level_sharding_to_reset) {
1666 to_reset.first->set_sharding(to_reset.second);
1667 }
1668 });
1669 std::vector<int64> lhs_dims;
1670 std::vector<int64> rhs_dims;
1671 std::vector<int64> output_dims;
1672 auto lhs_sharding_dims_adjusted_to_output =
1673 lhs.sharding().IsReplicated()
1674 ? std::vector<int64>(lhs.base_shape().rank(), 1)
1675 : lhs.sharding().tile_assignment().dimensions();
1676 auto rhs_sharding_dims_adjusted_to_output =
1677 rhs.sharding().IsReplicated()
1678 ? std::vector<int64>(rhs.base_shape().rank(), 1)
1679 : rhs.sharding().tile_assignment().dimensions();
1680 auto output_sharding_dims_adjusted_to_lhs =
1681 output_sharding.tile_assignment().dimensions();
1682 bool lhs_rhs_dims_matching = true;
1683 for (const auto& dim : dims_mapping.batch_dims) {
1684 lhs_dims.push_back(dim.lhs);
1685 rhs_dims.push_back(dim.rhs);
1686 output_dims.push_back(dim.output);
1687 if (lhs_sharding_dims_adjusted_to_output[dim.lhs] !=
1688 rhs_sharding_dims_adjusted_to_output[dim.rhs]) {
1689 lhs_rhs_dims_matching = false;
1690 }
1691 lhs_sharding_dims_adjusted_to_output[dim.lhs] =
1692 output_sharding.tile_assignment().dim(dim.output);
1693 rhs_sharding_dims_adjusted_to_output[dim.rhs] =
1694 output_sharding.tile_assignment().dim(dim.output);
1695 output_sharding_dims_adjusted_to_lhs[dim.output] =
1696 lhs.sharding().tile_assignment().dim(dim.lhs);
1697 }
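  // Illustrative example: if the output's batch dimension is split 4 ways,
  // the corresponding lhs/rhs entries above are set to 4 regardless of how
  // the operands are currently sharded; lhs_rhs_dims_matching records whether
  // lhs and rhs were already partitioned identically on the batch dims.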
1698 if (require_matching_devices_to_group && lhs_rhs_dims_matching) {
1699 lhs_rhs_dims_matching =
1700 rhs.sharding() == UngroupSharding(AlignGroupsWith(
1701 GroupShardingOnDims(rhs.sharding(), rhs_dims),
1702 GroupShardingOnDims(lhs.sharding(), lhs_dims)));
1703 }
1704 auto output_grouped = GroupShardingOnDims(output_sharding, output_dims);
1705 PartitionedHlo per_group_lhs = lhs;
1706 PartitionedHlo per_group_rhs = rhs;
1707 if (lhs_rhs_dims_matching) {
1708 auto lhs_grouped = GroupShardingOnDims(lhs.sharding(), lhs_dims);
1709 auto rhs_grouped = GroupShardingOnDims(rhs.sharding(), rhs_dims);
1710 if (ShapeUtil::ByteSizeOf(lhs.hlo()->shape()) >
1711 ShapeUtil::ByteSizeOf(rhs.hlo()->shape())) {
1712 rhs_grouped = AlignGroupsWith(std::move(rhs_grouped), lhs_grouped);
1713 rhs = rhs.Reshard(UngroupSharding(rhs_grouped));
1714 } else {
1715 lhs_grouped = AlignGroupsWith(std::move(lhs_grouped), rhs_grouped);
1716 lhs = lhs.Reshard(UngroupSharding(lhs_grouped));
1717 }
1718 auto reshaped_output_tiling = output_sharding.tile_assignment();
1719 reshaped_output_tiling.Reshape(output_sharding_dims_adjusted_to_lhs);
1720 output_grouped = AlignGroupsWith(
1721 GroupShardingOnDims(
1722 output_sharding.ReplicateOnLastTileDim()
1723 ? HloSharding::PartialTile(reshaped_output_tiling)
1724 : HloSharding::Tile(reshaped_output_tiling),
1725 output_dims),
1726 lhs_grouped);
1727 auto per_group_partitioner_state = CreatePerGroupPartitioningState(
1728 lhs.state(), lhs_grouped.device_groups, b);
1729 top_level_sharding_to_reset.emplace_back(lhs.hlo(), lhs.sharding());
1730 lhs.hlo()->set_sharding(lhs_grouped.sharding);
1731 top_level_sharding_to_reset.emplace_back(rhs.hlo(), rhs.sharding());
1732 rhs.hlo()->set_sharding(rhs_grouped.sharding);
1733 CHECK(lhs.hlo() != rhs.hlo() ||
1734 lhs_grouped.sharding == rhs_grouped.sharding);
1735 per_group_lhs = PartitionedHlo(
1736 lhs.hlo(), GetPerGroupBaseShape(lhs_grouped, lhs.base_shape()),
1737 per_group_partitioner_state);
1738 per_group_rhs = PartitionedHlo(
1739 rhs.hlo(), GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()),
1740 per_group_partitioner_state);
1741 } else {
1742 auto per_group_partitioner_state = CreatePerGroupPartitioningState(
1743 lhs.state(), output_grouped.device_groups, b);
1744 auto reshard_to_output_batch =
1745 [&](PartitionedHlo operand, absl::Span<const int64> batch_dims,
1746 absl::Span<const int64> contracting_dims,
1747 absl::Span<const int64> non_contracting_dims,
1748 int64 contracting_dim_partitions,
1749 int64 non_contracting_dim_partitions,
1750 int64 other_contracting_dim_partitions,
1751 std::vector<int64>* sharding_dims_adjusted_to_output)
1752 -> absl::optional<PartitionedHlo> {
1753 if (operand.sharding().IsTileMaximal()) {
1754 auto partially_sharded = PerGroupSliceFromReplicated(
1755 operand.Replicate().hlo(), operand.state().partition_id,
1756 output_grouped.device_groups, batch_dims,
1757 output_grouped.group_dim_sizes, b);
1758 partially_sharded->set_sharding(HloSharding::Replicate());
1759 return PartitionedHlo(partially_sharded, partially_sharded->shape(),
1760 per_group_partitioner_state);
1761 }
1762 auto reshaped_tiling = operand.sharding().tile_assignment();
1763     // It's possible that the operand, although tiled, is not initially sharded
1764     // on the batch dimensions in the same way as the output. In that case, the
1765     // current sharding_dims_adjusted_to_output may contain more partitions than
1766     // there are available devices, so we remove partitioning on the other
1767     // dimensions.
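    // Illustrative example: an operand tiled [2, 4] over 8 devices (batch dim
    // split 2 ways, contracting dim 4 ways) whose output batch dim is split 8
    // ways yields adjusted dims [8, 4] with a product of 32 > 8 tiles; the
    // ratio of 4 matches the contracting-dim partitions, so the contracting
    // entry is reset to 1 below, giving [8, 1].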
1768 if (Product(*sharding_dims_adjusted_to_output) >
1769 reshaped_tiling.num_elements()) {
1770 if (Product(*sharding_dims_adjusted_to_output) %
1771 reshaped_tiling.num_elements() !=
1772 0) {
1773 return absl::nullopt;
1774 }
1775 int64 ratio = Product(*sharding_dims_adjusted_to_output) /
1776 reshaped_tiling.num_elements();
1777 if (operand.sharding().ReplicateOnLastTileDim() &&
1778 reshaped_tiling.dimensions().back() % ratio == 0) {
1779 sharding_dims_adjusted_to_output->back() /= ratio;
1780 if (sharding_dims_adjusted_to_output->back() == 1) {
1781 sharding_dims_adjusted_to_output->pop_back();
1782 }
1783 } else if (ratio == non_contracting_dim_partitions &&
1784 (ratio != contracting_dim_partitions ||
1785 contracting_dim_partitions ==
1786 other_contracting_dim_partitions)) {
1787 for (int64 dim : non_contracting_dims) {
1788 (*sharding_dims_adjusted_to_output)[dim] = 1;
1789 }
1790 } else if (ratio == contracting_dim_partitions) {
1791 for (int64 dim : contracting_dims) {
1792 (*sharding_dims_adjusted_to_output)[dim] = 1;
1793 }
1794 } else {
1795 return absl::nullopt;
1796 }
1797 }
1798 // If the operand is initially sharded more ways than the output in the
1799 // batch dimensions, sharding_dims_adjusted_to_output currently contains
1800 // fewer partitions than available devices. We do not handle this case.
1801 if (Product(*sharding_dims_adjusted_to_output) <
1802 reshaped_tiling.num_elements()) {
1803 return absl::nullopt;
1804 }
1805 reshaped_tiling.Reshape(*sharding_dims_adjusted_to_output);
1806 auto grouped = AlignGroupsWith(
1807 GroupShardingOnDims(operand.base_shape().rank() <
1808 sharding_dims_adjusted_to_output->size()
1809 ? HloSharding::PartialTile(reshaped_tiling)
1810 : HloSharding::Tile(reshaped_tiling),
1811 batch_dims),
1812 output_grouped);
1813 if (require_matching_devices_to_group &&
1814 operand.sharding() != UngroupSharding(grouped)) {
1815 return absl::nullopt;
1816 }
1817 auto resharded = operand.Reshard(UngroupSharding(grouped));
1818 top_level_sharding_to_reset.emplace_back(resharded.hlo(),
1819 resharded.sharding());
1820 resharded.hlo()->set_sharding(grouped.sharding);
1821 return PartitionedHlo(resharded.hlo(),
1822 GetPerGroupBaseShape(grouped, operand.base_shape()),
1823 per_group_partitioner_state);
1824 };
1825 std::vector<int64> lhs_contracting_dims;
1826 std::vector<int64> rhs_contracting_dims;
1827 lhs_contracting_dims.reserve(dims_mapping.contracting_dims.size());
1828 rhs_contracting_dims.reserve(dims_mapping.contracting_dims.size());
1829 for (const auto& dim : dims_mapping.contracting_dims) {
1830 lhs_contracting_dims.push_back(dim.lhs);
1831 rhs_contracting_dims.push_back(dim.rhs);
1832 }
1833 std::vector<int64> lhs_non_contracting_dims;
1834 std::vector<int64> rhs_non_contracting_dims;
1835 lhs_non_contracting_dims.reserve(
1836 dims_mapping.lhs_non_contracting_dims.size());
1837 rhs_non_contracting_dims.reserve(
1838 dims_mapping.rhs_non_contracting_dims.size());
1839 for (const auto& dim : dims_mapping.lhs_non_contracting_dims) {
1840 lhs_non_contracting_dims.push_back(dim.lhs);
1841 }
1842 for (const auto& dim : dims_mapping.rhs_non_contracting_dims) {
1843 rhs_non_contracting_dims.push_back(dim.rhs);
1844 }
1845 if (auto resharded = reshard_to_output_batch(
1846 lhs, lhs_dims, lhs_contracting_dims, lhs_non_contracting_dims,
1847 lhs_contracting_partitions, lhs_non_contracting_partitions,
1848 rhs_contracting_partitions,
1849 &lhs_sharding_dims_adjusted_to_output)) {
1850 per_group_lhs = *resharded;
1851 } else {
1852 return nullptr;
1853 }
1854 if (auto resharded = reshard_to_output_batch(
1855 rhs, rhs_dims, rhs_contracting_dims, rhs_non_contracting_dims,
1856 rhs_contracting_partitions, rhs_non_contracting_partitions,
1857 lhs_contracting_partitions,
1858 &rhs_sharding_dims_adjusted_to_output)) {
1859 per_group_rhs = *resharded;
1860 } else {
1861 return nullptr;
1862 }
1863 CHECK(lhs.hlo() != rhs.hlo() ||
1864 per_group_lhs.sharding() == per_group_rhs.sharding());
1865 }
1866 TF_ASSIGN_OR_RETURN(
1867 auto dot,
1868 PartitionDot(per_group_lhs, per_group_rhs,
1869 GetPerGroupBaseShape(output_grouped, output_base_shape),
1870 output_grouped.sharding, dims_mapping,
1871 num_partitions / output_grouped.device_groups.size(),
1872 create_sharded_dot, conv_window, module, original_hlo,
1873 options, b, windowed_dot_general_loops));
1874 dot->set_sharding(UngroupSharding(output_grouped));
1875 return PartitionedHlo(dot, output_base_shape, lhs.state())
1876 .Reshard(output_sharding)
1877 .hlo();
1878 }
1879
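// Returns the grouped sharding for the operand whose non-contracting
// dimensions match the output's partitioning: its tiling on
// `partitioned_dims` is reshaped to the output's partitioning on those dims
// and the resulting device groups are aligned with the output's groups.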
1880 GroupedSharding GetNonContractingPartitionGroupedShardingForMatchedOperand(
1881 bool lhs_matching, const HloSharding& matching_sharding,
1882 const HloSharding& output_sharding,
1883 absl::Span<const DotConvDimsMapping::DimsMapping> partitioned_dims) {
1884 std::vector<int64> matching_sharding_dims =
1885 matching_sharding.tile_assignment().dimensions();
1886 std::vector<int64> matching_dims;
1887 std::vector<int64> output_dims;
1888   // Make sure the partitioning on the matching operand's non-contracting dims
1889   // defines the same device groups for both the matching operand and the output.
1890 for (const auto& dim : partitioned_dims) {
1891 int64 md = lhs_matching ? dim.lhs : dim.rhs;
1892 matching_sharding_dims[md] =
1893 output_sharding.tile_assignment().dim(dim.output);
1894 matching_dims.push_back(md);
1895 output_dims.push_back(dim.output);
1896 }
1897 GroupedSharding output_grouped =
1898 GroupShardingOnDims(output_sharding, output_dims);
1899 Array<int64> reshaped_matching_tiling = matching_sharding.tile_assignment();
1900 reshaped_matching_tiling.Reshape(matching_sharding_dims);
1901 return AlignGroupsWith(
1902 GroupShardingOnDims(
1903 matching_sharding.ReplicateOnLastTileDim()
1904 ? HloSharding::PartialTile(reshaped_matching_tiling)
1905 : HloSharding::Tile(reshaped_matching_tiling),
1906 matching_dims),
1907 output_grouped);
1908 }
1909
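// Tries to find a grouped sharding for the non-matching operand that is
// compatible with the device groups induced by the output's partitioned
// non-contracting dimensions. It prefers grouping on an existing replication
// dimension, then on dimensions already forming matching device groups, then
// on contracting or non-contracting dimensions that may be replicated within
// each group. Returns nullopt if no such grouping is found, in which case the
// caller replicates the operand.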
1910 absl::optional<GroupedSharding>
1911 GetNonContractingPartitionGroupedShardingForOtherOperand(
1912 bool lhs_matching, const Shape& output_base_shape, const Shape& other_shape,
1913 int64 other_contracting_partitions, int64 other_non_contracting_partitions,
1914 int64 matching_contracting_partitions,
1915 int64 output_other_non_contracting_partitions,
1916 const HloSharding& other_sharding, const HloSharding& output_sharding,
1917 absl::Span<const DotConvDimsMapping::DimsMapping> matching_partitioned_dims,
1918 absl::Span<const DotConvDimsMapping::DimsMapping>
1919 other_non_contracting_dims,
1920 absl::Span<const DotConvDimsMapping::DimsMapping> other_contracting_dims) {
1921 int64 group_count = 1;
1922 std::vector<int64> output_dims;
1923 for (const auto& dim : matching_partitioned_dims) {
1924 output_dims.push_back(dim.output);
1925 group_count *= output_sharding.tile_assignment().dim(dim.output);
1926 }
1927 GroupedSharding output_grouped =
1928 GroupShardingOnDims(output_sharding, output_dims);
1929 std::vector<int64> other_group_dims;
1930 if (other_sharding.ReplicateOnLastTileDim() &&
1931 other_sharding.tile_assignment().dimensions().back() % group_count == 0) {
1932 other_group_dims.push_back(
1933 other_sharding.tile_assignment().num_dimensions() - 1);
1934 } else {
1935 const bool may_replicate_other_contracting_dims =
1936 (other_contracting_partitions == group_count &&
1937 other_non_contracting_partitions ==
1938 output_other_non_contracting_partitions);
1939 const bool may_replicate_other_non_contracting_dims =
1940 group_count == other_non_contracting_partitions &&
1941 matching_contracting_partitions == other_contracting_partitions;
1942 if (auto found_dims = FindMatchingPartitionedDimsForGrouping(
1943 other_sharding, output_grouped.device_groups)) {
1944 other_group_dims = std::move(*found_dims);
1945 } else if (may_replicate_other_contracting_dims &&
1946                (!may_replicate_other_non_contracting_dims ||
1947                 ShapeUtil::ByteSizeOf(other_shape) <=
1948                     ShapeUtil::ByteSizeOf(MakePartitionedShape(
1949                         output_base_shape, output_sharding)))) {
1950 for (const auto& dim : other_contracting_dims) {
1951 other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs);
1952 }
1953 } else if (may_replicate_other_non_contracting_dims) {
1954 for (const auto& dim : other_non_contracting_dims) {
1955 other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs);
1956 }
1957 } else {
1958 return absl::nullopt;
1959 }
1960 }
1961 if (other_group_dims.size() == 1 &&
1962 other_group_dims[0] ==
1963 other_sharding.tile_assignment().num_dimensions() - 1) {
1964 return AlignGroupsWith(
1965 GroupShardingOnDims(
1966 other_sharding, {other_group_dims[0]},
1967 {other_sharding.tile_assignment().dimensions().back() /
1968 group_count}),
1969 output_grouped, /*ignore_group_order=*/true);
1970
1971 } else if (!other_sharding.IsReplicated()) {
1972 return AlignGroupsWith(
1973 GroupShardingOnDims(other_sharding, other_group_dims), output_grouped,
1974 /*ignore_group_order=*/true);
1975 }
1976 return absl::nullopt;
1977 }
1978
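// Partitions a dot by grouping devices along non-contracting dimensions that
// one operand (the "matching" operand) shares with the output. The other
// operand is resharded or (partially) replicated within each group, and the
// per-group dot is partitioned recursively.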
1979 StatusOr<HloInstruction*> PartitionDotGroupOnNonContracting(
1980 bool lhs_matching, PartitionedHlo matching, PartitionedHlo other,
1981 int64 matching_contracting_partitions, int64 other_contracting_partitions,
1982 absl::Span<const DotConvDimsMapping::DimsMapping>
1983 partitioned_non_contracting_dims,
1984 int64 other_non_contracting_partitions,
1985 int64 output_other_non_contracting_partitions,
1986 const Shape& output_base_shape, const HloSharding& output_sharding,
1987 const DotConvDimsMapping& dims_mapping, int64 num_partitions,
1988 const std::function<StatusOr<HloInstruction*>(
1989 HloInstruction*, HloInstruction*, SpmdBuilder*,
1990 const Window& conv_window)>& create_sharded_dot,
1991 const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
1992 bool require_matching_devices_to_group,
1993 const SpmdPartitionerOptions& options, SpmdBuilder* b,
1994 std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
1995 windowed_dot_general_loops) {
1996 std::vector<std::pair<HloInstruction*, HloSharding>>
1997 top_level_sharding_to_reset;
1998 auto cleaner = tensorflow::gtl::MakeCleanup([&] {
1999 for (auto& to_reset : top_level_sharding_to_reset) {
2000 to_reset.first->set_sharding(to_reset.second);
2001 }
2002 });
2003
2004 std::vector<int64> output_dims;
2005 for (const auto& dim : partitioned_non_contracting_dims) {
2006 output_dims.push_back(dim.output);
2007 }
2008 GroupedSharding output_grouped =
2009 GroupShardingOnDims(output_sharding, output_dims);
2010 GroupedSharding matching_grouped =
2011 GetNonContractingPartitionGroupedShardingForMatchedOperand(
2012 lhs_matching, matching.sharding(), output_sharding,
2013 partitioned_non_contracting_dims);
2014 if (require_matching_devices_to_group &&
2015 matching.sharding() != UngroupSharding(matching_grouped)) {
2016 return nullptr;
2017 }
2018 absl::optional<GroupedSharding> other_grouped =
2019 GetNonContractingPartitionGroupedShardingForOtherOperand(
2020 lhs_matching, output_base_shape, other.hlo()->shape(),
2021 other_contracting_partitions, other_non_contracting_partitions,
2022 matching_contracting_partitions,
2023 output_other_non_contracting_partitions, other.sharding(),
2024 output_sharding, partitioned_non_contracting_dims,
2025 lhs_matching ? dims_mapping.rhs_non_contracting_dims
2026 : dims_mapping.lhs_non_contracting_dims,
2027 dims_mapping.contracting_dims);
2028
2029 if (!other_grouped) {
2030 other = other.Replicate();
2031 }
2032 matching = matching.Reshard(UngroupSharding(matching_grouped));
2033 auto per_group_partitioner_state = CreatePerGroupPartitioningState(
2034 matching.state(), matching_grouped.device_groups, b);
2035 top_level_sharding_to_reset.emplace_back(matching.hlo(), matching.sharding());
2036 matching.hlo()->set_sharding(matching_grouped.sharding);
2037 auto matching_p = PartitionedHlo(
2038 matching.hlo(),
2039 GetPerGroupBaseShape(matching_grouped, matching.base_shape()),
2040 per_group_partitioner_state);
2041
2042 auto partially_replicated_other = other.hlo();
2043 if (other_grouped && other_grouped->group_dims.size() == 1 &&
2044 other_grouped->group_dims[0] == other.base_shape().rank()) {
2045 // Group on replication dim.
2046 other = other.Reshard(UngroupSharding(*other_grouped));
2047 partially_replicated_other = other.hlo();
2048 top_level_sharding_to_reset.emplace_back(other.hlo(), other.sharding());
2049 partially_replicated_other->set_sharding(other_grouped->sharding);
2050 } else if (!other.sharding().IsReplicated()) {
2051 other = other.Reshard(UngroupSharding(*other_grouped));
2052 partially_replicated_other =
2053 other
2054 .Reshard(hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
2055 other.sharding(), other_grouped->group_dims))
2056 .hlo();
2057 top_level_sharding_to_reset.emplace_back(
2058 partially_replicated_other, partially_replicated_other->sharding());
2059 partially_replicated_other->set_sharding(other_grouped->sharding);
2060 }
2061 auto other_p = PartitionedHlo(partially_replicated_other, other.base_shape(),
2062 per_group_partitioner_state);
2063 TF_ASSIGN_OR_RETURN(
2064 auto dot,
2065 PartitionDot(lhs_matching ? matching_p : other_p,
2066 lhs_matching ? other_p : matching_p,
2067 GetPerGroupBaseShape(output_grouped, output_base_shape),
2068 output_grouped.sharding, dims_mapping,
2069 num_partitions / matching_grouped.device_groups.size(),
2070 create_sharded_dot, conv_window, module, original_hlo,
2071 options, b, windowed_dot_general_loops));
2072 return dot;
2073 }
2074
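// Partitions a dot by grouping devices along contracting dimensions on which
// both operands are partitioned. The per-group dot is partitioned recursively
// and the per-group partial results are combined with an all-reduce along the
// grouped dimensions.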
2075 StatusOr<HloInstruction*> PartitionDotGroupOnContracting(
2076 PartitionedHlo lhs, PartitionedHlo rhs,
2077 absl::Span<const DotConvDimsMapping::DimsMapping>
2078         partitioned_contracting_dims,
2079 int64 output_batch_partitions, int64 output_lhs_non_contracting_partitions,
2080 int64 output_rhs_non_contracting_partitions, const Shape& output_base_shape,
2081 const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
2082 int64 num_partitions,
2083 const std::function<StatusOr<HloInstruction*>(
2084 HloInstruction*, HloInstruction*, SpmdBuilder*,
2085 const Window& conv_window)>& create_sharded_dot,
2086 const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
2087 bool require_matching_devices_to_group,
2088 const SpmdPartitionerOptions& options, SpmdBuilder* b,
2089 std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
2090 windowed_dot_general_loops) {
2091 std::vector<std::pair<HloInstruction*, HloSharding>>
2092 top_level_sharding_to_reset;
2093 auto cleaner = tensorflow::gtl::MakeCleanup([&] {
2094 for (auto& to_reset : top_level_sharding_to_reset) {
2095 to_reset.first->set_sharding(to_reset.second);
2096 }
2097 });
2098 auto lhs_sharding = lhs.sharding();
2099 auto rhs_sharding = rhs.sharding();
2100 auto lhs_tile_shape = lhs_sharding.tile_assignment().dimensions();
2101 auto rhs_tile_shape = rhs_sharding.tile_assignment().dimensions();
2102 std::vector<int64> lhs_dims;
2103 std::vector<int64> rhs_dims;
2104 int64 group_count = 1;
2105   for (const auto& dim : partitioned_contracting_dims) {
2106 lhs_dims.push_back(dim.lhs);
2107 rhs_dims.push_back(dim.rhs);
2108 group_count *= lhs_sharding.tile_assignment().dim(dim.lhs);
2109 }
2110 if (ShapeUtil::ByteSizeOf(lhs.hlo()->shape()) >
2111 ShapeUtil::ByteSizeOf(rhs.hlo()->shape())) {
2112     for (const auto& dim : partitioned_contracting_dims) {
2113 rhs_tile_shape[dim.rhs] = lhs_tile_shape[dim.lhs];
2114 }
2115 auto new_tile = rhs.sharding().tile_assignment();
2116 new_tile.Reshape(rhs_tile_shape);
2117 rhs_sharding = rhs_sharding.ReplicateOnLastTileDim()
2118 ? HloSharding::PartialTile(new_tile)
2119 : HloSharding::Tile(new_tile);
2120 } else {
2121     for (const auto& dim : partitioned_contracting_dims) {
2122 lhs_tile_shape[dim.lhs] = rhs_tile_shape[dim.rhs];
2123 }
2124 auto new_tile = lhs.sharding().tile_assignment();
2125 new_tile.Reshape(lhs_tile_shape);
2126 lhs_sharding = lhs_sharding.ReplicateOnLastTileDim()
2127 ? HloSharding::PartialTile(new_tile)
2128 : HloSharding::Tile(new_tile);
2129 }
2130 auto lhs_grouped = GroupShardingOnDims(lhs_sharding, lhs_dims);
2131 auto rhs_grouped = GroupShardingOnDims(rhs_sharding, rhs_dims);
2132 if (ShapeUtil::ByteSizeOf(lhs.hlo()->shape()) >
2133 ShapeUtil::ByteSizeOf(rhs.hlo()->shape())) {
2134 rhs_grouped = AlignGroupsWith(rhs_grouped, lhs_grouped);
2135 rhs_sharding = UngroupSharding(rhs_grouped);
2136 if (require_matching_devices_to_group && rhs.sharding() != rhs_sharding) {
2137 return nullptr;
2138 }
2139 rhs = rhs.Reshard(rhs_sharding);
2140 } else {
2141 lhs_grouped = AlignGroupsWith(lhs_grouped, rhs_grouped);
2142 lhs_sharding = UngroupSharding(lhs_grouped);
2143 if (require_matching_devices_to_group && lhs.sharding() != lhs_sharding) {
2144 return nullptr;
2145 }
2146 lhs = lhs.Reshard(lhs_sharding);
2147 }
2148 // Mask out invalid data.
2149 std::vector<int64> lhs_skipped_dims;
2150 for (int64 i = 0; i < lhs.base_shape().rank(); ++i) {
2151 if (absl::c_linear_search(lhs_dims, i)) {
2152 continue;
2153 }
2154 lhs_skipped_dims.push_back(i);
2155 }
2156 lhs = lhs.PadWithValue(
2157 CreateZero(ShapeUtil::MakeShape(lhs.base_shape().element_type(), {}), b),
2158 /*left_padded_dims=*/{}, lhs_skipped_dims);
2159 std::vector<int64> rhs_skipped_dims;
2160 for (int64 i = 0; i < rhs.base_shape().rank(); ++i) {
2161 if (absl::c_linear_search(rhs_dims, i)) {
2162 continue;
2163 }
2164 rhs_skipped_dims.push_back(i);
2165 }
2166 rhs = rhs.PadWithValue(
2167 CreateZero(ShapeUtil::MakeShape(rhs.base_shape().element_type(), {}), b),
2168 /*left_padded_dims=*/{}, rhs_skipped_dims);
2169 top_level_sharding_to_reset.emplace_back(lhs.hlo(), lhs_sharding);
2170 lhs.hlo()->set_sharding(lhs_grouped.sharding);
2171 top_level_sharding_to_reset.emplace_back(rhs.hlo(), rhs_sharding);
2172 rhs.hlo()->set_sharding(rhs_grouped.sharding);
2173
2174 HloSharding inner_output_sharding = HloSharding::Replicate();
2175 HloSharding outer_output_tmp_sharding = HloSharding::Replicate();
2176 if (output_sharding.ReplicateOnLastTileDim() &&
2177 output_sharding.tile_assignment().dimensions().back() % group_count ==
2178 0) {
2179 auto grouped = AlignGroupsWith(
2180 GroupShardingOnDims(
2181 output_sharding,
2182 {output_sharding.tile_assignment().num_dimensions() - 1},
2183 {output_sharding.tile_assignment().dimensions().back() /
2184 group_count}),
2185 lhs_grouped,
2186 /*ignore_group_order=*/true);
2187 outer_output_tmp_sharding = UngroupSharding(grouped);
2188 inner_output_sharding = std::move(grouped.sharding);
2189 } else {
2190 std::vector<int64> group_dims;
2191 if (auto found_dims = FindMatchingPartitionedDimsForGrouping(
2192 output_sharding, lhs_grouped.device_groups)) {
2193 group_dims = std::move(*found_dims);
2194 } else if (output_lhs_non_contracting_partitions == group_count ||
2195 output_rhs_non_contracting_partitions == group_count ||
2196 output_batch_partitions == group_count) {
2197 if (output_lhs_non_contracting_partitions == group_count) {
2198 for (const auto& dim : dims_mapping.lhs_non_contracting_dims) {
2199 group_dims.push_back(dim.output);
2200 }
2201 } else if (output_rhs_non_contracting_partitions == group_count) {
2202 for (const auto& dim : dims_mapping.rhs_non_contracting_dims) {
2203 group_dims.push_back(dim.output);
2204 }
2205 } else {
2206 for (const auto& dim : dims_mapping.batch_dims) {
2207 group_dims.push_back(dim.output);
2208 }
2209 }
2210 }
2211 if (!group_dims.empty()) {
2212 auto grouped = AlignGroupsWith(
2213 GroupShardingOnDims(output_sharding, group_dims), lhs_grouped);
2214 inner_output_sharding = grouped.sharding;
2215 outer_output_tmp_sharding =
2216 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
2217 UngroupSharding(grouped), group_dims);
2218 }
2219 }
2220 auto inner_state = CreatePerGroupPartitioningState(
2221 lhs.state(), lhs_grouped.device_groups, b);
2222 TF_ASSIGN_OR_RETURN(
2223 auto dot,
2224 PartitionDot(
2225 PartitionedHlo(lhs.hlo(),
2226 GetPerGroupBaseShape(lhs_grouped, lhs.base_shape()),
2227 inner_state),
2228 PartitionedHlo(rhs.hlo(),
2229 GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()),
2230 inner_state),
2231 output_base_shape, inner_output_sharding, dims_mapping,
2232 num_partitions / group_count, create_sharded_dot, conv_window, module,
2233 original_hlo, options, b, windowed_dot_general_loops));
2234 if (!dot) {
2235 return nullptr;
2236 }
2237 auto ar = lhs.state().partitioner->AllReduceAlongShardingDims(
2238 b, dot, lhs_sharding, lhs.state().next_channel_id, lhs_dims,
2239 lhs.state().collective_ops_creator,
2240 MakeBinaryAdd(output_base_shape.element_type(), module));
2241 ar->set_sharding(outer_output_tmp_sharding);
2242 return PartitionedHlo(ar, output_base_shape, lhs.state())
2243 .Reshard(output_sharding)
2244 .hlo();
2245 }
2246
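// Rewrites the dims mapping for a grouped convolution (feature_group_count >
// 1): the input feature / kernel output-feature dims act like a batch
// dimension across feature groups, while the input batch and kernel
// input-feature dims become non-contracting dimensions.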
2247 DotConvDimsMapping ConvertDimsMappingWithFeatureGroupCount(
2248 const DotConvDimsMapping& dims_mapping, HloInstruction* original_hlo) {
2249 const auto& dnums = original_hlo->convolution_dimension_numbers();
2250 DotConvDimsMapping new_dims_mapping;
2251 new_dims_mapping.batch_dims = dims_mapping.batch_dims;
2252 new_dims_mapping.conv_spatial_dims = dims_mapping.conv_spatial_dims;
2253 // Append batch dims.
2254 new_dims_mapping.batch_dims.emplace_back();
2255 new_dims_mapping.batch_dims.back().lhs = dnums.input_feature_dimension();
2256 new_dims_mapping.batch_dims.back().rhs =
2257 dnums.kernel_output_feature_dimension();
2258 new_dims_mapping.batch_dims.back().output = dnums.output_feature_dimension();
2259 new_dims_mapping.batch_dims.back().spatial = -1;
2260   // Set up the non-contracting dims.
2261 new_dims_mapping.lhs_non_contracting_dims.emplace_back();
2262 new_dims_mapping.lhs_non_contracting_dims.back().lhs =
2263 dnums.input_batch_dimension();
2264 new_dims_mapping.rhs_non_contracting_dims.emplace_back();
2265 new_dims_mapping.rhs_non_contracting_dims.back().rhs =
2266 dnums.kernel_input_feature_dimension();
2267 return new_dims_mapping;
2268 }
2269
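// Rewrites the dims mapping for a batch-grouped convolution
// (batch_group_count > 1): the input batch and kernel output-feature dims are
// treated as a batch dimension mapped to the output feature dimension.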
2270 DotConvDimsMapping ConvertDimsMappingWithBatchGroupCount(
2271 const DotConvDimsMapping& dims_mapping, HloInstruction* original_hlo) {
2272 const auto& dnums = original_hlo->convolution_dimension_numbers();
2273 DotConvDimsMapping new_dims_mapping;
2274 new_dims_mapping.batch_dims = dims_mapping.batch_dims;
2275 new_dims_mapping.conv_spatial_dims = dims_mapping.conv_spatial_dims;
2276 new_dims_mapping.contracting_dims = dims_mapping.contracting_dims;
2277 // Append batch dims.
2278 new_dims_mapping.batch_dims.emplace_back();
2279 new_dims_mapping.batch_dims.back().lhs = dnums.input_batch_dimension();
2280 new_dims_mapping.batch_dims.back().rhs =
2281 dnums.kernel_output_feature_dimension();
2282 new_dims_mapping.batch_dims.back().output = dnums.output_feature_dimension();
2283 new_dims_mapping.batch_dims.back().spatial = -1;
2284 return new_dims_mapping;
2285 }
2286
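// Decides whether to group on the LHS (rather than the RHS) non-contracting
// dimensions when both choices are viable, either by comparing how much
// replication of the other operand each choice causes or, if enabled, by
// estimating which choice leads to fewer iterations in a subsequent windowed
// einsum.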
2287 bool LhsIsBestMatchForNonContractingPartitioning(
2288 const DotConvDimsMapping& dims_mapping, const PartitionedHlo& lhs,
2289 const PartitionedHlo& rhs, const Shape& output_base_shape,
2290 const HloSharding& output_sharding, const SpmdPartitionerOptions& options,
2291 int64 num_partitions, int64 lhs_non_contracting_partitions,
2292 int64 rhs_non_contracting_partitions, int64 lhs_contracting_partitions,
2293 int64 rhs_contracting_partitions,
2294 int64 output_lhs_non_contracting_partitions,
2295 int64 output_rhs_non_contracting_partitions, int64 lhs_batch_partitions,
2296 int64 rhs_batch_partitions, bool may_group_on_lhs_non_contracting,
2297 bool may_group_on_rhs_non_contracting) {
2298   // If both operands match the output's non-contracting dimensions, choose the
2299   // one that results in less replication of the other operand.
2300 bool lhs_matching = may_group_on_lhs_non_contracting &&
2301 (!may_group_on_rhs_non_contracting ||
2302 lhs_non_contracting_partitions *
2303 ShapeUtil::ByteSizeOf(rhs.hlo()->shape()) <
2304 rhs_non_contracting_partitions *
2305 ShapeUtil::ByteSizeOf(lhs.hlo()->shape()));
2306   // If both groupings are available and the option to prefer faster windowed
2307   // einsums over saving memory is enabled, try to determine which operand will
2308   // require the fewest iterations for the windowed einsum when matched (if a
2309   // windowed einsum is generated at all).
2310 if (may_group_on_lhs_non_contracting && may_group_on_rhs_non_contracting &&
2311 options.choose_faster_windowed_einsum_over_mem) {
2312 const DotDimensionIndexMapping indices_map = ComputeDimensionIndexMapping(
2313 dims_mapping, lhs.base_shape().rank(), rhs.base_shape().rank(),
2314 output_base_shape.rank());
2315 auto subsequent_einsum_iterations_estimate =
2316 [&](bool assume_lhs_match) -> absl::optional<int64> {
2317 const std::vector<DotConvDimsMapping::DimsMapping>&
2318 matching_non_contracting_dims =
2319 assume_lhs_match ? dims_mapping.lhs_non_contracting_dims
2320 : dims_mapping.rhs_non_contracting_dims;
2321 const std::vector<DotConvDimsMapping::DimsMapping>&
2322 other_non_contracting_dims =
2323 assume_lhs_match ? dims_mapping.rhs_non_contracting_dims
2324 : dims_mapping.lhs_non_contracting_dims;
2325 const std::vector<int64>& output_to_matching_indices =
2326 assume_lhs_match ? indices_map.output_to_lhs_indices
2327 : indices_map.output_to_rhs_indices;
2328 const std::vector<int64>& output_to_other_indices =
2329 assume_lhs_match ? indices_map.output_to_rhs_indices
2330 : indices_map.output_to_lhs_indices;
2331 const std::vector<int64>& matching_to_output_indices =
2332 assume_lhs_match ? indices_map.lhs_to_output_indices
2333 : indices_map.rhs_to_output_indices;
2334 const std::vector<int64>& other_to_output_indices =
2335 assume_lhs_match ? indices_map.rhs_to_output_indices
2336 : indices_map.lhs_to_output_indices;
2337 const HloSharding& matching_sharding =
2338 assume_lhs_match ? lhs.sharding() : rhs.sharding();
2339 const HloSharding& other_sharding =
2340 assume_lhs_match ? rhs.sharding() : lhs.sharding();
2341 const PartitionedHlo& matching_partitioned = assume_lhs_match ? lhs : rhs;
2342 const PartitionedHlo& other_partitioned = assume_lhs_match ? rhs : lhs;
2343 const int64 matching_non_contracting_partitions =
2344 assume_lhs_match ? lhs_non_contracting_partitions
2345 : rhs_non_contracting_partitions;
2346 const int64 other_non_contracting_partitions =
2347 assume_lhs_match ? rhs_non_contracting_partitions
2348 : lhs_non_contracting_partitions;
2349 const int64 matching_contracting_partitions =
2350 assume_lhs_match ? lhs_contracting_partitions
2351 : rhs_contracting_partitions;
2352 const int64 other_contracting_partitions =
2353 assume_lhs_match ? rhs_contracting_partitions
2354 : lhs_contracting_partitions;
2355 const int64 output_matching_non_contracting_partitions =
2356 assume_lhs_match ? output_lhs_non_contracting_partitions
2357 : output_rhs_non_contracting_partitions;
2358 const int64 output_other_non_contracting_partitions =
2359 assume_lhs_match ? output_rhs_non_contracting_partitions
2360 : output_lhs_non_contracting_partitions;
2361 const int64 matching_batch_partitions =
2362 assume_lhs_match ? lhs_batch_partitions : rhs_batch_partitions;
2363 const int64 other_batch_partitions =
2364 assume_lhs_match ? rhs_batch_partitions : lhs_batch_partitions;
2365 std::vector<int64> output_dims;
2366 output_dims.reserve(matching_non_contracting_dims.size());
2367 for (const DotConvDimsMapping::DimsMapping& dim :
2368 matching_non_contracting_dims) {
2369 output_dims.push_back(dim.output);
2370 }
2371 GroupedSharding output_grouped =
2372 GroupShardingOnDims(output_sharding, output_dims);
2373 GroupedSharding matching_grouped =
2374 GetNonContractingPartitionGroupedShardingForMatchedOperand(
2375 assume_lhs_match, matching_sharding, output_sharding,
2376 matching_non_contracting_dims);
2377 absl::optional<GroupedSharding> other_grouped =
2378 GetNonContractingPartitionGroupedShardingForOtherOperand(
2379 assume_lhs_match, output_base_shape,
2380 other_partitioned.hlo()->shape(), other_contracting_partitions,
2381 other_non_contracting_partitions, matching_contracting_partitions,
2382 output_other_non_contracting_partitions, other_sharding,
2383 output_sharding, matching_non_contracting_dims,
2384 other_non_contracting_dims, dims_mapping.contracting_dims);
2385 absl::optional<HloSharding> output_sharding_transposed_to_match_matching =
2386 hlo_sharding_util::TransposeShardingWithCollapsedDims(
2387 output_grouped.sharding, output_to_matching_indices,
2388 matching_to_output_indices);
2389 absl::optional<HloSharding> output_sharding_transposed_to_match_other =
2390 hlo_sharding_util::TransposeShardingWithCollapsedDims(
2391 output_grouped.sharding, output_to_other_indices,
2392 other_to_output_indices);
2393 const int64 new_num_partitions =
2394 num_partitions / matching_non_contracting_partitions;
2395 const int64 output_sharding_dim = FirstShardingDimWithPartitionOfSize(
2396 new_num_partitions, output_grouped.sharding);
2397 absl::optional<WindowedEinsumConfig> e_config =
2398 GetWindowedEinsumConfiguration(
2399 new_num_partitions, output_matching_non_contracting_partitions,
2400 output_other_non_contracting_partitions, 1,
2401 other_non_contracting_partitions, other_batch_partitions,
2402 matching_contracting_partitions, 1, matching_batch_partitions,
2403 output_sharding_dim,
2404 ShapeSizeInBytes(other_partitioned.base_shape()),
2405 ShapeSizeInBytes(matching_partitioned.base_shape()) /
2406 matching_non_contracting_partitions,
2407 ShapeSizeInBytes(
2408 GetPerGroupBaseShape(output_grouped, output_base_shape)),
2409 options.threshold_for_windowed_einsum_mib,
2410 output_sharding_transposed_to_match_matching,
2411 output_sharding_transposed_to_match_other,
2412 matching_grouped.sharding, other_grouped->sharding);
2413 return e_config ? new_num_partitions
2414 : absl::optional<int64>(absl::nullopt);
2415 };
2416 absl::optional<int64> lhs_matching_iterations =
2417 subsequent_einsum_iterations_estimate(true);
2418 absl::optional<int64> rhs_matching_iterations =
2419 subsequent_einsum_iterations_estimate(false);
2420 if (lhs_matching_iterations && rhs_matching_iterations &&
2421 *lhs_matching_iterations != *rhs_matching_iterations) {
2422 lhs_matching = *lhs_matching_iterations < *rhs_matching_iterations;
2423 }
2424 }
2425 return lhs_matching;
2426 }
2427
2428 // Recursive partitioning function. If there are partially matching dimensions
2429 // between the operands and the output, group the devices along them and
2430 // recursively partition the per-group dot.
2431 StatusOr<HloInstruction*> PartitionDot(
2432 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
2433 const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
2434 int64 num_partitions,
2435 const std::function<StatusOr<HloInstruction*>(
2436 HloInstruction*, HloInstruction*, SpmdBuilder*,
2437 const Window& conv_window)>& create_sharded_dot,
2438 const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
2439 bool require_matching_devices_to_group,
2440 const SpmdPartitionerOptions& options, SpmdBuilder* b,
2441 std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
2442 windowed_dot_general_loops) {
2443   // If lhs' hlo and rhs' hlo are identical, make a copy for rhs.
2444 if (lhs.hlo() == rhs.hlo()) {
2445 auto copy_hlo = b->AddInstruction(HloInstruction::CreateUnary(
2446 rhs.hlo()->shape(), HloOpcode::kCopy, rhs.hlo()));
2447 copy_hlo->set_sharding(rhs.sharding());
2448 rhs = PartitionedHlo(copy_hlo, rhs.base_shape(), rhs.state());
2449 }
2450
2451 // lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output.
2452 auto get_partitions_for_dims =
2453 [&](const HloSharding& sharding,
2454 absl::Span<const DotConvDimsMapping::DimsMapping> dims,
2455 int lhs_rhs_or_output) {
2456 int64 partitions = 1;
2457 if (sharding.IsTileMaximal()) {
2458 return partitions;
2459 }
2460 for (const auto& dim : dims) {
2461 if (lhs_rhs_or_output == 0) {
2462 partitions *= sharding.tile_assignment().dim(dim.lhs);
2463 } else if (lhs_rhs_or_output == 1) {
2464 partitions *= sharding.tile_assignment().dim(dim.rhs);
2465 } else {
2466 CHECK_EQ(lhs_rhs_or_output, 2);
2467 partitions *= sharding.tile_assignment().dim(dim.output);
2468 }
2469 }
2470 return partitions;
2471 };
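  // For example, an LHS tiled with assignment dimensions [2, 4] and a single
  // contracting dim mapped to lhs dimension 1 yields 4 contracting partitions.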
2472 const int64 lhs_batch_partitions =
2473 get_partitions_for_dims(lhs.sharding(), dims_mapping.batch_dims, 0);
2474 const int64 rhs_batch_partitions =
2475 get_partitions_for_dims(rhs.sharding(), dims_mapping.batch_dims, 1);
2476 const int64 output_batch_partitions =
2477 get_partitions_for_dims(output_sharding, dims_mapping.batch_dims, 2);
2478 const int64 lhs_contracting_partitions =
2479 get_partitions_for_dims(lhs.sharding(), dims_mapping.contracting_dims, 0);
2480 const int64 rhs_contracting_partitions =
2481 get_partitions_for_dims(rhs.sharding(), dims_mapping.contracting_dims, 1);
2482 const int64 lhs_non_contracting_partitions = get_partitions_for_dims(
2483 lhs.sharding(), dims_mapping.lhs_non_contracting_dims, 0);
2484 const int64 rhs_non_contracting_partitions = get_partitions_for_dims(
2485 rhs.sharding(), dims_mapping.rhs_non_contracting_dims, 1);
2486 const int64 output_lhs_non_contracting_partitions = get_partitions_for_dims(
2487 output_sharding, dims_mapping.lhs_non_contracting_dims, 2);
2488 const int64 output_rhs_non_contracting_partitions = get_partitions_for_dims(
2489 output_sharding, dims_mapping.rhs_non_contracting_dims, 2);
2490 const int64 lhs_conv_spatial_partitions = get_partitions_for_dims(
2491 lhs.sharding(), dims_mapping.conv_spatial_dims, 0);
2492 const int64 rhs_conv_spatial_partitions = get_partitions_for_dims(
2493 rhs.sharding(), dims_mapping.conv_spatial_dims, 1);
2494 const int64 output_conv_spatial_partitions = get_partitions_for_dims(
2495 output_sharding, dims_mapping.conv_spatial_dims, 2);
2496 // Before we find partial matches along the dimensions, invoke base case again
2497 // without may_reshard_without_detecting_match.
2498
2499 // Try to partition the purely spatially-partitioned convolution with convolution
2500 // spatial dimension partitioned or depthwise parallel dimension partitioned.
2501 bool is_conv_spatial_dim_partitioned =
2502 (lhs_conv_spatial_partitions > 1 || rhs_conv_spatial_partitions > 1 ||
2503 output_conv_spatial_partitions > 1);
2504 bool is_conv_batch_or_contracting_dim_partitioned =
2505 (lhs_batch_partitions > 1 || rhs_batch_partitions > 1 ||
2506 output_batch_partitions > 1 ||
2507 (lhs_contracting_partitions > 1 && rhs_contracting_partitions > 1));
2508 if ((!dims_mapping.conv_spatial_dims.empty() &&
2509 is_conv_spatial_dim_partitioned &&
2510 !is_conv_batch_or_contracting_dim_partitioned) ||
2511 (original_hlo->opcode() == HloOpcode::kConvolution &&
2512 (original_hlo->batch_group_count() > 1 ||
2513 original_hlo->feature_group_count() > 1))) {
2514 // Partition with kernel_input_feature_dim > 1 and feature_group_count > 1
2515 // is not supported.
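// Returning nullptr lets the caller fall through to other strategies and,
// ultimately, the replicated default in the outer PartitionDot overload.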
2516 const auto& dnums = original_hlo->convolution_dimension_numbers();
2517 if (original_hlo->feature_group_count() > 1 &&
2518 rhs.hlo()->shape().dimensions(dnums.kernel_input_feature_dimension()) >
2519 1) {
2520 return nullptr;
2521 }
2522
2523 TF_ASSIGN_OR_RETURN(
2524 auto partitioned_conv,
2525 PartitionConvolution(lhs, rhs, output_base_shape, output_sharding,
2526 dims_mapping, create_sharded_dot, conv_window,
2527 original_hlo, num_partitions, options,
2528 lhs.state().partition_id, module, b));
2529
2530 if (partitioned_conv) {
2531 return partitioned_conv;
2532 }
2533
2534 // Recursively partition on different types of dimensions for convolution.
2535 // Case 0.a: Group partitions by feature group count.
2536 if (original_hlo->feature_group_count() > 1 ||
2537 original_hlo->batch_group_count() > 1) {
2538 DotConvDimsMapping new_dims_mapping;
2539 if (original_hlo->feature_group_count() > 1) {
2540 new_dims_mapping =
2541 ConvertDimsMappingWithFeatureGroupCount(dims_mapping, original_hlo);
2542 }
2543
2544 if (original_hlo->batch_group_count() > 1) {
2545 new_dims_mapping =
2546 ConvertDimsMappingWithBatchGroupCount(dims_mapping, original_hlo);
2547 }
2548
2549 const int64 conv_lhs_contracting_partitions = get_partitions_for_dims(
2550 lhs.sharding(), new_dims_mapping.contracting_dims, 0);
2551 const int64 conv_rhs_contracting_partitions = get_partitions_for_dims(
2552 rhs.sharding(), new_dims_mapping.contracting_dims, 1);
2553 const int64 conv_lhs_non_contracting_partitions = get_partitions_for_dims(
2554 lhs.sharding(), new_dims_mapping.lhs_non_contracting_dims, 0);
2555 const int64 conv_rhs_non_contracting_partitions = get_partitions_for_dims(
2556 rhs.sharding(), new_dims_mapping.rhs_non_contracting_dims, 1);
2557 const int64 conv_lhs_batch_partitions = get_partitions_for_dims(
2558 lhs.sharding(), new_dims_mapping.batch_dims, 0);
2559 const int64 conv_rhs_batch_partitions = get_partitions_for_dims(
2560 rhs.sharding(), new_dims_mapping.batch_dims, 1);
2561 const int64 conv_output_batch_partitions = get_partitions_for_dims(
2562 output_sharding, new_dims_mapping.batch_dims, 2);
2563 if ((conv_lhs_batch_partitions == conv_output_batch_partitions ||
2564 conv_rhs_batch_partitions == conv_output_batch_partitions) &&
2565 conv_output_batch_partitions > 1) {
2566 TF_ASSIGN_OR_RETURN(
2567 auto try_partitioned_conv,
2568 PartitionDotGroupOnBatch(
2569 lhs, rhs, output_base_shape, output_sharding, new_dims_mapping,
2570 num_partitions, conv_lhs_contracting_partitions,
2571 conv_rhs_contracting_partitions,
2572 conv_lhs_non_contracting_partitions,
2573 conv_rhs_non_contracting_partitions, create_sharded_dot,
2574 conv_window, module, original_hlo,
2575 require_matching_devices_to_group, options, b,
2576 windowed_dot_general_loops));
2577 if (try_partitioned_conv) {
2578 return try_partitioned_conv;
2579 }
2580 }
2581 return nullptr;
2582 }
2583 }
2584
2585 TF_ASSIGN_OR_RETURN(
2586 auto try_partitioned_dot,
2587 PartitionBaseCase(
2588 lhs, rhs, output_base_shape, output_sharding, dims_mapping,
2589 num_partitions, create_sharded_dot, conv_window, module, original_hlo,
2590 lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions,
2591 lhs_contracting_partitions, rhs_contracting_partitions,
2592 lhs_non_contracting_partitions, rhs_non_contracting_partitions,
2593 output_lhs_non_contracting_partitions,
2594 output_rhs_non_contracting_partitions, options, b,
2595 windowed_dot_general_loops,
2596 /*may_reshard_without_detecting_match=*/false));
2597 if (try_partitioned_dot) {
2598 return try_partitioned_dot;
2599 }
2600
2601 // Recursively partition on different types of dimensions.
2602 //
2603 // Case 1: Group partitions by batch.
2604 if ((lhs_batch_partitions == output_batch_partitions ||
2605 rhs_batch_partitions == output_batch_partitions) &&
2606 output_batch_partitions > 1) {
2607 TF_ASSIGN_OR_RETURN(
2608 auto dot,
2609 PartitionDotGroupOnBatch(
2610 lhs, rhs, output_base_shape, output_sharding, dims_mapping,
2611 num_partitions, lhs_contracting_partitions,
2612 rhs_contracting_partitions, lhs_non_contracting_partitions,
2613 rhs_non_contracting_partitions, create_sharded_dot, conv_window,
2614 module, original_hlo, require_matching_devices_to_group, options, b,
2615 windowed_dot_general_loops));
2616 if (dot) {
2617 return dot;
2618 }
2619 }
2620
2621 // Case 2: Group partitions by non-contracting dimensions.
2622 const bool may_group_on_lhs_non_contracting =
2623 lhs_non_contracting_partitions == output_lhs_non_contracting_partitions &&
2624 lhs_non_contracting_partitions > 1;
2625 const bool may_group_on_rhs_non_contracting =
2626 rhs_non_contracting_partitions == output_rhs_non_contracting_partitions &&
2627 rhs_non_contracting_partitions > 1;
2628 if (may_group_on_lhs_non_contracting || may_group_on_rhs_non_contracting) {
2629 const bool lhs_matching = LhsIsBestMatchForNonContractingPartitioning(
2630 dims_mapping, lhs, rhs, output_base_shape, output_sharding, options,
2631 num_partitions, lhs_non_contracting_partitions,
2632 rhs_non_contracting_partitions, lhs_contracting_partitions,
2633 rhs_contracting_partitions, output_lhs_non_contracting_partitions,
2634 output_rhs_non_contracting_partitions, lhs_batch_partitions,
2635 rhs_batch_partitions, may_group_on_lhs_non_contracting,
2636 may_group_on_rhs_non_contracting);
2637 TF_ASSIGN_OR_RETURN(
2638 auto dot,
2639 PartitionDotGroupOnNonContracting(
2640 lhs_matching, lhs_matching ? lhs : rhs, lhs_matching ? rhs : lhs,
2641 lhs_matching ? lhs_contracting_partitions
2642 : rhs_contracting_partitions,
2643 lhs_matching ? rhs_contracting_partitions
2644 : lhs_contracting_partitions,
2645 lhs_matching ? dims_mapping.lhs_non_contracting_dims
2646 : dims_mapping.rhs_non_contracting_dims,
2647 lhs_matching ? rhs_non_contracting_partitions
2648 : lhs_non_contracting_partitions,
2649 lhs_matching ? output_rhs_non_contracting_partitions
2650 : output_lhs_non_contracting_partitions,
2651 output_base_shape, output_sharding, dims_mapping, num_partitions,
2652 create_sharded_dot, conv_window, module, original_hlo,
2653 require_matching_devices_to_group, options, b,
2654 windowed_dot_general_loops));
2655 if (dot) {
2656 return dot;
2657 }
2658 }
2659 if (lhs_non_contracting_partitions > 1 &&
2660 output_lhs_non_contracting_partitions > 1) {
2661 // If part of LHS non-contracting dims match output, try them.
2662 std::vector<DotConvDimsMapping::DimsMapping> matching_dims;
2663 for (const auto& dim : dims_mapping.lhs_non_contracting_dims) {
2664 int64 lhs_partitions = lhs.sharding().tile_assignment().dim(dim.lhs);
2665 if (lhs_partitions > 1 &&
2666 lhs_partitions == output_sharding.tile_assignment().dim(dim.output)) {
2667 matching_dims.push_back(dim);
2668 }
2669 }
2670 if (!matching_dims.empty()) {
2671 TF_ASSIGN_OR_RETURN(
2672 auto dot, PartitionDotGroupOnNonContracting(
2673 /*lhs_matching=*/true, lhs, rhs,
2674 lhs_contracting_partitions, rhs_contracting_partitions,
2675 matching_dims, rhs_non_contracting_partitions,
2676 output_rhs_non_contracting_partitions,
2677 output_base_shape, output_sharding, dims_mapping,
2678 num_partitions, create_sharded_dot, conv_window, module,
2679 original_hlo, require_matching_devices_to_group,
2680 options, b, windowed_dot_general_loops));
2681 if (dot) {
2682 return dot;
2683 }
2684 }
2685 }
2686 if (rhs_non_contracting_partitions > 1 &&
2687 output_rhs_non_contracting_partitions > 1) {
2688 // If part of RHS non-contracting dims match output, try them.
2689 std::vector<DotConvDimsMapping::DimsMapping> matching_dims;
2690 for (const auto& dim : dims_mapping.rhs_non_contracting_dims) {
2691 int64 rhs_partitions = rhs.sharding().tile_assignment().dim(dim.rhs);
2692 if (rhs_partitions > 1 &&
2693 rhs_partitions == output_sharding.tile_assignment().dim(dim.output)) {
2694 matching_dims.push_back(dim);
2695 }
2696 }
2697 if (!matching_dims.empty()) {
2698 TF_ASSIGN_OR_RETURN(
2699 auto dot, PartitionDotGroupOnNonContracting(
2700 /*lhs_matching=*/false, rhs, lhs,
2701 rhs_contracting_partitions, lhs_contracting_partitions,
2702 matching_dims, lhs_non_contracting_partitions,
2703 output_lhs_non_contracting_partitions,
2704 output_base_shape, output_sharding, dims_mapping,
2705 num_partitions, create_sharded_dot, conv_window, module,
2706 original_hlo, require_matching_devices_to_group,
2707 options, b, windowed_dot_general_loops));
2708 if (dot) {
2709 return dot;
2710 }
2711 }
2712 }
2713
2714 // Case 3: Group partitions by contracting dimensions.
2715 if (lhs_contracting_partitions == rhs_contracting_partitions &&
2716 lhs_contracting_partitions > 1) {
2717 TF_ASSIGN_OR_RETURN(
2718 auto dot,
2719 PartitionDotGroupOnContracting(
2720 lhs, rhs, dims_mapping.contracting_dims, output_batch_partitions,
2721 output_lhs_non_contracting_partitions,
2722 output_rhs_non_contracting_partitions, output_base_shape,
2723 output_sharding, dims_mapping, num_partitions, create_sharded_dot,
2724 conv_window, module, original_hlo,
2725 require_matching_devices_to_group, options, b,
2726 windowed_dot_general_loops));
2727 if (dot) {
2728 return dot;
2729 }
2730 }
2731 if (lhs_contracting_partitions > 1 && rhs_contracting_partitions > 1) {
2732 // If part of contracting dims match, try them.
2733 std::vector<DotConvDimsMapping::DimsMapping> matching_dims;
2734 for (const auto& dim : dims_mapping.contracting_dims) {
2735 int64 lhs_partitions = lhs.sharding().tile_assignment().dim(dim.lhs);
2736 if (lhs_partitions > 1 &&
2737 lhs_partitions == rhs.sharding().tile_assignment().dim(dim.rhs)) {
2738 matching_dims.push_back(dim);
2739 }
2740 }
2741 if (!matching_dims.empty()) {
2742 TF_ASSIGN_OR_RETURN(
2743 auto dot, PartitionDotGroupOnContracting(
2744 lhs, rhs, matching_dims, output_batch_partitions,
2745 output_lhs_non_contracting_partitions,
2746 output_rhs_non_contracting_partitions,
2747 output_base_shape, output_sharding, dims_mapping,
2748 num_partitions, create_sharded_dot, conv_window, module,
2749 original_hlo, require_matching_devices_to_group,
2750 options, b, windowed_dot_general_loops));
2751 if (dot) {
2752 return dot;
2753 }
2754 }
2755 }
2756
2757 // Case 4: If operands are replicated but output is partially replicated,
2758 // recursive call with partial replication removed.
2759 if (lhs.sharding().IsReplicated() && rhs.sharding().IsReplicated() &&
2760 output_sharding.ReplicateOnLastTileDim()) {
2761 auto grouped_output =
2762 GroupShardingOnDims(output_sharding, {output_base_shape.rank()});
2763 auto inner_state = CreatePerGroupPartitioningState(
2764 lhs.state(), grouped_output.device_groups, b);
2765 TF_ASSIGN_OR_RETURN(
2766 auto dot,
2767 PartitionDot(PartitionedHlo(lhs.hlo(), lhs.base_shape(), inner_state),
2768 PartitionedHlo(rhs.hlo(), rhs.base_shape(), inner_state),
2769 output_base_shape, grouped_output.sharding, dims_mapping,
2770 output_sharding.NumTiles(), create_sharded_dot,
2771 conv_window, module, original_hlo, options, b,
2772 windowed_dot_general_loops));
2773 if (dot) {
2774 return dot;
2775 }
2776 }
2777
2778 // We failed to find partial matches, invoke base case again with
2779 // may_reshard_without_detecting_match.
2780 TF_ASSIGN_OR_RETURN(
2781 auto dot,
2782 PartitionBaseCase(
2783 lhs, rhs, output_base_shape, output_sharding, dims_mapping,
2784 num_partitions, create_sharded_dot, conv_window, module, original_hlo,
2785 lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions,
2786 lhs_contracting_partitions, rhs_contracting_partitions,
2787 lhs_non_contracting_partitions, rhs_non_contracting_partitions,
2788 output_lhs_non_contracting_partitions,
2789 output_rhs_non_contracting_partitions, options, b,
2790 windowed_dot_general_loops,
2791 /*may_reshard_without_detecting_match=*/true));
2792 if (dot) {
2793 return dot;
2794 }
2795 return nullptr;
2796 }
2797
2798 StatusOr<HloInstruction*> PartitionDot(
2799 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
2800 const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
2801 int64 num_partitions,
2802 const std::function<StatusOr<HloInstruction*>(
2803 HloInstruction*, HloInstruction*, SpmdBuilder*,
2804 const Window& conv_window)>& create_sharded_dot,
2805 const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
2806 const SpmdPartitionerOptions& options, SpmdBuilder* b,
2807 std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
2808 windowed_dot_general_loops) {
2809 // First try partitioning without resharding the groups, then allow
2810 // resharding the groups.
2811 for (bool require_matching_devices_to_group : {true, false}) {
2812 TF_ASSIGN_OR_RETURN(
2813 auto try_partition,
2814 PartitionDot(lhs, rhs, output_base_shape, output_sharding, dims_mapping,
2815 num_partitions, create_sharded_dot, conv_window, module,
2816 original_hlo, require_matching_devices_to_group, options,
2817 b, windowed_dot_general_loops));
2818 if (try_partition) {
2819 return try_partition;
2820 }
2821 }
2822
2823 // Default action.
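// All partitioning strategies failed: replicate both operands, emit the dot
// locally, and reshard the result to the requested output sharding.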
2824 TF_ASSIGN_OR_RETURN(
2825 auto dot, create_sharded_dot(lhs.Replicate().hlo(), rhs.Replicate().hlo(),
2826 b, conv_window));
2827 dot->set_sharding(HloSharding::Replicate());
2828 return PartitionedHlo(dot, output_base_shape, lhs.state())
2829 .Reshard(output_sharding)
2830 .hlo();
2831 }
2832
2833 } // namespace
2834
2835 Status SpmdPartitioningVisitor::HandleDotHelper(
2836 HloInstruction* hlo, const DotConvDimsMapping& dims_mapping,
2837 const std::function<StatusOr<HloInstruction*>(
2838 HloInstruction*, HloInstruction*, SpmdBuilder*,
2839 const Window& conv_window)>& create_sharded_dot) {
2840 auto& lhs = GetPartitionedHlo(hlo->operand(0));
2841 auto& rhs = GetPartitionedHlo(hlo->operand(1));
2842 Window conv_window;
2843 if (hlo->opcode() == HloOpcode::kConvolution) {
2844 conv_window = hlo->window();
2845 }
2846
2847 TF_ASSIGN_OR_RETURN(
2848 auto partitioned_dot,
2849 PartitionDot(lhs, rhs, hlo->shape(), hlo->sharding(), dims_mapping,
2850 num_partitions_, create_sharded_dot, conv_window, module_,
2851 hlo, options_, &b_, &windowed_dot_general_loops_));
2852 SetPartitionedHlo(hlo, [&] { return partitioned_dot; });
2853 return Status::OK();
2854 }
2855
2856 namespace {
2857
2858 // Finds a cluster of nodes that produce the inputs for `hlo` which only depend
2859 // on small operands, which means the cluster should start with broadcasts,
2860 // constants and iotas. All other internal nodes must be non-side-effecting
2861 // elementwise ops. Returns the set of nodes, and the small operands. E.g., for
2862 // the following graph,
2863 //
2864 // a -> broadcast -> multiply
2865 // iota ---> add--/
2866 // constant/
2867 //
2868 // FindInputNodesIfOnlyDependOnSmallOperands(multiply) will return
2869 // <{broadcast, iota, constant, add, multiply}, [a]>.
2870 std::pair<absl::flat_hash_set<HloInstruction*>, std::vector<HloInstruction*>>
2871 FindInputNodesIfOnlyDependOnSmallOperands(HloInstruction* hlo) {
2872 absl::flat_hash_set<HloInstruction*> nodes_found;
2873 std::vector<HloInstruction*> new_operands;
2874 absl::flat_hash_set<const HloInstruction*> new_operands_set;
2875 std::vector<HloInstruction*> worklist;
2876 worklist.push_back(hlo);
2877 while (!worklist.empty()) {
2878 auto inst = worklist.back();
2879 worklist.pop_back();
2880 if (nodes_found.count(inst) > 0) {
2881 continue;
2882 }
2883 if (inst->opcode() == HloOpcode::kBroadcast ||
2884 inst->opcode() == HloOpcode::kConstant ||
2885 inst->opcode() == HloOpcode::kIota) {
2886 nodes_found.insert(inst);
2887 for (auto o : inst->operands()) {
2888 auto res = new_operands_set.emplace(o);
2889 if (res.second) {
2890 new_operands.push_back(o);
2891 }
2892 }
2893 } else if (inst->IsElementwise() && !inst->HasSideEffectNoRecurse() &&
2894 inst->opcode() != HloOpcode::kAllReduce &&
2895 absl::c_all_of(inst->operands(),
2896 [inst](const HloInstruction* o) {
2897 return ShapeUtil::CompatibleIgnoringElementType(
2898 o->shape(), inst->shape());
2899 })) {
2900 nodes_found.insert(inst);
2901 for (auto o : inst->operands()) {
2902 worklist.push_back(o);
2903 }
2904 } else {
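// Found a node that is neither a broadcast/constant/iota nor a compatible
// elementwise op; give up on the whole cluster.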
2905 nodes_found.clear();
2906 new_operands.clear();
2907 break;
2908 }
2909 }
2910 return {std::move(nodes_found), std::move(new_operands)};
2911 }
2912
2913 // Moves a cluster of memory-reducing nodes into the windowed dot-general loop
2914 // on contracting dimensions. Such a loop has a dynamic slice on the
2915 // non-windowed operand. If we move the input nodes into the loop, the
2916 // dynamic-slice could be merged with them by later optimization passes, which
2917 // reduces memory.
2918 //
2919 // small_operands small_operands
2920 // | |
2921 // input_nodes loop { |
2922 // | => input_nodes
2923 // loop { | |
2924 // dynamic-slice dynamic-slice
2925 // ... ...
2926 // } }
2927 //
2928 // Later optimization passes (TpuPadSliceMover) will merge the dynamic slice
2929 // with the input nodes.
2930 Status SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions(
2931 HloInstruction* loop, int64 non_windowed_operand_index) {
2932 auto input_tuple = loop->mutable_operand(0);
2933 auto old_operand = input_tuple->mutable_operand(non_windowed_operand_index);
2934 auto input_nodes = FindInputNodesIfOnlyDependOnSmallOperands(old_operand);
2935 auto to_sink = std::move(input_nodes.first);
2936 auto new_operands = std::move(input_nodes.second);
2937 if (to_sink.empty()) {
2938 return Status::OK();
2939 }
2940 auto computation = loop->parent();
2941 // Replace the old operand with a tuple of the found small operands.
2942 auto new_input_subtuple =
2943 computation->AddInstruction(HloInstruction::CreateTuple(new_operands));
2944 TF_RETURN_IF_ERROR(input_tuple->ReplaceOperandWithDifferentShape(
2945 non_windowed_operand_index, new_input_subtuple));
2946
2947 auto body = loop->while_body();
2948 auto body_param = body->parameter_instruction(0);
2949 auto old_body_param_users = body_param->users();
2950 // Update all tuple shapes.
2951 for (auto tuple : std::vector<HloInstruction*>{
2952 input_tuple, loop, loop->while_condition()->parameter_instruction(0),
2953 body_param, body->root_instruction()}) {
2954 *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(),
2955 {non_windowed_operand_index}) =
2956 new_input_subtuple->shape();
2957 }
2958 // Now update the loop body.
2959 auto new_operand_tuple_inside =
2960 body->AddInstruction(HloInstruction::CreateGetTupleElement(
2961 new_input_subtuple->shape(), body_param, non_windowed_operand_index));
2962 TF_RETURN_IF_ERROR(body->root_instruction()->ReplaceOperandWithDifferentShape(
2963 non_windowed_operand_index, new_operand_tuple_inside));
2964
2965 // Create nodes inside the loop body.
2966 std::vector<HloInstruction*> worklist;
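// Maps each node outside the loop to its clone inside the loop body.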
2967 absl::flat_hash_map<const HloInstruction*, HloInstruction*> outside_to_inside;
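// Enqueues a user once all of its operands have been recreated inside the
// loop body.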
2968 auto add_users_if_available = [&](HloInstruction* inst) {
2969 for (auto u : inst->users()) {
2970 if (outside_to_inside.count(u) == 0 && to_sink.count(u) > 0 &&
2971 absl::c_all_of(u->operands(), [&](const HloInstruction* o) {
2972 return outside_to_inside.count(o) > 0;
2973 })) {
2974 worklist.push_back(u);
2975 }
2976 }
2977 };
2978 for (int64 i = 0; i < new_operands.size(); ++i) {
2979 outside_to_inside[new_operands[i]] =
2980 body->AddInstruction(HloInstruction::CreateGetTupleElement(
2981 new_operands[i]->shape(), new_operand_tuple_inside, i));
2982 add_users_if_available(new_operands[i]);
2983 }
2984 // HLOs to sink without operands.
2985 std::vector<HloInstruction*> nullaries_to_sink;
2986 for (auto inst : to_sink) {
2987 if (inst->operand_count() == 0) {
2988 nullaries_to_sink.push_back(inst);
2989 }
2990 }
2991 // Sort nullaries_to_sink to make it deterministic.
2992 absl::c_sort(nullaries_to_sink,
2993 [](const HloInstruction* a, const HloInstruction* b) {
2994 return a->unique_id() < b->unique_id();
2995 });
2996 worklist.reserve(nullaries_to_sink.size());
2997 for (auto inst : nullaries_to_sink) {
2998 worklist.push_back(inst);
2999 }
3000 while (!worklist.empty()) {
3001 auto inst = worklist.back();
3002 worklist.pop_back();
3003 std::vector<HloInstruction*> inst_new_operands(inst->operand_count());
3004 for (int64 i = 0; i < inst->operand_count(); ++i) {
3005 inst_new_operands[i] = outside_to_inside[inst->operand(i)];
3006 }
3007 outside_to_inside[inst] = body->AddInstruction(
3008 inst->CloneWithNewOperands(inst->shape(), inst_new_operands));
3009 add_users_if_available(inst);
3010 }
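// The original operand must now have a clone inside the loop body.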
3011 TF_RET_CHECK(outside_to_inside.count(old_operand) > 0);
3012 for (auto ou : old_body_param_users) {
3013 if (ou->opcode() == HloOpcode::kGetTupleElement &&
3014 ou->tuple_index() == non_windowed_operand_index) {
3015 TF_RETURN_IF_ERROR(
3016 ou->ReplaceAllUsesWith(outside_to_inside[old_operand]));
3017 TF_RETURN_IF_ERROR(body->RemoveInstruction(ou));
3018 }
3019 }
3020 return Status::OK();
3021 }
3022
3023 // Moves a cluster of memory-reducing nodes (with reduce nodes at the end) into
3024 // the windowed dot-general loop on non-contracting dimensions. Such a loop has
3025 // a dynamic-update-slice at the output. If we move the user nodes into the loop
3026 // and before the dynamic-update-slice, the user nodes can operate on smaller
3027 // shapes, which reduces memory.
3028 //
3029 // small_operands small_operands
3030 // | | => | |
3031 // | | loop { loop { | |
3032 // | | conv | broadcast conv
3033 // | | | | | /
3034 // | | dynamic-update-slice | dynamic-slice /
3035 // | | | | | /
3036 // | | } | | multiply-----
3037 // |broadcast / | /
3038 // | | / reduce
3039 // |multiply-- |
3040 // \ | dynamic-update-slice
3041 // reduce }
3042 //
3043 // Later optimization passes (TpuPadSliceMover) will merge the dynamic slice
3044 // with the input nodes (broadcast).
3045 Status MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions(
3046 HloInstruction* loop) {
3047 CHECK_EQ(loop->user_count(), 1);
3048 // There should be a single direct user of the while loop, which is the
3049 // gte for element 2, i.e., the dot output.
3050 auto user_gte = loop->users().front();
3051 CHECK_EQ(user_gte->opcode(), HloOpcode::kGetTupleElement);
3052 CHECK_EQ(user_gte->tuple_index(), 2);
3053 auto computation = loop->parent();
3054
3055 // Find the reduce outputs and the input nodes they depend on, if input nodes
3056 // only have small operands.
3057 absl::flat_hash_set<HloInstruction*> to_move;
3058 std::vector<HloInstruction*> new_operands;
3059 absl::flat_hash_set<const HloInstruction*> new_operands_set;
3060 std::vector<HloInstruction*> reduce_outputs;
3061 std::vector<HloInstruction*> worklist;
3062 Shape padded_shape = user_gte->shape();
3063 Shape unpadded_shape = user_gte->shape();
3064 auto original_output = user_gte;
3065
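// The loop output may be padded; if its only user is a slice, that slice
// defines the unpadded shape.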
3066 if (user_gte->user_count() == 1 &&
3067 user_gte->users().back()->opcode() == HloOpcode::kSlice) {
3068 original_output = user_gte->users().back();
3069 unpadded_shape = original_output->shape();
3070 }
3071 for (auto u : original_output->users()) {
3072 worklist.push_back(u);
3073 }
3074 to_move.insert(original_output);
3075 while (!worklist.empty()) {
3076 auto inst = worklist.back();
3077 worklist.pop_back();
3078 if (to_move.count(inst) > 0) {
3079 continue;
3080 }
3081 // We only support reduces with a simple reduction function, since we may need
3082 // to accumulate across iterations manually.
3083 if (inst->opcode() == HloOpcode::kReduce &&
3084 inst->to_apply()->instruction_count() == 3 &&
3085 inst->to_apply()->num_parameters() == 2 &&
3086 inst->to_apply()->root_instruction()->IsElementwise()) {
3087 to_move.insert(inst);
3088 auto other_operand = inst->mutable_operand(1);
3089 auto res = new_operands_set.emplace(other_operand);
3090 if (res.second) {
3091 new_operands.push_back(other_operand);
3092 }
3093 reduce_outputs.push_back(inst);
3094 } else if (inst != computation->root_instruction() &&
3095 inst->user_count() > 0 && inst->IsElementwise() &&
3096 !inst->HasSideEffectNoRecurse() &&
3097 inst->opcode() != HloOpcode::kAllReduce &&
3098 absl::c_all_of(inst->operands(),
3099 [inst](const HloInstruction* o) {
3100 return ShapeUtil::CompatibleIgnoringElementType(
3101 o->shape(), inst->shape());
3102 })) {
3103 // For an elementwise op, we need to make sure that it depends only on
3104 // nodes already in to_move and nodes with small operands.
3105 bool can_include = true;
3106 for (auto operand : inst->operands()) {
3107 if (to_move.count(operand) > 0) {
3108 continue;
3109 }
3110 auto find_result = FindInputNodesIfOnlyDependOnSmallOperands(operand);
3111 if (find_result.first.empty()) {
3112 can_include = false;
3113 break;
3114 }
3115 for (auto n : find_result.first) {
3116 to_move.insert(n);
3117 }
3118 for (auto new_operand : find_result.second) {
3119 auto res = new_operands_set.insert(new_operand);
3120 if (res.second) {
3121 new_operands.push_back(new_operand);
3122 }
3123 }
3124 }
3125 if (!can_include) {
3126 to_move.clear();
3127 break;
3128 }
3129 to_move.insert(inst);
3130 for (auto u : inst->users()) {
3131 worklist.push_back(u);
3132 }
3133 } else {
3134 to_move.clear();
3135 break;
3136 }
3137 }
3138 // If nothing is found, to_move could contain only original_output, or have
3139 // been cleared by the above code.
3140 if (to_move.size() <= 1) {
3141 return Status::OK();
3142 }
3143
3144 // We will replace the original loop output with reduce-shape outputs. Create
3145 // the initial buffers before the loop.
3146 for (auto out : reduce_outputs) {
3147 auto padded_out_shape = out->shape();
3148 int64 operand_dim = 0;
3149 int64 output_dim = 0;
3150 while (output_dim < padded_out_shape.rank()) {
3151 if (absl::c_linear_search(out->dimensions(), operand_dim)) {
3152 // Dimension collapsed.
3153 ++operand_dim;
3154 continue;
3155 }
3156 // Kept dimensions have the same size as the padded shape.
3157 padded_out_shape.set_dimensions(output_dim,
3158 padded_shape.dimensions(operand_dim));
3159 ++operand_dim;
3160 ++output_dim;
3161 }
3162 auto broadcast =
3163 computation->AddInstruction(HloInstruction::CreateBroadcast(
3164 padded_out_shape,
3165 computation->AddInstruction(HloInstruction::CreateConstant(
3166 LiteralUtil::Zero(out->shape().element_type()))),
3167 {}));
3168 new_operands.push_back(broadcast);
3169 }
3170
3171 auto input_tuple = loop->mutable_operand(0);
3172 // Create the new input subtuple that contains the small operands and the
3173 // reduce-shape result buffers.
3174 auto new_input_subtuple =
3175 computation->AddInstruction(HloInstruction::CreateTuple(new_operands));
3176 TF_RETURN_IF_ERROR(
3177 input_tuple->ReplaceOperandWithDifferentShape(2, new_input_subtuple));
3178 auto body = loop->while_body();
3179 auto body_param = body->parameter_instruction(0);
3180 auto body_root = body->root_instruction();
3181 CHECK_EQ(body_root->opcode(), HloOpcode::kTuple);
3182 // Update tuple shapes.
3183 for (auto tuple : std::vector<HloInstruction*>{
3184 input_tuple, loop, loop->while_condition()->parameter_instruction(0),
3185 body_param, body_root}) {
3186 *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), {2}) =
3187 new_input_subtuple->shape();
3188 }
3189 auto new_loop_input =
3190 body->AddInstruction(HloInstruction::CreateGetTupleElement(
3191 new_input_subtuple->shape(), body_param, 2));
3192
3193 // Now create the moved nodes inside the loop body.
3194 absl::flat_hash_map<const HloInstruction*, HloInstruction*> outside_to_inside;
3195 worklist.clear();
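// Adds a user to the worklist once all of its operands already have clones
// inside the loop.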
3196 auto add_users_if_available = [&](HloInstruction* inst) {
3197 for (auto u : inst->users()) {
3198 if (outside_to_inside.count(u) == 0 && to_move.count(u) > 0 &&
3199 absl::c_all_of(u->operands(), [&](const HloInstruction* o) {
3200 return outside_to_inside.count(o) > 0;
3201 })) {
3202 worklist.push_back(u);
3203 }
3204 }
3205 };
3206 for (int64 i = 0; i < new_operands.size(); ++i) {
3207 outside_to_inside[new_operands[i]] =
3208 body->AddInstruction(HloInstruction::CreateGetTupleElement(
3209 new_operands[i]->shape(), new_loop_input, i));
3210 add_users_if_available(new_operands[i]);
3211 }
3212 // The elementwise nodes will be created with sliced shape. The original loop
3213 // output corresponds to the dynamic-update-slice's update slice.
3214 auto dus = body_root->mutable_operand(2);
3215 CHECK_EQ(dus->opcode(), HloOpcode::kDynamicUpdateSlice);
3216 outside_to_inside[original_output] = dus->mutable_operand(1);
3217 add_users_if_available(original_output);
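// The dynamic-update-slice offsets identify the current window within the
// padded output; reuse them for slicing padded full-shape values.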
3218 std::vector<HloInstruction*> slice_offsets(padded_shape.rank());
3219 for (int64 i = 0; i < slice_offsets.size(); ++i) {
3220 slice_offsets[i] = dus->mutable_operand(i + 2);
3221 }
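// Returns a dynamic-slice of `padded` at the current window, matching the
// update slice's shape.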
3222 auto get_slice = [&](HloInstruction* padded) {
3223 return body->AddInstruction(HloInstruction::CreateDynamicSlice(
3224 ShapeUtil::ChangeElementType(dus->operand(1)->shape(),
3225 padded->shape().element_type()),
3226 padded, slice_offsets, dus->operand(1)->shape().dimensions()));
3227 };
3228 // Helper functions to create nodes with small operands.
3229 auto add_broadcast = [&](const HloInstruction* broadcast) {
3230 auto padded_operand_shape = broadcast->operand(0)->shape();
3231 for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
3232 padded_operand_shape.set_dimensions(
3233 i, padded_shape.dimensions(broadcast->dimensions(i)));
3234 }
3235 auto padded_operand = PadToShape(outside_to_inside[broadcast->operand(0)],
3236 padded_operand_shape, nullptr, body);
3237 outside_to_inside[broadcast] =
3238 get_slice(body->AddInstruction(broadcast->CloneWithNewOperands(
3239 ShapeUtil::ChangeElementType(padded_shape,
3240 padded_operand_shape.element_type()),
3241 {padded_operand})));
3242 };
3243 auto add_iota = [&](const HloInstruction* iota) {
3244 outside_to_inside[iota] =
3245 get_slice(body->AddInstruction(iota->CloneWithNewOperands(
3246 ShapeUtil::ChangeElementType(padded_shape,
3247 iota->shape().element_type()),
3248 {})));
3249 };
3250 auto add_constant = [&](const HloInstruction* constant) {
3251 outside_to_inside[constant] = body->AddInstruction(constant->Clone());
3252 outside_to_inside[constant] = get_slice(
3253 PadToShape(outside_to_inside[constant],
3254 ShapeUtil::ChangeElementType(
3255 padded_shape, constant->shape().element_type()),
3256 nullptr, body));
3257 };
3258 while (!worklist.empty()) {
3259 auto inst = worklist.back();
3260 worklist.pop_back();
3261 if (outside_to_inside.count(inst) > 0) {
3262 continue;
3263 }
3264 if (inst->opcode() == HloOpcode::kBroadcast) {
3265 add_broadcast(inst);
3266 } else if (inst->opcode() == HloOpcode::kIota) {
3267 add_iota(inst);
3268 } else if (inst->opcode() == HloOpcode::kConstant) {
3269 add_constant(inst);
3270 } else if (inst->opcode() == HloOpcode::kReduce) {
3271 // This is an output, for which we have special handling later.
3272 } else {
3273 std::vector<HloInstruction*> operands_inside(inst->operand_count());
3274 for (int64 i = 0; i < operands_inside.size(); ++i) {
3275 operands_inside[i] = outside_to_inside[inst->operand(i)];
3276 }
3277 outside_to_inside[inst] = body->AddInstruction(inst->CloneWithNewOperands(
3278 ShapeUtil::ChangeElementType(dus->operand(1)->shape(),
3279 inst->shape().element_type()),
3280 operands_inside));
3281 }
3282 add_users_if_available(inst);
3283 }
3284 std::vector<HloInstruction*> new_outputs_inside(new_operands.size());
3285 for (int64 i = 0; i < new_outputs_inside.size(); ++i) {
3286 new_outputs_inside[i] = outside_to_inside[new_operands[i]];
3287 }
3288 // Now create the reduce outputs inside the loop.
3289 for (int64 i = 0; i < reduce_outputs.size(); ++i) {
3290 auto reduce_outside = reduce_outputs[i];
3291 CHECK_EQ(reduce_outside->opcode(), HloOpcode::kReduce);
3292 int64 index_in_operand = new_operands.size() - reduce_outputs.size() + i;
3293 auto last_iter_result = outside_to_inside[new_operands[index_in_operand]];
3294 auto operand0 = outside_to_inside[reduce_outside->operand(0)];
3295 auto operand1 = outside_to_inside[reduce_outside->operand(1)];
3296 TF_ASSIGN_OR_RETURN(auto reduce_shape,
3297 ShapeInference::InferReduceShape(
3298 {&operand0->shape(), &operand1->shape()},
3299 reduce_outside->dimensions(),
3300 reduce_outside->to_apply()->ComputeProgramShape()));
3301 *reduce_shape.mutable_layout() = reduce_outside->shape().layout();
3302 std::vector<HloInstruction*> reduce_dus_offsets;
3303 // If any collapsed dimension is windowed, we need to accumulate with the last
3304 // iteration's result. If such a dimension has padding, we also need to mask
3305 // off invalid data.
3306 bool needs_accumulate = false;
3307 std::vector<int64> dims_to_mask;
3308 for (int64 i = 0; i < slice_offsets.size(); ++i) {
3309 if (absl::c_linear_search(reduce_outside->dimensions(), i)) {
3310 if (reduce_outside->operand(0)->shape().dimensions(i) !=
3311 operand0->shape().dimensions(i)) {
3312 needs_accumulate = true;
3313 if (unpadded_shape.dimensions(i) != padded_shape.dimensions(i)) {
3314 dims_to_mask.push_back(i);
3315 }
3316 }
3317 continue;
3318 }
3319 reduce_dus_offsets.push_back(slice_offsets[i]);
3320 }
3321 // Mask off invalid data in collapsed dimensions.
3322 for (int64 dim : dims_to_mask) {
3323 auto iota = body->AddInstruction(HloInstruction::CreateIota(
3324 ShapeUtil::ChangeElementType(operand0->shape(), S32), dim));
3325 auto add = body->AddInstruction(HloInstruction::CreateBinary(
3326 iota->shape(), HloOpcode::kAdd, iota,
3327 body->AddInstruction(HloInstruction::CreateBroadcast(
3328 iota->shape(), slice_offsets[dim], {}))));
3329 auto limit = body->AddInstruction(HloInstruction::CreateBroadcast(
3330 iota->shape(),
3331 body->AddInstruction(
3332 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
3333 reduce_outside->operand(0)->shape().dimensions(dim)))),
3334 {}));
3335 auto compare = body->AddInstruction(HloInstruction::CreateCompare(
3336 ShapeUtil::ChangeElementType(iota->shape(), PRED), add, limit,
3337 ComparisonDirection::kLt));
3338 operand0 = body->AddInstruction(HloInstruction::CreateTernary(
3339 operand0->shape(), HloOpcode::kSelect, compare, operand0,
3340 body->AddInstruction(HloInstruction::CreateBroadcast(
3341 operand0->shape(), operand1, {}))));
3342 }
3343 auto output_inside =
3344 body->AddInstruction(reduce_outside->CloneWithNewOperands(
3345 reduce_shape, {operand0, operand1}));
3346 // Accumulate with previous results if needed.
3347 if (needs_accumulate) {
3348 auto input_slice =
3349 body->AddInstruction(HloInstruction::CreateDynamicSlice(
3350 output_inside->shape(), last_iter_result, reduce_dus_offsets,
3351 output_inside->shape().dimensions()));
3352 output_inside = body->AddInstruction(HloInstruction::CreateBinary(
3353 output_inside->shape(),
3354 reduce_outside->to_apply()->root_instruction()->opcode(),
3355 output_inside, input_slice));
3356 }
3357 // Dynamic-update-slice if needed.
3358 if (!ShapeUtil::Compatible(output_inside->shape(),
3359 last_iter_result->shape())) {
3360 output_inside =
3361 body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
3362 last_iter_result->shape(), last_iter_result, output_inside,
3363 reduce_dus_offsets));
3364 }
3365 new_outputs_inside[index_in_operand] = output_inside;
3366 }
3367 // Body output.
3368 auto new_output_inside =
3369 body->AddInstruction(HloInstruction::CreateTuple(new_outputs_inside));
3370 TF_RETURN_IF_ERROR(
3371 body_root->ReplaceOperandWithDifferentShape(2, new_output_inside));
3372 TF_RETURN_IF_ERROR(body->RemoveInstructionAndUnusedOperands(dus));
3373 // Replace uses of the reduces outside the loop.
3374 auto new_output_gte =
3375 computation->AddInstruction(HloInstruction::CreateGetTupleElement(
3376 new_output_inside->shape(), loop, 2));
3377 for (int64 i = 0; i < reduce_outputs.size(); ++i) {
3378 int64 index_in_operand = new_operands.size() - reduce_outputs.size() + i;
3379 auto new_output =
3380 computation->AddInstruction(HloInstruction::CreateGetTupleElement(
3381 new_outputs_inside[index_in_operand]->shape(), new_output_gte,
3382 index_in_operand));
3383 if (!ShapeUtil::Compatible(new_output->shape(),
3384 reduce_outputs[i]->shape())) {
3385 new_output = computation->AddInstruction(HloInstruction::CreateSlice(
3386 reduce_outputs[i]->shape(), new_output,
3387 std::vector<int64>(new_output->shape().rank(), 0),
3388 reduce_outputs[i]->shape().dimensions(),
3389 std::vector<int64>(new_output->shape().rank(), 1)));
3390 }
3391 TF_RETURN_IF_ERROR(reduce_outputs[i]->ReplaceAllUsesWith(new_output));
3392 TF_RETURN_IF_ERROR(
3393 computation->RemoveInstructionAndUnusedOperands(reduce_outputs[i]));
3394 }
3395 return Status::OK();
3396 }
3397
3398 } // namespace
3399
3400 Status SpmdPartitioningVisitor::DoCodeMotionForWindowedDotGeneralLoops(
3401 HloComputation* computation, const SpmdPartitionerOptions& options) {
3402 for (auto& loop : windowed_dot_general_loops_) {
3403 if (loop.windowed_in_contracting_dims || loop.windowed_in_batch_dims ||
3404 loop.operands_sharded_at_contracting_dims) {
3405 // We have a dynamic-slice for the non-windowed operand in
3406 // batch/contracting-dim/noncontracting-dim windowed dot-general. So
3407 // moving the broadcast/iota/elementwise ops into the loop could help
3408 // reduce memory via fusion.
3409 TF_RETURN_IF_ERROR(
3410 SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions(
3411 loop.while_loop, 1 - loop.windowed_operand));
3412 }
3413 // Currently unrolled loop does not support this optimization.
3414 if (!options.bidirectional_windowed_einsum &&
3415 !options.unroll_windowed_einsum && !loop.windowed_in_contracting_dims &&
3416 !loop.operands_sharded_at_contracting_dims) {
3417 // We have a dynamic-update-slice for the output in
3418 // batch/non-contracting-dim windowed dot-general. So moving reduce ops
3419 // into the loop could help reduce memory.
3420 TF_RETURN_IF_ERROR(
3421 MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions(
3422 loop.while_loop));
3423 }
3424 }
3425 return Status::OK();
3426 }
3427
3428 } // namespace spmd
3429 } // namespace xla
3430