1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "absl/algorithm/container.h"
17 #include "absl/container/flat_hash_map.h"
18 #include "absl/container/flat_hash_set.h"
19 #include "absl/types/optional.h"
20 #include "tensorflow/compiler/xla/literal_util.h"
21 #include "tensorflow/compiler/xla/service/hlo_computation.h"
22 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
27 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
28 #include "tensorflow/compiler/xla/service/shape_inference.h"
29 #include "tensorflow/compiler/xla/service/spmd/convolution_handler.h"
30 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
31 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/compiler/xla/window_util.h"
35 #include "tensorflow/compiler/xla/xla_data.pb.h"
36 #include "tensorflow/core/lib/gtl/cleanup.h"
37 #include "tensorflow/core/platform/numbers.h"
38 
39 namespace xla {
40 namespace spmd {
41 
HandleDot(HloInstruction * hlo)42 Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) {
43   DotConvDimsMapping mapping;
44   const auto& dnums = hlo->dot_dimension_numbers();
45   int64 next_output_dim = 0;
46   for (int64 i = 0; i < dnums.lhs_batch_dimensions_size(); ++i) {
47     mapping.batch_dims.emplace_back();
48     mapping.batch_dims.back().lhs = dnums.lhs_batch_dimensions(i);
49     mapping.batch_dims.back().rhs = dnums.rhs_batch_dimensions(i);
50     mapping.batch_dims.back().output = next_output_dim++;
51   }
52   for (int64 i = 0; i < dnums.lhs_contracting_dimensions_size(); ++i) {
53     mapping.contracting_dims.emplace_back();
54     mapping.contracting_dims.back().lhs = dnums.lhs_contracting_dimensions(i);
55     mapping.contracting_dims.back().rhs = dnums.rhs_contracting_dimensions(i);
56     mapping.contracting_dims.back().output = -1;
57   }
58   for (int64 i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
59     if (absl::c_linear_search(dnums.lhs_batch_dimensions(), i) ||
60         absl::c_linear_search(dnums.lhs_contracting_dimensions(), i)) {
61       continue;
62     }
63     mapping.lhs_non_contracting_dims.emplace_back();
64     mapping.lhs_non_contracting_dims.back().lhs = i;
65     mapping.lhs_non_contracting_dims.back().rhs = -1;
66     mapping.lhs_non_contracting_dims.back().output = next_output_dim++;
67   }
68   for (int64 i = 0; i < hlo->operand(1)->shape().rank(); ++i) {
69     if (absl::c_linear_search(dnums.rhs_batch_dimensions(), i) ||
70         absl::c_linear_search(dnums.rhs_contracting_dimensions(), i)) {
71       continue;
72     }
73     mapping.rhs_non_contracting_dims.emplace_back();
74     mapping.rhs_non_contracting_dims.back().lhs = -1;
75     mapping.rhs_non_contracting_dims.back().rhs = i;
76     mapping.rhs_non_contracting_dims.back().output = next_output_dim++;
77   }
78   auto create_sharded_dot =
79       [&](HloInstruction* l, HloInstruction* r, SpmdBuilder* b,
80           const Window& conv_window) -> StatusOr<HloInstruction*> {
81     TF_ASSIGN_OR_RETURN(
82         auto sharded_dot_shape,
83         ShapeInference::InferDotOpShape(
84             l->shape(), r->shape(), hlo->dot_dimension_numbers(),
85             /*preferred_element_type=*/hlo->shape().element_type()));
86     return b->AddInstruction(HloInstruction::CreateDot(
87         sharded_dot_shape, l, r, hlo->dot_dimension_numbers(),
88         hlo->precision_config()));
89   };
90   return HandleDotHelper(hlo, mapping, create_sharded_dot);
91 }
92 
93 namespace {
94 
95 enum class WindowedEinsumOperand { LHS, RHS };
96 
97 struct WindowedEinsumConfig {
98   WindowedEinsumOperand windowed_op;
99   bool windowed_at_contracting_dims;
100   bool windowed_at_batch_dims;
101   bool operands_sharded_at_contracting_dims;
102 };
103 
104 struct DotDimensionIndexMapping {
105   std::vector<int64> lhs_to_rhs_indices;
106   std::vector<int64> lhs_to_output_indices;
107   std::vector<int64> rhs_to_lhs_indices;
108   std::vector<int64> rhs_to_output_indices;
109   std::vector<int64> output_to_lhs_indices;
110   std::vector<int64> output_to_rhs_indices;
111 };
112 
UpdateDDNums(DotDimensionNumbers * new_ddnums,int64 reshaped_dim,bool lhs)113 void UpdateDDNums(DotDimensionNumbers* new_ddnums, int64 reshaped_dim,
114                   bool lhs) {
115   auto update_dims =
116       [&reshaped_dim](tensorflow::protobuf::RepeatedField<int64>* dims) {
117         for (int64 i = 0; i < dims->size(); ++i) {
118           auto dim = dims->at(i);
119           if (reshaped_dim <= dim) {
120             dims->Set(i, dim + 1);
121           }
122         }
123         if (absl::c_linear_search(*dims, reshaped_dim)) {
124           dims->Add(reshaped_dim);
125         }
126       };
127 
128   if (lhs) {
129     update_dims(new_ddnums->mutable_lhs_contracting_dimensions());
130     update_dims(new_ddnums->mutable_lhs_batch_dimensions());
131   } else {  // rhs
132     update_dims(new_ddnums->mutable_rhs_contracting_dimensions());
133     update_dims(new_ddnums->mutable_rhs_batch_dimensions());
134   }
135 }
136 
GenNewWindow(const HloInstruction * original_dot,const HloInstruction * dot_lhs,const HloInstruction * dot_rhs,int64 lhs_concat_dim,int64 rhs_concat_dim,bool windowed_at_contracting_dims,bool windowed_at_batch_dims)137 Window GenNewWindow(const HloInstruction* original_dot,
138                     const HloInstruction* dot_lhs,
139                     const HloInstruction* dot_rhs, int64 lhs_concat_dim,
140                     int64 rhs_concat_dim, bool windowed_at_contracting_dims,
141                     bool windowed_at_batch_dims) {
142   auto new_window = original_dot->window();
143   const ConvolutionDimensionNumbers& conv_dnums =
144       original_dot->convolution_dimension_numbers();
145   if (lhs_concat_dim != -1) {
146     for (int64 i = 0; i < conv_dnums.input_spatial_dimensions_size(); ++i) {
147       if (conv_dnums.input_spatial_dimensions(i) == lhs_concat_dim) {
148         auto wd = new_window.mutable_dimensions(i);
149         auto lhs_size = dot_lhs->shape().dimensions(lhs_concat_dim + 1);
150         if (windowed_at_contracting_dims) {
151           wd->set_size(lhs_size);
152         }
153         if (windowed_at_batch_dims) {
154           wd->set_size(lhs_size);
155           wd->set_padding_low(0);
156           wd->set_padding_high(0);
157           wd->set_stride(std::max<int64>(1, lhs_size - 1));
158           wd->set_window_dilation(1);
159           wd->set_base_dilation(lhs_size);
160           wd->set_window_reversal(false);
161         }
162       }
163     }
164   }
165   if (rhs_concat_dim != -1) {
166     for (int64 i = 0; i < conv_dnums.kernel_spatial_dimensions_size(); ++i) {
167       if (conv_dnums.kernel_spatial_dimensions(i) == rhs_concat_dim &&
168           !windowed_at_contracting_dims && !windowed_at_batch_dims &&
169           lhs_concat_dim == -1) {
170         auto wd = new_window.mutable_dimensions(i);
171         auto rhs_size = dot_rhs->shape().dimensions(rhs_concat_dim + 1);
172         wd->set_size(rhs_size);
173         wd->set_padding_low(rhs_size - 1);
174         wd->set_padding_high(rhs_size - 1);
175       }
176     }
177   }
178   // Add the extra dimension to window.
179   WindowDimension* new_dim = new_window.add_dimensions();
180   if (windowed_at_contracting_dims) {
181     new_dim->set_size(2);
182     new_dim->set_padding_low(0);
183     new_dim->set_padding_high(0);
184     new_dim->set_stride(1);
185     new_dim->set_window_dilation(1);
186     new_dim->set_base_dilation(1);
187     new_dim->set_window_reversal(false);
188   } else if (windowed_at_batch_dims) {
189     new_dim->set_size(2);
190     new_dim->set_padding_low(0);
191     new_dim->set_padding_high(0);
192     new_dim->set_stride(1);  // std::max<int64>(1, 2 - 1)
193     new_dim->set_window_dilation(1);
194     new_dim->set_base_dilation(2);
195     new_dim->set_window_reversal(false);
196   } else {
197     if (lhs_concat_dim != -1) {
198       new_dim->set_size(1);
199       new_dim->set_padding_low(0);
200       new_dim->set_padding_high(0);
201       new_dim->set_stride(1);
202       new_dim->set_window_dilation(1);
203       new_dim->set_base_dilation(1);
204       new_dim->set_window_reversal(false);
205     }
206     if (rhs_concat_dim != -1) {
207       new_dim->set_size(2);          // rhs_size
208       new_dim->set_padding_low(1);   // rhs_size - 1
209       new_dim->set_padding_high(1);  // rhs_size - 1
210       new_dim->set_stride(1);
211       new_dim->set_window_dilation(1);
212       new_dim->set_base_dilation(1);
213       new_dim->set_window_reversal(true);
214     }
215   }
216 
217   VLOG(2) << "new_window: " << new_window.ShortDebugString();
218   return new_window;
219 }
220 
GenNewConvDNums(const HloInstruction * original_dot,const HloInstruction * dot_lhs,const HloInstruction * dot_rhs,int64 lhs_concat_dim,int64 rhs_concat_dim,bool windowed_at_contracting_dims,bool windowed_at_batch_dims,const std::vector<int64> & lhs_to_output_indices,const std::vector<int64> & rhs_to_output_indices,const Shape & new_dot_shape)221 ConvolutionDimensionNumbers GenNewConvDNums(
222     const HloInstruction* original_dot, const HloInstruction* dot_lhs,
223     const HloInstruction* dot_rhs, int64 lhs_concat_dim, int64 rhs_concat_dim,
224     bool windowed_at_contracting_dims, bool windowed_at_batch_dims,
225     const std::vector<int64>& lhs_to_output_indices,
226     const std::vector<int64>& rhs_to_output_indices,
227     const Shape& new_dot_shape) {
228   // Generate the new conv dimension numbers.
229   const ConvolutionDimensionNumbers& dnums =
230       original_dot->convolution_dimension_numbers();
231   // Handle the LHS dimension numbers.
232   int64 input_batch_dimension = dnums.input_batch_dimension();
233   int64 input_feature_dimension = dnums.input_feature_dimension();
234   std::vector<int64> input_spatial_dimensions(
235       dnums.input_spatial_dimensions().begin(),
236       dnums.input_spatial_dimensions().end());
237   if (lhs_concat_dim != -1) {
238     if (lhs_concat_dim <= input_batch_dimension) {
239       input_batch_dimension++;
240     }
241     if (lhs_concat_dim <= input_feature_dimension) {
242       input_feature_dimension++;
243     }
244     for (int64 i = 0; i < input_spatial_dimensions.size(); ++i) {
245       if (lhs_concat_dim <= input_spatial_dimensions[i]) {
246         input_spatial_dimensions[i]++;
247       }
248     }
249     input_spatial_dimensions.push_back(lhs_concat_dim);
250   }
251   if (rhs_concat_dim != -1 && !windowed_at_contracting_dims &&
252       !windowed_at_batch_dims) {
253     input_spatial_dimensions.push_back(dot_lhs->shape().dimensions_size() - 1);
254   }
255   // Handle the RHS dimension numbers.
256   int64 kernel_input_feature_dimension = dnums.kernel_input_feature_dimension();
257   int64 kernel_output_feature_dimension =
258       dnums.kernel_output_feature_dimension();
259   std::vector<int64> kernel_spatial_dimensions(
260       dnums.kernel_spatial_dimensions().begin(),
261       dnums.kernel_spatial_dimensions().end());
262   if (rhs_concat_dim != -1) {
263     if (rhs_concat_dim <= kernel_input_feature_dimension) {
264       kernel_input_feature_dimension++;
265     }
266     if (rhs_concat_dim <= kernel_output_feature_dimension) {
267       kernel_output_feature_dimension++;
268     }
269     for (int64 i = 0; i < kernel_spatial_dimensions.size(); ++i) {
270       if (rhs_concat_dim <= kernel_spatial_dimensions[i]) {
271         kernel_spatial_dimensions[i]++;
272       }
273     }
274     kernel_spatial_dimensions.push_back(rhs_concat_dim);
275   }
276   if (lhs_concat_dim != -1 && !windowed_at_contracting_dims &&
277       !windowed_at_batch_dims) {
278     kernel_spatial_dimensions.push_back(dot_rhs->shape().dimensions_size() - 1);
279   }
280   // Handle the Output dimension numbers.
281   int64 output_batch_dimension = dnums.output_batch_dimension();
282   int64 output_feature_dimension = dnums.output_feature_dimension();
283   std::vector<int64> output_spatial_dimensions(
284       dnums.output_spatial_dimensions().begin(),
285       dnums.output_spatial_dimensions().end());
286   if (!windowed_at_contracting_dims) {
287     auto output_slice_dim = lhs_concat_dim != -1
288                                 ? lhs_to_output_indices[lhs_concat_dim]
289                                 : rhs_to_output_indices[rhs_concat_dim];
290     if (output_slice_dim <= output_batch_dimension) {
291       output_batch_dimension++;
292     }
293     if (output_slice_dim <= output_feature_dimension) {
294       output_feature_dimension++;
295     }
296     for (int64 i = 0; i < output_spatial_dimensions.size(); ++i) {
297       if (output_slice_dim <= output_spatial_dimensions[i]) {
298         output_spatial_dimensions[i]++;
299       }
300     }
301     output_spatial_dimensions.push_back(output_slice_dim);
302   } else {
303     output_spatial_dimensions.push_back(new_dot_shape.dimensions_size() - 1);
304   }
305   // Construct the new dot dimension numbers.
306   ConvolutionDimensionNumbers new_dnums;
307   new_dnums.set_input_batch_dimension(input_batch_dimension);
308   new_dnums.set_input_feature_dimension(input_feature_dimension);
309   for (auto dim : input_spatial_dimensions) {
310     new_dnums.add_input_spatial_dimensions(dim);
311   }
312   new_dnums.set_kernel_input_feature_dimension(kernel_input_feature_dimension);
313   new_dnums.set_kernel_output_feature_dimension(
314       kernel_output_feature_dimension);
315   for (auto dim : kernel_spatial_dimensions) {
316     new_dnums.add_kernel_spatial_dimensions(dim);
317   }
318   new_dnums.set_output_batch_dimension(output_batch_dimension);
319   new_dnums.set_output_feature_dimension(output_feature_dimension);
320   for (auto dim : output_spatial_dimensions) {
321     new_dnums.add_output_spatial_dimensions(dim);
322   }
323 
324   return new_dnums;
325 }
326 
FirstShardingDimWithPartitionOfSize(int64 num_partitions,const HloSharding & sharding)327 int64 FirstShardingDimWithPartitionOfSize(int64 num_partitions,
328                                           const HloSharding& sharding) {
329   int64 sharding_dim = -1;
330   for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
331     if (sharding.tile_assignment().dim(i) == num_partitions) {
332       sharding_dim = i;
333       break;
334     }
335   }
336   return sharding_dim;
337 }
338 
ComputeDimensionIndexMapping(const DotConvDimsMapping & dims_mapping,int64 lhs_rank,int64 rhs_rank,int64 output_rank)339 DotDimensionIndexMapping ComputeDimensionIndexMapping(
340     const DotConvDimsMapping& dims_mapping, int64 lhs_rank, int64 rhs_rank,
341     int64 output_rank) {
342   std::vector<int64> lhs_to_rhs_indices(lhs_rank, -1);
343   std::vector<int64> lhs_to_output_indices(lhs_rank, -1);
344   std::vector<int64> rhs_to_lhs_indices(rhs_rank, -1);
345   std::vector<int64> rhs_to_output_indices(rhs_rank, -1);
346   std::vector<int64> output_to_lhs_indices(output_rank, -1);
347   std::vector<int64> output_to_rhs_indices(output_rank, -1);
348   auto populate_indices_mapping =
349       [&](const DotConvDimsMapping::DimsMapping& mapping) {
350         if (mapping.lhs >= 0) {
351           lhs_to_rhs_indices[mapping.lhs] = mapping.rhs;
352           lhs_to_output_indices[mapping.lhs] = mapping.output;
353         }
354         if (mapping.rhs >= 0) {
355           rhs_to_lhs_indices[mapping.rhs] = mapping.lhs;
356           rhs_to_output_indices[mapping.rhs] = mapping.output;
357         }
358         if (mapping.output >= 0) {
359           output_to_lhs_indices[mapping.output] = mapping.lhs;
360           output_to_rhs_indices[mapping.output] = mapping.rhs;
361         }
362       };
363   for (const auto& mapping : dims_mapping.batch_dims) {
364     populate_indices_mapping(mapping);
365   }
366   for (const auto& mapping : dims_mapping.contracting_dims) {
367     populate_indices_mapping(mapping);
368   }
369   for (const auto& mapping : dims_mapping.lhs_non_contracting_dims) {
370     populate_indices_mapping(mapping);
371   }
372   for (const auto& mapping : dims_mapping.rhs_non_contracting_dims) {
373     populate_indices_mapping(mapping);
374   }
375   for (const auto& mapping : dims_mapping.conv_spatial_dims) {
376     populate_indices_mapping(mapping);
377   }
378   return DotDimensionIndexMapping{lhs_to_rhs_indices,    lhs_to_output_indices,
379                                   rhs_to_lhs_indices,    rhs_to_output_indices,
380                                   output_to_lhs_indices, output_to_rhs_indices};
381 }
382 
GetWindowedEinsumConfiguration(int64 num_partitions,int64 output_lhs_non_contracting_partitions,int64 output_rhs_non_contracting_partitions,int64 rhs_contracting_partitions,int64 rhs_non_contracting_partitions,int64 rhs_batch_partitions,int64 lhs_contracting_partitions,int64 lhs_non_contracting_partitions,int64 lhs_batch_partitions,int64 output_sharding_dim,int64 rhs_shape_size,int64 lhs_shape_size,int64 output_shape_size,int64 einsum_threshold_mib,const absl::optional<HloSharding> & output_sharding_transposed_to_match_lhs,const absl::optional<HloSharding> & output_sharding_transposed_to_match_rhs,const HloSharding & lhs_sharding,const HloSharding & rhs_sharding)383 absl::optional<WindowedEinsumConfig> GetWindowedEinsumConfiguration(
384     int64 num_partitions, int64 output_lhs_non_contracting_partitions,
385     int64 output_rhs_non_contracting_partitions,
386     int64 rhs_contracting_partitions, int64 rhs_non_contracting_partitions,
387     int64 rhs_batch_partitions, int64 lhs_contracting_partitions,
388     int64 lhs_non_contracting_partitions, int64 lhs_batch_partitions,
389     int64 output_sharding_dim, int64 rhs_shape_size, int64 lhs_shape_size,
390     int64 output_shape_size, int64 einsum_threshold_mib,
391     const absl::optional<HloSharding>& output_sharding_transposed_to_match_lhs,
392     const absl::optional<HloSharding>& output_sharding_transposed_to_match_rhs,
393     const HloSharding& lhs_sharding, const HloSharding& rhs_sharding) {
394   if (output_lhs_non_contracting_partitions == num_partitions &&
395       output_sharding_transposed_to_match_lhs == lhs_sharding &&
396       rhs_shape_size >= einsum_threshold_mib * 1024 * 1024) {
397     if (rhs_contracting_partitions == num_partitions) {
398       return WindowedEinsumConfig{
399           /*windowed_op=*/WindowedEinsumOperand::RHS,
400           /*windowed_at_contracting_dims*/ true,
401           /*windowed_at_batch_dims=*/false,
402           /*operands_sharded_at_contracting_dims=*/false};
403     }
404     if (rhs_non_contracting_partitions == num_partitions) {
405       return WindowedEinsumConfig{
406           /*windowed_op=*/WindowedEinsumOperand::RHS,
407           /*windowed_at_contracting_dims*/ false,
408           /*windowed_at_batch_dims=*/false,
409           /*operands_sharded_at_contracting_dims=*/false};
410     }
411     if (rhs_batch_partitions == num_partitions) {
412       return WindowedEinsumConfig{
413           /*windowed_op=*/WindowedEinsumOperand::RHS,
414           /*windowed_at_contracting_dims*/ false,
415           /*windowed_at_batch_dims=*/true,
416           /*operands_sharded_at_contracting_dims=*/false};
417     }
418   }
419   if (output_rhs_non_contracting_partitions == num_partitions &&
420       output_sharding_transposed_to_match_rhs == rhs_sharding &&
421       lhs_shape_size >= einsum_threshold_mib * 1024 * 1024) {
422     if (lhs_contracting_partitions == num_partitions) {
423       return WindowedEinsumConfig{
424           /*windowed_op=*/WindowedEinsumOperand::LHS,
425           /*windowed_at_contracting_dims*/ true,
426           /*windowed_at_batch_dims=*/false,
427           /*operands_sharded_at_contracting_dims=*/false};
428     }
429     if (lhs_non_contracting_partitions == num_partitions) {
430       return WindowedEinsumConfig{
431           /*windowed_op=*/WindowedEinsumOperand::LHS,
432           /*windowed_at_contracting_dims*/ false,
433           /*windowed_at_batch_dims=*/false,
434           /*operands_sharded_at_contracting_dims=*/false};
435     }
436     if (lhs_batch_partitions == num_partitions) {
437       return WindowedEinsumConfig{
438           /*windowed_op=*/WindowedEinsumOperand::LHS,
439           /*windowed_at_contracting_dims*/ false,
440           /*windowed_at_batch_dims=*/true,
441           /*operands_sharded_at_contracting_dims=*/false};
442     }
443   }
444   if (lhs_contracting_partitions == rhs_contracting_partitions &&
445       lhs_contracting_partitions == num_partitions &&
446       output_sharding_dim > -1 &&
447       output_shape_size >= einsum_threshold_mib * 1024 * 1024) {
448     if (output_lhs_non_contracting_partitions == num_partitions) {
449       return WindowedEinsumConfig{
450           /*windowed_op=*/WindowedEinsumOperand::RHS,
451           /*windowed_at_contracting_dims*/ false,
452           /*windowed_at_batch_dims=*/false,
453           /*operands_sharded_at_contracting_dims=*/true};
454     }
455     if (output_rhs_non_contracting_partitions == num_partitions) {
456       return WindowedEinsumConfig{
457           /*windowed_op=*/WindowedEinsumOperand::LHS,
458           /*windowed_at_contracting_dims*/ false,
459           /*windowed_at_batch_dims=*/false,
460           /*operands_sharded_at_contracting_dims=*/true};
461     }
462   }
463   return absl::nullopt;
464 }
465 
PartitionBaseCase(PartitionedHlo lhs,PartitionedHlo rhs,const Shape & output_base_shape,const HloSharding & output_sharding,const DotConvDimsMapping & dims_mapping,int64 num_partitions,const std::function<StatusOr<HloInstruction * > (HloInstruction *,HloInstruction *,SpmdBuilder *,const Window & conv_window)> & create_sharded_dot,const Window & conv_window,HloModule * module,HloInstruction * original_hlo,int64 lhs_batch_partitions,int64 rhs_batch_partitions,int64 output_batch_partitions,int64 lhs_contracting_partitions,int64 rhs_contracting_partitions,int64 lhs_non_contracting_partitions,int64 rhs_non_contracting_partitions,int64 output_lhs_non_contracting_partitions,int64 output_rhs_non_contracting_partitions,const SpmdPartitionerOptions & options,SpmdBuilder * b,std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop> * windowed_dot_general_loops,bool may_reshard_without_detecting_match)466 StatusOr<HloInstruction*> PartitionBaseCase(
467     PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
468     const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
469     int64 num_partitions,
470     const std::function<StatusOr<HloInstruction*>(
471         HloInstruction*, HloInstruction*, SpmdBuilder*,
472         const Window& conv_window)>& create_sharded_dot,
473     const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
474     int64 lhs_batch_partitions, int64 rhs_batch_partitions,
475     int64 output_batch_partitions, int64 lhs_contracting_partitions,
476     int64 rhs_contracting_partitions, int64 lhs_non_contracting_partitions,
477     int64 rhs_non_contracting_partitions,
478     int64 output_lhs_non_contracting_partitions,
479     int64 output_rhs_non_contracting_partitions,
480     const SpmdPartitionerOptions& options, SpmdBuilder* b,
481     std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
482         windowed_dot_general_loops,
483     bool may_reshard_without_detecting_match) {
484   const HloSharding& lhs_sharding = lhs.sharding();
485   const HloSharding& rhs_sharding = rhs.sharding();
486   if (lhs_sharding.ReplicateOnLastTileDim() ||
487       rhs_sharding.ReplicateOnLastTileDim() ||
488       output_sharding.ReplicateOnLastTileDim()) {
489     return nullptr;
490   }
491   DotDimensionIndexMapping indices_map = ComputeDimensionIndexMapping(
492       dims_mapping, lhs.base_shape().rank(), rhs.base_shape().rank(),
493       output_base_shape.rank());
494   auto lhs_sharding_transposed_to_match_rhs =
495       hlo_sharding_util::TransposeShardingWithCollapsedDims(
496           lhs_sharding, indices_map.lhs_to_rhs_indices,
497           indices_map.rhs_to_lhs_indices);
498   auto rhs_sharding_transposed_to_match_lhs =
499       hlo_sharding_util::TransposeShardingWithCollapsedDims(
500           rhs_sharding, indices_map.rhs_to_lhs_indices,
501           indices_map.lhs_to_rhs_indices);
502   auto lhs_sharding_transposed_to_match_output =
503       hlo_sharding_util::TransposeShardingWithCollapsedDims(
504           lhs_sharding, indices_map.lhs_to_output_indices,
505           indices_map.output_to_lhs_indices);
506   auto rhs_sharding_transposed_to_match_output =
507       hlo_sharding_util::TransposeShardingWithCollapsedDims(
508           rhs_sharding, indices_map.rhs_to_output_indices,
509           indices_map.output_to_rhs_indices);
510   auto output_sharding_transposed_to_match_lhs =
511       hlo_sharding_util::TransposeShardingWithCollapsedDims(
512           output_sharding, indices_map.output_to_lhs_indices,
513           indices_map.lhs_to_output_indices);
514   auto output_sharding_transposed_to_match_rhs =
515       hlo_sharding_util::TransposeShardingWithCollapsedDims(
516           output_sharding, indices_map.output_to_rhs_indices,
517           indices_map.rhs_to_output_indices);
518 
519   // LHS and RHS are partitioned the same way and only partitioned in batch
520   // dimensions.
521   if (lhs_batch_partitions == rhs_batch_partitions &&
522       rhs_batch_partitions == num_partitions &&
523       lhs_sharding_transposed_to_match_rhs == rhs_sharding) {
524     TF_ASSIGN_OR_RETURN(
525         auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
526     dot->set_sharding(*lhs_sharding_transposed_to_match_output);
527     return PartitionedHlo(dot, output_base_shape, lhs.state())
528         .Reshard(output_sharding)
529         .hlo();
530   }
531 
532   // Try emit batch-partitioned einsum with one operand resharded. Returns
533   // partitioned HLO or nullptr if the attempt fails. If
534   // may_reshard_with_allreduce is false, reshard must be done using
535   // all-to-all/collective-permute; otherwise this attempt fails.
536   auto try_emit_output_batch_partitioned_einsum_with_reshard =
537       [&](bool may_reshard_with_allreduce) -> StatusOr<HloInstruction*> {
538     // LHS and output are batch partitioned in the same way.
539     if (lhs_batch_partitions == num_partitions &&
540         output_batch_partitions == num_partitions &&
541         lhs_sharding_transposed_to_match_output == output_sharding) {
542       if (!may_reshard_with_allreduce &&
543           !CanReshardWithCollectivePermute(
544               rhs.sharding(), *lhs_sharding_transposed_to_match_rhs) &&
545           !GetReshardAllToAllSourceTargetDims(
546               rhs.sharding(), *lhs_sharding_transposed_to_match_rhs)) {
547         return nullptr;
548       }
549       auto resharded_rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs);
550       TF_ASSIGN_OR_RETURN(
551           auto dot,
552           create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), b, conv_window));
553       return dot;
554     }
555     // RHS and output are batch partitioned in the same way.
556     if (rhs_batch_partitions == num_partitions &&
557         output_batch_partitions == num_partitions &&
558         rhs_sharding_transposed_to_match_output == output_sharding) {
559       if (!may_reshard_with_allreduce &&
560           !CanReshardWithCollectivePermute(
561               lhs.sharding(), *rhs_sharding_transposed_to_match_lhs) &&
562           !GetReshardAllToAllSourceTargetDims(
563               lhs.sharding(), *rhs_sharding_transposed_to_match_lhs)) {
564         return nullptr;
565       }
566       auto resharded_lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs);
567       TF_ASSIGN_OR_RETURN(
568           auto dot,
569           create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), b, conv_window));
570       return dot;
571     }
572     return nullptr;
573   };
574 
575   {
576     // Try batch-parallel by resharding one operand, and not using all-reduce.
577     TF_ASSIGN_OR_RETURN(
578         HloInstruction * partitioned_dot,
579         try_emit_output_batch_partitioned_einsum_with_reshard(false));
580     if (partitioned_dot) {
581       return partitioned_dot;
582     }
583   }
584 
585   const int64 output_sharding_dim =
586       FirstShardingDimWithPartitionOfSize(num_partitions, output_sharding);
587   // Try to emit windowed DotGeneral when one operand is partitioned in the same
588   // way as the output along non-contracting dimensions, but the other operand
589   // is tiled in other dimensions. Or both operands are partitioned in the same
590   // way along contracting dimensions, but the output is partitioned along
591   // non-contracting dimensions.
592   auto emit_windowed_dot_general =
593       [&](const WindowedEinsumConfig& einsum_config)
594       -> StatusOr<HloInstruction*> {
595     CHECK(!einsum_config.windowed_at_batch_dims ||
596           !einsum_config.windowed_at_contracting_dims);
597     const bool windowed_at_batch_dims = einsum_config.windowed_at_batch_dims;
598     const bool windowed_at_contracting_dims =
599         einsum_config.windowed_at_contracting_dims;
600     const bool operands_sharded_at_contracting_dims =
601         einsum_config.operands_sharded_at_contracting_dims;
602     auto unpadded_result_buffer_shape =
603         MakePartitionedShape(output_base_shape, output_sharding);
604     auto padded_result_buffer_shape = unpadded_result_buffer_shape;
605     const bool windowed_op_is_lhs =
606         einsum_config.windowed_op == WindowedEinsumOperand::LHS;
607     // For windowing at batch/non-contracting dims, we produce the result one
608     // partition at a time, so we need to pad the shape in case of uneven
609     // partitioning in order to make dynamic-update-slice in-bound.
610     if (!windowed_at_contracting_dims &&
611         !operands_sharded_at_contracting_dims) {
612       padded_result_buffer_shape = GetPaddedShapeForUnevenPartitioning(
613           padded_result_buffer_shape,
614           windowed_op_is_lhs ? *lhs_sharding_transposed_to_match_output
615                              : *rhs_sharding_transposed_to_match_output);
616     }
617     // Mask the padding area of the windowed operand with zero if there is
618     // uneven partitioning.
619     if (windowed_at_contracting_dims) {
620       auto& to_mask = windowed_op_is_lhs ? lhs : rhs;
621       to_mask =
622           to_mask.PadWithValue(b->AddInstruction(HloInstruction::CreateConstant(
623               LiteralUtil::Zero(output_base_shape.element_type()))));
624     }
625     if (operands_sharded_at_contracting_dims) {
626       auto zero = b->AddInstruction(HloInstruction::CreateConstant(
627           LiteralUtil::Zero(output_base_shape.element_type())));
628       lhs = lhs.PadWithValue(zero);
629       rhs = rhs.PadWithValue(zero);
630     }
631     auto result_buffer = CreateZero(padded_result_buffer_shape, b);
632     auto extra_buffer =
633         (!(options.bidirectional_windowed_einsum && num_partitions % 4 == 0) ||
634          operands_sharded_at_contracting_dims)
635             ? CreateZero(padded_result_buffer_shape, b)
636         : windowed_op_is_lhs ? lhs.hlo()
637                              : rhs.hlo();
638 
639     if (options.bidirectional_windowed_einsum && num_partitions % 4 == 0 &&
640         !operands_sharded_at_contracting_dims) {
641       std::vector<std::pair<int64, int64>> pre_sd_pairs(num_partitions);
642       for (int64 source = 0; source < num_partitions; ++source) {
643         // 0 -> 1, 1 -> 2, 2 -> 3, ...
644         pre_sd_pairs[source] = {source, (source + 1) % num_partitions};
645       }
646       extra_buffer =
647           lhs.state()
648               .collective_ops_creator.create_cross_partition_collective_permute(
649                   b, extra_buffer, pre_sd_pairs,
650                   (*lhs.state().next_channel_id)++);
651     }
652 
653     auto iteration = b->AddInstruction(
654         HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(0)));
655 
656     // Create a while loop that computes one window per iteration. During each
657     // iteration, each partition sends its input window to its neighbor using
658     // collective-permute for the next iteration.
659     SpmdBuilder body_b("windowed_dot_general_body", original_hlo);
660 
661     // Generate partial results used by bidirectional algorithm.
662     auto get_partial_bid_results =
663         [&](HloInstruction* l, HloInstruction* r, HloInstruction* o,
664             HloInstruction* extra_inout, HloInstruction* cw_cp_output,
665             HloInstruction* i) -> StatusOr<std::vector<HloInstruction*>> {
666       auto partition_id =
667           lhs.state().collective_ops_creator.create_partition_id(&body_b);
668       auto partition_count =
669           body_b.AddInstruction(HloInstruction::CreateConstant(
670               LiteralUtil::CreateR0<uint32>(num_partitions)));
671       auto ccw_data_partition_id =
672           body_b.AddInstruction(HloInstruction::CreateBinary(
673               i->shape(), HloOpcode::kAdd, i, partition_id));
674       auto cw_data_partition_id =
675           body_b.AddInstruction(HloInstruction::CreateBinary(
676               i->shape(), HloOpcode::kAdd, partition_count, partition_id));
677       if (operands_sharded_at_contracting_dims) {
678         ccw_data_partition_id =
679             body_b.AddInstruction(HloInstruction::CreateBinary(
680                 i->shape(), HloOpcode::kAdd, ccw_data_partition_id,
681                 body_b.AddInstruction(HloInstruction::CreateConstant(
682                     LiteralUtil::CreateR0<uint32>(num_partitions / 2 + 1)))));
683         cw_data_partition_id =
684             body_b.AddInstruction(HloInstruction::CreateBinary(
685                 i->shape(), HloOpcode::kSubtract, cw_data_partition_id,
686                 body_b.AddInstruction(HloInstruction::CreateConstant(
687                     LiteralUtil::CreateR0<uint32>(num_partitions / 2)))));
688       } else {
689         cw_data_partition_id =
690             body_b.AddInstruction(HloInstruction::CreateBinary(
691                 i->shape(), HloOpcode::kSubtract, cw_data_partition_id,
692                 CreateOne(cw_data_partition_id->shape(), &body_b)));
693       }
694       ccw_data_partition_id = body_b.AddInstruction(
695           HloInstruction::CreateBinary(i->shape(), HloOpcode::kRemainder,
696                                        ccw_data_partition_id, partition_count));
697       cw_data_partition_id = body_b.AddInstruction(HloInstruction::CreateBinary(
698           i->shape(), HloOpcode::kSubtract, cw_data_partition_id, i));
699       cw_data_partition_id = body_b.AddInstruction(
700           HloInstruction::CreateBinary(i->shape(), HloOpcode::kRemainder,
701                                        cw_data_partition_id, partition_count));
702       // Calculate concat dim.
703       const HloSharding* slice_sharding;
704       if (operands_sharded_at_contracting_dims) {
705         slice_sharding = windowed_op_is_lhs
706                              ? &*output_sharding_transposed_to_match_rhs
707                              : &*output_sharding_transposed_to_match_lhs;
708       } else if (windowed_at_contracting_dims || windowed_at_batch_dims) {
709         slice_sharding = windowed_op_is_lhs
710                              ? &*lhs_sharding_transposed_to_match_rhs
711                              : &*rhs_sharding_transposed_to_match_lhs;
712       } else {
713         slice_sharding = windowed_op_is_lhs
714                              ? &*lhs_sharding_transposed_to_match_output
715                              : &*rhs_sharding_transposed_to_match_output;
716       }
717       CHECK_EQ(Product(slice_sharding->tile_assignment().dimensions()),
718                num_partitions);
719       int64 slice_sharding_dim = -1;
720       for (int64 i = 0; i < slice_sharding->tile_assignment().num_dimensions();
721            ++i) {
722         if (slice_sharding->tile_assignment().dim(i) > 1) {
723           slice_sharding_dim = i;
724           break;
725         }
726       }
727       int64 lhs_concat_dim = -1;
728       int64 rhs_concat_dim = -1;
729       if (operands_sharded_at_contracting_dims) {
730         if (windowed_op_is_lhs) {
731           rhs_concat_dim = slice_sharding_dim;
732         } else {
733           lhs_concat_dim = slice_sharding_dim;
734         }
735       } else if (windowed_at_contracting_dims || windowed_at_batch_dims) {
736         lhs_concat_dim =
737             windowed_op_is_lhs
738                 ? indices_map.rhs_to_lhs_indices[slice_sharding_dim]
739                 : slice_sharding_dim;
740         rhs_concat_dim =
741             windowed_op_is_lhs
742                 ? slice_sharding_dim
743                 : indices_map.lhs_to_rhs_indices[slice_sharding_dim];
744       } else {
745         if (windowed_op_is_lhs) {
746           lhs_concat_dim =
747               indices_map.output_to_lhs_indices[slice_sharding_dim];
748         } else {
749           rhs_concat_dim =
750               indices_map.output_to_rhs_indices[slice_sharding_dim];
751         }
752       }
753 
754       DotDimensionNumbers new_ddnums;
755       if (original_hlo->opcode() == HloOpcode::kDot) {
756         new_ddnums = original_hlo->dot_dimension_numbers();
757       }
758 
759       auto dot_lhs = l;
760       auto dot_rhs = r;
761       auto original_dot_lhs = l;
762       auto original_dot_rhs = r;
763       if (windowed_at_contracting_dims || windowed_at_batch_dims ||
764           operands_sharded_at_contracting_dims) {
765         // Slice the matching operand according to the partitioned dimensions
766         // on the windowed operand or the output.
767         auto slice_operand = !windowed_op_is_lhs ? l : r;
768 
769         // Pad the sharding dim first (then the concat dim) for correctness.
770         auto sharding_dim_size =
771             slice_operand->shape().dimensions(slice_sharding_dim);
772         if (sharding_dim_size % num_partitions != 0) {
773           slice_operand = PadBaseShapeBeforeUnevenTiledSharding(
774               slice_operand, *slice_sharding, &body_b);
775         }
776 
777         // We do this by treating the matching operand as replicated, and
778         // resharding it to match the windowed operand or the output.
779         auto gen_slice = [&](HloInstruction* data_partition_id,
780                              bool ccw) -> HloInstruction* {
781           std::vector<int64> new_dims;
782           for (int64 i = 0; i < slice_operand->shape().dimensions_size(); ++i) {
783             if (i == slice_sharding_dim) {
784               new_dims.push_back(1);
785             }
786             new_dims.push_back(slice_operand->shape().dimensions(i));
787           }
788           auto reshaped_slice_operand =
789               body_b.AddInstruction(HloInstruction::CreateReshape(
790                   ShapeUtil::MakeShape(slice_operand->shape().element_type(),
791                                        new_dims),
792                   slice_operand));
793           auto min = body_b.AddInstruction(
794               HloInstruction::CreateConstant(LiteralUtil::MinValue(
795                   reshaped_slice_operand->shape().element_type())));
796           std::vector<int64> min_padding(
797               reshaped_slice_operand->shape().rank());
798           auto padded_slice_operand = reshaped_slice_operand;
799           auto padded_shape = padded_slice_operand->shape();
800           int64 padding_dim = slice_sharding_dim;
801           padded_shape.set_dimensions(padding_dim, 2);
802           if (ccw) {
803             // ccw pad high
804             PaddingConfig ccw_pad_config =
805                 window_util::MakeSymmetricPadding(min_padding);
806             ccw_pad_config.mutable_dimensions(padding_dim)
807                 ->set_edge_padding_low(0);
808             ccw_pad_config.mutable_dimensions(padding_dim)
809                 ->set_edge_padding_high(1);
810             padded_slice_operand =
811                 body_b.AddInstruction(HloInstruction::CreatePad(
812                     padded_shape, padded_slice_operand, min, ccw_pad_config));
813           } else {
814             // cw pad low
815             PaddingConfig cw_pad_config =
816                 window_util::MakeSymmetricPadding(min_padding);
817             cw_pad_config.mutable_dimensions(padding_dim)
818                 ->set_edge_padding_low(1);
819             cw_pad_config.mutable_dimensions(padding_dim)
820                 ->set_edge_padding_high(0);
821             padded_slice_operand =
822                 body_b.AddInstruction(HloInstruction::CreatePad(
823                     padded_shape, padded_slice_operand, min, cw_pad_config));
824           }
825 
826           padded_slice_operand->set_sharding(HloSharding::Replicate());
827           auto state = lhs.state();
828           state.b = &body_b;
829           state.partition_id = data_partition_id;
830           state.reshard_cache->per_hlo_cache.erase(padded_slice_operand);
831           auto padded_slice_sharding = hlo_sharding_util::ReshapeSharding(
832               slice_operand->shape(), reshaped_slice_operand->shape(),
833               *slice_sharding);
834           auto padded_slice =
835               PartitionedHlo(padded_slice_operand,
836                              padded_slice_operand->shape(), state)
837                   .Reshard(*padded_slice_sharding)
838                   .hlo();
839           padded_slice_operand->clear_sharding();
840           return padded_slice;
841         };
842 
843         auto ccw_slice = gen_slice(ccw_data_partition_id, true);
844         auto cw_slice = gen_slice(cw_data_partition_id, false);
845         auto slice = body_b.AddInstruction(HloInstruction::CreateBinary(
846             ccw_slice->shape(), HloOpcode::kMaximum, ccw_slice, cw_slice));
847         // Reshape. The reshaped slice will not be used to produce the final
848         // result, but used as a hint for the shape inference.
849         std::vector<int64> reshaped_slice_dims;
850         for (int64 i = 0; i < slice->shape().dimensions_size(); ++i) {
851           auto dim_size = slice->shape().dimensions(i);
852           if (i == (slice_sharding_dim + 1)) {
853             reshaped_slice_dims.push_back(dim_size * 2);
854           } else if (i != slice_sharding_dim) {
855             reshaped_slice_dims.push_back(dim_size);
856           }
857         }
858         auto reshaped_slice =
859             body_b.AddInstruction(HloInstruction::CreateReshape(
860                 ShapeUtil::MakeShape(slice->shape().element_type(),
861                                      reshaped_slice_dims),
862                 slice));
863 
864         if (!windowed_op_is_lhs) {
865           dot_lhs = slice;
866           original_dot_lhs = reshaped_slice;
867           if (original_hlo->opcode() == HloOpcode::kDot) {
868             UpdateDDNums(&new_ddnums, slice_sharding_dim, true);
869           }
870         } else {
871           dot_rhs = slice;
872           original_dot_rhs = reshaped_slice;
873           if (original_hlo->opcode() == HloOpcode::kDot) {
874             UpdateDDNums(&new_ddnums, slice_sharding_dim, false);
875           }
876         }
877       }
878 
879       auto ccw_dot_lhs = l;
880       auto ccw_dot_rhs = r;
881       auto cw_dot_lhs = windowed_op_is_lhs ? extra_inout : l;
882       auto cw_dot_rhs = windowed_op_is_lhs ? r : extra_inout;
883       if (lhs_concat_dim != -1 && windowed_op_is_lhs) {
884         auto lhs_concat_shape = ccw_dot_lhs->shape();
885         lhs_concat_shape.set_dimensions(
886             lhs_concat_dim,
887             ccw_dot_lhs->shape().dimensions(lhs_concat_dim) * 2);
888         dot_lhs = body_b.AddInstruction(HloInstruction::CreateConcatenate(
889             lhs_concat_shape, {ccw_dot_lhs, cw_dot_lhs}, lhs_concat_dim));
890         original_dot_lhs = dot_lhs;
891 
892         // Reshape
893         std::vector<int64> reshaped_dims(dot_lhs->shape().dimensions().begin(),
894                                          dot_lhs->shape().dimensions().end());
895         reshaped_dims[lhs_concat_dim] /= 2;
896         reshaped_dims.insert(reshaped_dims.begin() + lhs_concat_dim, 2);
897         dot_lhs = body_b.AddInstruction(HloInstruction::CreateReshape(
898             ShapeUtil::MakeShape(dot_lhs->shape().element_type(),
899                                  reshaped_dims),
900             dot_lhs));
901 
902         if (original_hlo->opcode() == HloOpcode::kDot) {
903           UpdateDDNums(&new_ddnums, lhs_concat_dim, true);
904         }
905       }
906       if (rhs_concat_dim != -1 && !windowed_op_is_lhs) {
907         auto rhs_concat_shape = ccw_dot_rhs->shape();
908         rhs_concat_shape.set_dimensions(
909             rhs_concat_dim,
910             ccw_dot_rhs->shape().dimensions(rhs_concat_dim) * 2);
911         dot_rhs = body_b.AddInstruction(HloInstruction::CreateConcatenate(
912             rhs_concat_shape, {ccw_dot_rhs, cw_dot_rhs}, rhs_concat_dim));
913         original_dot_rhs = dot_rhs;
914 
915         // Reshape
916         std::vector<int64> reshaped_dims(dot_rhs->shape().dimensions().begin(),
917                                          dot_rhs->shape().dimensions().end());
918         reshaped_dims[rhs_concat_dim] /= 2;
919         reshaped_dims.insert(reshaped_dims.begin() + rhs_concat_dim, 2);
920         dot_rhs = body_b.AddInstruction(HloInstruction::CreateReshape(
921             ShapeUtil::MakeShape(dot_rhs->shape().element_type(),
922                                  reshaped_dims),
923             dot_rhs));
924 
925         if (original_hlo->opcode() == HloOpcode::kDot) {
926           UpdateDDNums(&new_ddnums, rhs_concat_dim, false);
927         }
928       }
929 
930       // The generated original dot will not be used.
931       TF_ASSIGN_OR_RETURN(auto original_dot,
932                           create_sharded_dot(original_dot_lhs, original_dot_rhs,
933                                              &body_b, conv_window));
934       VLOG(2) << original_dot->ToString();
935 
936       // Generate the correct shape of the new dot/conv.
937       auto original_sharded_dot_shape = original_dot->shape();
938       auto new_dot_shape = original_sharded_dot_shape;
939       std::vector<int64> new_dims(new_dot_shape.dimensions().begin(),
940                                   new_dot_shape.dimensions().end());
941       if (!windowed_at_contracting_dims) {
942         auto slice_dim =
943             lhs_concat_dim != -1
944                 ? indices_map.lhs_to_output_indices[lhs_concat_dim]
945                 : indices_map.rhs_to_output_indices[rhs_concat_dim];
946         new_dims[slice_dim] /= 2;
947         new_dims.insert(new_dims.begin() + slice_dim, 2);
948       } else {
949         new_dims.push_back(1);
950       }
951       new_dot_shape =
952           ShapeUtil::MakeShape(original_hlo->shape().element_type(), new_dims);
953 
954       HloInstruction* dot;
955       if (original_hlo->opcode() == HloOpcode::kDot) {
956         dot = body_b.AddInstruction(HloInstruction::CreateDot(
957             new_dot_shape, dot_lhs, dot_rhs, new_ddnums,
958             original_hlo->precision_config()));
959       } else {
960         if (!windowed_at_contracting_dims && !windowed_at_batch_dims) {
961           if (lhs_concat_dim != -1) {
962             std::vector<int64> new_dims(dot_rhs->shape().dimensions().begin(),
963                                         dot_rhs->shape().dimensions().end());
964             new_dims.push_back(1);
965             dot_rhs = body_b.AddInstruction(HloInstruction::CreateReshape(
966                 ShapeUtil::MakeShape(dot_rhs->shape().element_type(), new_dims),
967                 dot_rhs));
968           }
969           if (rhs_concat_dim != -1) {
970             std::vector<int64> new_dims(dot_lhs->shape().dimensions().begin(),
971                                         dot_lhs->shape().dimensions().end());
972             new_dims.push_back(1);
973             dot_lhs = body_b.AddInstruction(HloInstruction::CreateReshape(
974                 ShapeUtil::MakeShape(dot_lhs->shape().element_type(), new_dims),
975                 dot_lhs));
976           }
977         }
978 
979         dot = body_b.AddInstruction(HloInstruction::CreateConvolve(
980             new_dot_shape, dot_lhs, dot_rhs,
981             original_dot->feature_group_count(),
982             original_dot->batch_group_count(),
983             GenNewWindow(original_dot, dot_lhs, dot_rhs, lhs_concat_dim,
984                          rhs_concat_dim, windowed_at_contracting_dims,
985                          windowed_at_batch_dims),
986             GenNewConvDNums(original_dot, dot_lhs, dot_rhs, lhs_concat_dim,
987                             rhs_concat_dim, windowed_at_contracting_dims,
988                             windowed_at_batch_dims,
989                             indices_map.lhs_to_output_indices,
990                             indices_map.rhs_to_output_indices, new_dot_shape),
991             original_dot->precision_config()));
992       }
993       VLOG(2) << dot->ToString();
994 
995       // Reshape to the original sharded dot shape.
996       dot = body_b.AddInstruction(
997           HloInstruction::CreateReshape(original_sharded_dot_shape, dot));
998 
999       if (windowed_at_contracting_dims) {
1000         // Accumulate the partial output to the result buffer.
1001         o = body_b.AddInstruction(
1002             HloInstruction::CreateBinary(o->shape(), HloOpcode::kAdd, o, dot));
1003       } else {
1004         // The windowing operand is partitioned along batch/non-contracting
1005         // dimensions, so we need a dynamic-update-slice to save the partial
1006         // output in the result buffer.
1007         auto slice_shape = dot->shape();
1008         auto slice_dim =
1009             lhs_concat_dim != -1
1010                 ? indices_map.lhs_to_output_indices[lhs_concat_dim]
1011                 : indices_map.rhs_to_output_indices[rhs_concat_dim];
1012         slice_shape.set_dimensions(slice_dim,
1013                                    dot->shape().dimensions(slice_dim) / 2);
1014         std::vector<int64> ccw_start_indices(dot->shape().rank(), 0);
1015         std::vector<int64> cw_start_indices(dot->shape().rank(), 0);
1016         cw_start_indices[slice_dim] = dot->shape().dimensions(slice_dim) / 2;
1017         auto ccw_dot = body_b.AddInstruction(HloInstruction::CreateSlice(
1018             slice_shape, dot, ccw_start_indices, slice_shape.dimensions(),
1019             std::vector<int64>(dot->shape().rank(), 1)));
1020         auto cw_dot = body_b.AddInstruction(HloInstruction::CreateSlice(
1021             slice_shape, dot, cw_start_indices, dot->shape().dimensions(),
1022             std::vector<int64>(dot->shape().rank(), 1)));
1023 
1024         if (operands_sharded_at_contracting_dims) {
1025           // Accumulate the partial output to the result buffer.
1026           o = body_b.AddInstruction(HloInstruction::CreateBinary(
1027               o->shape(), HloOpcode::kAdd, o, ccw_dot));
1028           cw_cp_output = body_b.AddInstruction(HloInstruction::CreateBinary(
1029               o->shape(), HloOpcode::kAdd, cw_cp_output, cw_dot));
1030         } else {
1031           auto ccw_offsets = MakePartitionOffsets(
1032               o->shape(),
1033               windowed_op_is_lhs ? *lhs_sharding_transposed_to_match_output
1034                                  : *rhs_sharding_transposed_to_match_output,
1035               ccw_data_partition_id, &body_b);
1036           auto cw_offsets = MakePartitionOffsets(
1037               o->shape(),
1038               windowed_op_is_lhs ? *lhs_sharding_transposed_to_match_output
1039                                  : *rhs_sharding_transposed_to_match_output,
1040               cw_data_partition_id, &body_b);
1041           o = body_b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
1042               o->shape(), o, ccw_dot, ccw_offsets));
1043           o = body_b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
1044               o->shape(), o, cw_dot, cw_offsets));
1045         }
1046       }
1047 
1048       std::vector<HloInstruction*> partial_results;
1049       partial_results.push_back(o);
1050       partial_results.push_back(cw_cp_output);
1051       return partial_results;
1052     };
1053 
1054     // Generate partial result used by unidirectional algorithm.
1055     auto get_partial_unid_result =
1056         [&](HloInstruction* l, HloInstruction* r, HloInstruction* o,
1057             HloInstruction* i) -> StatusOr<HloInstruction*> {
1058       auto partition_id =
1059           lhs.state().collective_ops_creator.create_partition_id(&body_b);
1060       auto data_partition_id =
1061           body_b.AddInstruction(HloInstruction::CreateBinary(
1062               i->shape(), HloOpcode::kAdd, i, partition_id));
1063       auto partition_count =
1064           body_b.AddInstruction(HloInstruction::CreateConstant(
1065               LiteralUtil::CreateR0<uint32>(num_partitions)));
1066       data_partition_id = body_b.AddInstruction(
1067           HloInstruction::CreateBinary(i->shape(), HloOpcode::kRemainder,
1068                                        data_partition_id, partition_count));
1069       auto dot_lhs = l;
1070       auto dot_rhs = r;
1071       if (windowed_at_contracting_dims || windowed_at_batch_dims ||
1072           operands_sharded_at_contracting_dims) {
1073         // Slice the matching operand according to the partitioned dimensions on
1074         // the windowed operand or the output.
1075         auto slice_operand = !windowed_op_is_lhs ? l : r;
1076         // We do this by treating the matching operand as replicated, and
1077         // resharding it to match the windowed operand or the output.
1078         slice_operand->set_sharding(HloSharding::Replicate());
1079         auto state = lhs.state();
1080         state.b = &body_b;
1081         state.partition_id = data_partition_id;
1082         state.reshard_cache->per_hlo_cache.erase(slice_operand);
1083         const HloSharding* slice_sharding;
1084         if (operands_sharded_at_contracting_dims) {
1085           slice_sharding = windowed_op_is_lhs
1086                                ? &*output_sharding_transposed_to_match_rhs
1087                                : &*output_sharding_transposed_to_match_lhs;
1088         } else {
1089           slice_sharding = windowed_op_is_lhs
1090                                ? &*lhs_sharding_transposed_to_match_rhs
1091                                : &*rhs_sharding_transposed_to_match_lhs;
1092         }
1093         auto slice =
1094             PartitionedHlo(slice_operand, slice_operand->shape(), state)
1095                 .Reshard(*slice_sharding)
1096                 .hlo();
1097         slice_operand->clear_sharding();
1098         if (!windowed_op_is_lhs) {
1099           dot_lhs = slice;
1100         } else {
1101           dot_rhs = slice;
1102         }
1103       }
1104       TF_ASSIGN_OR_RETURN(
1105           auto dot, create_sharded_dot(dot_lhs, dot_rhs, &body_b, conv_window));
1106       if (windowed_at_contracting_dims ||
1107           operands_sharded_at_contracting_dims) {
1108         // Accumulate the partial output to the result buffer.
1109         o = body_b.AddInstruction(
1110             HloInstruction::CreateBinary(o->shape(), HloOpcode::kAdd, o, dot));
1111       } else {
1112         // The windowing operand is partitioned along batch/non-contracting
1113         // dimensions, so we need a dynamic-update-slice to save the partial
1114         // output in the result buffer.
1115         auto offsets = MakePartitionOffsets(
1116             o->shape(),
1117             windowed_op_is_lhs ? *lhs_sharding_transposed_to_match_output
1118                                : *rhs_sharding_transposed_to_match_output,
1119             data_partition_id, &body_b);
1120         o = body_b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
1121             o->shape(), o, dot, offsets));
1122       }
1123       return o;
1124     };
1125 
1126     auto param = body_b.AddInstruction(HloInstruction::CreateParameter(
1127         /*parameter_number=*/0,
1128         ShapeUtil::MakeTupleShape({lhs.hlo()->shape(), rhs.hlo()->shape(),
1129                                    result_buffer->shape(),
1130                                    extra_buffer->shape(), iteration->shape()}),
1131         "param"));
1132     auto l = body_b.AddInstruction(
1133         HloInstruction::CreateGetTupleElement(lhs.hlo()->shape(), param, 0));
1134     auto r = body_b.AddInstruction(
1135         HloInstruction::CreateGetTupleElement(rhs.hlo()->shape(), param, 1));
1136     auto o = body_b.AddInstruction(HloInstruction::CreateGetTupleElement(
1137         result_buffer->shape(), param, 2));
1138     auto extra_inout = body_b.AddInstruction(
1139         HloInstruction::CreateGetTupleElement(extra_buffer->shape(), param, 3));
1140     auto i = body_b.AddInstruction(
1141         HloInstruction::CreateGetTupleElement(iteration->shape(), param, 4));
1142 
1143     // The bidirectional collective permute implementation has loop unrolling
1144     // of degree 2, so num_partitions is required to be a multiple of 4.
1145     if (options.bidirectional_windowed_einsum && num_partitions % 4 == 0) {
1146       std::vector<std::pair<int64, int64>> ccw_sd_pairs(num_partitions);
1147       for (int64 source = 0; source < num_partitions; ++source) {
1148         // 0 -> n-1, 1 -> 0, 2 -> 1, ...
1149         ccw_sd_pairs[source] = {source,
1150                                 (source - 1 + num_partitions) % num_partitions};
1151       }
1152       std::vector<std::pair<int64, int64>> cw_sd_pairs(num_partitions);
1153       for (int64 source = 0; source < num_partitions; ++source) {
1154         // 0 -> 1, 1 -> 2, 2 -> 3, ...
1155         cw_sd_pairs[source] = {source, (source + 1) % num_partitions};
1156       }
1157 
1158       // Even number iteration.
1159       auto next_l = l;
1160       auto next_r = r;
1161       auto ccw_cp_input = operands_sharded_at_contracting_dims ? o
1162                           : windowed_op_is_lhs                 ? l
1163                                                                : r;
1164       auto ccw_cp_output =
1165           lhs.state()
1166               .collective_ops_creator.create_cross_partition_collective_permute(
1167                   &body_b, ccw_cp_input, ccw_sd_pairs,
1168                   (*lhs.state().next_channel_id)++);
1169       if (operands_sharded_at_contracting_dims) {
1170         o = ccw_cp_output;
1171       } else if (windowed_op_is_lhs) {
1172         next_l = ccw_cp_output;
1173       } else {
1174         next_r = ccw_cp_output;
1175       }
1176       auto cw_cp_input = extra_inout;
1177       auto cw_cp_output =
1178           lhs.state()
1179               .collective_ops_creator.create_cross_partition_collective_permute(
1180                   &body_b, cw_cp_input, cw_sd_pairs,
1181                   (*lhs.state().next_channel_id)++);
1182 
1183       TF_ASSIGN_OR_RETURN(
1184           auto outputs,
1185           get_partial_bid_results(l, r, o, extra_inout, cw_cp_output, i));
1186       o = outputs[0];
1187       cw_cp_output = outputs[1];
1188 
1189       // ++i
1190       i = body_b.AddInstruction(HloInstruction::CreateBinary(
1191           i->shape(), HloOpcode::kAdd, i, CreateOne(i->shape(), &body_b)));
1192 
1193       // Odd number iteration.
1194       auto second_next_l = next_l;
1195       auto second_next_r = next_r;
1196       ccw_cp_input = operands_sharded_at_contracting_dims ? o
1197                      : windowed_op_is_lhs                 ? next_l
1198                                                           : next_r;
1199       ccw_cp_output =
1200           lhs.state()
1201               .collective_ops_creator.create_cross_partition_collective_permute(
1202                   &body_b, ccw_cp_input, ccw_sd_pairs,
1203                   (*lhs.state().next_channel_id)++);
1204       if (operands_sharded_at_contracting_dims) {
1205         o = ccw_cp_output;
1206       } else if (windowed_op_is_lhs) {
1207         second_next_l = ccw_cp_output;
1208       } else {
1209         second_next_r = ccw_cp_output;
1210       }
1211       auto next_cw_cp_input = cw_cp_output;
1212       auto next_cw_cp_output =
1213           lhs.state()
1214               .collective_ops_creator.create_cross_partition_collective_permute(
1215                   &body_b, next_cw_cp_input, cw_sd_pairs,
1216                   (*lhs.state().next_channel_id)++);
1217 
1218       TF_ASSIGN_OR_RETURN(
1219           outputs, get_partial_bid_results(next_l, next_r, o, cw_cp_output,
1220                                            next_cw_cp_output, i));
1221       o = outputs[0];
1222       next_cw_cp_output = outputs[1];
1223 
1224       // ++i
1225       i = body_b.AddInstruction(HloInstruction::CreateBinary(
1226           i->shape(), HloOpcode::kAdd, i, CreateOne(i->shape(), &body_b)));
1227 
1228       body_b.AddInstruction(HloInstruction::CreateTuple(
1229           {second_next_l, second_next_r, o, next_cw_cp_output, i}));
1230 
1231     } else if (options.unroll_windowed_einsum && num_partitions % 2 == 0) {
1232       if (operands_sharded_at_contracting_dims) {
1233         std::vector<std::pair<int64, int64>> output_sd_pairs(num_partitions);
1234         for (int64 source = 0; source < num_partitions; ++source) {
1235           // 0 -> n-2, 1 -> n-1, 2 -> 0, ...
1236           output_sd_pairs[source] = {
1237               source, (source - 2 + num_partitions) % num_partitions};
1238         }
1239 
1240         o = lhs.state()
1241                 .collective_ops_creator
1242                 .create_cross_partition_collective_permute(
1243                     &body_b, o, output_sd_pairs,
1244                     (*lhs.state().next_channel_id)++);
1245 
1246         TF_ASSIGN_OR_RETURN(extra_inout,
1247                             get_partial_unid_result(l, r, extra_inout, i));
1248 
1249         extra_inout = lhs.state()
1250                           .collective_ops_creator
1251                           .create_cross_partition_collective_permute(
1252                               &body_b, extra_inout, output_sd_pairs,
1253                               (*lhs.state().next_channel_id)++);
1254 
1255         // i+2
1256         i = body_b.AddInstruction(HloInstruction::CreateBinary(
1257             i->shape(), HloOpcode::kAdd, i,
1258             body_b.AddInstruction(HloInstruction::CreateConstant(
1259                 LiteralUtil::CreateR0<uint32>(2)))));
1260         auto real_i = body_b.AddInstruction(HloInstruction::CreateBinary(
1261             i->shape(), HloOpcode::kAdd, i,
1262             body_b.AddInstruction(HloInstruction::CreateConstant(
1263                 LiteralUtil::CreateR0<uint32>(1)))));
1264 
1265         TF_ASSIGN_OR_RETURN(o, get_partial_unid_result(l, r, o, real_i));
1266         body_b.AddInstruction(
1267             HloInstruction::CreateTuple({l, r, o, extra_inout, i}));
1268       } else {
1269         std::vector<std::pair<int64, int64>> sd_pairs(num_partitions);
1270         for (int64 source = 0; source < num_partitions; ++source) {
1271           // 0 -> n-1, 1 -> 0, 2 -> 1, ...
1272           sd_pairs[source] = {source,
1273                               (source - 1 + num_partitions) % num_partitions};
1274         }
1275 
1276         // Even number iteration.
1277         auto next_l = l;
1278         auto next_r = r;
1279         auto cp_input = windowed_op_is_lhs ? l : r;
1280         auto cp_output = lhs.state()
1281                              .collective_ops_creator
1282                              .create_cross_partition_collective_permute(
1283                                  &body_b, cp_input, sd_pairs,
1284                                  (*lhs.state().next_channel_id)++);
1285         if (windowed_op_is_lhs) {
1286           next_l = cp_output;
1287         } else {
1288           next_r = cp_output;
1289         }
1290         TF_ASSIGN_OR_RETURN(o, get_partial_unid_result(l, r, o, i));
1291 
1292         // ++i
1293         i = body_b.AddInstruction(HloInstruction::CreateBinary(
1294             i->shape(), HloOpcode::kAdd, i,
1295             body_b.AddInstruction(HloInstruction::CreateConstant(
1296                 LiteralUtil::CreateR0<uint32>(1)))));
1297 
1298         // Odd number iteration.
1299         auto second_next_l = next_l;
1300         auto second_next_r = next_r;
1301         cp_input = windowed_op_is_lhs ? next_l : next_r;
1302         cp_output = lhs.state()
1303                         .collective_ops_creator
1304                         .create_cross_partition_collective_permute(
1305                             &body_b, cp_input, sd_pairs,
1306                             (*lhs.state().next_channel_id)++);
1307         if (windowed_op_is_lhs) {
1308           second_next_l = cp_output;
1309         } else {
1310           second_next_r = cp_output;
1311         }
1312         TF_ASSIGN_OR_RETURN(o, get_partial_unid_result(next_l, next_r, o, i));
1313 
1314         // ++i
1315         i = body_b.AddInstruction(HloInstruction::CreateBinary(
1316             i->shape(), HloOpcode::kAdd, i,
1317             body_b.AddInstruction(HloInstruction::CreateConstant(
1318                 LiteralUtil::CreateR0<uint32>(1)))));
1319 
1320         body_b.AddInstruction(HloInstruction::CreateTuple(
1321             {second_next_l, second_next_r, o, extra_inout, i}));
1322       }
1323     } else {
1324       auto real_i = i;
1325       if (operands_sharded_at_contracting_dims) {
1326         // For reduce-scatter case, start from the data_partition_id + 1 to make
1327         // the data_partition_id of the final data shard in each partition the
1328         // same as the corresponding partition_id.
1329         real_i = body_b.AddInstruction(HloInstruction::CreateBinary(
1330             real_i->shape(), HloOpcode::kAdd, real_i,
1331             CreateOne(real_i->shape(), &body_b)));
1332       }
1333       TF_ASSIGN_OR_RETURN(o, get_partial_unid_result(l, r, o, real_i));
1334 
1335       // ++i
1336       i = body_b.AddInstruction(HloInstruction::CreateBinary(
1337           i->shape(), HloOpcode::kAdd, i,
1338           body_b.AddInstruction(HloInstruction::CreateConstant(
1339               LiteralUtil::CreateR0<uint32>(1)))));
1340       auto has_more = body_b.AddInstruction(HloInstruction::CreateCompare(
1341           ShapeUtil::MakeShape(PRED, {}), i,
1342           body_b.AddInstruction(HloInstruction::CreateConstant(
1343               LiteralUtil::CreateR0<uint32>(num_partitions))),
1344           ComparisonDirection::kLt));
1345       // Collective-permute for the next window. We don't need it for the last
1346       // iteration, so we use a conditional around the collective-permute.
1347       HloInstruction* conditional;
1348       {
1349         SpmdBuilder cp_b("window_collective_permute", original_hlo);
1350         {
1351           auto p = cp_b.AddInstruction(HloInstruction::CreateParameter(
1352               0,
1353               operands_sharded_at_contracting_dims ? o->shape()
1354               : windowed_op_is_lhs                 ? l->shape()
1355                                                    : r->shape(),
1356               "window"));
1357           std::vector<std::pair<int64, int64>> sd_pairs(num_partitions);
1358           for (int64 source = 0; source < num_partitions; ++source) {
1359             // 0 -> n-1, 1 -> 0, 2 -> 1, ...
1360             sd_pairs[source] = {source,
1361                                 (source - 1 + num_partitions) % num_partitions};
1362           }
1363           lhs.state()
1364               .collective_ops_creator.create_cross_partition_collective_permute(
1365                   &cp_b, p, sd_pairs, (*lhs.state().next_channel_id)++);
1366         }
1367         SpmdBuilder ncp_b("last_iteration_noop", original_hlo);
1368         {
1369           ncp_b.AddInstruction(HloInstruction::CreateParameter(
1370               0,
1371               operands_sharded_at_contracting_dims ? o->shape()
1372               : windowed_op_is_lhs                 ? l->shape()
1373                                                    : r->shape(),
1374               "window"));
1375         }
1376         conditional = body_b.AddInstruction(HloInstruction::CreateConditional(
1377             operands_sharded_at_contracting_dims ? o->shape()
1378             : windowed_op_is_lhs                 ? l->shape()
1379                                                  : r->shape(),
1380             has_more,
1381             operands_sharded_at_contracting_dims ? o
1382             : windowed_op_is_lhs                 ? l
1383                                                  : r,
1384             module->AddEmbeddedComputation(cp_b.Build()),
1385             operands_sharded_at_contracting_dims ? o
1386             : windowed_op_is_lhs                 ? l
1387                                                  : r,
1388             module->AddEmbeddedComputation(ncp_b.Build())));
1389       }
1390       if (operands_sharded_at_contracting_dims) {
1391         o = conditional;
1392       } else if (windowed_op_is_lhs) {
1393         l = conditional;
1394       } else {
1395         r = conditional;
1396       }
1397       body_b.AddInstruction(
1398           HloInstruction::CreateTuple({l, r, o, extra_inout, i}));
1399     }
1400 
1401     SpmdBuilder cond_b("windowed_dot_general_cond", original_hlo);
1402     auto cond_param = cond_b.AddInstruction(HloInstruction::CreateParameter(
1403         /*parameter_number=*/0,
1404         ShapeUtil::MakeTupleShape({lhs.hlo()->shape(), rhs.hlo()->shape(),
1405                                    result_buffer->shape(),
1406                                    extra_buffer->shape(), iteration->shape()}),
1407         "param"));
1408     auto cond_i = cond_b.AddInstruction(HloInstruction::CreateGetTupleElement(
1409         iteration->shape(), cond_param, 4));
1410     int64 adapted_num_partitions =
1411         (options.bidirectional_windowed_einsum && num_partitions % 4 == 0)
1412             ? num_partitions / 2
1413             : num_partitions;
1414     cond_b.AddInstruction(HloInstruction::CreateCompare(
1415         ShapeUtil::MakeShape(PRED, {}), cond_i,
1416         cond_b.AddInstruction(HloInstruction::CreateConstant(
1417             LiteralUtil::CreateR0<uint32>(adapted_num_partitions))),
1418         ComparisonDirection::kLt));
1419     auto while_loop = b->AddInstruction(HloInstruction::CreateWhile(
1420         cond_param->shape(), module->AddEmbeddedComputation(cond_b.Build()),
1421         module->AddEmbeddedComputation(body_b.Build()),
1422         b->AddInstruction(HloInstruction::CreateTuple(
1423             {lhs.hlo(), rhs.hlo(), result_buffer, extra_buffer, iteration}))));
1424     windowed_dot_general_loops->push_back(
1425         {while_loop, windowed_op_is_lhs ? 0 : 1, windowed_at_contracting_dims,
1426          windowed_at_batch_dims, operands_sharded_at_contracting_dims});
1427     auto result = b->AddInstruction(HloInstruction::CreateGetTupleElement(
1428         result_buffer->shape(), while_loop, 2));
1429     if (((options.bidirectional_windowed_einsum && num_partitions % 4 == 0) ||
1430          (options.unroll_windowed_einsum && num_partitions % 2 == 0)) &&
1431         operands_sharded_at_contracting_dims) {
1432       std::vector<std::pair<int64, int64>> extra_sd_pairs(num_partitions);
1433       for (int64 source = 0; source < num_partitions; ++source) {
1434         // 0 -> 1, 1 -> 2, 2 -> 3, ...
1435         extra_sd_pairs[source] = {source, (source + 1) % num_partitions};
1436       }
1437       auto extra_result =
1438           b->AddInstruction(HloInstruction::CreateGetTupleElement(
1439               extra_buffer->shape(), while_loop, 3));
1440       if (options.bidirectional_windowed_einsum && num_partitions % 4 == 0) {
1441         extra_result = lhs.state()
1442                            .collective_ops_creator
1443                            .create_cross_partition_collective_permute(
1444                                b, extra_result, extra_sd_pairs,
1445                                (*lhs.state().next_channel_id)++);
1446       }
1447       if (options.unroll_windowed_einsum && num_partitions % 2 == 0) {
1448         result = lhs.state()
1449                      .collective_ops_creator
1450                      .create_cross_partition_collective_permute(
1451                          b, result, extra_sd_pairs,
1452                          (*lhs.state().next_channel_id)++);
1453       }
1454       result = b->AddInstruction(HloInstruction::CreateBinary(
1455           result->shape(), HloOpcode::kAdd, result, extra_result));
1456     }
1457     if (!ShapeUtil::Compatible(padded_result_buffer_shape,
1458                                unpadded_result_buffer_shape)) {
1459       result = b->AddInstruction(HloInstruction::CreateSlice(
1460           unpadded_result_buffer_shape, result,
1461           std::vector<int64>(padded_result_buffer_shape.rank(), 0),
1462           unpadded_result_buffer_shape.dimensions(),
1463           std::vector<int64>(padded_result_buffer_shape.rank(), 1)));
1464     }
1465     return result;
1466   };
1467   absl::optional<WindowedEinsumConfig> e_config =
1468       GetWindowedEinsumConfiguration(
1469           num_partitions, output_lhs_non_contracting_partitions,
1470           output_rhs_non_contracting_partitions, rhs_contracting_partitions,
1471           rhs_non_contracting_partitions, rhs_batch_partitions,
1472           lhs_contracting_partitions, lhs_non_contracting_partitions,
1473           lhs_batch_partitions, output_sharding_dim,
1474           ShapeSizeInBytes(rhs.base_shape()),
1475           ShapeSizeInBytes(lhs.base_shape()),
1476           ShapeSizeInBytes(output_base_shape),
1477           options.threshold_for_windowed_einsum_mib,
1478           output_sharding_transposed_to_match_lhs,
1479           output_sharding_transposed_to_match_rhs, lhs_sharding, rhs_sharding);
1480   if (e_config) {
1481     return emit_windowed_dot_general(*e_config);
1482   }
1483 
1484   {
1485     // Try batch-parallel by resharding one operand, and allowing all-reduce.
1486     TF_ASSIGN_OR_RETURN(
1487         HloInstruction * partitioned_dot,
1488         try_emit_output_batch_partitioned_einsum_with_reshard(true));
1489     if (partitioned_dot) {
1490       return partitioned_dot;
1491     }
1492   }
1493 
1494   // LHS and RHS have the same partitioned contracting dimensions.
1495   if (lhs_contracting_partitions == rhs_contracting_partitions &&
1496       lhs_contracting_partitions == num_partitions) {
1497     auto zero = b->AddInstruction(HloInstruction::CreateConstant(
1498         LiteralUtil::Zero(output_base_shape.element_type())));
1499     // Pad both sides with zero, since NaN at one side cannot be masked by zero
1500     // on the other side.
1501     if (ShapeSizeInBytes(lhs.base_shape()) <
1502         ShapeSizeInBytes(rhs.base_shape())) {
1503       lhs =
1504           lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero);
1505       rhs = rhs.PadWithValue(zero);
1506     } else {
1507       lhs = lhs.PadWithValue(zero);
1508       rhs =
1509           rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero);
1510     }
1511     TF_ASSIGN_OR_RETURN(
1512         auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
1513     std::vector<int64> lhs_contracting_dims;
1514     lhs_contracting_dims.reserve(lhs.base_shape().rank());
1515     for (const auto& cd : dims_mapping.contracting_dims) {
1516       lhs_contracting_dims.push_back(cd.lhs);
1517     }
1518     auto ar = lhs.state().partitioner->AllReduceAlongShardingDims(
1519         b, dot, lhs.sharding(), lhs.state().next_channel_id,
1520         lhs_contracting_dims, lhs.state().collective_ops_creator,
1521         MakeBinaryAdd(output_base_shape.element_type(), module));
1522     ar->set_sharding(HloSharding::Replicate());
1523     return PartitionedHlo(ar, output_base_shape, lhs.state())
1524         .Reshard(output_sharding)
1525         .hlo();
1526   }
1527 
1528   // LHS and output have the same partitioned non-contracting dimensions.
1529   if (lhs_non_contracting_partitions == num_partitions &&
1530       output_lhs_non_contracting_partitions == num_partitions &&
1531       lhs_sharding_transposed_to_match_output == output_sharding) {
1532     auto rhs_replicated = rhs.Reshard(HloSharding::Replicate()).hlo();
1533     TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs_replicated,
1534                                                      b, conv_window));
1535     return dot;
1536   }
1537 
1538   // RHS and output have the same partitioned non-contracting dimensions.
1539   if (rhs_non_contracting_partitions == num_partitions &&
1540       output_rhs_non_contracting_partitions == num_partitions &&
1541       rhs_sharding_transposed_to_match_output == output_sharding) {
1542     auto lhs_replicated = lhs.Reshard(HloSharding::Replicate()).hlo();
1543     TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs_replicated, rhs.hlo(),
1544                                                      b, conv_window));
1545     return dot;
1546   }
1547 
1548   if (may_reshard_without_detecting_match) {
1549     // Output is batch partitioned.
1550     if (output_batch_partitions == num_partitions) {
1551       auto resharded_lhs =
1552           lhs.Reshard(*output_sharding_transposed_to_match_lhs);
1553       auto resharded_rhs =
1554           rhs.Reshard(*output_sharding_transposed_to_match_rhs);
1555       TF_ASSIGN_OR_RETURN(
1556           auto dot, create_sharded_dot(resharded_lhs.hlo(), resharded_rhs.hlo(),
1557                                        b, conv_window));
1558       return dot;
1559     }
1560     // Output is partitioned along LHS non-contracting dimensions.
1561     if (output_lhs_non_contracting_partitions == num_partitions) {
1562       auto resharded_lhs =
1563           lhs.Reshard(*output_sharding_transposed_to_match_lhs);
1564       auto replicated_rhs = rhs.Reshard(HloSharding::Replicate());
1565       TF_ASSIGN_OR_RETURN(
1566           auto dot, create_sharded_dot(resharded_lhs.hlo(),
1567                                        replicated_rhs.hlo(), b, conv_window));
1568       return dot;
1569     }
1570     // Output is partitioned along RHS non-contracting dimensions.
1571     if (output_rhs_non_contracting_partitions == num_partitions) {
1572       auto replicated_lhs = lhs.Reshard(HloSharding::Replicate());
1573       auto resharded_rhs =
1574           rhs.Reshard(*output_sharding_transposed_to_match_rhs);
1575       TF_ASSIGN_OR_RETURN(
1576           auto dot, create_sharded_dot(replicated_lhs.hlo(),
1577                                        resharded_rhs.hlo(), b, conv_window));
1578       return dot;
1579     }
1580   }
1581 
1582   // Returns true if it is beneficial to reshard the operand at `operand_idx`
1583   // across the contracting dimension.
1584   const auto should_partition_contracting_dim = [&](int64 operand_idx) {
1585     if (!output_sharding.IsReplicated()) {
1586       return false;
1587     }
1588 
1589     if (operand_idx == 0) {
1590       // If LHS and output are replicated, we compare the cost of all-gather
1591       // on RHS vs all-reduce on the output.
1592       return (rhs_contracting_partitions == num_partitions) &&
1593              lhs.sharding().IsReplicated() &&
1594              ShapeUtil::ElementsIn(rhs.base_shape()) >
1595                  ShapeUtil::ElementsIn(output_base_shape);
1596     } else {
1597       return (lhs_contracting_partitions == num_partitions) &&
1598              rhs.sharding().IsReplicated() &&
1599              ShapeUtil::ElementsIn(lhs.base_shape()) >
1600                  ShapeUtil::ElementsIn(output_base_shape);
1601     }
1602   };
1603 
1604   // When the output is replicated and one of the operands is partitioned along
1605   // contracting dimension, align the other operand to be partitioned along
1606   // the contracting dimensions.
1607   if (output_sharding.IsReplicated() && (should_partition_contracting_dim(0) ||
1608                                          should_partition_contracting_dim(1))) {
1609     auto zero = b->AddInstruction(HloInstruction::CreateConstant(
1610         LiteralUtil::Zero(output_base_shape.element_type())));
1611     if (should_partition_contracting_dim(0)) {
1612       lhs =
1613           lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero);
1614       rhs = rhs.PadWithValue(zero);
1615     } else {
1616       lhs = lhs.PadWithValue(zero);
1617       rhs =
1618           rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero);
1619     }
1620     TF_ASSIGN_OR_RETURN(
1621         auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
1622 
1623     std::vector<int64> lhs_contracting_dims;
1624     lhs_contracting_dims.reserve(lhs.base_shape().rank());
1625     for (const auto& cd : dims_mapping.contracting_dims) {
1626       lhs_contracting_dims.push_back(cd.lhs);
1627     }
1628     return lhs.state().partitioner->AllReduceAlongShardingDims(
1629         b, dot, lhs.sharding(), lhs.state().next_channel_id,
1630         lhs_contracting_dims, lhs.state().collective_ops_creator,
1631         MakeBinaryAdd(output_base_shape.element_type(), module));
1632   }
1633   return nullptr;
1634 }
1635 
1636 StatusOr<HloInstruction*> PartitionDot(
1637     PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
1638     const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
1639     int64 num_partitions,
1640     const std::function<StatusOr<HloInstruction*>(
1641         HloInstruction*, HloInstruction*, SpmdBuilder*,
1642         const Window& conv_window)>& create_sharded_dot,
1643     const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
1644     const SpmdPartitionerOptions& options, SpmdBuilder* b,
1645     std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
1646         windowed_dot_general_loops);
1647 
PartitionDotGroupOnBatch(PartitionedHlo lhs,PartitionedHlo rhs,const Shape & output_base_shape,const HloSharding & output_sharding,const DotConvDimsMapping & dims_mapping,int64 num_partitions,int64 lhs_contracting_partitions,int64 rhs_contracting_partitions,int64 lhs_non_contracting_partitions,int64 rhs_non_contracting_partitions,const std::function<StatusOr<HloInstruction * > (HloInstruction *,HloInstruction *,SpmdBuilder *,const Window & conv_window)> & create_sharded_dot,const Window & conv_window,HloModule * module,HloInstruction * original_hlo,bool require_matching_devices_to_group,const SpmdPartitionerOptions & options,SpmdBuilder * b,std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop> * windowed_dot_general_loops)1648 StatusOr<HloInstruction*> PartitionDotGroupOnBatch(
1649     PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
1650     const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
1651     int64 num_partitions, int64 lhs_contracting_partitions,
1652     int64 rhs_contracting_partitions, int64 lhs_non_contracting_partitions,
1653     int64 rhs_non_contracting_partitions,
1654     const std::function<StatusOr<HloInstruction*>(
1655         HloInstruction*, HloInstruction*, SpmdBuilder*,
1656         const Window& conv_window)>& create_sharded_dot,
1657     const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
1658     bool require_matching_devices_to_group,
1659     const SpmdPartitionerOptions& options, SpmdBuilder* b,
1660     std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
1661         windowed_dot_general_loops) {
1662   std::vector<std::pair<HloInstruction*, HloSharding>>
1663       top_level_sharding_to_reset;
1664   auto cleaner = tensorflow::gtl::MakeCleanup([&] {
1665     for (auto& to_reset : top_level_sharding_to_reset) {
1666       to_reset.first->set_sharding(to_reset.second);
1667     }
1668   });
1669   std::vector<int64> lhs_dims;
1670   std::vector<int64> rhs_dims;
1671   std::vector<int64> output_dims;
1672   auto lhs_sharding_dims_adjusted_to_output =
1673       lhs.sharding().IsReplicated()
1674           ? std::vector<int64>(lhs.base_shape().rank(), 1)
1675           : lhs.sharding().tile_assignment().dimensions();
1676   auto rhs_sharding_dims_adjusted_to_output =
1677       rhs.sharding().IsReplicated()
1678           ? std::vector<int64>(rhs.base_shape().rank(), 1)
1679           : rhs.sharding().tile_assignment().dimensions();
1680   auto output_sharding_dims_adjusted_to_lhs =
1681       output_sharding.tile_assignment().dimensions();
1682   bool lhs_rhs_dims_matching = true;
1683   for (const auto& dim : dims_mapping.batch_dims) {
1684     lhs_dims.push_back(dim.lhs);
1685     rhs_dims.push_back(dim.rhs);
1686     output_dims.push_back(dim.output);
1687     if (lhs_sharding_dims_adjusted_to_output[dim.lhs] !=
1688         rhs_sharding_dims_adjusted_to_output[dim.rhs]) {
1689       lhs_rhs_dims_matching = false;
1690     }
1691     lhs_sharding_dims_adjusted_to_output[dim.lhs] =
1692         output_sharding.tile_assignment().dim(dim.output);
1693     rhs_sharding_dims_adjusted_to_output[dim.rhs] =
1694         output_sharding.tile_assignment().dim(dim.output);
1695     output_sharding_dims_adjusted_to_lhs[dim.output] =
1696         lhs.sharding().tile_assignment().dim(dim.lhs);
1697   }
1698   if (require_matching_devices_to_group && lhs_rhs_dims_matching) {
1699     lhs_rhs_dims_matching =
1700         rhs.sharding() == UngroupSharding(AlignGroupsWith(
1701                               GroupShardingOnDims(rhs.sharding(), rhs_dims),
1702                               GroupShardingOnDims(lhs.sharding(), lhs_dims)));
1703   }
1704   auto output_grouped = GroupShardingOnDims(output_sharding, output_dims);
1705   PartitionedHlo per_group_lhs = lhs;
1706   PartitionedHlo per_group_rhs = rhs;
1707   if (lhs_rhs_dims_matching) {
1708     auto lhs_grouped = GroupShardingOnDims(lhs.sharding(), lhs_dims);
1709     auto rhs_grouped = GroupShardingOnDims(rhs.sharding(), rhs_dims);
1710     if (ShapeUtil::ByteSizeOf(lhs.hlo()->shape()) >
1711         ShapeUtil::ByteSizeOf(rhs.hlo()->shape())) {
1712       rhs_grouped = AlignGroupsWith(std::move(rhs_grouped), lhs_grouped);
1713       rhs = rhs.Reshard(UngroupSharding(rhs_grouped));
1714     } else {
1715       lhs_grouped = AlignGroupsWith(std::move(lhs_grouped), rhs_grouped);
1716       lhs = lhs.Reshard(UngroupSharding(lhs_grouped));
1717     }
1718     auto reshaped_output_tiling = output_sharding.tile_assignment();
1719     reshaped_output_tiling.Reshape(output_sharding_dims_adjusted_to_lhs);
1720     output_grouped = AlignGroupsWith(
1721         GroupShardingOnDims(
1722             output_sharding.ReplicateOnLastTileDim()
1723                 ? HloSharding::PartialTile(reshaped_output_tiling)
1724                 : HloSharding::Tile(reshaped_output_tiling),
1725             output_dims),
1726         lhs_grouped);
1727     auto per_group_partitioner_state = CreatePerGroupPartitioningState(
1728         lhs.state(), lhs_grouped.device_groups, b);
1729     top_level_sharding_to_reset.emplace_back(lhs.hlo(), lhs.sharding());
1730     lhs.hlo()->set_sharding(lhs_grouped.sharding);
1731     top_level_sharding_to_reset.emplace_back(rhs.hlo(), rhs.sharding());
1732     rhs.hlo()->set_sharding(rhs_grouped.sharding);
1733     CHECK(lhs.hlo() != rhs.hlo() ||
1734           lhs_grouped.sharding == rhs_grouped.sharding);
1735     per_group_lhs = PartitionedHlo(
1736         lhs.hlo(), GetPerGroupBaseShape(lhs_grouped, lhs.base_shape()),
1737         per_group_partitioner_state);
1738     per_group_rhs = PartitionedHlo(
1739         rhs.hlo(), GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()),
1740         per_group_partitioner_state);
1741   } else {
1742     auto per_group_partitioner_state = CreatePerGroupPartitioningState(
1743         lhs.state(), output_grouped.device_groups, b);
1744     auto reshard_to_output_batch =
1745         [&](PartitionedHlo operand, absl::Span<const int64> batch_dims,
1746             absl::Span<const int64> contracting_dims,
1747             absl::Span<const int64> non_contracting_dims,
1748             int64 contracting_dim_partitions,
1749             int64 non_contracting_dim_partitions,
1750             int64 other_contracting_dim_partitions,
1751             std::vector<int64>* sharding_dims_adjusted_to_output)
1752         -> absl::optional<PartitionedHlo> {
1753       if (operand.sharding().IsTileMaximal()) {
1754         auto partially_sharded = PerGroupSliceFromReplicated(
1755             operand.Replicate().hlo(), operand.state().partition_id,
1756             output_grouped.device_groups, batch_dims,
1757             output_grouped.group_dim_sizes, b);
1758         partially_sharded->set_sharding(HloSharding::Replicate());
1759         return PartitionedHlo(partially_sharded, partially_sharded->shape(),
1760                               per_group_partitioner_state);
1761       }
1762       auto reshaped_tiling = operand.sharding().tile_assignment();
1763       // It's possible that the operand is not initially sharded on batch
1764       // dimensions in the same way as the output, although being tiled. In that
1765       // case, the current sharding_dims_adjusted_to_output may contain more
1766       // partitions than available devices. We remove partitioning on other
1767       // dimensions.
1768       if (Product(*sharding_dims_adjusted_to_output) >
1769           reshaped_tiling.num_elements()) {
1770         if (Product(*sharding_dims_adjusted_to_output) %
1771                 reshaped_tiling.num_elements() !=
1772             0) {
1773           return absl::nullopt;
1774         }
1775         int64 ratio = Product(*sharding_dims_adjusted_to_output) /
1776                       reshaped_tiling.num_elements();
1777         if (operand.sharding().ReplicateOnLastTileDim() &&
1778             reshaped_tiling.dimensions().back() % ratio == 0) {
1779           sharding_dims_adjusted_to_output->back() /= ratio;
1780           if (sharding_dims_adjusted_to_output->back() == 1) {
1781             sharding_dims_adjusted_to_output->pop_back();
1782           }
1783         } else if (ratio == non_contracting_dim_partitions &&
1784                    (ratio != contracting_dim_partitions ||
1785                     contracting_dim_partitions ==
1786                         other_contracting_dim_partitions)) {
1787           for (int64 dim : non_contracting_dims) {
1788             (*sharding_dims_adjusted_to_output)[dim] = 1;
1789           }
1790         } else if (ratio == contracting_dim_partitions) {
1791           for (int64 dim : contracting_dims) {
1792             (*sharding_dims_adjusted_to_output)[dim] = 1;
1793           }
1794         } else {
1795           return absl::nullopt;
1796         }
1797       }
1798       // If the operand is initially sharded more ways than the output in the
1799       // batch dimensions, sharding_dims_adjusted_to_output currently contains
1800       // fewer partitions than available devices. We do not handle this case.
1801       if (Product(*sharding_dims_adjusted_to_output) <
1802           reshaped_tiling.num_elements()) {
1803         return absl::nullopt;
1804       }
1805       reshaped_tiling.Reshape(*sharding_dims_adjusted_to_output);
1806       auto grouped = AlignGroupsWith(
1807           GroupShardingOnDims(operand.base_shape().rank() <
1808                                       sharding_dims_adjusted_to_output->size()
1809                                   ? HloSharding::PartialTile(reshaped_tiling)
1810                                   : HloSharding::Tile(reshaped_tiling),
1811                               batch_dims),
1812           output_grouped);
1813       if (require_matching_devices_to_group &&
1814           operand.sharding() != UngroupSharding(grouped)) {
1815         return absl::nullopt;
1816       }
1817       auto resharded = operand.Reshard(UngroupSharding(grouped));
1818       top_level_sharding_to_reset.emplace_back(resharded.hlo(),
1819                                                resharded.sharding());
1820       resharded.hlo()->set_sharding(grouped.sharding);
1821       return PartitionedHlo(resharded.hlo(),
1822                             GetPerGroupBaseShape(grouped, operand.base_shape()),
1823                             per_group_partitioner_state);
1824     };
1825     std::vector<int64> lhs_contracting_dims;
1826     std::vector<int64> rhs_contracting_dims;
1827     lhs_contracting_dims.reserve(dims_mapping.contracting_dims.size());
1828     rhs_contracting_dims.reserve(dims_mapping.contracting_dims.size());
1829     for (const auto& dim : dims_mapping.contracting_dims) {
1830       lhs_contracting_dims.push_back(dim.lhs);
1831       rhs_contracting_dims.push_back(dim.rhs);
1832     }
1833     std::vector<int64> lhs_non_contracting_dims;
1834     std::vector<int64> rhs_non_contracting_dims;
1835     lhs_non_contracting_dims.reserve(
1836         dims_mapping.lhs_non_contracting_dims.size());
1837     rhs_non_contracting_dims.reserve(
1838         dims_mapping.rhs_non_contracting_dims.size());
1839     for (const auto& dim : dims_mapping.lhs_non_contracting_dims) {
1840       lhs_non_contracting_dims.push_back(dim.lhs);
1841     }
1842     for (const auto& dim : dims_mapping.rhs_non_contracting_dims) {
1843       rhs_non_contracting_dims.push_back(dim.rhs);
1844     }
1845     if (auto resharded = reshard_to_output_batch(
1846             lhs, lhs_dims, lhs_contracting_dims, lhs_non_contracting_dims,
1847             lhs_contracting_partitions, lhs_non_contracting_partitions,
1848             rhs_contracting_partitions,
1849             &lhs_sharding_dims_adjusted_to_output)) {
1850       per_group_lhs = *resharded;
1851     } else {
1852       return nullptr;
1853     }
1854     if (auto resharded = reshard_to_output_batch(
1855             rhs, rhs_dims, rhs_contracting_dims, rhs_non_contracting_dims,
1856             rhs_contracting_partitions, rhs_non_contracting_partitions,
1857             lhs_contracting_partitions,
1858             &rhs_sharding_dims_adjusted_to_output)) {
1859       per_group_rhs = *resharded;
1860     } else {
1861       return nullptr;
1862     }
1863     CHECK(lhs.hlo() != rhs.hlo() ||
1864           per_group_lhs.sharding() == per_group_rhs.sharding());
1865   }
1866   TF_ASSIGN_OR_RETURN(
1867       auto dot,
1868       PartitionDot(per_group_lhs, per_group_rhs,
1869                    GetPerGroupBaseShape(output_grouped, output_base_shape),
1870                    output_grouped.sharding, dims_mapping,
1871                    num_partitions / output_grouped.device_groups.size(),
1872                    create_sharded_dot, conv_window, module, original_hlo,
1873                    options, b, windowed_dot_general_loops));
1874   dot->set_sharding(UngroupSharding(output_grouped));
1875   return PartitionedHlo(dot, output_base_shape, lhs.state())
1876       .Reshard(output_sharding)
1877       .hlo();
1878 }
1879 
GetNonContractingPartitionGroupedShardingForMatchedOperand(bool lhs_matching,const HloSharding & matching_sharding,const HloSharding & output_sharding,absl::Span<const DotConvDimsMapping::DimsMapping> partitioned_dims)1880 GroupedSharding GetNonContractingPartitionGroupedShardingForMatchedOperand(
1881     bool lhs_matching, const HloSharding& matching_sharding,
1882     const HloSharding& output_sharding,
1883     absl::Span<const DotConvDimsMapping::DimsMapping> partitioned_dims) {
1884   std::vector<int64> matching_sharding_dims =
1885       matching_sharding.tile_assignment().dimensions();
1886   std::vector<int64> matching_dims;
1887   std::vector<int64> output_dims;
1888   // Make sure the partitioning on matching's non-contracting dimensions
1889   // defines the same device groups for both matching and output.
1890   for (const auto& dim : partitioned_dims) {
1891     int64 md = lhs_matching ? dim.lhs : dim.rhs;
1892     matching_sharding_dims[md] =
1893         output_sharding.tile_assignment().dim(dim.output);
1894     matching_dims.push_back(md);
1895     output_dims.push_back(dim.output);
1896   }
1897   GroupedSharding output_grouped =
1898       GroupShardingOnDims(output_sharding, output_dims);
1899   Array<int64> reshaped_matching_tiling = matching_sharding.tile_assignment();
1900   reshaped_matching_tiling.Reshape(matching_sharding_dims);
1901   return AlignGroupsWith(
1902       GroupShardingOnDims(
1903           matching_sharding.ReplicateOnLastTileDim()
1904               ? HloSharding::PartialTile(reshaped_matching_tiling)
1905               : HloSharding::Tile(reshaped_matching_tiling),
1906           matching_dims),
1907       output_grouped);
1908 }
1909 
1910 absl::optional<GroupedSharding>
GetNonContractingPartitionGroupedShardingForOtherOperand(bool lhs_matching,const Shape & output_base_shape,const Shape & other_shape,int64 other_contracting_partitions,int64 other_non_contracting_partitions,int64 matching_contracting_partitions,int64 output_other_non_contracting_partitions,const HloSharding & other_sharding,const HloSharding & output_sharding,absl::Span<const DotConvDimsMapping::DimsMapping> matching_partitioned_dims,absl::Span<const DotConvDimsMapping::DimsMapping> other_non_contracting_dims,absl::Span<const DotConvDimsMapping::DimsMapping> other_contracting_dims)1911 GetNonContractingPartitionGroupedShardingForOtherOperand(
1912     bool lhs_matching, const Shape& output_base_shape, const Shape& other_shape,
1913     int64 other_contracting_partitions, int64 other_non_contracting_partitions,
1914     int64 matching_contracting_partitions,
1915     int64 output_other_non_contracting_partitions,
1916     const HloSharding& other_sharding, const HloSharding& output_sharding,
1917     absl::Span<const DotConvDimsMapping::DimsMapping> matching_partitioned_dims,
1918     absl::Span<const DotConvDimsMapping::DimsMapping>
1919         other_non_contracting_dims,
1920     absl::Span<const DotConvDimsMapping::DimsMapping> other_contracting_dims) {
1921   int64 group_count = 1;
1922   std::vector<int64> output_dims;
1923   for (const auto& dim : matching_partitioned_dims) {
1924     output_dims.push_back(dim.output);
1925     group_count *= output_sharding.tile_assignment().dim(dim.output);
1926   }
1927   GroupedSharding output_grouped =
1928       GroupShardingOnDims(output_sharding, output_dims);
1929   std::vector<int64> other_group_dims;
1930   if (other_sharding.ReplicateOnLastTileDim() &&
1931       other_sharding.tile_assignment().dimensions().back() % group_count == 0) {
1932     other_group_dims.push_back(
1933         other_sharding.tile_assignment().num_dimensions() - 1);
1934   } else {
1935     const bool may_replicate_other_contracting_dims =
1936         (other_contracting_partitions == group_count &&
1937          other_non_contracting_partitions ==
1938              output_other_non_contracting_partitions);
1939     const bool may_replicate_other_non_contracting_dims =
1940         group_count == other_non_contracting_partitions &&
1941         matching_contracting_partitions == other_contracting_partitions;
1942     if (auto found_dims = FindMatchingPartitionedDimsForGrouping(
1943             other_sharding, output_grouped.device_groups)) {
1944       other_group_dims = std::move(*found_dims);
1945     } else if (may_replicate_other_contracting_dims &&
1946                (!may_replicate_other_non_contracting_dims ||
1947                 ShapeUtil::ByteSizeOf(other_shape)) <=
1948                    ShapeUtil::ByteSizeOf(MakePartitionedShape(
1949                        output_base_shape, output_sharding))) {
1950       for (const auto& dim : other_contracting_dims) {
1951         other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs);
1952       }
1953     } else if (may_replicate_other_non_contracting_dims) {
1954       for (const auto& dim : other_non_contracting_dims) {
1955         other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs);
1956       }
1957     } else {
1958       return absl::nullopt;
1959     }
1960   }
1961   if (other_group_dims.size() == 1 &&
1962       other_group_dims[0] ==
1963           other_sharding.tile_assignment().num_dimensions() - 1) {
1964     return AlignGroupsWith(
1965         GroupShardingOnDims(
1966             other_sharding, {other_group_dims[0]},
1967             {other_sharding.tile_assignment().dimensions().back() /
1968              group_count}),
1969         output_grouped, /*ignore_group_order=*/true);
1970 
1971   } else if (!other_sharding.IsReplicated()) {
1972     return AlignGroupsWith(
1973         GroupShardingOnDims(other_sharding, other_group_dims), output_grouped,
1974         /*ignore_group_order=*/true);
1975   }
1976   return absl::nullopt;
1977 }
1978 
PartitionDotGroupOnNonContracting(bool lhs_matching,PartitionedHlo matching,PartitionedHlo other,int64 matching_contracting_partitions,int64 other_contracting_partitions,absl::Span<const DotConvDimsMapping::DimsMapping> partitioned_non_contracting_dims,int64 other_non_contracting_partitions,int64 output_other_non_contracting_partitions,const Shape & output_base_shape,const HloSharding & output_sharding,const DotConvDimsMapping & dims_mapping,int64 num_partitions,const std::function<StatusOr<HloInstruction * > (HloInstruction *,HloInstruction *,SpmdBuilder *,const Window & conv_window)> & create_sharded_dot,const Window & conv_window,HloModule * module,HloInstruction * original_hlo,bool require_matching_devices_to_group,const SpmdPartitionerOptions & options,SpmdBuilder * b,std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop> * windowed_dot_general_loops)1979 StatusOr<HloInstruction*> PartitionDotGroupOnNonContracting(
1980     bool lhs_matching, PartitionedHlo matching, PartitionedHlo other,
1981     int64 matching_contracting_partitions, int64 other_contracting_partitions,
1982     absl::Span<const DotConvDimsMapping::DimsMapping>
1983         partitioned_non_contracting_dims,
1984     int64 other_non_contracting_partitions,
1985     int64 output_other_non_contracting_partitions,
1986     const Shape& output_base_shape, const HloSharding& output_sharding,
1987     const DotConvDimsMapping& dims_mapping, int64 num_partitions,
1988     const std::function<StatusOr<HloInstruction*>(
1989         HloInstruction*, HloInstruction*, SpmdBuilder*,
1990         const Window& conv_window)>& create_sharded_dot,
1991     const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
1992     bool require_matching_devices_to_group,
1993     const SpmdPartitionerOptions& options, SpmdBuilder* b,
1994     std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
1995         windowed_dot_general_loops) {
1996   std::vector<std::pair<HloInstruction*, HloSharding>>
1997       top_level_sharding_to_reset;
1998   auto cleaner = tensorflow::gtl::MakeCleanup([&] {
1999     for (auto& to_reset : top_level_sharding_to_reset) {
2000       to_reset.first->set_sharding(to_reset.second);
2001     }
2002   });
2003 
2004   std::vector<int64> output_dims;
2005   for (const auto& dim : partitioned_non_contracting_dims) {
2006     output_dims.push_back(dim.output);
2007   }
2008   GroupedSharding output_grouped =
2009       GroupShardingOnDims(output_sharding, output_dims);
2010   GroupedSharding matching_grouped =
2011       GetNonContractingPartitionGroupedShardingForMatchedOperand(
2012           lhs_matching, matching.sharding(), output_sharding,
2013           partitioned_non_contracting_dims);
2014   if (require_matching_devices_to_group &&
2015       matching.sharding() != UngroupSharding(matching_grouped)) {
2016     return nullptr;
2017   }
2018   absl::optional<GroupedSharding> other_grouped =
2019       GetNonContractingPartitionGroupedShardingForOtherOperand(
2020           lhs_matching, output_base_shape, other.hlo()->shape(),
2021           other_contracting_partitions, other_non_contracting_partitions,
2022           matching_contracting_partitions,
2023           output_other_non_contracting_partitions, other.sharding(),
2024           output_sharding, partitioned_non_contracting_dims,
2025           lhs_matching ? dims_mapping.rhs_non_contracting_dims
2026                        : dims_mapping.lhs_non_contracting_dims,
2027           dims_mapping.contracting_dims);
2028 
2029   if (!other_grouped) {
2030     other = other.Replicate();
2031   }
2032   matching = matching.Reshard(UngroupSharding(matching_grouped));
2033   auto per_group_partitioner_state = CreatePerGroupPartitioningState(
2034       matching.state(), matching_grouped.device_groups, b);
2035   top_level_sharding_to_reset.emplace_back(matching.hlo(), matching.sharding());
2036   matching.hlo()->set_sharding(matching_grouped.sharding);
2037   auto matching_p = PartitionedHlo(
2038       matching.hlo(),
2039       GetPerGroupBaseShape(matching_grouped, matching.base_shape()),
2040       per_group_partitioner_state);
2041 
2042   auto partially_replicated_other = other.hlo();
2043   if (other_grouped && other_grouped->group_dims.size() == 1 &&
2044       other_grouped->group_dims[0] == other.base_shape().rank()) {
2045     // Group on replication dim.
2046     other = other.Reshard(UngroupSharding(*other_grouped));
2047     partially_replicated_other = other.hlo();
2048     top_level_sharding_to_reset.emplace_back(other.hlo(), other.sharding());
2049     partially_replicated_other->set_sharding(other_grouped->sharding);
2050   } else if (!other.sharding().IsReplicated()) {
2051     other = other.Reshard(UngroupSharding(*other_grouped));
2052     partially_replicated_other =
2053         other
2054             .Reshard(hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
2055                 other.sharding(), other_grouped->group_dims))
2056             .hlo();
2057     top_level_sharding_to_reset.emplace_back(
2058         partially_replicated_other, partially_replicated_other->sharding());
2059     partially_replicated_other->set_sharding(other_grouped->sharding);
2060   }
2061   auto other_p = PartitionedHlo(partially_replicated_other, other.base_shape(),
2062                                 per_group_partitioner_state);
2063   TF_ASSIGN_OR_RETURN(
2064       auto dot,
2065       PartitionDot(lhs_matching ? matching_p : other_p,
2066                    lhs_matching ? other_p : matching_p,
2067                    GetPerGroupBaseShape(output_grouped, output_base_shape),
2068                    output_grouped.sharding, dims_mapping,
2069                    num_partitions / matching_grouped.device_groups.size(),
2070                    create_sharded_dot, conv_window, module, original_hlo,
2071                    options, b, windowed_dot_general_loops));
2072   return dot;
2073 }
2074 
PartitionDotGroupOnContracting(PartitionedHlo lhs,PartitionedHlo rhs,absl::Span<const DotConvDimsMapping::DimsMapping> partitioned_contractin_dims,int64 output_batch_partitions,int64 output_lhs_non_contracting_partitions,int64 output_rhs_non_contracting_partitions,const Shape & output_base_shape,const HloSharding & output_sharding,const DotConvDimsMapping & dims_mapping,int64 num_partitions,const std::function<StatusOr<HloInstruction * > (HloInstruction *,HloInstruction *,SpmdBuilder *,const Window & conv_window)> & create_sharded_dot,const Window & conv_window,HloModule * module,HloInstruction * original_hlo,bool require_matching_devices_to_group,const SpmdPartitionerOptions & options,SpmdBuilder * b,std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop> * windowed_dot_general_loops)2075 StatusOr<HloInstruction*> PartitionDotGroupOnContracting(
2076     PartitionedHlo lhs, PartitionedHlo rhs,
2077     absl::Span<const DotConvDimsMapping::DimsMapping>
2078         partitioned_contractin_dims,
2079     int64 output_batch_partitions, int64 output_lhs_non_contracting_partitions,
2080     int64 output_rhs_non_contracting_partitions, const Shape& output_base_shape,
2081     const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
2082     int64 num_partitions,
2083     const std::function<StatusOr<HloInstruction*>(
2084         HloInstruction*, HloInstruction*, SpmdBuilder*,
2085         const Window& conv_window)>& create_sharded_dot,
2086     const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
2087     bool require_matching_devices_to_group,
2088     const SpmdPartitionerOptions& options, SpmdBuilder* b,
2089     std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
2090         windowed_dot_general_loops) {
2091   std::vector<std::pair<HloInstruction*, HloSharding>>
2092       top_level_sharding_to_reset;
2093   auto cleaner = tensorflow::gtl::MakeCleanup([&] {
2094     for (auto& to_reset : top_level_sharding_to_reset) {
2095       to_reset.first->set_sharding(to_reset.second);
2096     }
2097   });
2098   auto lhs_sharding = lhs.sharding();
2099   auto rhs_sharding = rhs.sharding();
2100   auto lhs_tile_shape = lhs_sharding.tile_assignment().dimensions();
2101   auto rhs_tile_shape = rhs_sharding.tile_assignment().dimensions();
2102   std::vector<int64> lhs_dims;
2103   std::vector<int64> rhs_dims;
2104   int64 group_count = 1;
2105   for (const auto& dim : partitioned_contractin_dims) {
2106     lhs_dims.push_back(dim.lhs);
2107     rhs_dims.push_back(dim.rhs);
2108     group_count *= lhs_sharding.tile_assignment().dim(dim.lhs);
2109   }
2110   if (ShapeUtil::ByteSizeOf(lhs.hlo()->shape()) >
2111       ShapeUtil::ByteSizeOf(rhs.hlo()->shape())) {
2112     for (const auto& dim : partitioned_contractin_dims) {
2113       rhs_tile_shape[dim.rhs] = lhs_tile_shape[dim.lhs];
2114     }
2115     auto new_tile = rhs.sharding().tile_assignment();
2116     new_tile.Reshape(rhs_tile_shape);
2117     rhs_sharding = rhs_sharding.ReplicateOnLastTileDim()
2118                        ? HloSharding::PartialTile(new_tile)
2119                        : HloSharding::Tile(new_tile);
2120   } else {
2121     for (const auto& dim : partitioned_contractin_dims) {
2122       lhs_tile_shape[dim.lhs] = rhs_tile_shape[dim.rhs];
2123     }
2124     auto new_tile = lhs.sharding().tile_assignment();
2125     new_tile.Reshape(lhs_tile_shape);
2126     lhs_sharding = lhs_sharding.ReplicateOnLastTileDim()
2127                        ? HloSharding::PartialTile(new_tile)
2128                        : HloSharding::Tile(new_tile);
2129   }
2130   auto lhs_grouped = GroupShardingOnDims(lhs_sharding, lhs_dims);
2131   auto rhs_grouped = GroupShardingOnDims(rhs_sharding, rhs_dims);
2132   if (ShapeUtil::ByteSizeOf(lhs.hlo()->shape()) >
2133       ShapeUtil::ByteSizeOf(rhs.hlo()->shape())) {
2134     rhs_grouped = AlignGroupsWith(rhs_grouped, lhs_grouped);
2135     rhs_sharding = UngroupSharding(rhs_grouped);
2136     if (require_matching_devices_to_group && rhs.sharding() != rhs_sharding) {
2137       return nullptr;
2138     }
2139     rhs = rhs.Reshard(rhs_sharding);
2140   } else {
2141     lhs_grouped = AlignGroupsWith(lhs_grouped, rhs_grouped);
2142     lhs_sharding = UngroupSharding(lhs_grouped);
2143     if (require_matching_devices_to_group && lhs.sharding() != lhs_sharding) {
2144       return nullptr;
2145     }
2146     lhs = lhs.Reshard(lhs_sharding);
2147   }
2148   // Mask out invalid data.
2149   std::vector<int64> lhs_skipped_dims;
2150   for (int64 i = 0; i < lhs.base_shape().rank(); ++i) {
2151     if (absl::c_linear_search(lhs_dims, i)) {
2152       continue;
2153     }
2154     lhs_skipped_dims.push_back(i);
2155   }
2156   lhs = lhs.PadWithValue(
2157       CreateZero(ShapeUtil::MakeShape(lhs.base_shape().element_type(), {}), b),
2158       /*left_padded_dims=*/{}, lhs_skipped_dims);
2159   std::vector<int64> rhs_skipped_dims;
2160   for (int64 i = 0; i < rhs.base_shape().rank(); ++i) {
2161     if (absl::c_linear_search(rhs_dims, i)) {
2162       continue;
2163     }
2164     rhs_skipped_dims.push_back(i);
2165   }
2166   rhs = rhs.PadWithValue(
2167       CreateZero(ShapeUtil::MakeShape(rhs.base_shape().element_type(), {}), b),
2168       /*left_padded_dims=*/{}, rhs_skipped_dims);
2169   top_level_sharding_to_reset.emplace_back(lhs.hlo(), lhs_sharding);
2170   lhs.hlo()->set_sharding(lhs_grouped.sharding);
2171   top_level_sharding_to_reset.emplace_back(rhs.hlo(), rhs_sharding);
2172   rhs.hlo()->set_sharding(rhs_grouped.sharding);
2173 
2174   HloSharding inner_output_sharding = HloSharding::Replicate();
2175   HloSharding outer_output_tmp_sharding = HloSharding::Replicate();
2176   if (output_sharding.ReplicateOnLastTileDim() &&
2177       output_sharding.tile_assignment().dimensions().back() % group_count ==
2178           0) {
2179     auto grouped = AlignGroupsWith(
2180         GroupShardingOnDims(
2181             output_sharding,
2182             {output_sharding.tile_assignment().num_dimensions() - 1},
2183             {output_sharding.tile_assignment().dimensions().back() /
2184              group_count}),
2185         lhs_grouped,
2186         /*ignore_group_order=*/true);
2187     outer_output_tmp_sharding = UngroupSharding(grouped);
2188     inner_output_sharding = std::move(grouped.sharding);
2189   } else {
2190     std::vector<int64> group_dims;
2191     if (auto found_dims = FindMatchingPartitionedDimsForGrouping(
2192             output_sharding, lhs_grouped.device_groups)) {
2193       group_dims = std::move(*found_dims);
2194     } else if (output_lhs_non_contracting_partitions == group_count ||
2195                output_rhs_non_contracting_partitions == group_count ||
2196                output_batch_partitions == group_count) {
2197       if (output_lhs_non_contracting_partitions == group_count) {
2198         for (const auto& dim : dims_mapping.lhs_non_contracting_dims) {
2199           group_dims.push_back(dim.output);
2200         }
2201       } else if (output_rhs_non_contracting_partitions == group_count) {
2202         for (const auto& dim : dims_mapping.rhs_non_contracting_dims) {
2203           group_dims.push_back(dim.output);
2204         }
2205       } else {
2206         for (const auto& dim : dims_mapping.batch_dims) {
2207           group_dims.push_back(dim.output);
2208         }
2209       }
2210     }
2211     if (!group_dims.empty()) {
2212       auto grouped = AlignGroupsWith(
2213           GroupShardingOnDims(output_sharding, group_dims), lhs_grouped);
2214       inner_output_sharding = grouped.sharding;
2215       outer_output_tmp_sharding =
2216           hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
2217               UngroupSharding(grouped), group_dims);
2218     }
2219   }
2220   auto inner_state = CreatePerGroupPartitioningState(
2221       lhs.state(), lhs_grouped.device_groups, b);
2222   TF_ASSIGN_OR_RETURN(
2223       auto dot,
2224       PartitionDot(
2225           PartitionedHlo(lhs.hlo(),
2226                          GetPerGroupBaseShape(lhs_grouped, lhs.base_shape()),
2227                          inner_state),
2228           PartitionedHlo(rhs.hlo(),
2229                          GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()),
2230                          inner_state),
2231           output_base_shape, inner_output_sharding, dims_mapping,
2232           num_partitions / group_count, create_sharded_dot, conv_window, module,
2233           original_hlo, options, b, windowed_dot_general_loops));
2234   if (!dot) {
2235     return nullptr;
2236   }
2237   auto ar = lhs.state().partitioner->AllReduceAlongShardingDims(
2238       b, dot, lhs_sharding, lhs.state().next_channel_id, lhs_dims,
2239       lhs.state().collective_ops_creator,
2240       MakeBinaryAdd(output_base_shape.element_type(), module));
2241   ar->set_sharding(outer_output_tmp_sharding);
2242   return PartitionedHlo(ar, output_base_shape, lhs.state())
2243       .Reshard(output_sharding)
2244       .hlo();
2245 }
2246 
ConvertDimsMappingWithFeatureGroupCount(const DotConvDimsMapping & dims_mapping,HloInstruction * original_hlo)2247 DotConvDimsMapping ConvertDimsMappingWithFeatureGroupCount(
2248     const DotConvDimsMapping& dims_mapping, HloInstruction* original_hlo) {
2249   const auto& dnums = original_hlo->convolution_dimension_numbers();
2250   DotConvDimsMapping new_dims_mapping;
2251   new_dims_mapping.batch_dims = dims_mapping.batch_dims;
2252   new_dims_mapping.conv_spatial_dims = dims_mapping.conv_spatial_dims;
2253   // Append batch dims.
2254   new_dims_mapping.batch_dims.emplace_back();
2255   new_dims_mapping.batch_dims.back().lhs = dnums.input_feature_dimension();
2256   new_dims_mapping.batch_dims.back().rhs =
2257       dnums.kernel_output_feature_dimension();
2258   new_dims_mapping.batch_dims.back().output = dnums.output_feature_dimension();
2259   new_dims_mapping.batch_dims.back().spatial = -1;
2260   // Setup non contracting dims.
2261   new_dims_mapping.lhs_non_contracting_dims.emplace_back();
2262   new_dims_mapping.lhs_non_contracting_dims.back().lhs =
2263       dnums.input_batch_dimension();
2264   new_dims_mapping.rhs_non_contracting_dims.emplace_back();
2265   new_dims_mapping.rhs_non_contracting_dims.back().rhs =
2266       dnums.kernel_input_feature_dimension();
2267   return new_dims_mapping;
2268 }
2269 
ConvertDimsMappingWithBatchGroupCount(const DotConvDimsMapping & dims_mapping,HloInstruction * original_hlo)2270 DotConvDimsMapping ConvertDimsMappingWithBatchGroupCount(
2271     const DotConvDimsMapping& dims_mapping, HloInstruction* original_hlo) {
2272   const auto& dnums = original_hlo->convolution_dimension_numbers();
2273   DotConvDimsMapping new_dims_mapping;
2274   new_dims_mapping.batch_dims = dims_mapping.batch_dims;
2275   new_dims_mapping.conv_spatial_dims = dims_mapping.conv_spatial_dims;
2276   new_dims_mapping.contracting_dims = dims_mapping.contracting_dims;
2277   // Append batch dims.
2278   new_dims_mapping.batch_dims.emplace_back();
2279   new_dims_mapping.batch_dims.back().lhs = dnums.input_batch_dimension();
2280   new_dims_mapping.batch_dims.back().rhs =
2281       dnums.kernel_output_feature_dimension();
2282   new_dims_mapping.batch_dims.back().output = dnums.output_feature_dimension();
2283   new_dims_mapping.batch_dims.back().spatial = -1;
2284   return new_dims_mapping;
2285 }
2286 
LhsIsBestMatchForNonContractingPartitioning(const DotConvDimsMapping & dims_mapping,const PartitionedHlo & lhs,const PartitionedHlo & rhs,const Shape & output_base_shape,const HloSharding & output_sharding,const SpmdPartitionerOptions & options,int64 num_partitions,int64 lhs_non_contracting_partitions,int64 rhs_non_contracting_partitions,int64 lhs_contracting_partitions,int64 rhs_contracting_partitions,int64 output_lhs_non_contracting_partitions,int64 output_rhs_non_contracting_partitions,int64 lhs_batch_partitions,int64 rhs_batch_partitions,bool may_group_on_lhs_non_contracting,bool may_group_on_rhs_non_contracting)2287 bool LhsIsBestMatchForNonContractingPartitioning(
2288     const DotConvDimsMapping& dims_mapping, const PartitionedHlo& lhs,
2289     const PartitionedHlo& rhs, const Shape& output_base_shape,
2290     const HloSharding& output_sharding, const SpmdPartitionerOptions& options,
2291     int64 num_partitions, int64 lhs_non_contracting_partitions,
2292     int64 rhs_non_contracting_partitions, int64 lhs_contracting_partitions,
2293     int64 rhs_contracting_partitions,
2294     int64 output_lhs_non_contracting_partitions,
2295     int64 output_rhs_non_contracting_partitions, int64 lhs_batch_partitions,
2296     int64 rhs_batch_partitions, bool may_group_on_lhs_non_contracting,
2297     bool may_group_on_rhs_non_contracting) {
2298   // If both match output non-contracting dimensions, choose the one which
2299   // will result in smaller replication of the other operand.
2300   bool lhs_matching = may_group_on_lhs_non_contracting &&
2301                       (!may_group_on_rhs_non_contracting ||
2302                        lhs_non_contracting_partitions *
2303                                ShapeUtil::ByteSizeOf(rhs.hlo()->shape()) <
2304                            rhs_non_contracting_partitions *
2305                                ShapeUtil::ByteSizeOf(lhs.hlo()->shape()));
2306   // If both grouping are available and the option to choose faster windowed
2307   // einsums vs saving memory is enabled then try to determine which of the
2308   // operands will generate the least amount of iterations for the windowed
2309   // einsum when matched (if a windowed einsum is gonna be generated at all).
2310   if (may_group_on_lhs_non_contracting && may_group_on_rhs_non_contracting &&
2311       options.choose_faster_windowed_einsum_over_mem) {
2312     const DotDimensionIndexMapping indices_map = ComputeDimensionIndexMapping(
2313         dims_mapping, lhs.base_shape().rank(), rhs.base_shape().rank(),
2314         output_base_shape.rank());
2315     auto subsequent_einsum_iterations_estimate =
2316         [&](bool assume_lhs_match) -> absl::optional<int64> {
2317       const std::vector<DotConvDimsMapping::DimsMapping>&
2318           matching_non_contracting_dims =
2319               assume_lhs_match ? dims_mapping.lhs_non_contracting_dims
2320                                : dims_mapping.rhs_non_contracting_dims;
2321       const std::vector<DotConvDimsMapping::DimsMapping>&
2322           other_non_contracting_dims =
2323               assume_lhs_match ? dims_mapping.rhs_non_contracting_dims
2324                                : dims_mapping.lhs_non_contracting_dims;
2325       const std::vector<int64>& output_to_matching_indices =
2326           assume_lhs_match ? indices_map.output_to_lhs_indices
2327                            : indices_map.output_to_rhs_indices;
2328       const std::vector<int64>& output_to_other_indices =
2329           assume_lhs_match ? indices_map.output_to_rhs_indices
2330                            : indices_map.output_to_lhs_indices;
2331       const std::vector<int64>& matching_to_output_indices =
2332           assume_lhs_match ? indices_map.lhs_to_output_indices
2333                            : indices_map.rhs_to_output_indices;
2334       const std::vector<int64>& other_to_output_indices =
2335           assume_lhs_match ? indices_map.rhs_to_output_indices
2336                            : indices_map.lhs_to_output_indices;
2337       const HloSharding& matching_sharding =
2338           assume_lhs_match ? lhs.sharding() : rhs.sharding();
2339       const HloSharding& other_sharding =
2340           assume_lhs_match ? rhs.sharding() : lhs.sharding();
2341       const PartitionedHlo& matching_partitioned = assume_lhs_match ? lhs : rhs;
2342       const PartitionedHlo& other_partitioned = assume_lhs_match ? rhs : lhs;
2343       const int64 matching_non_contracting_partitions =
2344           assume_lhs_match ? lhs_non_contracting_partitions
2345                            : rhs_non_contracting_partitions;
2346       const int64 other_non_contracting_partitions =
2347           assume_lhs_match ? rhs_non_contracting_partitions
2348                            : lhs_non_contracting_partitions;
2349       const int64 matching_contracting_partitions =
2350           assume_lhs_match ? lhs_contracting_partitions
2351                            : rhs_contracting_partitions;
2352       const int64 other_contracting_partitions =
2353           assume_lhs_match ? rhs_contracting_partitions
2354                            : lhs_contracting_partitions;
2355       const int64 output_matching_non_contracting_partitions =
2356           assume_lhs_match ? output_lhs_non_contracting_partitions
2357                            : output_rhs_non_contracting_partitions;
2358       const int64 output_other_non_contracting_partitions =
2359           assume_lhs_match ? output_rhs_non_contracting_partitions
2360                            : output_lhs_non_contracting_partitions;
2361       const int64 matching_batch_partitions =
2362           assume_lhs_match ? lhs_batch_partitions : rhs_batch_partitions;
2363       const int64 other_batch_partitions =
2364           assume_lhs_match ? rhs_batch_partitions : lhs_batch_partitions;
2365       std::vector<int64> output_dims;
2366       output_dims.reserve(matching_non_contracting_dims.size());
2367       for (const DotConvDimsMapping::DimsMapping& dim :
2368            matching_non_contracting_dims) {
2369         output_dims.push_back(dim.output);
2370       }
2371       GroupedSharding output_grouped =
2372           GroupShardingOnDims(output_sharding, output_dims);
2373       GroupedSharding matching_grouped =
2374           GetNonContractingPartitionGroupedShardingForMatchedOperand(
2375               assume_lhs_match, matching_sharding, output_sharding,
2376               matching_non_contracting_dims);
2377       absl::optional<GroupedSharding> other_grouped =
2378           GetNonContractingPartitionGroupedShardingForOtherOperand(
2379               assume_lhs_match, output_base_shape,
2380               other_partitioned.hlo()->shape(), other_contracting_partitions,
2381               other_non_contracting_partitions, matching_contracting_partitions,
2382               output_other_non_contracting_partitions, other_sharding,
2383               output_sharding, matching_non_contracting_dims,
2384               other_non_contracting_dims, dims_mapping.contracting_dims);
2385       absl::optional<HloSharding> output_sharding_transposed_to_match_matching =
2386           hlo_sharding_util::TransposeShardingWithCollapsedDims(
2387               output_grouped.sharding, output_to_matching_indices,
2388               matching_to_output_indices);
2389       absl::optional<HloSharding> output_sharding_transposed_to_match_other =
2390           hlo_sharding_util::TransposeShardingWithCollapsedDims(
2391               output_grouped.sharding, output_to_other_indices,
2392               other_to_output_indices);
2393       const int64 new_num_partitions =
2394           num_partitions / matching_non_contracting_partitions;
2395       const int64 output_sharding_dim = FirstShardingDimWithPartitionOfSize(
2396           new_num_partitions, output_grouped.sharding);
2397       absl::optional<WindowedEinsumConfig> e_config =
2398           GetWindowedEinsumConfiguration(
2399               new_num_partitions, output_matching_non_contracting_partitions,
2400               output_other_non_contracting_partitions, 1,
2401               other_non_contracting_partitions, other_batch_partitions,
2402               matching_contracting_partitions, 1, matching_batch_partitions,
2403               output_sharding_dim,
2404               ShapeSizeInBytes(other_partitioned.base_shape()),
2405               ShapeSizeInBytes(matching_partitioned.base_shape()) /
2406                   matching_non_contracting_partitions,
2407               ShapeSizeInBytes(
2408                   GetPerGroupBaseShape(output_grouped, output_base_shape)),
2409               options.threshold_for_windowed_einsum_mib,
2410               output_sharding_transposed_to_match_matching,
2411               output_sharding_transposed_to_match_other,
2412               matching_grouped.sharding, other_grouped->sharding);
2413       return e_config ? new_num_partitions
2414                       : absl::optional<int64>(absl::nullopt);
2415     };
2416     absl::optional<int64> lhs_matching_iterations =
2417         subsequent_einsum_iterations_estimate(true);
2418     absl::optional<int64> rhs_matching_iterations =
2419         subsequent_einsum_iterations_estimate(false);
2420     if (lhs_matching_iterations && rhs_matching_iterations &&
2421         *lhs_matching_iterations != *rhs_matching_iterations) {
2422       lhs_matching = *lhs_matching_iterations < *rhs_matching_iterations;
2423     }
2424   }
2425   return lhs_matching;
2426 }
2427 
2428 // Recursive partitioning function. If there are partial dimensions matching in
2429 // the operands and output, group the devices and recursively partition the
2430 // in-group dot.
PartitionDot(PartitionedHlo lhs,PartitionedHlo rhs,const Shape & output_base_shape,const HloSharding & output_sharding,const DotConvDimsMapping & dims_mapping,int64 num_partitions,const std::function<StatusOr<HloInstruction * > (HloInstruction *,HloInstruction *,SpmdBuilder *,const Window & conv_window)> & create_sharded_dot,const Window & conv_window,HloModule * module,HloInstruction * original_hlo,bool require_matching_devices_to_group,const SpmdPartitionerOptions & options,SpmdBuilder * b,std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop> * windowed_dot_general_loops)2431 StatusOr<HloInstruction*> PartitionDot(
2432     PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
2433     const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
2434     int64 num_partitions,
2435     const std::function<StatusOr<HloInstruction*>(
2436         HloInstruction*, HloInstruction*, SpmdBuilder*,
2437         const Window& conv_window)>& create_sharded_dot,
2438     const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
2439     bool require_matching_devices_to_group,
2440     const SpmdPartitionerOptions& options, SpmdBuilder* b,
2441     std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
2442         windowed_dot_general_loops) {
2443   // If lhs‘ hlo and rhs' hlo are identical, make a copy for rhs.
2444   if (lhs.hlo() == rhs.hlo()) {
2445     auto copy_hlo = b->AddInstruction(HloInstruction::CreateUnary(
2446         rhs.hlo()->shape(), HloOpcode::kCopy, rhs.hlo()));
2447     copy_hlo->set_sharding(rhs.sharding());
2448     rhs = PartitionedHlo(copy_hlo, rhs.base_shape(), rhs.state());
2449   }
2450 
2451   // lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output.
2452   auto get_partitions_for_dims =
2453       [&](const HloSharding& sharding,
2454           absl::Span<const DotConvDimsMapping::DimsMapping> dims,
2455           int lhs_rhs_or_output) {
2456         int64 partitions = 1;
2457         if (sharding.IsTileMaximal()) {
2458           return partitions;
2459         }
2460         for (const auto& dim : dims) {
2461           if (lhs_rhs_or_output == 0) {
2462             partitions *= sharding.tile_assignment().dim(dim.lhs);
2463           } else if (lhs_rhs_or_output == 1) {
2464             partitions *= sharding.tile_assignment().dim(dim.rhs);
2465           } else {
2466             CHECK_EQ(lhs_rhs_or_output, 2);
2467             partitions *= sharding.tile_assignment().dim(dim.output);
2468           }
2469         }
2470         return partitions;
2471       };
2472   const int64 lhs_batch_partitions =
2473       get_partitions_for_dims(lhs.sharding(), dims_mapping.batch_dims, 0);
2474   const int64 rhs_batch_partitions =
2475       get_partitions_for_dims(rhs.sharding(), dims_mapping.batch_dims, 1);
2476   const int64 output_batch_partitions =
2477       get_partitions_for_dims(output_sharding, dims_mapping.batch_dims, 2);
2478   const int64 lhs_contracting_partitions =
2479       get_partitions_for_dims(lhs.sharding(), dims_mapping.contracting_dims, 0);
2480   const int64 rhs_contracting_partitions =
2481       get_partitions_for_dims(rhs.sharding(), dims_mapping.contracting_dims, 1);
2482   const int64 lhs_non_contracting_partitions = get_partitions_for_dims(
2483       lhs.sharding(), dims_mapping.lhs_non_contracting_dims, 0);
2484   const int64 rhs_non_contracting_partitions = get_partitions_for_dims(
2485       rhs.sharding(), dims_mapping.rhs_non_contracting_dims, 1);
2486   const int64 output_lhs_non_contracting_partitions = get_partitions_for_dims(
2487       output_sharding, dims_mapping.lhs_non_contracting_dims, 2);
2488   const int64 output_rhs_non_contracting_partitions = get_partitions_for_dims(
2489       output_sharding, dims_mapping.rhs_non_contracting_dims, 2);
2490   const int64 lhs_conv_spatial_partitions = get_partitions_for_dims(
2491       lhs.sharding(), dims_mapping.conv_spatial_dims, 0);
2492   const int64 rhs_conv_spatial_partitions = get_partitions_for_dims(
2493       rhs.sharding(), dims_mapping.conv_spatial_dims, 1);
2494   const int64 output_conv_spatial_partitions = get_partitions_for_dims(
2495       output_sharding, dims_mapping.conv_spatial_dims, 2);
2496   // Before we find partial matches along the dimensions, invoke base case again
2497   // without may_reshard_without_detecting_match.
2498 
2499   // Try partition the purely spatially-partitioned convolution with convolution
2500   // spatial dimension partitioned or depthwise parallel dimension partitioned.
2501   bool is_conv_spatial_dim_partitioned =
2502       (lhs_conv_spatial_partitions > 1 || rhs_conv_spatial_partitions > 1 ||
2503        output_conv_spatial_partitions > 1);
2504   bool is_conv_batch_or_contracting_dim_partitioned =
2505       (lhs_batch_partitions > 1 || rhs_batch_partitions > 1 ||
2506        output_batch_partitions > 1 ||
2507        (lhs_contracting_partitions > 1 && rhs_contracting_partitions > 1));
2508   if ((!dims_mapping.conv_spatial_dims.empty() &&
2509        is_conv_spatial_dim_partitioned &&
2510        !is_conv_batch_or_contracting_dim_partitioned) ||
2511       (original_hlo->opcode() == HloOpcode::kConvolution &&
2512        (original_hlo->batch_group_count() > 1 ||
2513         original_hlo->feature_group_count() > 1))) {
2514     // Partition with kernel_input_feature_dim > 1 and feature_group_count > 1
2515     // is not supported.
2516     const auto& dnums = original_hlo->convolution_dimension_numbers();
2517     if (original_hlo->feature_group_count() > 1 &&
2518         rhs.hlo()->shape().dimensions(dnums.kernel_input_feature_dimension()) >
2519             1) {
2520       return nullptr;
2521     }
2522 
2523     TF_ASSIGN_OR_RETURN(
2524         auto partitioned_conv,
2525         PartitionConvolution(lhs, rhs, output_base_shape, output_sharding,
2526                              dims_mapping, create_sharded_dot, conv_window,
2527                              original_hlo, num_partitions, options,
2528                              lhs.state().partition_id, module, b));
2529 
2530     if (partitioned_conv) {
2531       return partitioned_conv;
2532     }
2533 
2534     // Recursively partition on different types of dimensions for convolution.
2535     // Case 0.a: Group partitions by feature group count.
2536     if (original_hlo->feature_group_count() > 1 ||
2537         original_hlo->batch_group_count() > 1) {
2538       DotConvDimsMapping new_dims_mapping;
2539       if (original_hlo->feature_group_count() > 1) {
2540         new_dims_mapping =
2541             ConvertDimsMappingWithFeatureGroupCount(dims_mapping, original_hlo);
2542       }
2543 
2544       if (original_hlo->batch_group_count() > 1) {
2545         new_dims_mapping =
2546             ConvertDimsMappingWithBatchGroupCount(dims_mapping, original_hlo);
2547       }
2548 
2549       const int64 conv_lhs_contracting_partitions = get_partitions_for_dims(
2550           lhs.sharding(), new_dims_mapping.contracting_dims, 0);
2551       const int64 conv_rhs_contracting_partitions = get_partitions_for_dims(
2552           rhs.sharding(), new_dims_mapping.contracting_dims, 1);
2553       const int64 conv_lhs_non_contracting_partitions = get_partitions_for_dims(
2554           lhs.sharding(), new_dims_mapping.lhs_non_contracting_dims, 0);
2555       const int64 conv_rhs_non_contracting_partitions = get_partitions_for_dims(
2556           rhs.sharding(), new_dims_mapping.rhs_non_contracting_dims, 1);
2557       const int64 conv_lhs_batch_partitions = get_partitions_for_dims(
2558           lhs.sharding(), new_dims_mapping.batch_dims, 0);
2559       const int64 conv_rhs_batch_partitions = get_partitions_for_dims(
2560           rhs.sharding(), new_dims_mapping.batch_dims, 1);
2561       const int64 conv_output_batch_partitions = get_partitions_for_dims(
2562           output_sharding, new_dims_mapping.batch_dims, 2);
2563       if ((conv_lhs_batch_partitions == conv_output_batch_partitions ||
2564            conv_rhs_batch_partitions == conv_output_batch_partitions) &&
2565           conv_output_batch_partitions > 1) {
2566         TF_ASSIGN_OR_RETURN(
2567             auto try_partitioned_conv,
2568             PartitionDotGroupOnBatch(
2569                 lhs, rhs, output_base_shape, output_sharding, new_dims_mapping,
2570                 num_partitions, conv_lhs_contracting_partitions,
2571                 conv_rhs_contracting_partitions,
2572                 conv_lhs_non_contracting_partitions,
2573                 conv_rhs_non_contracting_partitions, create_sharded_dot,
2574                 conv_window, module, original_hlo,
2575                 require_matching_devices_to_group, options, b,
2576                 windowed_dot_general_loops));
2577         if (try_partitioned_conv) {
2578           return try_partitioned_conv;
2579         }
2580       }
2581       return nullptr;
2582     }
2583   }
2584 
2585   TF_ASSIGN_OR_RETURN(
2586       auto try_partitioned_dot,
2587       PartitionBaseCase(
2588           lhs, rhs, output_base_shape, output_sharding, dims_mapping,
2589           num_partitions, create_sharded_dot, conv_window, module, original_hlo,
2590           lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions,
2591           lhs_contracting_partitions, rhs_contracting_partitions,
2592           lhs_non_contracting_partitions, rhs_non_contracting_partitions,
2593           output_lhs_non_contracting_partitions,
2594           output_rhs_non_contracting_partitions, options, b,
2595           windowed_dot_general_loops,
2596           /*may_reshard_without_detecting_match=*/false));
2597   if (try_partitioned_dot) {
2598     return try_partitioned_dot;
2599   }
2600 
2601   // Recursively partition on different types of dimensions.
2602   //
2603   // Case 1: Group partitions by batch.
2604   if ((lhs_batch_partitions == output_batch_partitions ||
2605        rhs_batch_partitions == output_batch_partitions) &&
2606       output_batch_partitions > 1) {
2607     TF_ASSIGN_OR_RETURN(
2608         auto dot,
2609         PartitionDotGroupOnBatch(
2610             lhs, rhs, output_base_shape, output_sharding, dims_mapping,
2611             num_partitions, lhs_contracting_partitions,
2612             rhs_contracting_partitions, lhs_non_contracting_partitions,
2613             rhs_non_contracting_partitions, create_sharded_dot, conv_window,
2614             module, original_hlo, require_matching_devices_to_group, options, b,
2615             windowed_dot_general_loops));
2616     if (dot) {
2617       return dot;
2618     }
2619   }
2620 
2621   // Case 2: Group partitions by non-contracting dimensions.
2622   const bool may_group_on_lhs_non_contracting =
2623       lhs_non_contracting_partitions == output_lhs_non_contracting_partitions &&
2624       lhs_non_contracting_partitions > 1;
2625   const bool may_group_on_rhs_non_contracting =
2626       rhs_non_contracting_partitions == output_rhs_non_contracting_partitions &&
2627       rhs_non_contracting_partitions > 1;
2628   if (may_group_on_lhs_non_contracting || may_group_on_rhs_non_contracting) {
2629     const bool lhs_matching = LhsIsBestMatchForNonContractingPartitioning(
2630         dims_mapping, lhs, rhs, output_base_shape, output_sharding, options,
2631         num_partitions, lhs_non_contracting_partitions,
2632         rhs_non_contracting_partitions, lhs_contracting_partitions,
2633         rhs_contracting_partitions, output_lhs_non_contracting_partitions,
2634         output_rhs_non_contracting_partitions, lhs_batch_partitions,
2635         rhs_batch_partitions, may_group_on_lhs_non_contracting,
2636         may_group_on_rhs_non_contracting);
2637     TF_ASSIGN_OR_RETURN(
2638         auto dot,
2639         PartitionDotGroupOnNonContracting(
2640             lhs_matching, lhs_matching ? lhs : rhs, lhs_matching ? rhs : lhs,
2641             lhs_matching ? lhs_contracting_partitions
2642                          : rhs_contracting_partitions,
2643             lhs_matching ? rhs_contracting_partitions
2644                          : lhs_contracting_partitions,
2645             lhs_matching ? dims_mapping.lhs_non_contracting_dims
2646                          : dims_mapping.rhs_non_contracting_dims,
2647             lhs_matching ? rhs_non_contracting_partitions
2648                          : lhs_non_contracting_partitions,
2649             lhs_matching ? output_rhs_non_contracting_partitions
2650                          : output_lhs_non_contracting_partitions,
2651             output_base_shape, output_sharding, dims_mapping, num_partitions,
2652             create_sharded_dot, conv_window, module, original_hlo,
2653             require_matching_devices_to_group, options, b,
2654             windowed_dot_general_loops));
2655     if (dot) {
2656       return dot;
2657     }
2658   }
2659   if (lhs_non_contracting_partitions > 1 &&
2660       output_lhs_non_contracting_partitions > 1) {
2661     // If part of LHS non-contracting dims match output, try them.
2662     std::vector<DotConvDimsMapping::DimsMapping> matching_dims;
2663     for (const auto& dim : dims_mapping.lhs_non_contracting_dims) {
2664       int64 lhs_partitions = lhs.sharding().tile_assignment().dim(dim.lhs);
2665       if (lhs_partitions > 1 &&
2666           lhs_partitions == output_sharding.tile_assignment().dim(dim.output)) {
2667         matching_dims.push_back(dim);
2668       }
2669     }
2670     if (!matching_dims.empty()) {
2671       TF_ASSIGN_OR_RETURN(
2672           auto dot, PartitionDotGroupOnNonContracting(
2673                         /*lhs_matching=*/true, lhs, rhs,
2674                         lhs_contracting_partitions, rhs_contracting_partitions,
2675                         matching_dims, rhs_non_contracting_partitions,
2676                         output_rhs_non_contracting_partitions,
2677                         output_base_shape, output_sharding, dims_mapping,
2678                         num_partitions, create_sharded_dot, conv_window, module,
2679                         original_hlo, require_matching_devices_to_group,
2680                         options, b, windowed_dot_general_loops));
2681       if (dot) {
2682         return dot;
2683       }
2684     }
2685   }
2686   if (rhs_non_contracting_partitions > 1 &&
2687       output_rhs_non_contracting_partitions > 1) {
2688     // If part of RHS non-contracting dims match output, try them.
2689     std::vector<DotConvDimsMapping::DimsMapping> matching_dims;
2690     for (const auto& dim : dims_mapping.rhs_non_contracting_dims) {
2691       int64 rhs_partitions = rhs.sharding().tile_assignment().dim(dim.rhs);
2692       if (rhs_partitions > 1 &&
2693           rhs_partitions == output_sharding.tile_assignment().dim(dim.output)) {
2694         matching_dims.push_back(dim);
2695       }
2696     }
2697     if (!matching_dims.empty()) {
2698       TF_ASSIGN_OR_RETURN(
2699           auto dot, PartitionDotGroupOnNonContracting(
2700                         /*lhs_matching=*/false, rhs, lhs,
2701                         rhs_contracting_partitions, lhs_contracting_partitions,
2702                         matching_dims, lhs_non_contracting_partitions,
2703                         output_lhs_non_contracting_partitions,
2704                         output_base_shape, output_sharding, dims_mapping,
2705                         num_partitions, create_sharded_dot, conv_window, module,
2706                         original_hlo, require_matching_devices_to_group,
2707                         options, b, windowed_dot_general_loops));
2708       if (dot) {
2709         return dot;
2710       }
2711     }
2712   }
2713 
2714   // Case 3: Group partitions by contracting dimensions.
2715   if (lhs_contracting_partitions == rhs_contracting_partitions &&
2716       lhs_contracting_partitions > 1) {
2717     TF_ASSIGN_OR_RETURN(
2718         auto dot,
2719         PartitionDotGroupOnContracting(
2720             lhs, rhs, dims_mapping.contracting_dims, output_batch_partitions,
2721             output_lhs_non_contracting_partitions,
2722             output_rhs_non_contracting_partitions, output_base_shape,
2723             output_sharding, dims_mapping, num_partitions, create_sharded_dot,
2724             conv_window, module, original_hlo,
2725             require_matching_devices_to_group, options, b,
2726             windowed_dot_general_loops));
2727     if (dot) {
2728       return dot;
2729     }
2730   }
2731   if (lhs_contracting_partitions > 1 && rhs_contracting_partitions > 1) {
2732     // If part of contracting dims match, try them.
2733     std::vector<DotConvDimsMapping::DimsMapping> matching_dims;
2734     for (const auto& dim : dims_mapping.contracting_dims) {
2735       int64 lhs_partitions = lhs.sharding().tile_assignment().dim(dim.lhs);
2736       if (lhs_partitions > 1 &&
2737           lhs_partitions == rhs.sharding().tile_assignment().dim(dim.rhs)) {
2738         matching_dims.push_back(dim);
2739       }
2740     }
2741     if (!matching_dims.empty()) {
2742       TF_ASSIGN_OR_RETURN(
2743           auto dot, PartitionDotGroupOnContracting(
2744                         lhs, rhs, matching_dims, output_batch_partitions,
2745                         output_lhs_non_contracting_partitions,
2746                         output_rhs_non_contracting_partitions,
2747                         output_base_shape, output_sharding, dims_mapping,
2748                         num_partitions, create_sharded_dot, conv_window, module,
2749                         original_hlo, require_matching_devices_to_group,
2750                         options, b, windowed_dot_general_loops));
2751       if (dot) {
2752         return dot;
2753       }
2754     }
2755   }
2756 
2757   // Case 4: If operands are replicated but output is partially replicated,
2758   // recursive call with partial replication removed.
2759   if (lhs.sharding().IsReplicated() && rhs.sharding().IsReplicated() &&
2760       output_sharding.ReplicateOnLastTileDim()) {
2761     auto grouped_output =
2762         GroupShardingOnDims(output_sharding, {output_base_shape.rank()});
2763     auto inner_state = CreatePerGroupPartitioningState(
2764         lhs.state(), grouped_output.device_groups, b);
2765     TF_ASSIGN_OR_RETURN(
2766         auto dot,
2767         PartitionDot(PartitionedHlo(lhs.hlo(), lhs.base_shape(), inner_state),
2768                      PartitionedHlo(rhs.hlo(), rhs.base_shape(), inner_state),
2769                      output_base_shape, grouped_output.sharding, dims_mapping,
2770                      output_sharding.NumTiles(), create_sharded_dot,
2771                      conv_window, module, original_hlo, options, b,
2772                      windowed_dot_general_loops));
2773     if (dot) {
2774       return dot;
2775     }
2776   }
2777 
2778   // We failed to find partial matches, invoke base case again with
2779   // may_reshard_without_detecting_match.
2780   TF_ASSIGN_OR_RETURN(
2781       auto dot,
2782       PartitionBaseCase(
2783           lhs, rhs, output_base_shape, output_sharding, dims_mapping,
2784           num_partitions, create_sharded_dot, conv_window, module, original_hlo,
2785           lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions,
2786           lhs_contracting_partitions, rhs_contracting_partitions,
2787           lhs_non_contracting_partitions, rhs_non_contracting_partitions,
2788           output_lhs_non_contracting_partitions,
2789           output_rhs_non_contracting_partitions, options, b,
2790           windowed_dot_general_loops,
2791           /*may_reshard_without_detecting_match=*/true));
2792   if (dot) {
2793     return dot;
2794   }
2795   return nullptr;
2796 }
2797 
PartitionDot(PartitionedHlo lhs,PartitionedHlo rhs,const Shape & output_base_shape,const HloSharding & output_sharding,const DotConvDimsMapping & dims_mapping,int64 num_partitions,const std::function<StatusOr<HloInstruction * > (HloInstruction *,HloInstruction *,SpmdBuilder *,const Window & conv_window)> & create_sharded_dot,const Window & conv_window,HloModule * module,HloInstruction * original_hlo,const SpmdPartitionerOptions & options,SpmdBuilder * b,std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop> * windowed_dot_general_loops)2798 StatusOr<HloInstruction*> PartitionDot(
2799     PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
2800     const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
2801     int64 num_partitions,
2802     const std::function<StatusOr<HloInstruction*>(
2803         HloInstruction*, HloInstruction*, SpmdBuilder*,
2804         const Window& conv_window)>& create_sharded_dot,
2805     const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
2806     const SpmdPartitionerOptions& options, SpmdBuilder* b,
2807     std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
2808         windowed_dot_general_loops) {
2809   // First try partitioning without resharding the groups, then try allow
2810   // resharding the groups.
2811   for (bool require_matching_devices_to_group : {true, false}) {
2812     TF_ASSIGN_OR_RETURN(
2813         auto try_partition,
2814         PartitionDot(lhs, rhs, output_base_shape, output_sharding, dims_mapping,
2815                      num_partitions, create_sharded_dot, conv_window, module,
2816                      original_hlo, require_matching_devices_to_group, options,
2817                      b, windowed_dot_general_loops));
2818     if (try_partition) {
2819       return try_partition;
2820     }
2821   }
2822 
2823   // Default action.
2824   TF_ASSIGN_OR_RETURN(
2825       auto dot, create_sharded_dot(lhs.Replicate().hlo(), rhs.Replicate().hlo(),
2826                                    b, conv_window));
2827   dot->set_sharding(HloSharding::Replicate());
2828   return PartitionedHlo(dot, output_base_shape, lhs.state())
2829       .Reshard(output_sharding)
2830       .hlo();
2831 }
2832 
2833 }  // namespace
2834 
HandleDotHelper(HloInstruction * hlo,const DotConvDimsMapping & dims_mapping,const std::function<StatusOr<HloInstruction * > (HloInstruction *,HloInstruction *,SpmdBuilder *,const Window & conv_window)> & create_sharded_dot)2835 Status SpmdPartitioningVisitor::HandleDotHelper(
2836     HloInstruction* hlo, const DotConvDimsMapping& dims_mapping,
2837     const std::function<StatusOr<HloInstruction*>(
2838         HloInstruction*, HloInstruction*, SpmdBuilder*,
2839         const Window& conv_window)>& create_sharded_dot) {
2840   auto& lhs = GetPartitionedHlo(hlo->operand(0));
2841   auto& rhs = GetPartitionedHlo(hlo->operand(1));
2842   Window conv_window;
2843   if (hlo->opcode() == HloOpcode::kConvolution) {
2844     conv_window = hlo->window();
2845   }
2846 
2847   TF_ASSIGN_OR_RETURN(
2848       auto partitioned_dot,
2849       PartitionDot(lhs, rhs, hlo->shape(), hlo->sharding(), dims_mapping,
2850                    num_partitions_, create_sharded_dot, conv_window, module_,
2851                    hlo, options_, &b_, &windowed_dot_general_loops_));
2852   SetPartitionedHlo(hlo, [&] { return partitioned_dot; });
2853   return Status::OK();
2854 }
2855 
2856 namespace {
2857 
2858 // Finds a cluster of nodes that produce the inputs for `hlo` which only depend
2859 // on small operands, which means the cluster should start with broadcasts,
2860 // constants and iotas. All other internal nodes must be non-side-effecting
2861 // elemntwise ops. Returns the set of nodes, and the small operands. E.g., for
2862 // the following graph,
2863 //
2864 //     a -> broadcast -> multiply
2865 //     iota  ---> add--/
2866 //     constant/
2867 //
2868 // FindInputNodesIfOnlyDependOnSmallOperands(multiply) will return
2869 //    <{broadcast, iota, constant, add, multiply}, [a]>.
2870 std::pair<absl::flat_hash_set<HloInstruction*>, std::vector<HloInstruction*>>
FindInputNodesIfOnlyDependOnSmallOperands(HloInstruction * hlo)2871 FindInputNodesIfOnlyDependOnSmallOperands(HloInstruction* hlo) {
2872   absl::flat_hash_set<HloInstruction*> nodes_found;
2873   std::vector<HloInstruction*> new_operands;
2874   absl::flat_hash_set<const HloInstruction*> new_operands_set;
2875   std::vector<HloInstruction*> worklist;
2876   worklist.push_back(hlo);
2877   while (!worklist.empty()) {
2878     auto inst = worklist.back();
2879     worklist.pop_back();
2880     if (nodes_found.count(inst) > 0) {
2881       continue;
2882     }
2883     if (inst->opcode() == HloOpcode::kBroadcast ||
2884         inst->opcode() == HloOpcode::kConstant ||
2885         inst->opcode() == HloOpcode::kIota) {
2886       nodes_found.insert(inst);
2887       for (auto o : inst->operands()) {
2888         auto res = new_operands_set.emplace(o);
2889         if (res.second) {
2890           new_operands.push_back(o);
2891         }
2892       }
2893     } else if (inst->IsElementwise() && !inst->HasSideEffectNoRecurse() &&
2894                inst->opcode() != HloOpcode::kAllReduce &&
2895                absl::c_all_of(inst->operands(),
2896                               [inst](const HloInstruction* o) {
2897                                 return ShapeUtil::CompatibleIgnoringElementType(
2898                                     o->shape(), inst->shape());
2899                               })) {
2900       nodes_found.insert(inst);
2901       for (auto o : inst->operands()) {
2902         worklist.push_back(o);
2903       }
2904     } else {
2905       nodes_found.clear();
2906       new_operands.clear();
2907       break;
2908     }
2909   }
2910   return {std::move(nodes_found), std::move(new_operands)};
2911 }
2912 
2913 // Moves a cluster of memory-reducing nodes into the windowed dot-general loop
2914 // on contracting dimensions. Such a loop has a dynamic slice on the
2915 // non-windowed operand. If we move the input nodes into the loop, the
2916 // dynamic-slice could be merged with them by later optimization passes, which
2917 // reduces memory.
2918 //
2919 // small_operands             small_operands
2920 //        |                          |
2921 // input_nodes                loop { |
2922 //        |          =>         input_nodes
2923 // loop { |                          |
2924 //    dynamic-slice             dynamic-slice
2925 //    ...                       ...
2926 // }                          }
2927 //
2928 // Later optimization passes (TpuPadSliceMover) will merge the dynamic slice
2929 // with the input nodes.
SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions(HloInstruction * loop,int64 non_windowed_operand_index)2930 Status SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions(
2931     HloInstruction* loop, int64 non_windowed_operand_index) {
2932   auto input_tuple = loop->mutable_operand(0);
2933   auto old_operand = input_tuple->mutable_operand(non_windowed_operand_index);
2934   auto input_nodes = FindInputNodesIfOnlyDependOnSmallOperands(old_operand);
2935   auto to_sink = std::move(input_nodes.first);
2936   auto new_operands = std::move(input_nodes.second);
2937   if (to_sink.empty()) {
2938     return Status::OK();
2939   }
2940   auto computation = loop->parent();
2941   // Replace the old operand with a tuple of the found small operands.
2942   auto new_input_subtuple =
2943       computation->AddInstruction(HloInstruction::CreateTuple(new_operands));
2944   TF_RETURN_IF_ERROR(input_tuple->ReplaceOperandWithDifferentShape(
2945       non_windowed_operand_index, new_input_subtuple));
2946 
2947   auto body = loop->while_body();
2948   auto body_param = body->parameter_instruction(0);
2949   auto old_body_param_users = body_param->users();
2950   // Update all tuple shapes.
2951   for (auto tuple : std::vector<HloInstruction*>{
2952            input_tuple, loop, loop->while_condition()->parameter_instruction(0),
2953            body_param, body->root_instruction()}) {
2954     *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(),
2955                                    {non_windowed_operand_index}) =
2956         new_input_subtuple->shape();
2957   }
2958   // Now update the loop body.
2959   auto new_operand_tuple_inside =
2960       body->AddInstruction(HloInstruction::CreateGetTupleElement(
2961           new_input_subtuple->shape(), body_param, non_windowed_operand_index));
2962   TF_RETURN_IF_ERROR(body->root_instruction()->ReplaceOperandWithDifferentShape(
2963       non_windowed_operand_index, new_operand_tuple_inside));
2964 
2965   // Create nodes inside the loop body.
2966   std::vector<HloInstruction*> worklist;
2967   absl::flat_hash_map<const HloInstruction*, HloInstruction*> outside_to_inside;
2968   auto add_users_if_available = [&](HloInstruction* inst) {
2969     for (auto u : inst->users()) {
2970       if (outside_to_inside.count(u) == 0 && to_sink.count(u) > 0 &&
2971           absl::c_all_of(u->operands(), [&](const HloInstruction* o) {
2972             return outside_to_inside.count(o) > 0;
2973           })) {
2974         worklist.push_back(u);
2975       }
2976     }
2977   };
2978   for (int64 i = 0; i < new_operands.size(); ++i) {
2979     outside_to_inside[new_operands[i]] =
2980         body->AddInstruction(HloInstruction::CreateGetTupleElement(
2981             new_operands[i]->shape(), new_operand_tuple_inside, i));
2982     add_users_if_available(new_operands[i]);
2983   }
2984   // HLOs to sink without operands.
2985   std::vector<HloInstruction*> nullaries_to_sink;
2986   for (auto inst : to_sink) {
2987     if (inst->operand_count() == 0) {
2988       nullaries_to_sink.push_back(inst);
2989     }
2990   }
2991   // Sort nullaries_to_sink to make it deterministic.
2992   absl::c_sort(nullaries_to_sink,
2993                [](const HloInstruction* a, const HloInstruction* b) {
2994                  return a->unique_id() < b->unique_id();
2995                });
2996   worklist.reserve(nullaries_to_sink.size());
2997   for (auto inst : nullaries_to_sink) {
2998     worklist.push_back(inst);
2999   }
3000   while (!worklist.empty()) {
3001     auto inst = worklist.back();
3002     worklist.pop_back();
3003     std::vector<HloInstruction*> inst_new_operands(inst->operand_count());
3004     for (int64 i = 0; i < inst->operand_count(); ++i) {
3005       inst_new_operands[i] = outside_to_inside[inst->operand(i)];
3006     }
3007     outside_to_inside[inst] = body->AddInstruction(
3008         inst->CloneWithNewOperands(inst->shape(), inst_new_operands));
3009     add_users_if_available(inst);
3010   }
3011   TF_RET_CHECK(outside_to_inside.count(old_operand) > 0);
3012   for (auto ou : old_body_param_users) {
3013     if (ou->opcode() == HloOpcode::kGetTupleElement &&
3014         ou->tuple_index() == non_windowed_operand_index) {
3015       TF_RETURN_IF_ERROR(
3016           ou->ReplaceAllUsesWith(outside_to_inside[old_operand]));
3017       TF_RETURN_IF_ERROR(body->RemoveInstruction(ou));
3018     }
3019   }
3020   return Status::OK();
3021 }
3022 
3023 // Moves a cluster of memory-reducing nodes (with reduce nodes at the end) into
3024 // the windowed dot-general loop on non-contracting dimensions. Such a loop has
3025 // a dynamic-update-slice at the output. If we move the user nodes into the loop
3026 // and before the dynamic-update-slice, the user nodes can operate on smaller
3027 // shapes, which reduces memory.
3028 //
3029 // small_operands                   small_operands
3030 //  | |                 =>                  | |
3031 //  | |  loop {                     loop {  | |
3032 //  | |    conv                             | broadcast      conv
3033 //  | |      |                              |     |           /
3034 //  | | dynamic-update-slice                |  dynamic-slice /
3035 //  | |         |                           |     |         /
3036 //  | |  }      |                           |  multiply-----
3037 //  |broadcast  /                           |    /
3038 //  | |        /                            reduce
3039 //  |multiply--                             |
3040 //  \ |                                dynamic-update-slice
3041 //   reduce                         }
3042 //
3043 // Later optimization passes (TpuPadSliceMover) will merge the dynamic slice
3044 // with the input nodes (broadcast).
MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions(HloInstruction * loop)3045 Status MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions(
3046     HloInstruction* loop) {
3047   CHECK_EQ(loop->user_count(), 1);
3048   // There should be a single direct user of the while loop, which is the
3049   // gte for element 2, i.e., the dot output.
3050   auto user_gte = loop->users().front();
3051   CHECK_EQ(user_gte->opcode(), HloOpcode::kGetTupleElement);
3052   CHECK_EQ(user_gte->tuple_index(), 2);
3053   auto computation = loop->parent();
3054 
3055   // Find the reduce outputs and the input nodes they depend on, if input nodes
3056   // only have small operands.
3057   absl::flat_hash_set<HloInstruction*> to_move;
3058   std::vector<HloInstruction*> new_operands;
3059   absl::flat_hash_set<const HloInstruction*> new_operands_set;
3060   std::vector<HloInstruction*> reduce_outputs;
3061   std::vector<HloInstruction*> worklist;
3062   Shape padded_shape = user_gte->shape();
3063   Shape unpadded_shape = user_gte->shape();
3064   auto original_output = user_gte;
3065 
3066   if (user_gte->user_count() == 1 &&
3067       user_gte->users().back()->opcode() == HloOpcode::kSlice) {
3068     original_output = user_gte->users().back();
3069     unpadded_shape = original_output->shape();
3070   }
3071   for (auto u : original_output->users()) {
3072     worklist.push_back(u);
3073   }
3074   to_move.insert(original_output);
3075   while (!worklist.empty()) {
3076     auto inst = worklist.back();
3077     worklist.pop_back();
3078     if (to_move.count(inst) > 0) {
3079       continue;
3080     }
3081     // We only support reduces with simple reduction function, since we may need
3082     // to accumulate across iterations manually.
3083     if (inst->opcode() == HloOpcode::kReduce &&
3084         inst->to_apply()->instruction_count() == 3 &&
3085         inst->to_apply()->num_parameters() == 2 &&
3086         inst->to_apply()->root_instruction()->IsElementwise()) {
3087       to_move.insert(inst);
3088       auto other_operand = inst->mutable_operand(1);
3089       auto res = new_operands_set.emplace(other_operand);
3090       if (res.second) {
3091         new_operands.push_back(other_operand);
3092       }
3093       reduce_outputs.push_back(inst);
3094     } else if (inst != computation->root_instruction() &&
3095                inst->user_count() > 0 && inst->IsElementwise() &&
3096                !inst->HasSideEffectNoRecurse() &&
3097                inst->opcode() != HloOpcode::kAllReduce &&
3098                absl::c_all_of(inst->operands(),
3099                               [inst](const HloInstruction* o) {
3100                                 return ShapeUtil::CompatibleIgnoringElementType(
3101                                     o->shape(), inst->shape());
3102                               })) {
3103       // For an elementwise op, we need to make sure that they depend on only
3104       // nodes already in to_move and nodes with small operands.
3105       bool can_include = true;
3106       for (auto operand : inst->operands()) {
3107         if (to_move.count(operand) > 0) {
3108           continue;
3109         }
3110         auto find_result = FindInputNodesIfOnlyDependOnSmallOperands(operand);
3111         if (find_result.first.empty()) {
3112           can_include = false;
3113           break;
3114         }
3115         for (auto n : find_result.first) {
3116           to_move.insert(n);
3117         }
3118         for (auto new_operand : find_result.second) {
3119           auto res = new_operands_set.insert(new_operand);
3120           if (res.second) {
3121             new_operands.push_back(new_operand);
3122           }
3123         }
3124       }
3125       if (!can_include) {
3126         to_move.clear();
3127         break;
3128       }
3129       to_move.insert(inst);
3130       for (auto u : inst->users()) {
3131         worklist.push_back(u);
3132       }
3133     } else {
3134       to_move.clear();
3135       break;
3136     }
3137   }
3138   // If nothing is found, to_move could contain only original_output, or cleared
3139   // by the above code.
3140   if (to_move.size() <= 1) {
3141     return Status::OK();
3142   }
3143 
3144   // We will replace the original loop output with reduce-shape outputs. Create
3145   // the initial buffers before the loop.
3146   for (auto out : reduce_outputs) {
3147     auto padded_out_shape = out->shape();
3148     int64 operand_dim = 0;
3149     int64 output_dim = 0;
3150     while (output_dim < padded_out_shape.rank()) {
3151       if (absl::c_linear_search(out->dimensions(), operand_dim)) {
3152         // Dimension colapsed.
3153         ++operand_dim;
3154         continue;
3155       }
3156       // Kept dimensions have the same size of the padded shape.
3157       padded_out_shape.set_dimensions(output_dim,
3158                                       padded_shape.dimensions(operand_dim));
3159       ++operand_dim;
3160       ++output_dim;
3161     }
3162     auto broadcast =
3163         computation->AddInstruction(HloInstruction::CreateBroadcast(
3164             padded_out_shape,
3165             computation->AddInstruction(HloInstruction::CreateConstant(
3166                 LiteralUtil::Zero(out->shape().element_type()))),
3167             {}));
3168     new_operands.push_back(broadcast);
3169   }
3170 
3171   auto input_tuple = loop->mutable_operand(0);
3172   // Create the new input subtuple that contains the small operands and the
3173   // reduce-shape result buffers.
3174   auto new_input_subtuple =
3175       computation->AddInstruction(HloInstruction::CreateTuple(new_operands));
3176   TF_RETURN_IF_ERROR(
3177       input_tuple->ReplaceOperandWithDifferentShape(2, new_input_subtuple));
3178   auto body = loop->while_body();
3179   auto body_param = body->parameter_instruction(0);
3180   auto body_root = body->root_instruction();
3181   CHECK_EQ(body_root->opcode(), HloOpcode::kTuple);
3182   // Update tuple shapes.
3183   for (auto tuple : std::vector<HloInstruction*>{
3184            input_tuple, loop, loop->while_condition()->parameter_instruction(0),
3185            body_param, body_root}) {
3186     *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), {2}) =
3187         new_input_subtuple->shape();
3188   }
3189   auto new_loop_input =
3190       body->AddInstruction(HloInstruction::CreateGetTupleElement(
3191           new_input_subtuple->shape(), body_param, 2));
3192 
3193   // Now create the moved nodes inside the loop body.
3194   absl::flat_hash_map<const HloInstruction*, HloInstruction*> outside_to_inside;
3195   worklist.clear();
3196   auto add_users_if_available = [&](HloInstruction* inst) {
3197     for (auto u : inst->users()) {
3198       if (outside_to_inside.count(u) == 0 && to_move.count(u) > 0 &&
3199           absl::c_all_of(u->operands(), [&](const HloInstruction* o) {
3200             return outside_to_inside.count(o) > 0;
3201           })) {
3202         worklist.push_back(u);
3203       }
3204     }
3205   };
3206   for (int64 i = 0; i < new_operands.size(); ++i) {
3207     outside_to_inside[new_operands[i]] =
3208         body->AddInstruction(HloInstruction::CreateGetTupleElement(
3209             new_operands[i]->shape(), new_loop_input, i));
3210     add_users_if_available(new_operands[i]);
3211   }
3212   // The elementwise nodes will be created with sliced shape. The original loop
3213   // output corresponds to the dynamic-update-slice's update slice.
3214   auto dus = body_root->mutable_operand(2);
3215   CHECK_EQ(dus->opcode(), HloOpcode::kDynamicUpdateSlice);
3216   outside_to_inside[original_output] = dus->mutable_operand(1);
3217   add_users_if_available(original_output);
3218   std::vector<HloInstruction*> slice_offsets(padded_shape.rank());
3219   for (int64 i = 0; i < slice_offsets.size(); ++i) {
3220     slice_offsets[i] = dus->mutable_operand(i + 2);
3221   }
3222   auto get_slice = [&](HloInstruction* padded) {
3223     return body->AddInstruction(HloInstruction::CreateDynamicSlice(
3224         ShapeUtil::ChangeElementType(dus->operand(1)->shape(),
3225                                      padded->shape().element_type()),
3226         padded, slice_offsets, dus->operand(1)->shape().dimensions()));
3227   };
3228   // Helper functions to create nodes with small operands.
3229   auto add_broadcast = [&](const HloInstruction* broadcast) {
3230     auto padded_operand_shape = broadcast->operand(0)->shape();
3231     for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
3232       padded_operand_shape.set_dimensions(
3233           i, padded_shape.dimensions(broadcast->dimensions(i)));
3234     }
3235     auto padded_operand = PadToShape(outside_to_inside[broadcast->operand(0)],
3236                                      padded_operand_shape, nullptr, body);
3237     outside_to_inside[broadcast] =
3238         get_slice(body->AddInstruction(broadcast->CloneWithNewOperands(
3239             ShapeUtil::ChangeElementType(padded_shape,
3240                                          padded_operand_shape.element_type()),
3241             {padded_operand})));
3242   };
3243   auto add_iota = [&](const HloInstruction* iota) {
3244     outside_to_inside[iota] =
3245         get_slice(body->AddInstruction(iota->CloneWithNewOperands(
3246             ShapeUtil::ChangeElementType(padded_shape,
3247                                          iota->shape().element_type()),
3248             {})));
3249   };
3250   auto add_constant = [&](const HloInstruction* constant) {
3251     outside_to_inside[constant] = body->AddInstruction(constant->Clone());
3252     outside_to_inside[constant] = get_slice(
3253         PadToShape(outside_to_inside[constant],
3254                    ShapeUtil::ChangeElementType(
3255                        padded_shape, constant->shape().element_type()),
3256                    nullptr, body));
3257   };
3258   while (!worklist.empty()) {
3259     auto inst = worklist.back();
3260     worklist.pop_back();
3261     if (outside_to_inside.count(inst) > 0) {
3262       continue;
3263     }
3264     if (inst->opcode() == HloOpcode::kBroadcast) {
3265       add_broadcast(inst);
3266     } else if (inst->opcode() == HloOpcode::kIota) {
3267       add_iota(inst);
3268     } else if (inst->opcode() == HloOpcode::kConstant) {
3269       add_constant(inst);
3270     } else if (inst->opcode() == HloOpcode::kReduce) {
3271       // This is an output, for which we has special handling later.
3272     } else {
3273       std::vector<HloInstruction*> operands_inside(inst->operand_count());
3274       for (int64 i = 0; i < operands_inside.size(); ++i) {
3275         operands_inside[i] = outside_to_inside[inst->operand(i)];
3276       }
3277       outside_to_inside[inst] = body->AddInstruction(inst->CloneWithNewOperands(
3278           ShapeUtil::ChangeElementType(dus->operand(1)->shape(),
3279                                        inst->shape().element_type()),
3280           operands_inside));
3281     }
3282     add_users_if_available(inst);
3283   }
3284   std::vector<HloInstruction*> new_outputs_inside(new_operands.size());
3285   for (int64 i = 0; i < new_outputs_inside.size(); ++i) {
3286     new_outputs_inside[i] = outside_to_inside[new_operands[i]];
3287   }
3288   // Now create the reduce outpus inside of the loop.
3289   for (int64 i = 0; i < reduce_outputs.size(); ++i) {
3290     auto reduce_outside = reduce_outputs[i];
3291     CHECK_EQ(reduce_outside->opcode(), HloOpcode::kReduce);
3292     int64 index_in_operand = new_operands.size() - reduce_outputs.size() + i;
3293     auto last_iter_result = outside_to_inside[new_operands[index_in_operand]];
3294     auto operand0 = outside_to_inside[reduce_outside->operand(0)];
3295     auto operand1 = outside_to_inside[reduce_outside->operand(1)];
3296     TF_ASSIGN_OR_RETURN(auto reduce_shape,
3297                         ShapeInference::InferReduceShape(
3298                             {&operand0->shape(), &operand1->shape()},
3299                             reduce_outside->dimensions(),
3300                             reduce_outside->to_apply()->ComputeProgramShape()));
3301     *reduce_shape.mutable_layout() = reduce_outside->shape().layout();
3302     std::vector<HloInstruction*> reduce_dus_offsets;
3303     // If any collapsed dimension is windowed, we need to accumulate with last
3304     // iteration's result. If such a dimension has padding, we also need to mask
3305     // off invalid data.
3306     bool needs_accumulate = false;
3307     std::vector<int64> dims_to_mask;
3308     for (int64 i = 0; i < slice_offsets.size(); ++i) {
3309       if (absl::c_linear_search(reduce_outside->dimensions(), i)) {
3310         if (reduce_outside->operand(0)->shape().dimensions(i) !=
3311             operand0->shape().dimensions(i)) {
3312           needs_accumulate = true;
3313           if (unpadded_shape.dimensions(i) != padded_shape.dimensions(i)) {
3314             dims_to_mask.push_back(i);
3315           }
3316         }
3317         continue;
3318       }
3319       reduce_dus_offsets.push_back(slice_offsets[i]);
3320     }
3321     // Mask off invalid data in collapsed dimensions.
3322     for (int64 dim : dims_to_mask) {
3323       auto iota = body->AddInstruction(HloInstruction::CreateIota(
3324           ShapeUtil::ChangeElementType(operand0->shape(), S32), dim));
3325       auto add = body->AddInstruction(HloInstruction::CreateBinary(
3326           iota->shape(), HloOpcode::kAdd, iota,
3327           body->AddInstruction(HloInstruction::CreateBroadcast(
3328               iota->shape(), slice_offsets[dim], {}))));
3329       auto limit = body->AddInstruction(HloInstruction::CreateBroadcast(
3330           iota->shape(),
3331           body->AddInstruction(
3332               HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
3333                   reduce_outside->operand(0)->shape().dimensions(dim)))),
3334           {}));
3335       auto compare = body->AddInstruction(HloInstruction::CreateCompare(
3336           ShapeUtil::ChangeElementType(iota->shape(), PRED), add, limit,
3337           ComparisonDirection::kLt));
3338       operand0 = body->AddInstruction(HloInstruction::CreateTernary(
3339           operand0->shape(), HloOpcode::kSelect, compare, operand0,
3340           body->AddInstruction(HloInstruction::CreateBroadcast(
3341               operand0->shape(), operand1, {}))));
3342     }
3343     auto output_inside =
3344         body->AddInstruction(reduce_outside->CloneWithNewOperands(
3345             reduce_shape, {operand0, operand1}));
3346     // Accumulate with previous results if needed.
3347     if (needs_accumulate) {
3348       auto input_slice =
3349           body->AddInstruction(HloInstruction::CreateDynamicSlice(
3350               output_inside->shape(), last_iter_result, reduce_dus_offsets,
3351               output_inside->shape().dimensions()));
3352       output_inside = body->AddInstruction(HloInstruction::CreateBinary(
3353           output_inside->shape(),
3354           reduce_outside->to_apply()->root_instruction()->opcode(),
3355           output_inside, input_slice));
3356     }
3357     // Dynamic-update-slice if needed.
3358     if (!ShapeUtil::Compatible(output_inside->shape(),
3359                                last_iter_result->shape())) {
3360       output_inside =
3361           body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
3362               last_iter_result->shape(), last_iter_result, output_inside,
3363               reduce_dus_offsets));
3364     }
3365     new_outputs_inside[index_in_operand] = output_inside;
3366   }
3367   // Body output.
3368   auto new_output_inside =
3369       body->AddInstruction(HloInstruction::CreateTuple(new_outputs_inside));
3370   TF_RETURN_IF_ERROR(
3371       body_root->ReplaceOperandWithDifferentShape(2, new_output_inside));
3372   TF_RETURN_IF_ERROR(body->RemoveInstructionAndUnusedOperands(dus));
3373   // Replace uses of the reduces outside the loop.
3374   auto new_output_gte =
3375       computation->AddInstruction(HloInstruction::CreateGetTupleElement(
3376           new_output_inside->shape(), loop, 2));
3377   for (int64 i = 0; i < reduce_outputs.size(); ++i) {
3378     int64 index_in_operand = new_operands.size() - reduce_outputs.size() + i;
3379     auto new_output =
3380         computation->AddInstruction(HloInstruction::CreateGetTupleElement(
3381             new_outputs_inside[index_in_operand]->shape(), new_output_gte,
3382             index_in_operand));
3383     if (!ShapeUtil::Compatible(new_output->shape(),
3384                                reduce_outputs[i]->shape())) {
3385       new_output = computation->AddInstruction(HloInstruction::CreateSlice(
3386           reduce_outputs[i]->shape(), new_output,
3387           std::vector<int64>(new_output->shape().rank(), 0),
3388           reduce_outputs[i]->shape().dimensions(),
3389           std::vector<int64>(new_output->shape().rank(), 1)));
3390     }
3391     TF_RETURN_IF_ERROR(reduce_outputs[i]->ReplaceAllUsesWith(new_output));
3392     TF_RETURN_IF_ERROR(
3393         computation->RemoveInstructionAndUnusedOperands(reduce_outputs[i]));
3394   }
3395   return Status::OK();
3396 }
3397 
3398 }  // namespace
3399 
DoCodeMotionForWindowedDotGeneralLoops(HloComputation * computation,const SpmdPartitionerOptions & options)3400 Status SpmdPartitioningVisitor::DoCodeMotionForWindowedDotGeneralLoops(
3401     HloComputation* computation, const SpmdPartitionerOptions& options) {
3402   for (auto& loop : windowed_dot_general_loops_) {
3403     if (loop.windowed_in_contracting_dims || loop.windowed_in_batch_dims ||
3404         loop.operands_sharded_at_contracting_dims) {
3405       // We have a dynamic-slice for the non-windowed operand in
3406       // batch/contracting-dim/noncontracting-dim windowed dot-general. So
3407       // moving the broadcast/iota/elementwise ops into the loop could help
3408       // reduce memory via fusion.
3409       TF_RETURN_IF_ERROR(
3410           SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions(
3411               loop.while_loop, 1 - loop.windowed_operand));
3412     }
3413     // Currently unrolled loop does not support this optimization.
3414     if (!options.bidirectional_windowed_einsum &&
3415         !options.unroll_windowed_einsum && !loop.windowed_in_contracting_dims &&
3416         !loop.operands_sharded_at_contracting_dims) {
3417       // We have a dynamic-update-slice for the output in
3418       // batch/non-contracting-dim windowed dot-general. So moving reduce ops
3419       // into the loop could help reduce memory.
3420       TF_RETURN_IF_ERROR(
3421           MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions(
3422               loop.while_loop));
3423     }
3424   }
3425   return Status::OK();
3426 }
3427 
3428 }  // namespace spmd
3429 }  // namespace xla
3430