/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/spmd/convolution_handler.h"

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/numbers.h"

namespace xla {
namespace spmd {

namespace {

// Partition convolution with batch group count.
StatusOr<HloInstruction*> PartitionConvolutionWithBatchGroupCount(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_conv,
    const Window& conv_window, HloInstruction* original_hlo,
    int64 num_partitions, SpmdBuilder* b) {
  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
  if (original_hlo->batch_group_count() == 1 ||
      original_hlo->batch_group_count() < num_partitions) {
    return nullptr;
  }

  const auto& dnums = original_hlo->convolution_dimension_numbers();
  // Only supports the case where batch_group_size equals input_batch_size.
  const int64 input_batch_size =
      lhs.base_shape().dimensions(dnums.input_batch_dimension());
  const int64 kernel_output_feature_size =
      rhs.base_shape().dimensions(dnums.kernel_output_feature_dimension());
  if (input_batch_size != kernel_output_feature_size ||
      original_hlo->batch_group_count() != input_batch_size) {
    return nullptr;
  }

  // Map RHS indices to LHS indices.
  std::vector<int64> rhs_to_lhs_indices(output_base_shape.rank());
  rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] =
      dnums.input_batch_dimension();
  rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] =
      dnums.input_feature_dimension();
  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] =
        dnums.input_spatial_dimensions(i);
  }

  // Map LHS indices to RHS indices.
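  // (lhs_to_rhs_indices is the inverse permutation of rhs_to_lhs_indices.)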
  std::vector<int64> lhs_to_rhs_indices(output_base_shape.rank());
  for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) {
    lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
  }

  // Map LHS indices to output indices.
  std::vector<int64> lhs_to_output_indices(lhs.base_shape().rank(), -1);
  lhs_to_output_indices[dnums.input_batch_dimension()] =
      dnums.output_feature_dimension();
  lhs_to_output_indices[dnums.input_feature_dimension()] =
      dnums.output_batch_dimension();
  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    lhs_to_output_indices[dnums.input_spatial_dimensions(i)] =
        dnums.output_spatial_dimensions(i);
  }

  // Align LHS or RHS to the other operand if the input batch dim or the
  // kernel output feature dim is partitioned.
  auto aligned_rhs_sharding =
      hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices);
  auto aligned_lhs_sharding =
      hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices);

  bool lhs_batch_dim_is_partitioned =
      (ShardCountAtDim(lhs.sharding(), dnums.input_batch_dimension()) ==
       num_partitions);
  bool rhs_output_feature_dim_is_partitioned =
      (ShardCountAtDim(rhs.sharding(),
                       dnums.kernel_output_feature_dimension()) ==
       num_partitions);
  if (!lhs_batch_dim_is_partitioned && !rhs_output_feature_dim_is_partitioned) {
    return nullptr;
  }
  // Reshard LHS or RHS so that its batch dimension or output feature
  // dimension is partitioned the same way as the other operand.
  if (lhs_batch_dim_is_partitioned) {
    rhs = rhs.Reshard(aligned_rhs_sharding);
  } else {
    lhs = lhs.Reshard(aligned_lhs_sharding);
  }
  // Align output sharding after LHS and RHS sharding are consistent.
  auto aligned_output_sharding = hlo_sharding_util::TransposeSharding(
      lhs.sharding(), lhs_to_output_indices);

  // Create partitioned convolution.
  TF_ASSIGN_OR_RETURN(
      auto sharded_conv,
      create_sharded_conv(lhs.hlo(), rhs.hlo(), b, conv_window));
  sharded_conv->set_sharding(aligned_output_sharding);
  return PartitionedHlo(sharded_conv, output_base_shape, lhs.state())
      .Reshard(output_sharding)
      .hlo();
}

// Partition convolution with feature group count.
StatusOr<HloInstruction*> PartitionConvolutionWithFeatureGroupCount(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_conv,
    const Window& conv_window, HloInstruction* original_hlo,
    int64 num_partitions, SpmdBuilder* b) {
  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
  if (original_hlo->feature_group_count() == 1 ||
      original_hlo->feature_group_count() < num_partitions) {
    return nullptr;
  }

  const auto& dnums = original_hlo->convolution_dimension_numbers();
  const int64 input_feature_size =
      lhs.base_shape().dimensions(dnums.input_feature_dimension());
  const int64 kernel_output_feature_size =
      rhs.base_shape().dimensions(dnums.kernel_output_feature_dimension());
  if (input_feature_size != kernel_output_feature_size ||
      input_feature_size % original_hlo->feature_group_count() != 0) {
    return nullptr;
  }

  // Align RHS indices to LHS.
  std::vector<int64> rhs_to_lhs_indices(output_base_shape.rank());
  rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] =
      dnums.input_feature_dimension();
  rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] =
      dnums.input_batch_dimension();
  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] =
        dnums.input_spatial_dimensions(i);
  }

  // Align LHS indices to RHS.
  std::vector<int64> lhs_to_rhs_indices(output_base_shape.rank());
  for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) {
    lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
  }

  // Align LHS indices to output.
  std::vector<int64> lhs_to_output_indices(output_base_shape.rank());
  lhs_to_output_indices[dnums.input_feature_dimension()] =
      dnums.output_feature_dimension();
  lhs_to_output_indices[dnums.input_batch_dimension()] =
      dnums.output_batch_dimension();
  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    lhs_to_output_indices[dnums.input_spatial_dimensions(i)] =
        dnums.output_spatial_dimensions(i);
  }

  // Align LHS or RHS if input_feature_dim or kernel_output_feature_dim is
  // partitioned.
  auto aligned_rhs_sharding =
      hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices);
  auto aligned_lhs_sharding =
      hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices);

  bool lhs_feature_dim_is_partitioned =
      (ShardCountAtDim(lhs.sharding(), dnums.input_feature_dimension()) ==
       num_partitions);
  bool rhs_output_feature_dim_is_partitioned =
      (ShardCountAtDim(rhs.sharding(),
                       dnums.kernel_output_feature_dimension()) ==
       num_partitions);
  if (!lhs_feature_dim_is_partitioned &&
      !rhs_output_feature_dim_is_partitioned) {
    return nullptr;
  }
  // Reshard LHS or RHS so that its input feature dimension or output feature
  // dimension is partitioned the same way as the other operand.
  if (lhs_feature_dim_is_partitioned) {
    rhs = rhs.Reshard(aligned_rhs_sharding);
  } else {
    lhs = lhs.Reshard(aligned_lhs_sharding);
  }

  // Align output sharding after LHS and RHS sharding are consistent.
  auto aligned_output_sharding = hlo_sharding_util::TransposeSharding(
      lhs.sharding(), lhs_to_output_indices);

  TF_ASSIGN_OR_RETURN(
      auto sharded_conv,
      create_sharded_conv(lhs.hlo(), rhs.hlo(), b, conv_window));
  sharded_conv->set_sharding(aligned_output_sharding);
  return PartitionedHlo(sharded_conv, output_base_shape, lhs.state())
      .Reshard(output_sharding)
      .hlo();
}

// Partition convolution when both LHS and RHS are partitioned at spatial
// dimensions. Halo exchange will happen on RHS only.
StatusOr<HloInstruction*>
PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_conv,
    const Window& conv_window, HloInstruction* original_hlo,
    HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
  TF_RET_CHECK(!lhs.sharding().IsTileMaximal() &&
               !rhs.sharding().IsTileMaximal());

  const auto& dnums = original_hlo->convolution_dimension_numbers();
  std::vector<int64> rhs_to_lhs_indices(output_base_shape.rank());
  rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] =
      dnums.input_batch_dimension();
  rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] =
      dnums.input_feature_dimension();
  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] =
        dnums.input_spatial_dimensions(i);
  }
  std::vector<int64> lhs_to_rhs_indices(output_base_shape.rank());
  for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) {
    lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
  }
  auto aligned_rhs_sharding =
      hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices);
  auto aligned_lhs_sharding =
      hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices);

  auto unsupported_sharding = [&](const HloSharding& lhs_sharding,
                                  const HloSharding& rhs_sharding) {
    // We currently don't support partitioning input batch or output feature
    // dimensions.
    return lhs_sharding.tile_assignment().dim(dnums.input_batch_dimension()) !=
               1 ||
           rhs_sharding.tile_assignment().dim(
               dnums.kernel_output_feature_dimension()) != 1;
  };

  auto zero = b->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::Zero(output_base_shape.element_type())));
  if (ShapeSizeInBytes(lhs.base_shape()) < ShapeSizeInBytes(rhs.base_shape())) {
    if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) {
      return nullptr;
    }
    lhs = lhs.Reshard(aligned_lhs_sharding).PadWithValue(zero);
    rhs = rhs.PadWithValue(zero);
  } else {
    if (unsupported_sharding(lhs.sharding(), aligned_rhs_sharding)) {
      return nullptr;
    }
    lhs = lhs.PadWithValue(zero);
    rhs = rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero);
  }

  if (original_hlo->feature_group_count() > 1 &&
      (lhs.sharding().tile_assignment().dim(dnums.input_feature_dimension()) >
           1 ||
       rhs.sharding().tile_assignment().dim(
           dnums.kernel_output_feature_dimension()) > 1)) {
    return nullptr;
  }

  if (original_hlo->batch_group_count() > 1 &&
      (lhs.sharding().tile_assignment().dim(dnums.input_batch_dimension()) >
           1 ||
       rhs.sharding().tile_assignment().dim(
           dnums.kernel_output_feature_dimension()) > 1)) {
    return nullptr;
  }

  // Reshard RHS so that each shard computes the partial sum of the full
  // shape result, and add AllReduce. See HandleConvolutionTiledLhsAndRhs()
  // which reshards LHS.
  //
  // The size of halo on each dimension can be calculated from the
  // projection onto the RHS that shard i needs to read. RHS and LHS below
  // refer to the shard sizes of RHS and LHS, WC is the number of windows,
  // and D is the window dilation.
  //
  // * offset(i): LHS * i + low_padding - (WC - 1) * stride
  // * limit(i): LHS * (i + 1) + low_padding
  //
  // Since shard i has RHS of range [i * RHS * D, (i + 1) * RHS * D)
  // * left-halo: i * RHS * D - offset(i)
  //            = i * (RHS * D - LHS) + (WC - 1) * stride - low_padding
  // * right-halo: limit(i) - (i + 1) * RHS * D
  //             = (i + 1) * (LHS - RHS * D) + low_padding
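  //
  // For example (illustrative values only, plugged into the formulas above):
  // with per-shard sizes LHS = 4 and RHS = 2, D = 1, stride = 1, WC = 3, and
  // low_padding = 0,
  //   left-halo(i)  = i * (2 - 4) + 2 = 2 - 2 * i
  //   right-halo(i) = (i + 1) * (4 - 2) = 2 * i + 2
  // A non-positive halo size simply means that shard needs no data from that
  // side.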
  const auto& collective_ops_creator = lhs.state().collective_ops_creator;
  std::vector<int64> shard_counts(dnums.input_spatial_dimensions_size());
  std::vector<int64> lhs_shard_sizes(dnums.input_spatial_dimensions_size());
  std::vector<int64> rhs_shard_sizes(dnums.input_spatial_dimensions_size());

  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    int64 lhs_dimension = dnums.input_spatial_dimensions(i);
    int64 rhs_dimension = dnums.kernel_spatial_dimensions(i);
    int64 shard_count = rhs.sharding().tile_assignment().dim(rhs_dimension);
    const auto& wd = conv_window.dimensions(i);
    if (wd.base_dilation() != 1 || wd.window_reversal()) {
      return nullptr;
    }

    int64 lhs_shard_size =
        CeilOfRatio(lhs.base_shape().dimensions(lhs_dimension), shard_count);
    int64 rhs_shard_size =
        CeilOfRatio(rhs.base_shape().dimensions(rhs_dimension), shard_count);
    shard_counts[i] = shard_count;
    lhs_shard_sizes[i] = lhs_shard_size;
    rhs_shard_sizes[i] = rhs_shard_size;
  }

  std::vector<OffsetCalculation> left_halo_size_functions(
      output_base_shape.rank());
  std::vector<OffsetCalculation> right_halo_size_functions(
      output_base_shape.rank());
  Window new_window = conv_window;

  // Data structures needed for Pad and DynamicSlice on LHS if needed.
  bool need_dynamic_slice_lhs = false;
  auto partition_ordinals =
      MakeTiledPartitionOrdinals(lhs.sharding(), partition_id, b);
  std::vector<int64> zero_padding(output_base_shape.rank());
  PaddingConfig pad_config = window_util::MakeSymmetricPadding(zero_padding);
  auto zero_s32 =
      b->AddInstruction(HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
  std::vector<HloInstruction*> dynamic_slice_start_indices(
      output_base_shape.rank(), zero_s32);
  Shape dynamic_slice_shape = lhs.hlo()->shape();
  Shape pad_shape = lhs.hlo()->shape();

  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    int64 lhs_dimension = dnums.input_spatial_dimensions(i);
    int64 rhs_dimension = dnums.kernel_spatial_dimensions(i);
    int64 lhs_shard_size = lhs_shard_sizes[i];
    int64 rhs_shard_size = rhs_shard_sizes[i];

    if (shard_counts[i] == 1) {
      continue;
    }

    // Calculate the left and right halo sizes as described in the comments
    // above. It calculates the halo sizes with dilation, so we apply
    // CeilOfRatio({left,right}_halo_size, window_dilation).
    const auto& wd = conv_window.dimensions(i);
    int64 padding_low = wd.padding_low();
    int64 padding_high = wd.padding_high();
    int64 base = lhs.base_shape().dimensions(lhs_dimension);
    int64 window_count = 1 + (padding_low + padding_high + base -
                              (1 + (wd.size() - 1) * wd.window_dilation())) /
                                 wd.stride();
    left_halo_size_functions[rhs_dimension] =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(
            rhs_shard_size * wd.window_dilation() - lhs_shard_size,
            (window_count - 1) * wd.stride() - padding_low +
                wd.window_dilation() - 1,
            wd.window_dilation()));
    right_halo_size_functions[rhs_dimension] =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(
            lhs_shard_size - rhs_shard_size * wd.window_dilation(),
            lhs_shard_size - rhs_shard_size * wd.window_dilation() +
                padding_low + wd.window_dilation() - 1,
            wd.window_dilation()));

    // New RHS window size includes the maximum of both left and right
    // halos.
    int64 halo_size =
        left_halo_size_functions[rhs_dimension].MaxInRange(1, shard_counts[i]) +
        right_halo_size_functions[rhs_dimension].MaxInRange(
            0, shard_counts[i] - 1);
    int64 new_window_size =
        rhs.hlo()->shape().dimensions(rhs_dimension) + halo_size;

    // The amount of new low padding could be dynamic (e.g., window_dilation
    // != 1), which requires pad (to the maximum) and dynamic slice on LHS.
    //
    // If we consider the first window, the offset of the dilated RHS that
    // aligns with the first valid LHS element for shard i is 'padding_low +
    // LHS * i'. When the left halo is added to RHS, the offset of the first
    // RHS element is (RHS * i - left_halo) * window_dilation. The
    // difference between the two values is the amount of padding_low we
    // need on LHS.
    auto new_padding_low_function =
        OffsetCalculation(HloOpcode::kMultiply,
                          left_halo_size_functions[rhs_dimension],
                          OffsetCalculation(MultiplyAddDivideOffsetCalculation(
                              0, wd.window_dilation(), 1))) -
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(
            rhs_shard_size * wd.window_dilation() - lhs_shard_size,
            -padding_low, 1));

    int64 new_padding_low_max =
        new_padding_low_function.MaxInRange(0, shard_counts[i]);
    int64 new_padding_low = new_padding_low_max;
    int64 new_padding_high = window_count * wd.stride() +
                             (new_window_size - 1) * wd.window_dilation() -
                             new_padding_low - lhs_shard_size;

    // We do pad/dynamic-slice only when the padding is dynamic.
    if (!new_padding_low_function.IsConstant()) {
      need_dynamic_slice_lhs = true;
      new_padding_low = 0;
      pad_config.mutable_dimensions(lhs_dimension)
          ->set_edge_padding_low(new_padding_low_max);
      pad_config.mutable_dimensions(lhs_dimension)
          ->set_edge_padding_high(new_padding_low_max);
      pad_shape.set_dimensions(lhs_dimension,
                               lhs_shard_size + 2 * new_padding_low_max);
      dynamic_slice_start_indices[lhs_dimension] =
          (OffsetCalculation(
               MultiplyAddDivideOffsetCalculation(0, new_padding_low_max, 1)) -
           new_padding_low_function)
              .Calculate(partition_ordinals[lhs_dimension], b);
      dynamic_slice_shape.set_dimensions(lhs_dimension,
                                         lhs_shard_size + new_padding_low_max);
    }

    // Since the convolution RHS operand size increased with halos, adjust
    // the window config accordingly.
    new_window.mutable_dimensions(i)->set_padding_low(new_padding_low);
    new_window.mutable_dimensions(i)->set_padding_high(new_padding_high);
    new_window.mutable_dimensions(i)->set_size(
        rhs.hlo()->shape().dimensions(rhs_dimension) + halo_size);
  }

  HloInstruction* conv_lhs = lhs.hlo();
  if (need_dynamic_slice_lhs) {
    auto pad = b->AddInstruction(
        HloInstruction::CreatePad(pad_shape, lhs.hlo(), zero, pad_config));
    conv_lhs = b->AddInstruction(HloInstruction::CreateDynamicSlice(
        dynamic_slice_shape, pad, dynamic_slice_start_indices,
        dynamic_slice_shape.dimensions()));
  }

  // Exchange halo and concatenate.
  HloInstruction* rhs_with_halo = rhs.hlo();
  for (int i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) {
    int64 dim = dnums.kernel_spatial_dimensions(i);
    int64 explicit_left_padding_on_full_shape =
        left_halo_size_functions[dim].Calculate(0);
    int64 shard_size_with_halo = new_window.dimensions(i).size();

    // offset_on_padded_shape and padded_full_shape_size are needed only if
    // we want to mask out-of-range values in ExchangeHaloAndGetValidData().
    // Since the default value for the collective-permute is zero and we also
    // call PadWithValue() on both operands at the beginning, we don't need
    // to mask here.
    //
    // TODO(hyoulkee): Consider removing one of the two PadWithValue() calls
    // if it's always safe.
    auto offset_on_padded_shape =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(
            rhs_shard_sizes[i], explicit_left_padding_on_full_shape, 1)) -
        left_halo_size_functions[dim];
    int64 padded_full_shape_size =
        offset_on_padded_shape.Calculate(shard_counts[i] - 1) +
        new_window.dimensions(i).size();
    auto concat = ExchangeHaloAndGetValidData(
        rhs_with_halo, rhs.base_shape(), left_halo_size_functions[dim],
        right_halo_size_functions[dim], explicit_left_padding_on_full_shape,
        padded_full_shape_size, shard_size_with_halo, dim, rhs.sharding(),
        offset_on_padded_shape.Calculate(partition_ordinals[dim], b), zero,
        partition_ordinals[dim], collective_ops_creator,
        lhs.state().next_channel_id, b,
        /*mask_invalid_region=*/false);
    if (!concat) {
      return nullptr;
    }
    rhs_with_halo = *concat;
  }

  TF_ASSIGN_OR_RETURN(
      auto conv, create_sharded_conv(conv_lhs, rhs_with_halo, b, new_window));

  auto ar = collective_ops_creator.create_cross_partition_all_reduce(
      b, conv, MakeBinaryAdd(original_hlo->shape().element_type(), module), {},
      (*lhs.state().next_channel_id)++);
  ar->set_sharding(HloSharding::Replicate());
  return PartitionedHlo(ar, output_base_shape, lhs.state())
      .Reshard(output_sharding)
      .hlo();
}

// Partition convolution when both LHS and RHS are partitioned at spatial
// dimensions. Halo exchange will happen on LHS only.
StatusOr<HloInstruction*>
PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_conv,
    const Window& conv_window, HloInstruction* original_hlo,
    HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
  TF_RET_CHECK(!lhs.sharding().IsTileMaximal() &&
               !rhs.sharding().IsTileMaximal());

  const auto& dnums = original_hlo->convolution_dimension_numbers();

  // Check if the operand shardings are aligned. Also we currently don't
  // support partitioning non-spatial dimensions.
  std::vector<int64> rhs_to_lhs_indices(output_base_shape.rank());
  rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] =
      dnums.input_batch_dimension();
  rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] =
      dnums.input_feature_dimension();
  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] =
        dnums.input_spatial_dimensions(i);
  }
  std::vector<int64> lhs_to_rhs_indices(output_base_shape.rank());
  for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) {
    lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
  }

  const Window& window = conv_window;
  std::vector<int64> reversed_rhs_dims;
  for (int64 i = 0; i < window.dimensions_size(); ++i) {
    if (window.dimensions(i).window_reversal()) {
      reversed_rhs_dims.push_back(dnums.kernel_spatial_dimensions(i));
    }
  }
  if (!reversed_rhs_dims.empty()) {
    // Make the reversed dims left-padded to prepare for window reversal.
    auto left_padded_rhs = HaloExchangeToPadOnLeft(rhs, reversed_rhs_dims);
    if (left_padded_rhs == nullptr) {
      return nullptr;
    }
    left_padded_rhs->set_sharding(rhs.sharding());
    rhs = PartitionedHlo(left_padded_rhs, rhs.base_shape(), rhs.state());
  }
  // Consider window reversal when resharding RHS or LHS. Note: this will not
  // reverse the data in the shard. We use window reversal to do that.
  auto aligned_rhs_sharding = hlo_sharding_util::ReverseSharding(
      hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices),
      reversed_rhs_dims);
  auto aligned_lhs_sharding = hlo_sharding_util::TransposeSharding(
      hlo_sharding_util::ReverseSharding(rhs.sharding(), reversed_rhs_dims),
      lhs_to_rhs_indices);

  auto unsupported_sharding = [&](const HloSharding& lhs_sharding,
                                  const HloSharding& rhs_sharding) {
    return lhs_sharding.tile_assignment().dim(dnums.input_batch_dimension()) !=
               1 ||
           rhs_sharding.tile_assignment().dim(
               dnums.kernel_output_feature_dimension()) != 1;
  };

  auto zero = b->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::Zero(output_base_shape.element_type())));
  if (ShapeSizeInBytes(lhs.base_shape()) < ShapeSizeInBytes(rhs.base_shape())) {
    if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) {
      return nullptr;
    }
    lhs = lhs.Reshard(aligned_lhs_sharding).PadWithValue(zero);
    rhs = rhs.PadWithValue(zero, reversed_rhs_dims);
  } else {
    if (unsupported_sharding(lhs.sharding(), aligned_rhs_sharding)) {
      return nullptr;
    }
    lhs = lhs.PadWithValue(zero);
    rhs =
        rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero, reversed_rhs_dims);
  }

  if (original_hlo->feature_group_count() > 1 &&
      (lhs.sharding().tile_assignment().dim(dnums.input_feature_dimension()) >
           1 ||
       rhs.sharding().tile_assignment().dim(
           dnums.kernel_output_feature_dimension()) > 1)) {
    return nullptr;
  }

  if (original_hlo->batch_group_count() > 1 &&
      (lhs.sharding().tile_assignment().dim(dnums.input_batch_dimension()) >
           1 ||
       rhs.sharding().tile_assignment().dim(
           dnums.kernel_output_feature_dimension()) > 1)) {
    return nullptr;
  }
  // Reshard LHS by exchanging halo such that each shard computes the partial
  // sum of the full shape result, and add AllReduce.
  //
  // The size of halo on each dimension can be calculated from the projection
  // onto the LHS that each RHS shard i needs to read. RHS and LHS below refer
  // to the shard sizes of RHS and LHS, WC is the number of windows, and D is
  // the window dilation.
  //
  // * offset(i): RHS * D * i - low_padding
  // * limit(i): {RHS * (i + 1) * D - (D - 1)} + (WC - 1) * stride - low_padding
  //
  // Since shard i has LHS of range [i * LHS, (i + 1) * LHS)
  // * left-halo: i * LHS - offset(i)
  //            = (LHS - RHS * D) * i + low_padding
  // * right-halo: limit(i) - (i + 1) * LHS
  //             = (RHS * D - LHS) * (i + 1) + (1 - D) + (WC - 1) * stride - low_padding
  //             = (RHS * D - LHS) * i + (RHS * D - LHS) + (1 - D)
  //               + (WC - 1) * stride - low_padding
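  //
  // For example (illustrative values only, plugged into the formulas above):
  // with per-shard sizes LHS = 4 and RHS = 2, D = 1, stride = 1, WC = 4, and
  // low_padding = 1,
  //   left-halo(i)  = (4 - 2) * i + 1 = 2 * i + 1
  //   right-halo(i) = (2 - 4) * i + (2 - 4) + 0 + 3 - 1 = -2 * i
  // A non-positive halo size means that shard needs no data from that side.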
  std::vector<int64> shard_counts(dnums.input_spatial_dimensions_size());
  std::vector<int64> lhs_shard_sizes(dnums.input_spatial_dimensions_size());
  std::vector<int64> rhs_shard_sizes(dnums.input_spatial_dimensions_size());
  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    int64 lhs_dimension = dnums.input_spatial_dimensions(i);
    int64 rhs_dimension = dnums.kernel_spatial_dimensions(i);
    int64 shard_count = lhs.sharding().tile_assignment().dim(lhs_dimension);
    const auto& wd = window.dimensions(i);
    if (wd.base_dilation() != 1) {
      // TODO(wangtao): support the parallel dim if it is replicated here.
      return nullptr;
    }

    int64 lhs_shard_size =
        CeilOfRatio(lhs.base_shape().dimensions(lhs_dimension), shard_count);
    int64 rhs_shard_size =
        CeilOfRatio(rhs.base_shape().dimensions(rhs_dimension), shard_count);
    shard_counts[i] = shard_count;
    lhs_shard_sizes[i] = lhs_shard_size;
    rhs_shard_sizes[i] = rhs_shard_size;
  }

  std::vector<OffsetCalculation> left_halo_size_functions(
      output_base_shape.rank());
  std::vector<OffsetCalculation> right_halo_size_functions(
      output_base_shape.rank());
  Window new_window = window;

  auto partition_ordinals =
      MakeTiledPartitionOrdinals(lhs.sharding(), partition_id, b);
  HloInstruction* lhs_with_halo = lhs.hlo();
  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    int64 lhs_dimension = dnums.input_spatial_dimensions(i);
    int64 lhs_shard_size = lhs_shard_sizes[i];
    int64 rhs_shard_size = rhs_shard_sizes[i];

    if (shard_counts[i] == 1) {
      continue;
    }

    // Calculate the left and right halo sizes as described in the comments
    // above.
    const auto& wd = window.dimensions(i);
    int64 padding_low = wd.padding_low();
    int64 padding_high = wd.padding_high();
    int64 base = lhs.base_shape().dimensions(lhs_dimension);
    int64 window_count = 1 + (padding_low + padding_high + base -
                              (1 + (wd.size() - 1) * wd.window_dilation())) /
                                 wd.stride();
    int64 rhs_shard_size_dilated =
        (rhs_shard_size - 1) * wd.window_dilation() + 1;

    left_halo_size_functions[lhs_dimension] =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(
            lhs_shard_size - rhs_shard_size * wd.window_dilation(), padding_low,
            1));
    right_halo_size_functions[lhs_dimension] =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(
            rhs_shard_size * wd.window_dilation() - lhs_shard_size,
            rhs_shard_size * wd.window_dilation() - lhs_shard_size + 1 -
                wd.window_dilation() + wd.stride() * (window_count - 1) -
                padding_low,
            1));

    // Exchange halo and concatenate.
    int64 dim = dnums.input_spatial_dimensions(i);
    int64 explicit_left_padding_on_full_shape = padding_low;
    int64 shard_size_with_halo =
        wd.stride() * (window_count - 1) + rhs_shard_size_dilated;

    new_window.mutable_dimensions(i)->set_padding_low(0);
    new_window.mutable_dimensions(i)->set_padding_high(0);
    new_window.mutable_dimensions(i)->set_size(rhs_shard_size);

    // offset_on_padded_shape and padded_full_shape_size are needed only if
    // we want to mask out-of-range values in ExchangeHaloAndGetValidData().
    // Since the default value for the collective-permute is zero and we also
    // call PadWithValue() on both operands at the beginning, we don't need
    // to mask here.
    //
    // TODO(hyoulkee): Consider removing one of the two PadWithValue() calls
    // if it's always safe.
    auto offset_on_padded_shape =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation());
    int64 padded_full_shape_size = 0;
    auto concat = ExchangeHaloAndGetValidData(
        lhs_with_halo, lhs.base_shape(), left_halo_size_functions[dim],
        right_halo_size_functions[dim], explicit_left_padding_on_full_shape,
        padded_full_shape_size, shard_size_with_halo, dim, lhs.sharding(),
        offset_on_padded_shape.Calculate(partition_ordinals[dim], b), zero,
        partition_ordinals[dim], lhs.state().collective_ops_creator,
        lhs.state().next_channel_id, b,
        /*mask_invalid_region=*/false);
    if (!concat) {
      return nullptr;
    }
    lhs_with_halo = *concat;
  }

  TF_ASSIGN_OR_RETURN(
      auto conv, create_sharded_conv(lhs_with_halo, rhs.hlo(), b, new_window));
  auto ar =
      lhs.state().collective_ops_creator.create_cross_partition_all_reduce(
          b, conv, MakeBinaryAdd(output_base_shape.element_type(), module), {},
          (*lhs.state().next_channel_id)++);
  ar->set_sharding(HloSharding::Replicate());
  return PartitionedHlo(ar, output_base_shape, lhs.state())
      .Reshard(output_sharding)
      .hlo();
}

// Partition convolution when output is sharded. Will shard LHS with replicated
// RHS.
StatusOr<HloInstruction*> PartitionConvolutionTiledOutput(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_conv,
    const Window& conv_window, HloInstruction* original_hlo, SpmdBuilder* b) {
  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
  const auto& dnums = original_hlo->convolution_dimension_numbers();
  TF_RET_CHECK(!output_sharding.IsTileMaximal());
  // We don't currently support sharding on output feature dimension.
  if (output_sharding.tile_assignment().dim(dnums.output_feature_dimension()) >
      1) {
    return nullptr;
  }

  // Check if the operand and the output sharding are aligned.
  std::vector<int64> input_to_output_indices(output_base_shape.rank());
  input_to_output_indices[dnums.input_batch_dimension()] =
      dnums.output_batch_dimension();
  input_to_output_indices[dnums.input_feature_dimension()] =
      dnums.output_feature_dimension();
  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    input_to_output_indices[dnums.input_spatial_dimensions(i)] =
        dnums.output_spatial_dimensions(i);
  }
  auto target_operand_sharding = hlo_sharding_util::TransposeSharding(
      output_sharding, input_to_output_indices);
  lhs = lhs.Reshard(target_operand_sharding);

  // Replicate the RHS.
  rhs = rhs.Reshard(HloSharding::Replicate());

  // Convolution window config does not include batch and feature dimensions,
  // whereas ReshardAsWindowedInput() expects the same number of window
  // dimensions as the rank of the operand. So add two more trivial
  // dimensions.
  std::vector<int64> ones(output_base_shape.rank(), 1);
  auto operand_window = window_util::MakeWindow(ones);
  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    *operand_window.mutable_dimensions(dnums.input_spatial_dimensions(i)) =
        conv_window.dimensions(i);
  }

  auto zero = b->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::Zero(output_base_shape.element_type())));
  auto resharded_operand_and_window =
      lhs.ReshardAsWindowedInput(operand_window, target_operand_sharding, zero);
  if (!resharded_operand_and_window.has_value()) {
    return nullptr;
  }
  Window new_window;
  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    *new_window.add_dimensions() =
        resharded_operand_and_window->shard_window.dimensions(
            dnums.input_spatial_dimensions(i));
  }

  TF_ASSIGN_OR_RETURN(
      auto sharded_conv,
      create_sharded_conv(resharded_operand_and_window->sharded_input,
                          rhs.hlo(), b, new_window));

  auto shard_shape = MakePartitionedShape(output_base_shape, output_sharding);
  if (!resharded_operand_and_window->dynamic_slice_index_on_output
           .has_value()) {
    CHECK(ShapeUtil::Compatible(shard_shape, sharded_conv->shape()));
    return sharded_conv;
  }
  return b->AddInstruction(HloInstruction::CreateDynamicSlice(
      shard_shape, sharded_conv,
      *resharded_operand_and_window->dynamic_slice_index_on_output,
      shard_shape.dimensions()));
}

// Partition convolution with only one kind of dimension partitioned.
StatusOr<HloInstruction*> PartitionConvolutionBaseCase(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_conv,
    const Window& conv_window, HloInstruction* original_hlo,
    int64 num_partitions, const SpmdPartitionerOptions& options,
    HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);

  // Case 1: Handle depthwise convolution with batch group count or
  // feature group count.
  if (original_hlo->batch_group_count() > 1) {
    TF_ASSIGN_OR_RETURN(
        auto parallel_partitioned_conv,
        PartitionConvolutionWithBatchGroupCount(
            lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
            conv_window, original_hlo, num_partitions, b));
    if (parallel_partitioned_conv) {
      return parallel_partitioned_conv;
    }
  }

  if (original_hlo->feature_group_count() > 1) {
    TF_ASSIGN_OR_RETURN(
        auto parallel_partitioned_conv,
        PartitionConvolutionWithFeatureGroupCount(
            lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
            conv_window, original_hlo, num_partitions, b));
    if (parallel_partitioned_conv) {
      return parallel_partitioned_conv;
    }
  }

  // Case 2: both RHS and LHS are tiled.
  // Handle the cases where both operands' shardings are aligned. We check
  // that the LHS batch dimension is not partitioned, because it is mapped to
  // the output feature dimension in aligned_rhs_sharding, and those two are
  // not the same dimension.
  if (!lhs.sharding().IsTileMaximal() && !rhs.sharding().IsTileMaximal()) {
    if (options.conv_halo_exchange_always_on_lhs) {
      TF_ASSIGN_OR_RETURN(
          auto partitioned_conv,
          PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS(
              lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
              conv_window, original_hlo, partition_id, module, b));
      if (partitioned_conv) {
        return partitioned_conv;
      }
    } else {
      TF_ASSIGN_OR_RETURN(
          auto partitioned_conv,
          PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS(
              lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
              conv_window, original_hlo, partition_id, module, b));

      if (partitioned_conv) {
        return partitioned_conv;
      }
    }
  }

  // Case 3: output is tiled.
  if (!output_sharding.IsTileMaximal()) {
    TF_ASSIGN_OR_RETURN(auto partitioned_conv,
                        PartitionConvolutionTiledOutput(
                            lhs, rhs, output_base_shape, output_sharding,
                            create_sharded_conv, conv_window, original_hlo, b));

    if (partitioned_conv) {
      return partitioned_conv;
    }
  }
  return nullptr;
}

StatusOr<std::unique_ptr<HloInstruction>> CreateShardedConvConvolution(
    const HloInstruction& conv,
    const dot_as_convolution_util::DotConvolutionDimsInfo& dot_dnums,
    HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo,
    const Window& conv_window) {
  CHECK_EQ(conv.opcode(), HloOpcode::kConvolution);
  const auto& conv_dnums = conv.convolution_dimension_numbers();
  auto window = conv.window();
  for (const auto& dim : dot_dnums.batch_dims) {
    auto wd = window.mutable_dimensions(dim.spatial_dim);
    wd->set_size(sharded_lhs_hlo->shape().dimensions(
        conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
    wd->set_stride(std::max<int64>(1, wd->size() - 1));
    wd->set_base_dilation(wd->size());
  }
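  // (With size n, base dilation n, and stride n - 1 as set above, each output
  // position along such a dimension reads exactly one real LHS element paired
  // with the matching kernel element, which is how a dot batch dimension is
  // expressed as a convolution spatial dimension.)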
  for (const auto& dim : dot_dnums.contracting_dims) {
    if (dim.spatial_dim < 0) {
      continue;
    }
    auto wd = window.mutable_dimensions(dim.spatial_dim);
    wd->set_size(sharded_lhs_hlo->shape().dimensions(
        conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
  }
  for (const auto& dim : dot_dnums.rhs_non_contracting_dims) {
    if (dim.spatial_dim < 0) {
      continue;
    }
    auto wd = window.mutable_dimensions(dim.spatial_dim);
    wd->set_size(sharded_rhs_hlo->shape().dimensions(
        conv_dnums.kernel_spatial_dimensions(dim.spatial_dim)));
    wd->set_padding_high(wd->size() - 1);
    wd->set_padding_low(wd->size() - 1);
  }

  for (const auto& dim : dot_dnums.conv_spatial_dims) {
    auto wd = window.mutable_dimensions(dim.spatial_dim);
    const auto& new_window_dimension = conv_window.dimensions(dim.spatial_dim);
    wd->set_size(new_window_dimension.size());
    wd->set_padding_high(new_window_dimension.padding_high());
    wd->set_padding_low(new_window_dimension.padding_low());
  }

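  // Recompute the group counts for the sharded operand shapes: if the grouped
  // dimension itself is partitioned, each shard keeps only its local groups.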
  int64 feature_group_count = conv.feature_group_count();
  if (feature_group_count > 1) {
    feature_group_count = sharded_lhs_hlo->shape().dimensions(
                              conv_dnums.input_feature_dimension()) /
                          sharded_rhs_hlo->shape().dimensions(
                              conv_dnums.kernel_input_feature_dimension());
  }

  int64 batch_group_count = conv.batch_group_count();
  if (batch_group_count > 1) {
    batch_group_count =
        sharded_lhs_hlo->shape().dimensions(conv_dnums.input_batch_dimension());
  }

  TF_ASSIGN_OR_RETURN(
      Shape sharded_conv_shape,
      ShapeInference::InferConvolveShape(
          sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(),
          feature_group_count, batch_group_count, window, conv_dnums,
          /*preferred_element_type=*/conv.shape().element_type()));
  *sharded_conv_shape.mutable_layout() = conv.shape().layout();
  return HloInstruction::CreateConvolve(
      sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo, feature_group_count,
      batch_group_count, window, conv_dnums, conv.precision_config());
}

}  // namespace

// Partition convolution.
StatusOr<HloInstruction*> PartitionConvolution(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_conv,
    const Window& conv_window, HloInstruction* original_hlo,
    int64 num_partitions, const SpmdPartitionerOptions& options,
    HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);

  TF_ASSIGN_OR_RETURN(auto try_partitioned_conv,
                      PartitionConvolutionBaseCase(
                          lhs, rhs, output_base_shape, output_sharding,
                          create_sharded_conv, conv_window, original_hlo,
                          num_partitions, options, partition_id, module, b));
  if (try_partitioned_conv) {
    return try_partitioned_conv;
  }

  return nullptr;
}

Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) {
  auto dims_info = dot_as_convolution_util::ParseConvolutionDimsInfo(hlo);
  spmd::DotConvDimsMapping mapping;
  for (const auto& dims : dims_info.batch_dims) {
    mapping.batch_dims.emplace_back();
    mapping.batch_dims.back().lhs = dims.lhs;
    mapping.batch_dims.back().rhs = dims.rhs;
    mapping.batch_dims.back().output = dims.output;
    mapping.batch_dims.back().spatial = dims.spatial_dim;
  }
  for (const auto& dims : dims_info.contracting_dims) {
    mapping.contracting_dims.emplace_back();
    mapping.contracting_dims.back().lhs = dims.lhs;
    mapping.contracting_dims.back().rhs = dims.rhs;
    mapping.contracting_dims.back().output = dims.output;
    mapping.contracting_dims.back().spatial = dims.spatial_dim;
  }
  for (const auto& dims : dims_info.lhs_non_contracting_dims) {
    mapping.lhs_non_contracting_dims.emplace_back();
    mapping.lhs_non_contracting_dims.back().lhs = dims.lhs;
    mapping.lhs_non_contracting_dims.back().rhs = dims.rhs;
    mapping.lhs_non_contracting_dims.back().output = dims.output;
    mapping.lhs_non_contracting_dims.back().spatial = dims.spatial_dim;
  }
  for (const auto& dims : dims_info.rhs_non_contracting_dims) {
    mapping.rhs_non_contracting_dims.emplace_back();
    mapping.rhs_non_contracting_dims.back().lhs = dims.lhs;
    mapping.rhs_non_contracting_dims.back().rhs = dims.rhs;
    mapping.rhs_non_contracting_dims.back().output = dims.output;
    mapping.rhs_non_contracting_dims.back().spatial = dims.spatial_dim;
  }
  for (const auto& dims : dims_info.conv_spatial_dims) {
    mapping.conv_spatial_dims.emplace_back();
    mapping.conv_spatial_dims.back().lhs = dims.lhs;
    mapping.conv_spatial_dims.back().rhs = dims.rhs;
    mapping.conv_spatial_dims.back().output = dims.output;
    mapping.conv_spatial_dims.back().spatial = dims.spatial_dim;
  }
  auto create_sharded_conv =
      [&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo,
          spmd::SpmdBuilder* b,
          const Window& conv_window) -> StatusOr<HloInstruction*> {
    if (dims_info.conv_spatial_dims.empty()) {
      TF_ASSIGN_OR_RETURN(
          auto sharded_conv,
          dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution(
              *hlo, dims_info, lhs_hlo, rhs_hlo));
      return b->AddInstruction(std::move(sharded_conv));
    } else {
      TF_ASSIGN_OR_RETURN(auto sharded_conv,
                          CreateShardedConvConvolution(*hlo, dims_info, lhs_hlo,
                                                       rhs_hlo, conv_window));
      return b->AddInstruction(std::move(sharded_conv));
    }
  };

  return HandleDotHelper(hlo, mapping, create_sharded_conv);
}

}  // namespace spmd
}  // namespace xla