/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/spmd/convolution_handler.h"

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/numbers.h"

namespace xla {
namespace spmd {

namespace {

// Partition convolution with batch group count.
StatusOr<HloInstruction*> PartitionConvolutionWithBatchGroupCount(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_conv,
    const Window& conv_window, HloInstruction* original_hlo,
    int64 num_partitions, SpmdBuilder* b) {
  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
  if (original_hlo->batch_group_count() == 1 ||
      original_hlo->batch_group_count() < num_partitions) {
    return nullptr;
  }

  const auto& dnums = original_hlo->convolution_dimension_numbers();
  // Only the case where batch_group_count equals the input batch size is
  // supported.
  const int64 input_batch_size =
      lhs.base_shape().dimensions(dnums.input_batch_dimension());
  const int64 kernel_output_feature_size =
      rhs.base_shape().dimensions(dnums.kernel_output_feature_dimension());
  if (input_batch_size != kernel_output_feature_size ||
      original_hlo->batch_group_count() != input_batch_size) {
    return nullptr;
  }

  // Map RHS indices to LHS indices.
  std::vector<int64> rhs_to_lhs_indices(output_base_shape.rank());
  rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] =
      dnums.input_batch_dimension();
  rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] =
      dnums.input_feature_dimension();
  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] =
        dnums.input_spatial_dimensions(i);
  }

  // Map LHS indices to RHS indices.
  std::vector<int64> lhs_to_rhs_indices(output_base_shape.rank());
  for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) {
    lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
  }

  // Map LHS indices to output indices.
  std::vector<int64> lhs_to_output_indices(lhs.base_shape().rank(), -1);
  lhs_to_output_indices[dnums.input_batch_dimension()] =
      dnums.output_feature_dimension();
  lhs_to_output_indices[dnums.input_feature_dimension()] =
      dnums.output_batch_dimension();
  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    lhs_to_output_indices[dnums.input_spatial_dimensions(i)] =
        dnums.output_spatial_dimensions(i);
  }
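  // Illustrative example (editor's note; the layouts are chosen purely for
  // exposition): with an NHWC input (batch=0, feature=3, spatial={1,2}) and
  // an HWIO kernel (output_feature=3, input_feature=2, spatial={0,1}), the
  // maps above are
  //   rhs_to_lhs_indices    = {1, 2, 3, 0}
  //   lhs_to_rhs_indices    = {3, 0, 1, 2}
  //   lhs_to_output_indices = {3, 1, 2, 0}
  // where the last map swaps the batch and feature roles, since a
  // batch-group-count convolution emits the grouped batch as the output
  // feature dimension.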

  // Align LHS or RHS to the other operand if the input batch dim or the
  // kernel output feature dim is partitioned.
  auto aligned_rhs_sharding =
      hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices);
  auto aligned_lhs_sharding =
      hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices);

  bool lhs_batch_dim_is_partitioned =
      (ShardCountAtDim(lhs.sharding(), dnums.input_batch_dimension()) ==
       num_partitions);
  bool rhs_output_feature_dim_is_partitioned =
      (ShardCountAtDim(rhs.sharding(),
                       dnums.kernel_output_feature_dimension()) ==
       num_partitions);
  if (!lhs_batch_dim_is_partitioned &&
      !rhs_output_feature_dim_is_partitioned) {
    return nullptr;
  }
  // Reshard LHS or RHS so that it is partitioned at the batch dimension or
  // the output feature dimension, matching the other operand.
  if (lhs_batch_dim_is_partitioned) {
    rhs = rhs.Reshard(aligned_rhs_sharding);
  } else {
    lhs = lhs.Reshard(aligned_lhs_sharding);
  }
  // Align the output sharding after the LHS and RHS shardings are consistent.
  auto aligned_output_sharding = hlo_sharding_util::TransposeSharding(
      lhs.sharding(), lhs_to_output_indices);

  // Create the partitioned convolution.
  TF_ASSIGN_OR_RETURN(
      auto sharded_conv,
      create_sharded_conv(lhs.hlo(), rhs.hlo(), b, conv_window));
  sharded_conv->set_sharding(aligned_output_sharding);
  return PartitionedHlo(sharded_conv, output_base_shape, lhs.state())
      .Reshard(output_sharding)
      .hlo();
}

// Partition convolution with feature group count.
StatusOr<HloInstruction*> PartitionConvolutionWithFeatureGroupCount(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_conv,
    const Window& conv_window, HloInstruction* original_hlo,
    int64 num_partitions, SpmdBuilder* b) {
  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
  if (original_hlo->feature_group_count() == 1 ||
      original_hlo->feature_group_count() < num_partitions) {
    return nullptr;
  }

  const auto& dnums = original_hlo->convolution_dimension_numbers();
  const int64 input_feature_size =
      lhs.base_shape().dimensions(dnums.input_feature_dimension());
  const int64 kernel_output_feature_size =
      rhs.base_shape().dimensions(dnums.kernel_output_feature_dimension());
  if (input_feature_size != kernel_output_feature_size ||
      input_feature_size % original_hlo->feature_group_count() != 0) {
    return nullptr;
  }

  // Align RHS indices to LHS.
  std::vector<int64> rhs_to_lhs_indices(output_base_shape.rank());
  rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] =
      dnums.input_feature_dimension();
  rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] =
      dnums.input_batch_dimension();
  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] =
        dnums.input_spatial_dimensions(i);
  }

  // Align LHS indices to RHS.
  std::vector<int64> lhs_to_rhs_indices(output_base_shape.rank());
  for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) {
    lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
  }

  // Align LHS indices to output.
  std::vector<int64> lhs_to_output_indices(output_base_shape.rank());
  lhs_to_output_indices[dnums.input_feature_dimension()] =
      dnums.output_feature_dimension();
  lhs_to_output_indices[dnums.input_batch_dimension()] =
      dnums.output_batch_dimension();
  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    lhs_to_output_indices[dnums.input_spatial_dimensions(i)] =
        dnums.output_spatial_dimensions(i);
  }

  // Align LHS or RHS if input_feature_dim or kernel_output_feature_dim is
  // partitioned.
  auto aligned_rhs_sharding =
      hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices);
  auto aligned_lhs_sharding =
      hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices);

  bool lhs_feature_dim_is_partitioned =
      (ShardCountAtDim(lhs.sharding(), dnums.input_feature_dimension()) ==
       num_partitions);
  bool rhs_output_feature_dim_is_partitioned =
      (ShardCountAtDim(rhs.sharding(),
                       dnums.kernel_output_feature_dimension()) ==
       num_partitions);
  if (!lhs_feature_dim_is_partitioned &&
      !rhs_output_feature_dim_is_partitioned) {
    return nullptr;
  }
  // Reshard LHS or RHS so that it is partitioned at the input feature
  // dimension or the output feature dimension, matching the other operand.
  if (lhs_feature_dim_is_partitioned) {
    rhs = rhs.Reshard(aligned_rhs_sharding);
  } else {
    lhs = lhs.Reshard(aligned_lhs_sharding);
  }

  // Align the output sharding after the LHS and RHS shardings are consistent.
  auto aligned_output_sharding = hlo_sharding_util::TransposeSharding(
      lhs.sharding(), lhs_to_output_indices);

  TF_ASSIGN_OR_RETURN(
      auto sharded_conv,
      create_sharded_conv(lhs.hlo(), rhs.hlo(), b, conv_window));
  sharded_conv->set_sharding(aligned_output_sharding);
  return PartitionedHlo(sharded_conv, output_base_shape, lhs.state())
      .Reshard(output_sharding)
      .hlo();
}

// Partition convolution when both LHS and RHS are partitioned at spatial
// dimensions. Halo exchange will happen on RHS only.
StatusOr<HloInstruction*>
PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_conv,
    const Window& conv_window, HloInstruction* original_hlo,
    HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
  TF_RET_CHECK(!lhs.sharding().IsTileMaximal() &&
               !rhs.sharding().IsTileMaximal());

  const auto& dnums = original_hlo->convolution_dimension_numbers();
  std::vector<int64> rhs_to_lhs_indices(output_base_shape.rank());
  rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] =
      dnums.input_batch_dimension();
  rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] =
      dnums.input_feature_dimension();
  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] =
        dnums.input_spatial_dimensions(i);
  }
  std::vector<int64> lhs_to_rhs_indices(output_base_shape.rank());
  for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) {
    lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
  }
  auto aligned_rhs_sharding =
      hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices);
  auto aligned_lhs_sharding =
      hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices);

  auto unsupported_sharding = [&](const HloSharding& lhs_sharding,
                                  const HloSharding& rhs_sharding) {
    // We currently don't support partitioning input batch or output feature
    // dimensions.
    return lhs_sharding.tile_assignment().dim(dnums.input_batch_dimension()) !=
               1 ||
           rhs_sharding.tile_assignment().dim(
               dnums.kernel_output_feature_dimension()) != 1;
  };

  auto zero = b->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::Zero(output_base_shape.element_type())));
  if (ShapeSizeInBytes(lhs.base_shape()) < ShapeSizeInBytes(rhs.base_shape())) {
    if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) {
      return nullptr;
    }
    lhs = lhs.Reshard(aligned_lhs_sharding).PadWithValue(zero);
    rhs = rhs.PadWithValue(zero);
  } else {
    if (unsupported_sharding(lhs.sharding(), aligned_rhs_sharding)) {
      return nullptr;
    }
    lhs = lhs.PadWithValue(zero);
    rhs = rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero);
  }

  if (original_hlo->feature_group_count() > 1 &&
      (lhs.sharding().tile_assignment().dim(dnums.input_feature_dimension()) >
           1 ||
       rhs.sharding().tile_assignment().dim(
           dnums.kernel_output_feature_dimension()) > 1)) {
    return nullptr;
  }

  if (original_hlo->batch_group_count() > 1 &&
      (lhs.sharding().tile_assignment().dim(dnums.input_batch_dimension()) >
           1 ||
       rhs.sharding().tile_assignment().dim(
           dnums.kernel_output_feature_dimension()) > 1)) {
    return nullptr;
  }

  // Reshard RHS so that each shard computes the partial sum of the full
  // shape result, and add an AllReduce. See
  // HandleConvolutionTiledLhsAndRhs(), which reshards LHS instead.
  //
  // The size of the halo on each dimension can be calculated from the
  // projection onto the RHS that shard i needs to read. RHS and LHS below
  // refer to the shard sizes of RHS and LHS, WC is the number of windows,
  // and D is the window dilation.
  //
  // * offset(i): LHS * i + low_padding - (WC - 1) * stride
  // * limit(i): LHS * (i + 1) + low_padding
  //
  // Since shard i has RHS of range [i * RHS * D, (i + 1) * RHS * D)
  // * left-halo: i * RHS * D - offset(i)
  //              = i * (RHS * D - LHS) + (WC - 1) * stride - low_padding
  // * right-halo: limit(i) - (i + 1) * RHS * D
  //              = (i + 1) * (LHS - RHS * D) + low_padding
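  //
  // Worked example (editor's addition, with illustrative numbers): take a
  // full LHS spatial size of 12 and a full RHS spatial size of 6 split over
  // 2 shards, so LHS = 6 and RHS = 3, with D = 1, stride = 1 and
  // low_padding = 0, which gives WC = 12 - 6 + 1 = 7 windows. Plugging into
  // the formulas above:
  //   left-halo(i)  = i * (3 - 6) + 6 - 0 = 6 - 3 * i   -> 3 for shard 1
  //   right-halo(i) = (i + 1) * (6 - 3) + 0 = 3 * (i + 1) -> 3 for shard 0
  // matching the ranges used below: left halos are taken over shards
  // [1, shard_count) and right halos over shards [0, shard_count - 1).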
  const auto& collective_ops_creator = lhs.state().collective_ops_creator;
  std::vector<int64> shard_counts(dnums.input_spatial_dimensions_size());
  std::vector<int64> lhs_shard_sizes(dnums.input_spatial_dimensions_size());
  std::vector<int64> rhs_shard_sizes(dnums.input_spatial_dimensions_size());

  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    int64 lhs_dimension = dnums.input_spatial_dimensions(i);
    int64 rhs_dimension = dnums.kernel_spatial_dimensions(i);
    int64 shard_count = rhs.sharding().tile_assignment().dim(rhs_dimension);
    const auto& wd = conv_window.dimensions(i);
    if (wd.base_dilation() != 1 || wd.window_reversal()) {
      return nullptr;
    }

    int64 lhs_shard_size =
        CeilOfRatio(lhs.base_shape().dimensions(lhs_dimension), shard_count);
    int64 rhs_shard_size =
        CeilOfRatio(rhs.base_shape().dimensions(rhs_dimension), shard_count);
    shard_counts[i] = shard_count;
    lhs_shard_sizes[i] = lhs_shard_size;
    rhs_shard_sizes[i] = rhs_shard_size;
  }

  std::vector<OffsetCalculation> left_halo_size_functions(
      output_base_shape.rank());
  std::vector<OffsetCalculation> right_halo_size_functions(
      output_base_shape.rank());
  Window new_window = conv_window;

  // Data structures needed for Pad and DynamicSlice on LHS if needed.
  bool need_dynamic_slice_lhs = false;
  auto partition_ordinals =
      MakeTiledPartitionOrdinals(lhs.sharding(), partition_id, b);
  std::vector<int64> zero_padding(output_base_shape.rank());
  PaddingConfig pad_config = window_util::MakeSymmetricPadding(zero_padding);
  auto zero_s32 =
      b->AddInstruction(HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
  std::vector<HloInstruction*> dynamic_slice_start_indices(
      output_base_shape.rank(), zero_s32);
  Shape dynamic_slice_shape = lhs.hlo()->shape();
  Shape pad_shape = lhs.hlo()->shape();

  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    int64 lhs_dimension = dnums.input_spatial_dimensions(i);
    int64 rhs_dimension = dnums.kernel_spatial_dimensions(i);
    int64 lhs_shard_size = lhs_shard_sizes[i];
    int64 rhs_shard_size = rhs_shard_sizes[i];

    if (shard_counts[i] == 1) {
      continue;
    }

    // Calculate the left and right halo sizes as described in the comments
    // above. The calculation includes dilation, so we apply
    // CeilOfRatio({left,right}_halo_size, window_dilation).
    const auto& wd = conv_window.dimensions(i);
    int64 padding_low = wd.padding_low();
    int64 padding_high = wd.padding_high();
    int64 base = lhs.base_shape().dimensions(lhs_dimension);
    int64 window_count = 1 + (padding_low + padding_high + base -
                              (1 + (wd.size() - 1) * wd.window_dilation())) /
                                 wd.stride();
    left_halo_size_functions[rhs_dimension] =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(
            rhs_shard_size * wd.window_dilation() - lhs_shard_size,
            (window_count - 1) * wd.stride() - padding_low +
                wd.window_dilation() - 1,
            wd.window_dilation()));
    right_halo_size_functions[rhs_dimension] =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(
            lhs_shard_size - rhs_shard_size * wd.window_dilation(),
            lhs_shard_size - rhs_shard_size * wd.window_dilation() +
                padding_low + wd.window_dilation() - 1,
            wd.window_dilation()));

    // New RHS window size includes the maximum of both left and right
    // halos.
    int64 halo_size =
        left_halo_size_functions[rhs_dimension].MaxInRange(1, shard_counts[i]) +
        right_halo_size_functions[rhs_dimension].MaxInRange(
            0, shard_counts[i] - 1);
    int64 new_window_size =
        rhs.hlo()->shape().dimensions(rhs_dimension) + halo_size;

    // The amount of new low padding could be dynamic (e.g., window_dilation
    // != 1), which requires pad (to the maximum) and dynamic slice on LHS.
    //
    // If we consider the first window, the offset of the dilated RHS that
    // aligns with the first valid LHS element for shard i is 'padding_low +
    // LHS * i'. When the left halo is added to RHS, the offset of the first
    // RHS element is (RHS * i - left_halo) * window_dilation. The
    // difference between the two values is the amount of padding_low we
    // need on LHS.
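    // Expanding these two terms (editor's note), the low padding needed on
    // LHS for shard i works out to
    //   new_padding_low(i) = left_halo(i) * D - (RHS * D - LHS) * i
    //                        + padding_low
    // which varies with i exactly when left_halo is not constant.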
    auto new_padding_low_function =
        OffsetCalculation(HloOpcode::kMultiply,
                          left_halo_size_functions[rhs_dimension],
                          OffsetCalculation(MultiplyAddDivideOffsetCalculation(
                              0, wd.window_dilation(), 1))) -
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(
            rhs_shard_size * wd.window_dilation() - lhs_shard_size,
            -padding_low, 1));

    int64 new_padding_low_max =
        new_padding_low_function.MaxInRange(0, shard_counts[i]);
    int64 new_padding_low = new_padding_low_max;
    int64 new_padding_high = window_count * wd.stride() +
                             (new_window_size - 1) * wd.window_dilation() -
                             new_padding_low - lhs_shard_size;

    // We do pad/dynamic-slice only when the padding is dynamic.
    if (!new_padding_low_function.IsConstant()) {
      need_dynamic_slice_lhs = true;
      new_padding_low = 0;
      pad_config.mutable_dimensions(lhs_dimension)
          ->set_edge_padding_low(new_padding_low_max);
      pad_config.mutable_dimensions(lhs_dimension)
          ->set_edge_padding_high(new_padding_low_max);
      pad_shape.set_dimensions(lhs_dimension,
                               lhs_shard_size + 2 * new_padding_low_max);
      dynamic_slice_start_indices[lhs_dimension] =
          (OffsetCalculation(
               MultiplyAddDivideOffsetCalculation(0, new_padding_low_max, 1)) -
           new_padding_low_function)
              .Calculate(partition_ordinals[lhs_dimension], b);
      dynamic_slice_shape.set_dimensions(lhs_dimension,
                                         lhs_shard_size + new_padding_low_max);
    }

    // Since the convolution RHS operand size increased with halos, adjust
    // the window config accordingly.
    new_window.mutable_dimensions(i)->set_padding_low(new_padding_low);
    new_window.mutable_dimensions(i)->set_padding_high(new_padding_high);
    new_window.mutable_dimensions(i)->set_size(
        rhs.hlo()->shape().dimensions(rhs_dimension) + halo_size);
  }

  HloInstruction* conv_lhs = lhs.hlo();
  if (need_dynamic_slice_lhs) {
    auto pad = b->AddInstruction(
        HloInstruction::CreatePad(pad_shape, lhs.hlo(), zero, pad_config));
    conv_lhs = b->AddInstruction(HloInstruction::CreateDynamicSlice(
        dynamic_slice_shape, pad, dynamic_slice_start_indices,
        dynamic_slice_shape.dimensions()));
  }

  // Exchange halo and concatenate.
  HloInstruction* rhs_with_halo = rhs.hlo();
  for (int i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) {
    int64 dim = dnums.kernel_spatial_dimensions(i);
    int64 explicit_left_padding_on_full_shape =
        left_halo_size_functions[dim].Calculate(0);
    int64 shard_size_with_halo = new_window.dimensions(i).size();

    // offset_on_padded_shape and padded_full_shape_size are needed only if
    // we want to mask out-of-range values in ExchangeHaloAndGetValidData().
    // Since the default value for the collective-permute is zero and we
    // also call PadWithValue() on both operands at the beginning, we don't
    // need to mask here.
    //
    // TODO(hyoulkee): Consider removing one of the two PadWithValue() calls
    // if it's always safe.
    auto offset_on_padded_shape =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(
            rhs_shard_sizes[i], explicit_left_padding_on_full_shape, 1)) -
        left_halo_size_functions[dim];
    int64 padded_full_shape_size =
        offset_on_padded_shape.Calculate(shard_counts[i] - 1) +
        new_window.dimensions(i).size();
    auto concat = ExchangeHaloAndGetValidData(
        rhs_with_halo, rhs.base_shape(), left_halo_size_functions[dim],
        right_halo_size_functions[dim], explicit_left_padding_on_full_shape,
        padded_full_shape_size, shard_size_with_halo, dim, rhs.sharding(),
        offset_on_padded_shape.Calculate(partition_ordinals[dim], b), zero,
        partition_ordinals[dim], collective_ops_creator,
        lhs.state().next_channel_id, b,
        /*mask_invalid_region=*/false);
    if (!concat) {
      return nullptr;
    }
    rhs_with_halo = *concat;
  }

  TF_ASSIGN_OR_RETURN(
      auto conv, create_sharded_conv(conv_lhs, rhs_with_halo, b, new_window));

  auto ar = collective_ops_creator.create_cross_partition_all_reduce(
      b, conv, MakeBinaryAdd(original_hlo->shape().element_type(), module), {},
      (*lhs.state().next_channel_id)++);
  ar->set_sharding(HloSharding::Replicate());
  return PartitionedHlo(ar, output_base_shape, lhs.state())
      .Reshard(output_sharding)
      .hlo();
}

// Partition convolution when both LHS and RHS are partitioned at spatial
// dimensions. Halo exchange will happen on LHS only.
StatusOr<HloInstruction*>
PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_conv,
    const Window& conv_window, HloInstruction* original_hlo,
    HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
  TF_RET_CHECK(!lhs.sharding().IsTileMaximal() &&
               !rhs.sharding().IsTileMaximal());

  const auto& dnums = original_hlo->convolution_dimension_numbers();

  // Check if the operand shardings are aligned. We also currently don't
  // support partitioning non-spatial dimensions.
  std::vector<int64> rhs_to_lhs_indices(output_base_shape.rank());
  rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] =
      dnums.input_batch_dimension();
  rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] =
      dnums.input_feature_dimension();
  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] =
        dnums.input_spatial_dimensions(i);
  }
  std::vector<int64> lhs_to_rhs_indices(output_base_shape.rank());
  for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) {
    lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
  }

  const Window& window = conv_window;
  std::vector<int64> reversed_rhs_dims;
  for (int64 i = 0; i < window.dimensions_size(); ++i) {
    if (window.dimensions(i).window_reversal()) {
      reversed_rhs_dims.push_back(dnums.kernel_spatial_dimensions(i));
    }
  }
  if (!reversed_rhs_dims.empty()) {
    // Make the reversed dims left-padded to prepare for window reversal.
    auto left_padded_rhs = HaloExchangeToPadOnLeft(rhs, reversed_rhs_dims);
    if (left_padded_rhs == nullptr) {
      return nullptr;
    }
    left_padded_rhs->set_sharding(rhs.sharding());
    rhs = PartitionedHlo(left_padded_rhs, rhs.base_shape(), rhs.state());
  }
  // Consider window reversal when resharding RHS or LHS. Note: this will not
  // reverse the data in the shard. We use window reversal to do that.
  auto aligned_rhs_sharding = hlo_sharding_util::ReverseSharding(
      hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices),
      reversed_rhs_dims);
  auto aligned_lhs_sharding = hlo_sharding_util::TransposeSharding(
      hlo_sharding_util::ReverseSharding(rhs.sharding(), reversed_rhs_dims),
      lhs_to_rhs_indices);

  auto unsupported_sharding = [&](const HloSharding& lhs_sharding,
                                  const HloSharding& rhs_sharding) {
    return lhs_sharding.tile_assignment().dim(dnums.input_batch_dimension()) !=
               1 ||
           rhs_sharding.tile_assignment().dim(
               dnums.kernel_output_feature_dimension()) != 1;
  };

  auto zero = b->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::Zero(output_base_shape.element_type())));
  if (ShapeSizeInBytes(lhs.base_shape()) < ShapeSizeInBytes(rhs.base_shape())) {
    if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) {
      return nullptr;
    }
    lhs = lhs.Reshard(aligned_lhs_sharding).PadWithValue(zero);
    rhs = rhs.PadWithValue(zero, reversed_rhs_dims);
  } else {
    if (unsupported_sharding(lhs.sharding(), aligned_rhs_sharding)) {
      return nullptr;
    }
    lhs = lhs.PadWithValue(zero);
    rhs =
        rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero, reversed_rhs_dims);
  }

  if (original_hlo->feature_group_count() > 1 &&
      (lhs.sharding().tile_assignment().dim(dnums.input_feature_dimension()) >
           1 ||
       rhs.sharding().tile_assignment().dim(
           dnums.kernel_output_feature_dimension()) > 1)) {
    return nullptr;
  }

  if (original_hlo->batch_group_count() > 1 &&
      (lhs.sharding().tile_assignment().dim(dnums.input_batch_dimension()) >
           1 ||
       rhs.sharding().tile_assignment().dim(
           dnums.kernel_output_feature_dimension()) > 1)) {
    return nullptr;
  }
  // Reshard LHS by exchanging halos such that each shard computes the partial
  // sum of the full shape result, and add an AllReduce.
  //
  // The size of the halo on each dimension can be calculated from the
  // projection onto the LHS that each RHS shard i needs to read. RHS and LHS
  // below refer to the shard sizes of RHS and LHS, WC is the number of
  // windows, and D is the window dilation.
  //
  // * offset(i): RHS * D * i - low_padding
  // * limit(i): {RHS * (i + 1) * D - (D - 1)} + (WC - 1) * stride - low_padding
  //
  // Since shard i has LHS of range [i * LHS, (i + 1) * LHS)
  // * left-halo: i * LHS - offset(i)
  //              = (LHS - RHS * D) * i + low_padding
  // * right-halo: limit(i) - (i + 1) * LHS
  //   = (RHS * D - LHS) * (i + 1) + (1 - D) + (WC - 1) * stride - low_padding
  //   = (RHS * D - LHS) * i + (RHS * D - LHS) + (1 - D)
  //     + (WC - 1) * stride - low_padding
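  //
  // Worked example (editor's addition, with illustrative numbers): with
  // LHS = 6, RHS = 3, D = 1, stride = 1, low_padding = 0 and WC = 7 (the
  // same shapes as in the RHS-halo case above), the formulas give
  //   left-halo(i)  = (6 - 3) * i + 0 = 3 * i             -> 3 for shard 1
  //   right-halo(i) = (3 - 6) * i + (3 - 6) + 0 + 6 - 0
  //                 = 3 - 3 * i                           -> 3 for shard 0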
  std::vector<int64> shard_counts(dnums.input_spatial_dimensions_size());
  std::vector<int64> lhs_shard_sizes(dnums.input_spatial_dimensions_size());
  std::vector<int64> rhs_shard_sizes(dnums.input_spatial_dimensions_size());
  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    int64 lhs_dimension = dnums.input_spatial_dimensions(i);
    int64 rhs_dimension = dnums.kernel_spatial_dimensions(i);
    int64 shard_count = lhs.sharding().tile_assignment().dim(lhs_dimension);
    const auto& wd = window.dimensions(i);
    if (wd.base_dilation() != 1) {
      // TODO(wangtao): support parallel dim if it is replicate here.
      return nullptr;
    }

    int64 lhs_shard_size =
        CeilOfRatio(lhs.base_shape().dimensions(lhs_dimension), shard_count);
    int64 rhs_shard_size =
        CeilOfRatio(rhs.base_shape().dimensions(rhs_dimension), shard_count);
    shard_counts[i] = shard_count;
    lhs_shard_sizes[i] = lhs_shard_size;
    rhs_shard_sizes[i] = rhs_shard_size;
  }

  std::vector<OffsetCalculation> left_halo_size_functions(
      output_base_shape.rank());
  std::vector<OffsetCalculation> right_halo_size_functions(
      output_base_shape.rank());
  Window new_window = window;

  auto partition_ordinals =
      MakeTiledPartitionOrdinals(lhs.sharding(), partition_id, b);
  HloInstruction* lhs_with_halo = lhs.hlo();
  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    int64 lhs_dimension = dnums.input_spatial_dimensions(i);
    int64 lhs_shard_size = lhs_shard_sizes[i];
    int64 rhs_shard_size = rhs_shard_sizes[i];

    if (shard_counts[i] == 1) {
      continue;
    }

    // Calculate the left and right halo sizes as described in the comments
    // above.
    const auto& wd = window.dimensions(i);
    int64 padding_low = wd.padding_low();
    int64 padding_high = wd.padding_high();
    int64 base = lhs.base_shape().dimensions(lhs_dimension);
    int64 window_count = 1 + (padding_low + padding_high + base -
                              (1 + (wd.size() - 1) * wd.window_dilation())) /
                                 wd.stride();
    int64 rhs_shard_size_dilated =
        (rhs_shard_size - 1) * wd.window_dilation() + 1;

    left_halo_size_functions[lhs_dimension] =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(
            lhs_shard_size - rhs_shard_size * wd.window_dilation(), padding_low,
            1));
    right_halo_size_functions[lhs_dimension] =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(
            rhs_shard_size * wd.window_dilation() - lhs_shard_size,
            rhs_shard_size * wd.window_dilation() - lhs_shard_size + 1 -
                wd.window_dilation() + wd.stride() * (window_count - 1) -
                padding_low,
            1));

    // Exchange halo and concatenate.
    int64 dim = dnums.input_spatial_dimensions(i);
    int64 explicit_left_padding_on_full_shape = padding_low;
    int64 shard_size_with_halo =
        wd.stride() * (window_count - 1) + rhs_shard_size_dilated;

    new_window.mutable_dimensions(i)->set_padding_low(0);
    new_window.mutable_dimensions(i)->set_padding_high(0);
    new_window.mutable_dimensions(i)->set_size(rhs_shard_size);

    // offset_on_padded_shape and padded_full_shape_size are needed only if
    // we want to mask out-of-range values in ExchangeHaloAndGetValidData().
    // Since the default value for the collective-permute is zero and we
    // also call PadWithValue() on both operands at the beginning, we don't
    // need to mask here.
    //
    // TODO(hyoulkee): Consider removing one of the two PadWithValue() calls
    // if it's always safe.
    auto offset_on_padded_shape =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation());
    int64 padded_full_shape_size = 0;
    auto concat = ExchangeHaloAndGetValidData(
        lhs_with_halo, lhs.base_shape(), left_halo_size_functions[dim],
        right_halo_size_functions[dim], explicit_left_padding_on_full_shape,
        padded_full_shape_size, shard_size_with_halo, dim, lhs.sharding(),
        offset_on_padded_shape.Calculate(partition_ordinals[dim], b), zero,
        partition_ordinals[dim], lhs.state().collective_ops_creator,
        lhs.state().next_channel_id, b,
        /*mask_invalid_region=*/false);
    if (!concat) {
      return nullptr;
    }
    lhs_with_halo = *concat;
  }

  TF_ASSIGN_OR_RETURN(
      auto conv, create_sharded_conv(lhs_with_halo, rhs.hlo(), b, new_window));
  auto ar =
      lhs.state().collective_ops_creator.create_cross_partition_all_reduce(
          b, conv, MakeBinaryAdd(output_base_shape.element_type(), module), {},
          (*lhs.state().next_channel_id)++);
  ar->set_sharding(HloSharding::Replicate());
  return PartitionedHlo(ar, output_base_shape, lhs.state())
      .Reshard(output_sharding)
      .hlo();
}

// Partition convolution when the output is sharded. Shards LHS with a
// replicated RHS.
StatusOr<HloInstruction*> PartitionConvolutionTiledOutput(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_conv,
    const Window& conv_window, HloInstruction* original_hlo, SpmdBuilder* b) {
  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
  const auto& dnums = original_hlo->convolution_dimension_numbers();
  TF_RET_CHECK(!output_sharding.IsTileMaximal());
  // We don't currently support sharding on the output feature dimension.
  if (output_sharding.tile_assignment().dim(dnums.output_feature_dimension()) >
      1) {
    return nullptr;
  }

  // Check if the operand and the output sharding are aligned.
  std::vector<int64> input_to_output_indices(output_base_shape.rank());
  input_to_output_indices[dnums.input_batch_dimension()] =
      dnums.output_batch_dimension();
  input_to_output_indices[dnums.input_feature_dimension()] =
      dnums.output_feature_dimension();
  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    input_to_output_indices[dnums.input_spatial_dimensions(i)] =
        dnums.output_spatial_dimensions(i);
  }
  auto target_operand_sharding = hlo_sharding_util::TransposeSharding(
      output_sharding, input_to_output_indices);
  lhs = lhs.Reshard(target_operand_sharding);

  // Replicate the RHS.
  rhs = rhs.Reshard(HloSharding::Replicate());

  // Convolution window config does not include batch and feature dimensions,
  // whereas ReshardAsWindowedInput() expects the same number of window
  // dimensions as the rank of the operand. So add two more trivial
  // dimensions.
  std::vector<int64> ones(output_base_shape.rank(), 1);
  auto operand_window = window_util::MakeWindow(ones);
  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    *operand_window.mutable_dimensions(dnums.input_spatial_dimensions(i)) =
        conv_window.dimensions(i);
  }
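  // For instance (editor's note, illustrative): with a rank-4 NHWC operand,
  // input_spatial_dimensions = {1, 2}, so operand_window ends up with four
  // dimensions where dims 1 and 2 carry the convolution's window config and
  // dims 0 (batch) and 3 (feature) keep the trivial size-1 window.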

  auto zero = b->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::Zero(output_base_shape.element_type())));
  auto resharded_operand_and_window =
      lhs.ReshardAsWindowedInput(operand_window, target_operand_sharding, zero);
  if (!resharded_operand_and_window.has_value()) {
    return nullptr;
  }
  Window new_window;
  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    *new_window.add_dimensions() =
        resharded_operand_and_window->shard_window.dimensions(
            dnums.input_spatial_dimensions(i));
  }

  TF_ASSIGN_OR_RETURN(
      auto sharded_conv,
      create_sharded_conv(resharded_operand_and_window->sharded_input,
                          rhs.hlo(), b, new_window));

  auto shard_shape = MakePartitionedShape(output_base_shape, output_sharding);
  if (!resharded_operand_and_window->dynamic_slice_index_on_output
           .has_value()) {
    CHECK(ShapeUtil::Compatible(shard_shape, sharded_conv->shape()));
    return sharded_conv;
  }
  return b->AddInstruction(HloInstruction::CreateDynamicSlice(
      shard_shape, sharded_conv,
      *resharded_operand_and_window->dynamic_slice_index_on_output,
      shard_shape.dimensions()));
}

// Partition convolution with only one kind of dimension partitioned.
StatusOr<HloInstruction*> PartitionConvolutionBaseCase(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_conv,
    const Window& conv_window, HloInstruction* original_hlo,
    int64 num_partitions, const SpmdPartitionerOptions& options,
    HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);

  // Case 1: Handle depthwise convolution with batch group count or
  // feature group count.
  if (original_hlo->batch_group_count() > 1) {
    TF_ASSIGN_OR_RETURN(
        auto parallel_partitioned_conv,
        PartitionConvolutionWithBatchGroupCount(
            lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
            conv_window, original_hlo, num_partitions, b));
    if (parallel_partitioned_conv) {
      return parallel_partitioned_conv;
    }
  }

  if (original_hlo->feature_group_count() > 1) {
    TF_ASSIGN_OR_RETURN(
        auto parallel_partitioned_conv,
        PartitionConvolutionWithFeatureGroupCount(
            lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
            conv_window, original_hlo, num_partitions, b));
    if (parallel_partitioned_conv) {
      return parallel_partitioned_conv;
    }
  }

  // Case 2: both RHS and LHS are tiled.
  // Handle the cases where both operands' shardings are aligned. We check
  // that the LHS batch dimension is not partitioned because it is mapped to
  // the output feature dimension in aligned_rhs_sharding, and these are not
  // the same dimension.
  if (!lhs.sharding().IsTileMaximal() && !rhs.sharding().IsTileMaximal()) {
    if (options.conv_halo_exchange_always_on_lhs) {
      TF_ASSIGN_OR_RETURN(
          auto partitioned_conv,
          PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS(
              lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
              conv_window, original_hlo, partition_id, module, b));
      if (partitioned_conv) {
        return partitioned_conv;
      }
    } else {
      TF_ASSIGN_OR_RETURN(
          auto partitioned_conv,
          PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS(
              lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
              conv_window, original_hlo, partition_id, module, b));

      if (partitioned_conv) {
        return partitioned_conv;
      }
    }
  }

  // Case 3: output is tiled.
  if (!output_sharding.IsTileMaximal()) {
    TF_ASSIGN_OR_RETURN(auto partitioned_conv,
                        PartitionConvolutionTiledOutput(
                            lhs, rhs, output_base_shape, output_sharding,
                            create_sharded_conv, conv_window, original_hlo, b));

    if (partitioned_conv) {
      return partitioned_conv;
    }
  }
  return nullptr;
}

StatusOr<std::unique_ptr<HloInstruction>> CreateShardedConvConvolution(
    const HloInstruction& conv,
    const dot_as_convolution_util::DotConvolutionDimsInfo& dot_dnums,
    HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo,
    const Window& conv_window) {
  CHECK_EQ(conv.opcode(), HloOpcode::kConvolution);
  const auto& conv_dnums = conv.convolution_dimension_numbers();
  auto window = conv.window();
  for (const auto& dim : dot_dnums.batch_dims) {
    auto wd = window.mutable_dimensions(dim.spatial_dim);
    wd->set_size(sharded_lhs_hlo->shape().dimensions(
        conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
    wd->set_stride(std::max<int64>(1, wd->size() - 1));
    wd->set_base_dilation(wd->size());
  }
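  // Editor's note on the settings above (illustrative): a batch dimension
  // that dot-as-convolution models as a spatial dimension of size n gets
  // size = n, stride = max(1, n - 1) and base_dilation = n, so the
  // convolution produces exactly n outputs, each reading a single
  // batch-aligned input element. E.g., n = 4 base-dilates the input to
  // length (4 - 1) * 4 + 1 = 13; a window of size 4 with stride 3 then
  // yields (13 - 4) / 3 + 1 = 4 outputs.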
  for (const auto& dim : dot_dnums.contracting_dims) {
    if (dim.spatial_dim < 0) {
      continue;
    }
    auto wd = window.mutable_dimensions(dim.spatial_dim);
    wd->set_size(sharded_lhs_hlo->shape().dimensions(
        conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
  }
  for (const auto& dim : dot_dnums.rhs_non_contracting_dims) {
    if (dim.spatial_dim < 0) {
      continue;
    }
    auto wd = window.mutable_dimensions(dim.spatial_dim);
    wd->set_size(sharded_rhs_hlo->shape().dimensions(
        conv_dnums.kernel_spatial_dimensions(dim.spatial_dim)));
    wd->set_padding_high(wd->size() - 1);
    wd->set_padding_low(wd->size() - 1);
  }

  for (const auto& dim : dot_dnums.conv_spatial_dims) {
    auto wd = window.mutable_dimensions(dim.spatial_dim);
    const auto& new_window_dimension = conv_window.dimensions(dim.spatial_dim);
    wd->set_size(new_window_dimension.size());
    wd->set_padding_high(new_window_dimension.padding_high());
    wd->set_padding_low(new_window_dimension.padding_low());
  }

  int64 feature_group_count = conv.feature_group_count();
  if (feature_group_count > 1) {
    feature_group_count = sharded_lhs_hlo->shape().dimensions(
                              conv_dnums.input_feature_dimension()) /
                          sharded_rhs_hlo->shape().dimensions(
                              conv_dnums.kernel_input_feature_dimension());
  }
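  // For example (editor's addition, hypothetical shard shapes): if the
  // sharded LHS has 32 input features and the sharded RHS has 8 input
  // features per group, the per-shard feature_group_count becomes
  // 32 / 8 = 4.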

  int64 batch_group_count = conv.batch_group_count();
  if (batch_group_count > 1) {
    batch_group_count =
        sharded_lhs_hlo->shape().dimensions(conv_dnums.input_batch_dimension());
  }

  TF_ASSIGN_OR_RETURN(
      Shape sharded_conv_shape,
      ShapeInference::InferConvolveShape(
          sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(),
          feature_group_count, batch_group_count, window, conv_dnums,
          /*preferred_element_type=*/conv.shape().element_type()));
  *sharded_conv_shape.mutable_layout() = conv.shape().layout();
  return HloInstruction::CreateConvolve(
      sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo, feature_group_count,
      batch_group_count, window, conv_dnums, conv.precision_config());
}

}  // namespace

// Partition convolution.
StatusOr<HloInstruction*> PartitionConvolution(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_conv,
    const Window& conv_window, HloInstruction* original_hlo,
    int64 num_partitions, const SpmdPartitionerOptions& options,
    HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);

  TF_ASSIGN_OR_RETURN(auto try_partitioned_conv,
                      PartitionConvolutionBaseCase(
                          lhs, rhs, output_base_shape, output_sharding,
                          create_sharded_conv, conv_window, original_hlo,
                          num_partitions, options, partition_id, module, b));
  if (try_partitioned_conv) {
    return try_partitioned_conv;
  }

  return nullptr;
}

Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) {
  auto dims_info = dot_as_convolution_util::ParseConvolutionDimsInfo(hlo);
  spmd::DotConvDimsMapping mapping;
  for (const auto& dims : dims_info.batch_dims) {
    mapping.batch_dims.emplace_back();
    mapping.batch_dims.back().lhs = dims.lhs;
    mapping.batch_dims.back().rhs = dims.rhs;
    mapping.batch_dims.back().output = dims.output;
    mapping.batch_dims.back().spatial = dims.spatial_dim;
  }
  for (const auto& dims : dims_info.contracting_dims) {
    mapping.contracting_dims.emplace_back();
    mapping.contracting_dims.back().lhs = dims.lhs;
    mapping.contracting_dims.back().rhs = dims.rhs;
    mapping.contracting_dims.back().output = dims.output;
    mapping.contracting_dims.back().spatial = dims.spatial_dim;
  }
  for (const auto& dims : dims_info.lhs_non_contracting_dims) {
    mapping.lhs_non_contracting_dims.emplace_back();
    mapping.lhs_non_contracting_dims.back().lhs = dims.lhs;
    mapping.lhs_non_contracting_dims.back().rhs = dims.rhs;
    mapping.lhs_non_contracting_dims.back().output = dims.output;
    mapping.lhs_non_contracting_dims.back().spatial = dims.spatial_dim;
  }
  for (const auto& dims : dims_info.rhs_non_contracting_dims) {
    mapping.rhs_non_contracting_dims.emplace_back();
    mapping.rhs_non_contracting_dims.back().lhs = dims.lhs;
    mapping.rhs_non_contracting_dims.back().rhs = dims.rhs;
    mapping.rhs_non_contracting_dims.back().output = dims.output;
    mapping.rhs_non_contracting_dims.back().spatial = dims.spatial_dim;
  }
  for (const auto& dims : dims_info.conv_spatial_dims) {
    mapping.conv_spatial_dims.emplace_back();
    mapping.conv_spatial_dims.back().lhs = dims.lhs;
    mapping.conv_spatial_dims.back().rhs = dims.rhs;
    mapping.conv_spatial_dims.back().output = dims.output;
    mapping.conv_spatial_dims.back().spatial = dims.spatial_dim;
  }
  auto create_sharded_conv =
      [&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo,
          spmd::SpmdBuilder* b,
          const Window& conv_window) -> StatusOr<HloInstruction*> {
    if (dims_info.conv_spatial_dims.empty()) {
      TF_ASSIGN_OR_RETURN(
          auto sharded_conv,
          dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution(
              *hlo, dims_info, lhs_hlo, rhs_hlo));
      return b->AddInstruction(std::move(sharded_conv));
    } else {
      TF_ASSIGN_OR_RETURN(auto sharded_conv,
                          CreateShardedConvConvolution(*hlo, dims_info, lhs_hlo,
                                                       rhs_hlo, conv_window));
      return b->AddInstruction(std::move(sharded_conv));
    }
  };

  return HandleDotHelper(hlo, mapping, create_sharded_conv);
}

}  // namespace spmd
}  // namespace xla