1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
17
18 #include <algorithm>
19 #include <memory>
20
21 #include "absl/algorithm/container.h"
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/container/inlined_vector.h"
24 #include "absl/memory/memory.h"
25 #include "absl/strings/str_join.h"
26 #include "absl/types/optional.h"
27 #include "tensorflow/compiler/xla/literal_util.h"
28 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
29 #include "tensorflow/compiler/xla/service/hlo_computation.h"
30 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
31 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
32 #include "tensorflow/compiler/xla/service/hlo_module.h"
33 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
34 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
35 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
36 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
37 #include "tensorflow/compiler/xla/service/shape_inference.h"
38 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
39 #include "tensorflow/compiler/xla/shape_util.h"
40 #include "tensorflow/compiler/xla/util.h"
41 #include "tensorflow/compiler/xla/window_util.h"
42 #include "tensorflow/compiler/xla/xla_data.pb.h"
43
44 namespace xla {
45 namespace spmd {
46
47 bool HasReplicatedSharding(const HloSharding& sharding) {
48 if (sharding.IsTuple()) {
49 return absl::c_any_of(sharding.tuple_elements(), HasReplicatedSharding);
50 }
51 return sharding.IsReplicated();
52 }
53
54 HloInstruction* CreateConstant(const Shape& shape, Literal value,
55 SpmdBuilder* b) {
56 if (shape.IsTuple()) {
57 std::vector<HloInstruction*> elements;
58 for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
59 elements.push_back(CreateConstant(
60 ShapeUtil::GetTupleElementShape(shape, i), value.Clone(), b));
61 }
62 return b->AddInstruction(HloInstruction::CreateTuple(elements));
63 }
64
65 CHECK(
66 ShapeUtil::IsScalarWithElementType(value.shape(), shape.element_type()));
67 auto c = b->AddInstruction(HloInstruction::CreateConstant(std::move(value)));
68 return b->AddInstruction(HloInstruction::CreateBroadcast(shape, c, {}));
69 }
70
71 HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b) {
72 if (shape.IsTuple()) {
73 std::vector<HloInstruction*> elements;
74 for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
75 elements.push_back(
76 CreateZero(ShapeUtil::GetTupleElementShape(shape, i), b));
77 }
78 return b->AddInstruction(HloInstruction::CreateTuple(elements));
79 }
80
81 if (shape.IsToken()) {
82 return b->AddInstruction(HloInstruction::CreateToken());
83 }
84 auto zero = b->AddInstruction(
85 HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
86 if (shape.rank() == 0) {
87 return zero;
88 }
89 return b->AddInstruction(HloInstruction::CreateBroadcast(shape, zero, {}));
90 }
91
92 HloInstruction* CreateOne(const Shape& shape, SpmdBuilder* b) {
93 if (shape.IsTuple()) {
94 std::vector<HloInstruction*> elements;
95 for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
96 elements.push_back(
97 CreateOne(ShapeUtil::GetTupleElementShape(shape, i), b));
98 }
99 return b->AddInstruction(HloInstruction::CreateTuple(elements));
100 }
101
102 if (shape.IsToken()) {
103 return b->AddInstruction(HloInstruction::CreateToken());
104 }
105 auto one = b->AddInstruction(
106 HloInstruction::CreateConstant(LiteralUtil::One(shape.element_type())));
107 return b->AddInstruction(HloInstruction::CreateBroadcast(shape, one, {}));
108 }
109
110 HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module) {
111 HloComputation::Builder sum_b("add");
112 auto x = sum_b.AddInstruction(HloInstruction::CreateParameter(
113 /*parameter_number=*/0, ShapeUtil::MakeShape(type, {}), "x"));
114 auto y = sum_b.AddInstruction(HloInstruction::CreateParameter(
115 /*parameter_number=*/1, ShapeUtil::MakeShape(type, {}), "y"));
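// For PRED, use logical OR as the "addition" so the reduction stays boolean.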
116 if (type == PRED) {
117 sum_b.AddInstruction(HloInstruction::CreateBinary(
118 ShapeUtil::MakeShape(type, {}), HloOpcode::kOr, x, y));
119 } else {
120 sum_b.AddInstruction(HloInstruction::CreateBinary(
121 ShapeUtil::MakeShape(type, {}), HloOpcode::kAdd, x, y));
122 }
123 HloComputation* reduction = module->AddEmbeddedComputation(sum_b.Build());
124 return reduction;
125 }
126
127 bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding) {
128 if (sharding.IsTuple()) {
129 for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
130 if (!EvenlyPartitions(ShapeUtil::GetTupleElementShape(shape, i),
131 sharding.GetSubSharding(shape, {i}))) {
132 return false;
133 }
134 }
135 }
136
137 if (sharding.IsTileMaximal()) {
138 return sharding.IsReplicated();
139 }
140 for (int64 i = 0; i < shape.dimensions_size(); ++i) {
141 if (shape.dimensions(i) % sharding.tile_assignment().dim(i) != 0) {
142 return false;
143 }
144 }
145 return true;
146 }
147
148 Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding) {
149 if (sharding.IsTuple()) {
150 std::vector<Shape> subshapes;
151 for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
152 subshapes.push_back(
153 MakePartitionedShape(ShapeUtil::GetTupleElementShape(shape, i),
154 sharding.GetSubSharding(shape, {i})));
155 }
156 return ShapeUtil::MakeTupleShape(subshapes);
157 }
158 return sharding.TileShape(shape);
159 }
160
161 int64 ShapeSizeInBytes(const Shape& shape) {
162 return ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()) *
163 ShapeUtil::ElementsIn(shape);
164 }
165
166 Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape,
167 const HloSharding& sharding,
168 int64 partition_id) {
169 if (sharding.IsTuple()) {
170 std::vector<Shape> subshapes;
171 for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
172 subshapes.push_back(MakeNonPaddedShapeForGivenPartition(
173 ShapeUtil::GetTupleElementShape(shape, i),
174 sharding.GetSubSharding(shape, {i}), partition_id));
175 }
176 return ShapeUtil::MakeTupleShape(subshapes);
177 }
178
179 if (sharding.IsReplicated()) {
180 return shape;
181 }
182 if (sharding.IsTileMaximal()) {
183 if (partition_id == *sharding.UniqueDevice()) {
184 return shape;
185 }
186 return ShapeUtil::MakeTupleShape({});
187 }
188
189 auto partition_shape = shape;
190 std::vector<int64> tile_offset =
191 sharding.TileOffsetForDevice(shape, partition_id);
192 std::vector<int64> tile_limit =
193 sharding.TileLimitForDevice(shape, partition_id);
194 for (int64 i = 0; i < tile_offset.size(); ++i) {
195 if (sharding.UsesDevice(partition_id)) {
196 partition_shape.set_dimensions(i, tile_limit[i] - tile_offset[i]);
197 } else {
198 partition_shape.set_dimensions(i, 0);
199 }
200 }
201 return partition_shape;
202 }
203
204 std::vector<HloInstruction*> MakePartitionOffsets(
205 const Shape& shape, const HloSharding& sharding,
206 HloInstruction* partition_id, SpmdBuilder* b,
207 absl::Span<const int64> dims) {
208 CHECK(!shape.IsTuple());
209
210 std::vector<std::vector<int32>> offset_arrays(shape.rank());
211 for (int64 i = 0; i < shape.rank(); ++i) {
212 offset_arrays[i].resize(sharding.tile_assignment().num_elements());
213 }
214 auto shard_shape = MakePartitionedShape(shape, sharding);
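// For each dimension, record every device's tile start offset (tile index
// times the per-shard size).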
215 sharding.tile_assignment().Each(
216 [&](absl::Span<const int64> indices, int64 device) {
217 for (int64 i = 0; i < shape.rank(); ++i) {
218 offset_arrays[i][device] = indices[i] * shard_shape.dimensions(i);
219 }
220 });
221 std::vector<HloInstruction*> offsets;
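// Untiled dimensions (or dimensions not listed in `dims`) get a zero offset;
// tiled ones look up this partition's entry with a dynamic-slice on
// partition_id.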
222 for (int64 i = 0; i < shape.rank(); ++i) {
223 if (sharding.tile_assignment().dim(i) == 1 ||
224 (!dims.empty() && !absl::c_linear_search(dims, i))) {
225 offsets.push_back(b->AddInstruction(
226 HloInstruction::CreateConstant(LiteralUtil::Zero(S32))));
227 } else {
228 auto offset_table = b->AddInstruction(HloInstruction::CreateConstant(
229 LiteralUtil::CreateR1<int32>(offset_arrays[i])));
230 auto index = b->AddInstruction(HloInstruction::CreateDynamicSlice(
231 ShapeUtil::MakeShape(S32, {1}), offset_table, {partition_id}, {1}));
232 offsets.push_back(b->AddInstruction(
233 HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), index)));
234 }
235 }
236 return offsets;
237 }
238
239 std::vector<HloInstruction*> MakeTiledPartitionOrdinals(
240 const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b) {
241 CHECK(!sharding.IsTileMaximal());
242 auto dimensions = sharding.tile_assignment().dimensions();
243 if (sharding.ReplicateOnLastTileDim()) {
244 dimensions.pop_back();
245 }
246 auto table_shape = ShapeUtil::MakeShape(S32, dimensions);
247 return MakePartitionOffsets(table_shape, sharding, partition_id, b);
248 }
249
250 HloInstruction* PadToShape(HloInstruction* hlo, const Shape& padded_shape,
251 SpmdBuilder* b, HloComputation* computation) {
252 CHECK(b == nullptr || computation == nullptr);
253 if (ShapeUtil::Compatible(hlo->shape(), padded_shape)) {
254 return hlo;
255 }
256 PaddingConfig padding_config;
257 for (int64 i = 0; i < padded_shape.rank(); ++i) {
258 auto padding_config_dim = padding_config.add_dimensions();
259 padding_config_dim->set_edge_padding_low(0);
260 padding_config_dim->set_interior_padding(0);
261 padding_config_dim->set_edge_padding_high(padded_shape.dimensions(i) -
262 hlo->shape().dimensions(i));
263 }
264 auto add_hlo = [&](std::unique_ptr<HloInstruction> to_add) {
265 if (b == nullptr) {
266 return computation->AddInstruction(std::move(to_add));
267 }
268 return b->AddInstruction(std::move(to_add));
269 };
270 auto zero = add_hlo(HloInstruction::CreateConstant(
271 LiteralUtil::Zero(hlo->shape().element_type())));
272 return add_hlo(
273 HloInstruction::CreatePad(padded_shape, hlo, zero, padding_config));
274 }
275
276 Shape GetPaddedShapeForUnevenPartitioning(const Shape& base_shape,
277 const HloSharding& sharding) {
278 if (sharding.IsTileMaximal()) {
279 return base_shape;
280 }
281 if (EvenlyPartitions(base_shape, sharding)) {
282 return base_shape;
283 }
284 auto shard_shape = MakePartitionedShape(base_shape, sharding);
285 Shape padded_base_shape = base_shape;
286 for (int64 i = 0; i < padded_base_shape.rank(); ++i) {
287 padded_base_shape.set_dimensions(
288 i, shard_shape.dimensions(i) * sharding.tile_assignment().dim(i));
289 }
290 return padded_base_shape;
291 }
292
293 HloInstruction* PadBaseShapeBeforeUnevenTiledSharding(
294 HloInstruction* hlo, const HloSharding& sharding, SpmdBuilder* b) {
295 auto padded_base_shape =
296 GetPaddedShapeForUnevenPartitioning(hlo->shape(), sharding);
297 if (ShapeUtil::Compatible(padded_base_shape, hlo->shape())) {
298 return hlo;
299 }
300 return PadToShape(hlo, padded_base_shape, b);
301 }
302
303 absl::optional<HloSharding> PartialReplicateReshardCompatibleSharding(
304 const HloSharding& partial_sharding, const HloSharding& target_sharding) {
305 if (!partial_sharding.ReplicateOnLastTileDim()) {
306 return absl::nullopt;
307 }
308 int64 rank = partial_sharding.tile_assignment().num_dimensions() - 1;
309 int64 target_rank = target_sharding.tile_assignment().num_dimensions() -
310 (target_sharding.ReplicateOnLastTileDim() ? 1 : 0);
311 if (target_rank != rank) {
312 return absl::nullopt;
313 }
314
315 absl::flat_hash_map<int64, int64> device_to_replication_group;
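// The group id is the linearized index of the tile coordinates over the
// non-replicated dimensions, so devices replicating the same tile share it.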
316 partial_sharding.tile_assignment().Each(
317 [&](absl::Span<const int64> indices, int64 device) {
318 int64 gid = 0;
319 for (int64 i = 0; i < rank; ++i) {
320 gid *= partial_sharding.tile_assignment().dim(i);
321 gid += indices[i];
322 }
323 device_to_replication_group[device] = gid;
324 });
325
326 // A dimension is expanded when target_tile_size > partial_tile_size and
327 // target_tile_size % partial_tile_size == 0.
328 // expand_tile_dims_indices[dim] is the position of dim among the expanded dims, or -1 if dim is not expanded.
329 std::vector<int64> expand_tile_dims_indices(rank, -1);
330 // expand_tile_size = target_tile_size / partial_tile_size.
331 std::vector<int64> expand_tile_sizes;
332 int num_expand_dims = 0;
333 for (int64 dim = 0; dim < rank; dim++) {
334 int64 partial_tile_size = partial_sharding.tile_assignment().dim(dim);
335 int64 target_tile_size = target_sharding.tile_assignment().dim(dim);
336 if (target_tile_size % partial_tile_size != 0 ||
337 target_tile_size < partial_tile_size) {
338 return absl::nullopt;
339 }
340
341 if (target_tile_size > partial_tile_size) {
342 expand_tile_dims_indices[dim] = num_expand_dims++;
343 expand_tile_sizes.emplace_back(target_tile_size / partial_tile_size);
344 }
345 }
346
347 // Reshape the partial replicate tile_dimensions.
348 int64 num_target_replication = 1;
349 if (target_sharding.ReplicateOnLastTileDim()) {
350 num_target_replication =
351 target_sharding.tile_assignment().dimensions().back();
352 }
353 auto reshape_dimensions = partial_sharding.tile_assignment().dimensions();
354 int64 num_replication = reshape_dimensions.back();
355 if (num_replication / num_target_replication != Product(expand_tile_sizes) ||
356 num_replication % num_target_replication != 0) {
357 return absl::nullopt;
358 }
359
360 reshape_dimensions.pop_back();
361 reshape_dimensions.insert(reshape_dimensions.end(), expand_tile_sizes.begin(),
362 expand_tile_sizes.end());
363
364 if (target_sharding.ReplicateOnLastTileDim()) {
365 reshape_dimensions.push_back(num_target_replication);
366 }
367
368 auto reshape_tile_assignment = partial_sharding.tile_assignment();
369 reshape_tile_assignment.Reshape(reshape_dimensions);
370
371 // Transpose.
372 std::vector<int64> perm;
373 perm.reserve(rank + expand_tile_sizes.size());
374 for (int64 dim = 0; dim < rank; dim++) {
375 perm.emplace_back(dim);
376 if (expand_tile_dims_indices[dim] > -1) {
377 perm.emplace_back(expand_tile_dims_indices[dim] + rank);
378 }
379 }
380 auto transpose_sharding = hlo_sharding_util::TransposeSharding(
381 target_sharding.ReplicateOnLastTileDim()
382 ? HloSharding::PartialTile(reshape_tile_assignment)
383 : HloSharding::Tile(reshape_tile_assignment),
384 perm);
385
386 // Reshape to target shape
387 auto transpose_tile_assignment = transpose_sharding.tile_assignment();
388 transpose_tile_assignment.Reshape(
389 target_sharding.tile_assignment().dimensions());
390
391 bool groups_matching = true;
392 target_sharding.tile_assignment().Each(
393 [&](absl::Span<const int64> indices, int64 device) {
394 if (device_to_replication_group[device] !=
395 device_to_replication_group[transpose_tile_assignment(indices)]) {
396 groups_matching = false;
397 }
398 });
399
400 if (groups_matching) {
401 return target_sharding;
402 }
403 return target_sharding.ReplicateOnLastTileDim()
404 ? HloSharding::PartialTile(transpose_tile_assignment)
405 : HloSharding::Tile(transpose_tile_assignment);
406 }
407
408 absl::optional<HloInstruction*> TileToPartialReplicateHaloExchange(
409 HloInstruction* hlo, const Shape& base_shape,
410 const HloSharding& src_sharding, const HloSharding& dst_sharding,
411 const std::vector<int64>& replicate_dims,
412 const SPMDCollectiveOpsCreator& collective_ops_creator,
413 int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) {
414 // Source is tile sharding.
415 auto padded_src_shape =
416 GetPaddedShapeForUnevenPartitioning(base_shape, src_sharding);
417 // Target is partial replicate.
418 auto padded_dst_shape =
419 GetPaddedShapeForUnevenPartitioning(base_shape, dst_sharding);
420 if (ShapeUtil::Compatible(padded_dst_shape, hlo->shape())) {
421 return hlo;
422 }
423
424 auto partition_ordinals =
425 MakeTiledPartitionOrdinals(dst_sharding, partition_id, b);
426
427 auto result = hlo;
428 auto hlo_shape = hlo->shape();
429 for (auto dim : replicate_dims) {
430 int64 dst_shard_count = dst_sharding.tile_assignment().dim(dim);
431 int64 src_per_shard_size =
432 padded_src_shape.dimensions(dim) / dst_shard_count;
433 // Calculate the dst per-shard size to check whether the src shards hold
434 // redundant data that needs to be re-distributed.
435 int64 dst_per_shard_size =
436 padded_dst_shape.dimensions(dim) / dst_shard_count;
437
438 // Skip if the src shards don't hold redundant data.
439 if (src_per_shard_size <= dst_per_shard_size || dst_shard_count == 1) {
440 continue;
441 }
442
443 // If src_per_shard * replicate_factor > dst_per_shard, the data needs to be
444 // re-distributed between the shards using collective permute. For example,
445 // if the dimension size is 6, sharded 4 ways in the src but 2 ways in the
446 // dst: 4-way sharding has 2 elements per shard while 2-way sharding has 3,
447 // so the last element in the first shard would be sliced out and
448 // re-distribution is needed.
449 //
450 // 1. Calculate left_halo size.
451 // left-halo size is
452 // (src_per_shard_size - dst_per_shard_size) * i / replicate_factor
453 int64 replicate_factor = src_sharding.tile_assignment().dim(dim) /
454 dst_sharding.tile_assignment().dim(dim);
455 OffsetCalculation left_halo_size_function =
456 OffsetCalculation(MultiplyAddDivideOffsetCalculation(
457 src_per_shard_size - dst_per_shard_size, 0, replicate_factor));
458
459 // 2. Calculate right_halo size.
460 // right-halo size is 0
461 OffsetCalculation right_halo_size_function =
462 OffsetCalculation(MultiplyAddDivideOffsetCalculation(0, 0, 1));
463
464 auto concat = result;
465 // 3. Halo exchange.
466 auto halo_exchange_result = ExchangeHalo(
467 result, left_halo_size_function, right_halo_size_function, dim,
468 src_sharding, collective_ops_creator, next_channel_id, b);
469
470 if (halo_exchange_result.has_value()) {
471 concat = halo_exchange_result.value();
472 } else {
473 return absl::nullopt;
474 }
475
476 // 4. Slice the valid result.
477 // Slice offset is
478 // (dst_shard_count - i - 1) *
479 // (src_per_shard_size - dst_per_shard_size)
480 // i is the index in dst_sharding.
481 auto zero_s32 = b->AddInstruction(
482 HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
483 OffsetCalculation start_offset_on_padded_concat_calculation =
484 OffsetCalculation(MultiplyAddDivideOffsetCalculation(
485 dst_per_shard_size - src_per_shard_size,
486 (src_per_shard_size - dst_per_shard_size) * (dst_shard_count - 1),
487 1));
488 auto slice_shape = concat->shape();
489 slice_shape.set_dimensions(dim,
490 padded_src_shape.dimensions(dim) /
491 src_sharding.tile_assignment().dim(dim));
492 std::vector<HloInstruction*> slice_offsets(concat->shape().rank(),
493 zero_s32);
494 slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate(
495 partition_ordinals[dim], b);
496 result = b->AddInstruction(HloInstruction::CreateDynamicSlice(
497 slice_shape, concat, slice_offsets, slice_shape.dimensions()));
498 }
499 return result;
500 }
501
502 absl::optional<HloInstruction*> PadFromPartialReplicateShape(
503 HloInstruction* hlo, const Shape& base_shape,
504 const HloSharding& src_sharding, const HloSharding& dst_sharding,
505 const std::vector<int64>& expand_tile_dims,
506 const SPMDCollectiveOpsCreator& collective_ops_creator,
507 int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) {
508 auto padded_src_shape =
509 GetPaddedShapeForUnevenPartitioning(base_shape, src_sharding);
510 auto padded_dst_shape =
511 GetPaddedShapeForUnevenPartitioning(base_shape, dst_sharding);
512 if (ShapeUtil::Compatible(padded_dst_shape, hlo->shape())) {
513 return hlo;
514 }
515
516 auto partition_ordinals =
517 MakeTiledPartitionOrdinals(src_sharding, partition_id, b);
518
519 HloInstruction* result = hlo;
520 auto zero = b->AddInstruction(HloInstruction::CreateConstant(
521 LiteralUtil::Zero(hlo->shape().element_type())));
522 std::vector<int64> expand_dims_without_halo_exchange;
523 // Pad the dimensions that need a halo exchange, and record the padded dims
524 // that won't need one.
525 for (auto dim : expand_tile_dims) {
526 int64 src_shard_count = src_sharding.tile_assignment().dim(dim);
527 int64 src_per_shard_size =
528 padded_src_shape.dimensions(dim) / src_shard_count;
529 // Calculate the per-shard size under each sharding to check whether
530 // dst_sharding needs more padding at the end.
531 int64 dst_per_shard_size =
532 padded_dst_shape.dimensions(dim) / src_shard_count;
533
534 // Skip if dst_sharding doesn't need more padding at the end.
535 if (src_per_shard_size >= dst_per_shard_size) {
536 continue;
537 }
538 // If the src sharding is not partitioned along this dimension, simply pad
539 // to the desired shape.
540 if (src_shard_count == 1) {
541 expand_dims_without_halo_exchange.emplace_back(dim);
542 continue;
543 }
544
545 // If dst_sharding needs more padding at the end, the data needs to be
546 // re-distributed between the shards using collective permute.
547 // For example, if the dimension size is 6, sharded 2 ways in the src but
548 // 4 ways in the dst: the 4-way sharding needs 2 zeros padded at the end
549 // and holds 2 elements per shard, while the 2-way sharding holds 3, so
550 // re-distribution is needed.
551 //
552 // 1. Calculate left_halo size.
553 // left-halo size is 0
554 OffsetCalculation left_halo_size_function =
555 OffsetCalculation(MultiplyAddDivideOffsetCalculation(0, 0, 1));
556
557 // 2. Calculate right_halo size.
558 // right-halo size is D * (i + 1) - S * (i + 1) = (D - S) * i + (D - S)
559 OffsetCalculation right_halo_size_function =
560 OffsetCalculation(MultiplyAddDivideOffsetCalculation(
561 dst_per_shard_size - src_per_shard_size,
562 dst_per_shard_size - src_per_shard_size, 1));
563
564 auto concat = result;
565 // 3. Halo exchange.
566 auto halo_exchange_result = ExchangeHalo(
567 result, left_halo_size_function, right_halo_size_function, dim,
568 src_sharding, collective_ops_creator, next_channel_id, b);
569
570 if (halo_exchange_result.has_value()) {
571 concat = halo_exchange_result.value();
572 } else {
573 return absl::nullopt;
574 }
575
576 // 4. Pad.
577 std::vector<int64> zero_padding(concat->shape().rank());
578 PaddingConfig pad_config = window_util::MakeSymmetricPadding(zero_padding);
579 pad_config.mutable_dimensions(dim)->set_edge_padding_low(0);
580 int64 max_right_halo_size =
581 right_halo_size_function.MaxInRange(0, src_shard_count - 1);
582 pad_config.mutable_dimensions(dim)->set_edge_padding_high(std::max(
583 int64{0}, padded_dst_shape.dimensions(dim) -
584 padded_src_shape.dimensions(dim) - max_right_halo_size));
585 auto padded_concat_shape = ShapeInference::InferPadShape(
586 concat->shape(), zero->shape(), pad_config)
587 .ValueOrDie();
588 concat = b->AddInstruction(HloInstruction::CreatePad(
589 padded_concat_shape, concat, zero, pad_config));
590
591 // 5. Slice the valid result.
592 // Slice offset is (D-S) * i
593 auto zero_s32 = b->AddInstruction(
594 HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
595 OffsetCalculation start_offset_on_padded_concat_calculation =
596 OffsetCalculation(MultiplyAddDivideOffsetCalculation(
597 dst_per_shard_size - src_per_shard_size, 0, 1));
598 auto slice_shape = concat->shape();
599 slice_shape.set_dimensions(dim, dst_per_shard_size);
600 std::vector<HloInstruction*> slice_offsets(concat->shape().rank(),
601 zero_s32);
602 slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate(
603 partition_ordinals[dim], b);
604 result = b->AddInstruction(HloInstruction::CreateDynamicSlice(
605 slice_shape, concat, slice_offsets, slice_shape.dimensions()));
606 }
607
608 // Pad other dimensions that won't need halo exchange with a single pad.
609 if (!expand_dims_without_halo_exchange.empty()) {
610 std::vector<int64> zero_padding(result->shape().rank());
611 PaddingConfig pad_config = window_util::MakeSymmetricPadding(zero_padding);
612
613 auto padded_shape = result->shape();
614 for (auto dim : expand_dims_without_halo_exchange) {
615 pad_config.mutable_dimensions(dim)->set_edge_padding_low(0);
616 pad_config.mutable_dimensions(dim)->set_edge_padding_high(
617 padded_dst_shape.dimensions(dim) - padded_src_shape.dimensions(dim));
618 padded_shape.set_dimensions(dim, result->shape().dimensions(dim) +
619 padded_dst_shape.dimensions(dim) -
620 padded_src_shape.dimensions(dim));
621 }
622 result = b->AddInstruction(
623 HloInstruction::CreatePad(padded_shape, result, zero, pad_config));
624 }
625
626 return result;
627 }
628
629 absl::optional<int64> UniqueTiledDim(const HloSharding& sharding) {
630 if (sharding.IsTileMaximal()) {
631 return absl::nullopt;
632 }
633 int64 dim = -1;
634 int64 rank = sharding.ReplicateOnLastTileDim()
635 ? sharding.tile_assignment().num_dimensions() - 1
636 : sharding.tile_assignment().num_dimensions();
637 for (int64 i = 0; i < rank; ++i) {
638 if (sharding.tile_assignment().dim(i) > 1) {
639 if (dim != -1) {
640 return absl::nullopt;
641 }
642 dim = i;
643 }
644 }
645 CHECK_NE(dim, -1);
646 return dim;
647 }
648
649 MultiplyAddDivideOffsetCalculation::MultiplyAddDivideOffsetCalculation(
650 int64 multiplier, int64 offset, int64 divisor)
651 : multiplier_(multiplier), offset_(offset), divisor_(divisor) {
652 CHECK_GT(divisor_, 0);
653 Simplify();
654 }
655
656 OffsetCalculation MultiplyAddDivideOffsetCalculation::operator-(
657 const MultiplyAddDivideOffsetCalculation& other) const {
658 if (divisor_ == 1 && other.divisor_ == 1) {
659 return OffsetCalculation(MultiplyAddDivideOffsetCalculation(
660 multiplier_ - other.multiplier_, offset_ - other.offset_, 1));
661 }
662 return OffsetCalculation(HloOpcode::kSubtract, *this, other);
663 }
664
665 void MultiplyAddDivideOffsetCalculation::Simplify() {
666 // We could simplify the calculation when multiplier is a multiple of
667 // divisor_. However, when offset_ is not a multiple of divisor_, we must
668 // make sure that offset_ and multiplier_ are both non-negative or both
669 // non-positive. E.g., (3 * i - 1) / 3 is not equivalent to i or i - 1.
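// For instance, with C++ truncating division (3 * i - 1) / 3 is 0 at both
// i = 0 and i = 1, matching i at 0 but i - 1 at 1.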
670 if (divisor_ != 1 && multiplier_ % divisor_ == 0 &&
671 (offset_ % divisor_ == 0 || offset_ * multiplier_ > 0)) {
672 multiplier_ /= divisor_;
673 offset_ /= divisor_;
674 divisor_ = 1;
675 }
676 }
677
678 int64 MultiplyAddDivideOffsetCalculation::Calculate(int64 shard_ordinal) const {
679 return (shard_ordinal * multiplier_ + offset_) / divisor_;
680 }
681
682 HloInstruction* MultiplyAddDivideOffsetCalculation::Calculate(
683 HloInstruction* shard_ordinal, SpmdBuilder* b) const {
684 auto scalar_shape = ShapeUtil::MakeShape(S32, {});
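// Emit a constant when the ordinal does not matter; otherwise emit only the
// multiply/add/divide ops whose operands are non-trivial.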
685 if (multiplier_ == 0) {
686 return b->AddInstruction(HloInstruction::CreateConstant(
687 LiteralUtil::CreateR0<int32>(offset_ / divisor_)));
688 }
689 HloInstruction* result = shard_ordinal;
690 if (multiplier_ != 1) {
691 result = b->AddInstruction(HloInstruction::CreateBinary(
692 scalar_shape, HloOpcode::kMultiply, shard_ordinal,
693 b->AddInstruction(HloInstruction::CreateConstant(
694 LiteralUtil::CreateR0<int32>(multiplier_)))));
695 }
696 if (offset_ != 0) {
697 auto offset = b->AddInstruction(
698 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(offset_)));
699 result = b->AddInstruction(HloInstruction::CreateBinary(
700 scalar_shape, HloOpcode::kAdd, result, offset));
701 }
702 if (divisor_ != 1) {
703 auto divisor = b->AddInstruction(
704 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(divisor_)));
705 result = b->AddInstruction(HloInstruction::CreateBinary(
706 scalar_shape, HloOpcode::kDivide, result, divisor));
707 }
708 return result;
709 }
710
711 int64 MultiplyAddDivideOffsetCalculation::MaxInRange(
712 int64 start_ordinal, int64 limit_ordinal) const {
713 int64 max = Calculate(start_ordinal);
714 for (int64 i = start_ordinal + 1; i < limit_ordinal; ++i) {
715 max = std::max(max, Calculate(i));
716 }
717 return max;
718 }
719
720 OffsetCalculation& OffsetCalculation::operator=(
721 const OffsetCalculation& other) {
722 opcode_ = other.opcode_;
723 copy_from_ = other.copy_from_;
724 if (opcode_ != HloOpcode::kCopy) {
725 lhs_ = absl::make_unique<OffsetCalculation>(*other.lhs_);
726 rhs_ = absl::make_unique<OffsetCalculation>(*other.rhs_);
727 }
728 return *this;
729 }
730
731 bool OffsetCalculation::IsConstant() const {
732 if (opcode_ == HloOpcode::kCopy) {
733 return copy_from_.IsConstant();
734 }
735 if (opcode_ == HloOpcode::kSubtract && *lhs_ == *rhs_) {
736 return true;
737 }
738 return lhs_->IsConstant() && rhs_->IsConstant();
739 }
740
741 OffsetCalculation OffsetCalculation::operator-(
742 const OffsetCalculation& other) const {
743 if (opcode_ == HloOpcode::kCopy && other.opcode_ == HloOpcode::kCopy) {
744 return copy_from_ - other.copy_from_;
745 }
746 return OffsetCalculation(HloOpcode::kSubtract, *this, other);
747 }
748
749 bool OffsetCalculation::operator==(const OffsetCalculation& other) const {
750 if (opcode_ != other.opcode_) {
751 return false;
752 }
753 if (opcode_ == HloOpcode::kCopy) {
754 return copy_from_ == other.copy_from_;
755 }
756 return *lhs_ == *other.lhs_ && *rhs_ == *other.rhs_;
757 }
758
759 int64 OffsetCalculation::Calculate(int64 shard_ordinal) const {
760 switch (opcode_) {
761 case HloOpcode::kCopy:
762 return copy_from_.Calculate(shard_ordinal);
763 case HloOpcode::kSubtract:
764 return lhs_->Calculate(shard_ordinal) - rhs_->Calculate(shard_ordinal);
765 case HloOpcode::kMultiply:
766 return lhs_->Calculate(shard_ordinal) * rhs_->Calculate(shard_ordinal);
767 default:
768 LOG(FATAL) << "Should not happen";
769 }
770 }
771
772 HloInstruction* OffsetCalculation::Calculate(HloInstruction* shard_ordinal,
773 SpmdBuilder* b) const {
774 if (opcode_ == HloOpcode::kCopy) {
775 return copy_from_.Calculate(shard_ordinal, b);
776 }
777 auto lhs = lhs_->Calculate(shard_ordinal, b);
778 auto rhs = rhs_->Calculate(shard_ordinal, b);
779 return b->AddInstruction(
780 HloInstruction::CreateBinary(lhs->shape(), opcode_, lhs, rhs));
781 }
782
783 int64 OffsetCalculation::MaxInRange(int64 start_ordinal,
784 int64 limit_ordinal) const {
785 if (IsConstant()) {
786 return Calculate(start_ordinal);
787 }
788 if (opcode_ == HloOpcode::kCopy) {
789 return std::max(Calculate(start_ordinal), Calculate(limit_ordinal - 1));
790 }
791 int64 max = Calculate(start_ordinal);
792 for (int64 i = start_ordinal + 1; i < limit_ordinal; ++i) {
793 max = std::max(max, Calculate(i));
794 }
795 return max;
796 }
797
798 absl::optional<HloInstruction*> ExchangeHalo(
799 HloInstruction* hlo, const OffsetCalculation& left_halo_size_function,
800 const OffsetCalculation& right_halo_size_function, int64 dim,
801 const HloSharding& target,
802 const SPMDCollectiveOpsCreator& collective_ops_creator,
803 int64* next_channel_id, SpmdBuilder* b) {
804 int64 input_shard_size = hlo->shape().dimensions(dim);
805 int64 shard_count = target.tile_assignment().dim(dim);
806
807 std::vector<HloInstruction*> concat_pieces;
808
809 int64 max_left_halo_size = left_halo_size_function.MaxInRange(1, shard_count);
810 int64 max_right_halo_size =
811 right_halo_size_function.MaxInRange(0, shard_count - 1);
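// Give up if a halo is wider than one shard while the halos plus the local
// shard would cover the whole dimension; callers handle the nullopt case.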
812 if (max_left_halo_size + max_right_halo_size + input_shard_size >=
813 input_shard_size * shard_count &&
814 (max_left_halo_size > input_shard_size ||
815 max_right_halo_size > input_shard_size)) {
816 return absl::nullopt;
817 }
818 // Left halo.
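// The left halo of each shard is gathered from the shards to its left, one
// collective permute per source distance, at most one shard's worth each.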
819 for (int64 i = CeilOfRatio(max_left_halo_size, input_shard_size) - 1; i >= 0;
820 --i) {
821 std::vector<std::pair<int64, int64>> source_target_pairs;
822 target.tile_assignment().Each(
823 [&](absl::Span<const int64> indices, int64 device) {
824 if (indices[dim] > i) {
825 std::vector<int64> source_indices(indices.begin(), indices.end());
826 source_indices[dim] -= i + 1;
827 source_target_pairs.emplace_back(
828 target.tile_assignment()(source_indices), device);
829 }
830 });
831 int64 halo_size =
832 std::min(max_left_halo_size - input_shard_size * i, input_shard_size);
833 auto halo_shape = hlo->shape();
834 auto source_halo_slice = hlo;
835 if (halo_size != hlo->shape().dimensions(dim)) {
836 halo_shape.set_dimensions(dim, halo_size);
837 std::vector<int64> halo_start_indices(halo_shape.rank(), 0);
838 halo_start_indices[dim] = hlo->shape().dimensions(dim) - halo_size;
839 std::vector<int64> halo_slice_strides(halo_shape.rank(), 1);
840 source_halo_slice = b->AddInstruction(HloInstruction::CreateSlice(
841 halo_shape, hlo, halo_start_indices, hlo->shape().dimensions(),
842 halo_slice_strides));
843 }
844 auto left_halo =
845 collective_ops_creator.create_cross_partition_collective_permute(
846 b, source_halo_slice, source_target_pairs, (*next_channel_id)++);
847 concat_pieces.push_back(left_halo);
848 }
849
850 concat_pieces.push_back(hlo);
851
852 // Right halo.
853 for (int64 i = 0; i < CeilOfRatio(max_right_halo_size, input_shard_size);
854 ++i) {
855 std::vector<std::pair<int64, int64>> source_target_pairs;
856 target.tile_assignment().Each(
857 [&](absl::Span<const int64> indices, int64 device) {
858 if (indices[dim] > i) {
859 std::vector<int64> target_indices(indices.begin(), indices.end());
860 target_indices[dim] -= i + 1;
861 source_target_pairs.emplace_back(
862 device, target.tile_assignment()(target_indices));
863 }
864 });
865 int64 halo_size =
866 std::min(max_right_halo_size - input_shard_size * i, input_shard_size);
867 auto halo_shape = hlo->shape();
868 HloInstruction* source_halo_slice = hlo;
869 if (halo_size != halo_shape.dimensions(dim)) {
870 halo_shape.set_dimensions(dim, halo_size);
871 std::vector<int64> halo_start_indices(halo_shape.rank(), 0);
872 std::vector<int64> halo_slice_strides(halo_shape.rank(), 1);
873 source_halo_slice = b->AddInstruction(HloInstruction::CreateSlice(
874 halo_shape, hlo, halo_start_indices, halo_shape.dimensions(),
875 halo_slice_strides));
876 }
877 auto right_halo =
878 collective_ops_creator.create_cross_partition_collective_permute(
879 b, source_halo_slice, source_target_pairs, (*next_channel_id)++);
880 concat_pieces.push_back(right_halo);
881 }
882
883 auto concat = hlo;
884 // Concat with halos/padding.
885 if (concat_pieces.size() > 1) {
886 auto concat_shape = hlo->shape();
887 int64 concat_dim_size = 0;
888 for (auto piece : concat_pieces) {
889 concat_dim_size += piece->shape().dimensions(dim);
890 }
891 concat_shape.set_dimensions(dim, concat_dim_size);
892 concat = b->AddInstruction(
893 HloInstruction::CreateConcatenate(concat_shape, concat_pieces, dim));
894 }
895
896 return concat;
897 }
898
899 absl::optional<HloInstruction*> ExchangeHalo(
900 HloInstruction* hlo,
901 std::vector<OffsetCalculation> left_halo_size_functions,
902 std::vector<OffsetCalculation> right_halo_size_functions,
903 const HloSharding& target,
904 const SPMDCollectiveOpsCreator& collective_ops_creator,
905 int64* next_channel_id, SpmdBuilder* b) {
906 CHECK(left_halo_size_functions.size() == hlo->shape().rank());
907 CHECK(right_halo_size_functions.size() == hlo->shape().rank());
908
909 HloInstruction* visiting_hlo = hlo;
910 for (int dim = 0; dim < hlo->shape().rank(); ++dim) {
911 auto concat = ExchangeHalo(visiting_hlo, left_halo_size_functions[dim],
912 right_halo_size_functions[dim], dim, target,
913 collective_ops_creator, next_channel_id, b);
914 if (!concat) {
915 return absl::nullopt;
916 }
917 visiting_hlo = *concat;
918 }
919 return visiting_hlo;
920 }
921
922 absl::optional<HloInstruction*> ExchangeHaloAndGetValidData(
923 HloInstruction* hlo, const Shape& base_shape,
924 const OffsetCalculation& left_halo_size_function,
925 const OffsetCalculation& right_halo_size_function,
926 int64 explicit_left_padding_on_full_shape, int64 padded_full_shape_size,
927 int64 shard_size_with_halo, int64 dim, const HloSharding& target,
928 HloInstruction* offset_on_padded_shape, HloInstruction* pad_value,
929 HloInstruction* partition_ordinal,
930 const SPMDCollectiveOpsCreator& collective_ops_creator,
931 int64* next_channel_id, SpmdBuilder* b, bool mask_invalid_region) {
932 auto halo_exchange_result =
933 ExchangeHalo(hlo, left_halo_size_function, right_halo_size_function, dim,
934 target, collective_ops_creator, next_channel_id, b);
935 if (!halo_exchange_result) {
936 return absl::nullopt;
937 }
938 auto concat = *halo_exchange_result;
939 int64 shard_count = target.tile_assignment().dim(dim);
940 int64 max_left_halo_size = left_halo_size_function.MaxInRange(1, shard_count);
941
942 // Now we determine if we need extra padding after the concat.
943 //
944 // The max of halo size or the first shard's explicit left padding.
945 int64 max_left_halo_or_padding_size =
946 std::max(std::max(int64{0}, max_left_halo_size),
947 explicit_left_padding_on_full_shape);
948 // The calculation that returns the dynamic slice index for a shard on the
949 // padded concat, which is the difference between
950 // max_left_halo_or_padding_size and its left halo size.
951 auto start_offset_on_padded_concat_calculation =
952 OffsetCalculation(MultiplyAddDivideOffsetCalculation(
953 0, max_left_halo_or_padding_size, 1)) -
954 left_halo_size_function;
955
956 // See if we need to pad the concat before dynamic slice.
957 int64 extra_left_padding =
958 std::max(int64{0}, max_left_halo_or_padding_size -
959 std::max(int64{0}, max_left_halo_size));
960 int64 extra_right_padding =
961 start_offset_on_padded_concat_calculation.MaxInRange(0, shard_count) +
962 shard_size_with_halo - concat->shape().dimensions(dim) -
963 extra_left_padding;
964 extra_right_padding = std::max(int64{0}, extra_right_padding);
965 if (extra_left_padding > 0 || extra_right_padding > 0) {
966 PaddingConfig padding_config;
967 auto padded_concat_shape = concat->shape();
968 for (int64 i = 0; i < base_shape.rank(); ++i) {
969 auto padding_config_dim = padding_config.add_dimensions();
970 padding_config_dim->set_interior_padding(0);
971 padding_config_dim->set_edge_padding_low(0);
972 padding_config_dim->set_edge_padding_high(0);
973 if (i != dim) {
974 continue;
975 }
976 padding_config_dim->set_edge_padding_low(extra_left_padding);
977 padding_config_dim->set_edge_padding_high(extra_right_padding);
978 padded_concat_shape.set_dimensions(dim, concat->shape().dimensions(dim) +
979 extra_left_padding +
980 extra_right_padding);
981 }
982 concat = b->AddInstruction(HloInstruction::CreatePad(
983 padded_concat_shape, concat, pad_value, padding_config));
984 }
985
986 auto valid_slice = concat;
987 if (shard_size_with_halo != concat->shape().dimensions(dim)) {
988 // Concat is bigger than the shard shape, so we need a dynamic slice.
989 CHECK_LT(shard_size_with_halo, concat->shape().dimensions(dim));
990 auto slice_shape = concat->shape();
991 slice_shape.set_dimensions(dim, shard_size_with_halo);
992
993 if (left_halo_size_function.IsConstant() &&
994 left_halo_size_function.Calculate(0) ==
995 explicit_left_padding_on_full_shape) {
996 std::vector<int64> start_indices(slice_shape.rank(), 0);
997 std::vector<int64> strides(slice_shape.rank(), 1);
998 valid_slice = b->AddInstruction(
999 HloInstruction::CreateSlice(slice_shape, concat, start_indices,
1000 slice_shape.dimensions(), strides));
1001 } else {
1002 auto zero = b->AddInstruction(
1003 HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
1004 std::vector<HloInstruction*> slice_offsets(base_shape.rank(), zero);
1005 slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate(
1006 partition_ordinal, b);
1007 valid_slice = b->AddInstruction(HloInstruction::CreateDynamicSlice(
1008 slice_shape, concat, slice_offsets, slice_shape.dimensions()));
1009 }
1010 }
1011
1012 if (!mask_invalid_region) {
1013 return valid_slice;
1014 }
1015
1016 int64 total_right_padding = padded_full_shape_size -
1017 base_shape.dimensions(dim) -
1018 explicit_left_padding_on_full_shape;
1019 // Mask off garbage data due to uneven partition or low/high padding.
1020 if (explicit_left_padding_on_full_shape > 0 || total_right_padding > 0) {
1021 auto index_shape = ShapeUtil::ChangeElementType(valid_slice->shape(), S32);
1022 auto iota = b->AddInstruction(HloInstruction::CreateIota(index_shape, dim));
1023 auto broadcast_start_index_in_padded_shape =
1024 b->AddInstruction(HloInstruction::CreateBroadcast(
1025 index_shape, offset_on_padded_shape, {}));
1026 auto index_in_padded_shape = b->AddInstruction(
1027 HloInstruction::CreateBinary(index_shape, HloOpcode::kAdd, iota,
1028 broadcast_start_index_in_padded_shape));
1029 auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED);
1030 std::vector<HloInstruction*> predicates;
1031 if (explicit_left_padding_on_full_shape > 0) {
1032 auto valid_index_start =
1033 b->AddInstruction(HloInstruction::CreateBroadcast(
1034 index_shape,
1035 b->AddInstruction(
1036 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
1037 explicit_left_padding_on_full_shape))),
1038 {}));
1039 predicates.push_back(b->AddInstruction(HloInstruction::CreateCompare(
1040 mask_shape, index_in_padded_shape, valid_index_start,
1041 ComparisonDirection::kGe)));
1042 }
1043 if (total_right_padding > 0) {
1044 auto valid_index_limit =
1045 b->AddInstruction(HloInstruction::CreateBroadcast(
1046 index_shape,
1047 b->AddInstruction(
1048 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
1049 base_shape.dimensions(dim) +
1050 explicit_left_padding_on_full_shape))),
1051 {}));
1052 predicates.push_back(b->AddInstruction(HloInstruction::CreateCompare(
1053 mask_shape, index_in_padded_shape, valid_index_limit,
1054 ComparisonDirection::kLt)));
1055 }
1056 CHECK(!predicates.empty());
1057 auto is_valid =
1058 predicates.size() == 2
1059 ? b->AddInstruction(HloInstruction::CreateBinary(
1060 mask_shape, HloOpcode::kAnd, predicates[0], predicates[1]))
1061 : predicates[0];
1062 auto masking_value = b->AddInstruction(
1063 HloInstruction::CreateBroadcast(valid_slice->shape(), pad_value, {}));
1064 valid_slice = b->AddInstruction(
1065 HloInstruction::CreateTernary(valid_slice->shape(), HloOpcode::kSelect,
1066 is_valid, valid_slice, masking_value));
1067 }
1068 return valid_slice;
1069 }
1070
1071 HloInstruction* HaloExchangeToPadOnLeft(PartitionedHlo& original,
1072 absl::Span<const int64> dims) {
1073 if (original.sharding().IsTileMaximal()) {
1074 return original.hlo();
1075 }
1076 // Create a window config to perform a halo exchange for the unevenly
1077 // partitioned reverse dimensions.
1078 Window window;
1079 for (int64 i = 0; i < original.base_shape().rank(); ++i) {
1080 WindowDimension* dim = window.add_dimensions();
1081 dim->set_size(1);
1082 dim->set_stride(1);
1083 dim->set_window_dilation(1);
1084 dim->set_window_reversal(false);
1085 int64 low_padding = 0;
1086 if (absl::c_linear_search(dims, i)) {
1087 low_padding =
1088 RoundUpToNearest(original.base_shape().dimensions(i),
1089 original.sharding().tile_assignment().dim(i)) -
1090 original.base_shape().dimensions(i);
1091 }
1092 dim->set_padding_low(low_padding);
1093 dim->set_padding_high(0);
1094 dim->set_base_dilation(1);
1095 }
1096
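// Reshard as a windowed input: the low padding above is materialized through
// a halo exchange rather than an explicit pad of the full shape.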
1097 auto reshard_window = original.ReshardAsWindowedInput(
1098 window, original.sharding(),
1099 CreateZero(ShapeUtil::MakeShape(original.base_shape().element_type(), {}),
1100 original.state().b),
1101 /*mask_invalid_region=*/false);
1102 if (!reshard_window.has_value()) {
1103 return nullptr;
1104 }
1105 CHECK(!reshard_window->dynamic_slice_index_on_output.has_value());
1106 return reshard_window->sharded_input;
1107 }
1108
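// Matches the total-order comparator used for TopK-style sorts: F32 (or BF16
// converted to F32) keys are bitcast to integers and negative values are
// flipped so a plain integer Gt gives a NaN-safe ordering.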
1109 bool IsNanSafeGt(HloComputation* comp) {
1110 namespace m = match;
1111 auto match_bitcast_f32 = [](int64 parameter_number) {
1112 auto param = m::Parameter(parameter_number)
1113 .WithShape(m::Shape().WithElementType(F32));
1114 auto param_s32 =
1115 m::BitcastConvert(param).WithShape(m::Shape().WithElementType(S32));
1116 auto param_u32 =
1117 m::BitcastConvert(param).WithShape(m::Shape().WithElementType(U32));
1118 return m::Select(
1119 m::Lt(param_s32, m::ConstantScalar(0)),
1120 m::BitcastConvert(
1121 m::Subtract(m::ConstantScalar(std::numeric_limits<int32>::max()),
1122 param_u32))
1123 .WithShape(m::Shape().WithElementType(S32)),
1124 param_s32);
1125 };
1126 auto match_bitcast_bf16 = [](int64 parameter_number) {
1127 auto param = m::Convert(m::Parameter(parameter_number)
1128 .WithShape(m::Shape().WithElementType(BF16)))
1129 .WithShape(m::Shape().WithElementType(F32));
1130 auto param_s32 =
1131 m::BitcastConvert(param).WithShape(m::Shape().WithElementType(S32));
1132 auto param_u32 =
1133 m::BitcastConvert(param).WithShape(m::Shape().WithElementType(U32));
1134 return m::Select(
1135 m::Lt(param_s32, m::ConstantScalar(0)),
1136 m::BitcastConvert(
1137 m::Subtract(m::ConstantScalar(std::numeric_limits<int32>::max()),
1138 param_u32))
1139 .WithShape(m::Shape().WithElementType(S32)),
1140 param_s32);
1141 };
1142 // If the root is a kSelect (ties broken by index), match the value comparison in its predicate.
1143 if (comp->root_instruction()->opcode() == HloOpcode::kSelect) {
1144 return Match(comp->root_instruction()->operand(2),
1145 m::Gt(match_bitcast_f32(0), match_bitcast_f32(1))) ||
1146 Match(comp->root_instruction()->operand(2),
1147 m::Gt(match_bitcast_bf16(0), match_bitcast_bf16(1)));
1148 }
1149 return Match(comp->root_instruction(),
1150 m::Gt(match_bitcast_f32(0), match_bitcast_f32(1))) ||
1151 Match(comp->root_instruction(),
1152 m::Gt(match_bitcast_bf16(0), match_bitcast_bf16(1)));
1153 }
1154
1155 absl::optional<int64> GetKValueInTopKWhenPartitionSortDim(HloInstruction* hlo) {
1156 HloSortInstruction* sort = DynCast<HloSortInstruction>(hlo);
1157 if (sort == nullptr || sort->operand_count() != 2) {
1158 return absl::nullopt;
1159 }
1160 if (!IsNanSafeGt(sort->to_apply())) {
1161 return absl::nullopt;
1162 }
1163 HloInstruction* data = sort->mutable_operand(0);
1164 HloIotaInstruction* iota =
1165 DynCast<HloIotaInstruction>(sort->mutable_operand(1));
1166 const PrimitiveType element_type = data->shape().element_type();
1167 if (iota == nullptr || iota->shape().element_type() != S32 ||
1168 iota->opcode() != HloOpcode::kIota ||
1169 iota->iota_dimension() != sort->sort_dimension()) {
1170 return absl::nullopt;
1171 }
1172
1173 const int64 sort_dim = sort->sort_dimension();
1174
1175 if (element_type != F32 && element_type != BF16 && element_type != S32 &&
1176 element_type != U32) {
1177 return absl::nullopt;
1178 }
1179
1180 bool supported = true;
1181 absl::optional<int64> k;
1182 for (HloInstruction* gte : sort->users()) {
1183 if (gte->opcode() != HloOpcode::kGetTupleElement) {
1184 supported = false;
1185 break;
1186 }
1187
1188 const HloInstruction* slice = gte->users()[0];
1189 if (slice->opcode() != HloOpcode::kSlice) {
1190 // A non-slice user means we are not doing a TopK.
1191 supported = false;
1192 break;
1193 }
1194 if (absl::c_any_of(slice->slice_starts(), [](int x) { return x != 0; }) ||
1195 absl::c_any_of(slice->slice_strides(), [](int x) { return x != 1; })) {
1196 // Strided slices or slices that don't start at the beginning aren't supported.
1197 supported = false;
1198 break;
1199 }
1200 for (int64 dim = 0; dim < data->shape().dimensions_size(); dim++) {
1201 if (dim == sort_dim) {
1202 continue;
1203 }
1204 if (slice->slice_limits(dim) !=
1205 slice->operand(0)->shape().dimensions(dim)) {
1206 // Slicing along the other dimension isn't supported.
1207 supported = false;
1208 break;
1209 }
1210 }
1211 if (!k.has_value()) {
1212 k = slice->slice_limits(sort_dim);
1213 } else if (k != slice->slice_limits(sort_dim)) {
1214 // Different k for the different operands isn't supported.
1215 supported = false;
1216 break;
1217 }
1218 }
1219 if (k == absl::nullopt || !supported) {
1220 return absl::nullopt;
1221 }
1222
1223 // Only supported when the sort dimension is sharded.
1224 if (!data->has_sharding()) {
1225 return absl::nullopt;
1226 }
1227 const HloSharding& sharding = sort->operand(0)->sharding();
1228
1229 if (sharding.IsTileMaximal()) {
1230 return absl::nullopt;
1231 }
1232
1233 // Check if partitioned at sort dimension.
1234 for (int64 dim = 0; dim < sort->shape().tuple_shapes(0).dimensions_size();
1235 ++dim) {
1236 if (sharding.tile_assignment().dim(dim) > 1) {
1237 if (dim != sort_dim) {
1238 return absl::nullopt;
1239 }
1240 }
1241 }
1242
1243 // Check that k is smaller than the per-partition size of the sort dimension.
1244 const int64 shard_count = sharding.tile_assignment().dim(sort_dim);
1245
1246 if (shard_count <= 1) {
1247 return absl::nullopt;
1248 }
1249
1250 const int64 input_size = hlo->operand(0)->shape().dimensions(sort_dim);
1251 const int64 per_partition_size = CeilOfRatio(input_size, shard_count);
1252
1253 if (k.value() >= per_partition_size) {
1254 return absl::nullopt;
1255 }
1256
1257 return k;
1258 }
1259
1260 // Slice first k elements from sort_dim.
1261 HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder,
1262 int64 slice_dim, int64 k) {
1263 const Shape& hlo_shape = hlo->shape();
1264 auto hlo_dims = hlo_shape.dimensions();
1265 std::vector<int64> start_indices(hlo_shape.dimensions_size(), 0);
1266 std::vector<int64> limit_indices(hlo_dims.begin(), hlo_dims.end());
1267 std::vector<int64> strides(hlo_shape.dimensions_size(), 1);
1268 limit_indices[slice_dim] = k;
1269 auto output_shape = hlo_shape;
1270 output_shape.set_dimensions(slice_dim, k);
1271 return builder->AddInstruction(HloInstruction::CreateSlice(
1272 output_shape, hlo, start_indices, limit_indices, strides));
1273 }
1274
1275 // Returns the number of shards along the given dimension.
1276 int64 ShardCountAtDim(const HloSharding& sharding, int64 dim) {
1277 if (sharding.IsTileMaximal()) {
1278 return 1;
1279 }
1280 return sharding.tile_assignment().dim(dim);
1281 }
1282
1283 absl::optional<std::vector<std::pair<int64, int64>>>
1284 GetReshardAllToAllSourceTargetDims(const HloSharding& source,
1285 const HloSharding& target) {
1286 if (source.IsTileMaximal() || target.IsTileMaximal() ||
1287 source.tile_assignment().num_dimensions() !=
1288 target.tile_assignment().num_dimensions() ||
1289 source.NumTiles() != target.NumTiles()) {
1290 return absl::nullopt;
1291 }
1292 // Map each partition count to the dimensions that have different partition
1293 // counts in the source and target shardings.
1294 std::map<int64, std::vector<int64>> source_size_to_dim;
1295 std::map<int64, std::vector<int64>> target_size_to_dim;
1296 for (int64 i = 0; i < source.tile_assignment().num_dimensions(); ++i) {
1297 if (source.tile_assignment().dim(i) == target.tile_assignment().dim(i)) {
1298 continue;
1299 }
1300 source_size_to_dim[source.tile_assignment().dim(i)].push_back(i);
1301 target_size_to_dim[target.tile_assignment().dim(i)].push_back(i);
1302 }
1303 // In order to shard via AllToAll, source_size_to_dim and target_size_to_dim
1304 // must have the same distribution.
1305 if (source_size_to_dim.empty() ||
1306 source_size_to_dim.size() != target_size_to_dim.size()) {
1307 return absl::nullopt;
1308 }
1309 for (const auto& entry : source_size_to_dim) {
1310 auto target_it = target_size_to_dim.find(entry.first);
1311 if (target_it == target_size_to_dim.end() ||
1312 target_it->second.size() != entry.second.size()) {
1313 return absl::nullopt;
1314 }
1315 }
1316 std::vector<std::pair<int64, int64>> result;
1317 auto remove_entry = [](int64 size, int64 dim,
1318 std::map<int64, std::vector<int64>>& size_to_dim) {
1319 size_to_dim[size].erase(
1320 std::remove_if(size_to_dim[size].begin(), size_to_dim[size].end(),
1321 [dim](int64 a) { return a == dim; }),
1322 size_to_dim[size].end());
1323 if (size_to_dim[size].empty()) {
1324 size_to_dim.erase(size);
1325 }
1326 };
1327 // Find one pair of dimensions to swap at a time.
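// For example, resharding tile dims [4,2] to [2,4] yields the single pair
// (0, 1), i.e. one all-to-all that swaps those two dimensions.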
1328 while (!source_size_to_dim.empty()) {
1329 int64 source_size = source_size_to_dim.begin()->first;
1330 int64 i = source_size_to_dim.begin()->second.back();
1331 int64 target_i_size = target.tile_assignment().dim(i);
1332 if (target_i_size == source_size) {
1333 remove_entry(source_size, i, source_size_to_dim);
1334 remove_entry(source_size, i, target_size_to_dim);
1335 continue;
1336 }
1337 auto j_it = source_size_to_dim[target_i_size].begin();
1338 int64 j = *j_it;
1339 if (source_size == 1) {
1340 // If possible, find a j where the target partition count is not one, so
1341 // that when we swap, the resulting size-1 dimension will still be useful
1342 // to other dimensions.
1343 while (target.tile_assignment().dim(j) == 1) {
1344 if (++j_it == source_size_to_dim[target_i_size].end()) {
1345 break;
1346 }
1347 j = *j_it;
1348 }
1349 } else if (target_i_size % source_size == 0) {
1350 // If possible, find a j where the target partition count is source_size,
1351 // so that we can do a single swap.
1352 while (target.tile_assignment().dim(j) != source_size) {
1353 if (++j_it == source_size_to_dim[target_i_size].end()) {
1354 break;
1355 }
1356 j = *j_it;
1357 }
1358 } else {
1359 return absl::nullopt;
1360 }
1361 result.emplace_back(j, i);
1362 remove_entry(target_i_size, i, target_size_to_dim);
1363 source_size_to_dim.begin()->second.back() = j;
1364 remove_entry(target_i_size, j, source_size_to_dim);
1365 }
1366 return result;
1367 }
1368
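// Resharding needs only a single CollectivePermute when source and target
// tile the data identically (same tile dimensions and partial replication)
// but assign the tiles to different devices.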
bool CanReshardWithCollectivePermute(const HloSharding& source,
                                     const HloSharding& target) {
  return !source.IsTileMaximal() && !target.IsTileMaximal() &&
         source.tile_assignment().dimensions() ==
             target.tile_assignment().dimensions() &&
         source.ReplicateOnLastTileDim() == target.ReplicateOnLastTileDim() &&
         source.tile_assignment() != target.tile_assignment();
}

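// Groups the devices of `sharding` along `group_dims`. Each group keeps
// `group_dim_shards[i]` shards of dimension group_dims[i] (1 in the overload
// without explicit shards), so that dimension contributes
// dim_size / group_dim_shards[i] groups. For example, grouping the tile
// assignment [[0,1],[2,3]] on dimension 0 yields device groups {0,1} and
// {2,3}, each with a per-group sharding that splits dimension 1 in two.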
GroupedSharding GroupShardingOnDims(const HloSharding& sharding,
                                    absl::Span<const int64> group_dims) {
  std::vector<int64> group_dim_shards(group_dims.size(), 1);
  return GroupShardingOnDims(sharding, group_dims, group_dim_shards);
}

GroupedSharding GroupShardingOnDims(const HloSharding& sharding,
                                    absl::Span<const int64> group_dims,
                                    absl::Span<const int64> group_dim_shards) {
  CHECK(!sharding.IsTileMaximal());
  std::vector<int64> grouped_tiling_dims =
      sharding.tile_assignment().dimensions();
  std::vector<int64> group_dim_sizes(group_dims.size());
  for (int64 i = 0; i < group_dims.size(); ++i) {
    CHECK_EQ(grouped_tiling_dims[group_dims[i]] % group_dim_shards[i], 0);
    group_dim_sizes[i] =
        grouped_tiling_dims[group_dims[i]] / group_dim_shards[i];
    grouped_tiling_dims[group_dims[i]] = group_dim_shards[i];
  }

  std::vector<std::vector<int64>> device_groups(Product(group_dim_sizes));
  sharding.tile_assignment().Each(
      [&](absl::Span<const int64> indices, int64 device) {
        int64 group_id = 0;
        for (int64 i = 0; i < group_dims.size(); ++i) {
          group_id *= sharding.tile_assignment().dim(group_dims[i]) /
                      group_dim_shards[i];
          group_id += indices[group_dims[i]] / group_dim_shards[i];
        }
        device_groups[group_id].push_back(device);
      });
  auto grouped = GroupedSharding(
      std::move(device_groups),
      std::vector<int64>(group_dims.begin(), group_dims.end()),
      std::move(group_dim_sizes), sharding.tile_assignment().num_dimensions(),
      HloSharding::Replicate());
  if (sharding.ReplicateOnLastTileDim()) {
    grouped.data_rank--;
  }
  if (Product(grouped_tiling_dims) == 1 ||
      (sharding.ReplicateOnLastTileDim() &&
       Product(grouped_tiling_dims) == grouped_tiling_dims.back())) {
    return grouped;
  }
  if (sharding.ReplicateOnLastTileDim() && grouped_tiling_dims.back() == 1) {
    grouped_tiling_dims.pop_back();
  }
  Array<int64> grouped_tiling(grouped_tiling_dims);
  grouped_tiling.FillIota(0);
  grouped.sharding = sharding.ReplicateOnLastTileDim() &&
                             grouped_tiling_dims.size() ==
                                 sharding.tile_assignment().num_dimensions()
                         ? HloSharding::PartialTile(grouped_tiling)
                         : HloSharding::Tile(grouped_tiling);
  return grouped;
}

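// Reverses GroupShardingOnDims: expands the per-group sharding back into a
// full tile assignment by placing each device group's devices at their
// positions along the grouped dimensions.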
HloSharding UngroupSharding(const GroupedSharding& grouped_sharding) {
  std::vector<int64> tiling_dims;
  bool partial_sharding = false;
  auto grouped_tiling = grouped_sharding.sharding.tile_assignment();
  if (grouped_sharding.sharding.IsTileMaximal()) {
    tiling_dims = std::vector<int64>(grouped_sharding.data_rank, 1);
    if (grouped_sharding.device_groups[0].size() != 1) {
      // This is partial sharding.
      tiling_dims.push_back(grouped_sharding.device_groups[0].size());
      partial_sharding = true;
    }
    grouped_tiling = Array<int64>(tiling_dims);
    grouped_tiling.FillIota(0);
  } else {
    partial_sharding = grouped_sharding.sharding.ReplicateOnLastTileDim();
    tiling_dims = grouped_sharding.sharding.tile_assignment().dimensions();
    if (absl::c_linear_search(grouped_sharding.group_dims,
                              tiling_dims.size())) {
      tiling_dims.push_back(1);
      grouped_tiling.Reshape(tiling_dims);
      partial_sharding = true;
    }
  }
  for (int64 i = 0; i < grouped_sharding.group_dims.size(); ++i) {
    int64 dim = grouped_sharding.group_dims[i];
    tiling_dims[dim] *= grouped_sharding.group_dim_sizes[i];
  }
  Array<int64> tiling(tiling_dims);
  grouped_tiling.Each([&](absl::Span<const int64> indices, int64 device) {
    std::vector<int64> ungrouped_inds(indices.begin(), indices.end());
    for (int64 g = 0; g < grouped_sharding.device_groups.size(); ++g) {
      int64 remaining_group_index = g;
      for (int64 i = grouped_sharding.group_dims.size() - 1; i >= 0; --i) {
        int64 dim = grouped_sharding.group_dims[i];
        int64 groups_in_this_dim = grouped_sharding.group_dim_sizes[i];
        ungrouped_inds[dim] = (remaining_group_index % groups_in_this_dim) *
                                  grouped_tiling.dim(dim) +
                              indices[dim];
        remaining_group_index /= groups_in_this_dim;
      }
      tiling(ungrouped_inds) = grouped_sharding.device_groups[g][device];
    }
  });
  return partial_sharding ? HloSharding::PartialTile(tiling)
                          : HloSharding::Tile(tiling);
}

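// Aligns the device groups of `grouped_sharding` with those of `reference`.
// If every group maps onto exactly one reference group (in the same order,
// unless ignore_group_order is set), the per-group tile assignment is
// permuted so that device ids follow the reference's in-group ordering. The
// returned sharding always adopts the reference's device groups.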
GroupedSharding AlignGroupsWith(GroupedSharding grouped_sharding,
                                const GroupedSharding& reference,
                                bool ignore_group_order) {
  // Returns src -> dst index mapping.
  auto get_permutation = [](absl::Span<const int64> src,
                            absl::Span<const int64> dst) {
    CHECK_EQ(src.size(), dst.size());
    absl::flat_hash_map<int64, int64> dst_reverse_map;
    for (int64 i = 0; i < dst.size(); ++i) {
      dst_reverse_map[dst[i]] = i;
    }
    std::vector<int64> permutation(src.size());
    for (int64 i = 0; i < src.size(); ++i) {
      auto it = dst_reverse_map.find(src[i]);
      CHECK(it != dst_reverse_map.end());
      permutation[i] = it->second;
    }
    return permutation;
  };
  CHECK_EQ(grouped_sharding.device_groups.size(),
           reference.device_groups.size());
  absl::flat_hash_map<int64, int64> device_to_ref_group;
  for (int64 g = 0; g < reference.device_groups.size(); ++g) {
    for (int64 device : reference.device_groups[g]) {
      device_to_ref_group[device] = g;
    }
  }
  auto unique_ref_dev_group = [&](absl::Span<const int64> devices) -> int64 {
    int64 ref_g = -1;
    for (int64 device : devices) {
      if (ref_g == -1) {
        ref_g = device_to_ref_group[device];
      } else if (ref_g != device_to_ref_group[device]) {
        return -1;
      }
    }
    return ref_g;
  };
  bool matching_groups = true;
  std::vector<int64> original_src_to_ref_permutation;
  for (int64 g = 0; g < grouped_sharding.device_groups.size(); ++g) {
    int64 ref_g = unique_ref_dev_group(grouped_sharding.device_groups[g]);
    if (ref_g < 0 || (!ignore_group_order && g != ref_g)) {
      matching_groups = false;
      break;
    }
    if (g == 0) {
      original_src_to_ref_permutation = get_permutation(
          grouped_sharding.device_groups[g], reference.device_groups[ref_g]);
    }
  }
  if (matching_groups && !grouped_sharding.sharding.IsTileMaximal()) {
    auto tiles = grouped_sharding.sharding.tile_assignment();
    tiles.Each([&](absl::Span<const int64> indices, int64* device) {
      *device = original_src_to_ref_permutation[*device];
    });
    grouped_sharding.sharding =
        grouped_sharding.sharding.ReplicateOnLastTileDim()
            ? HloSharding::PartialTile(tiles)
            : HloSharding::Tile(tiles);
  }
  grouped_sharding.device_groups = std::move(reference.device_groups);
  return grouped_sharding;
}

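// Convenience wrapper: groups `sharding` on `sharding_dims`, aligns the
// groups with `reference` grouped on `reference_dims`, and ungroups the
// result.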
HloSharding AlignShardingOnDims(const HloSharding& sharding,
                                absl::Span<const int64> sharding_dims,
                                const HloSharding& reference,
                                absl::Span<const int64> reference_dims) {
  auto sharding_grouped = GroupShardingOnDims(sharding, sharding_dims);
  auto reference_grouped = GroupShardingOnDims(reference, reference_dims);
  return UngroupSharding(AlignGroupsWith(sharding_grouped, reference_grouped));
}

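// Returns the base shape as seen from inside one group: every grouped
// dimension is divided by the number of groups along it.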
Shape GetPerGroupBaseShape(const GroupedSharding& grouped_sharding,
                           const Shape& original_base_shape) {
  auto result = original_base_shape;
  for (int64 i = 0; i < grouped_sharding.group_dims.size(); ++i) {
    int64 dim = grouped_sharding.group_dims[i];
    if (dim >= original_base_shape.rank()) {
      continue;
    }
    int64 groups = grouped_sharding.group_dim_sizes[i];
    result.set_dimensions(dim, result.dimensions(dim) / groups);
  }
  return result;
}

namespace {

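// Returns the partition id relative to its device group: a constant table
// maps each global partition id to its index within its group, and the table
// is dynamic-sliced with the global partition id.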
HloInstruction* GetInGroupPartitionId(
    HloInstruction* partition_id,
    const std::vector<std::vector<int64>>& device_groups, SpmdBuilder* b) {
  int64 total_devices = device_groups.size() * device_groups[0].size();
  std::vector<uint32> in_group_ids(total_devices);
  for (uint32 i = 0; i < device_groups.size(); ++i) {
    for (uint32 j = 0; j < device_groups[i].size(); ++j) {
      in_group_ids[device_groups[i][j]] = j;
    }
  }
  auto id_table = b->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<uint32>(in_group_ids)));
  return b->AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeScalarShape(U32),
      b->AddInstruction(HloInstruction::CreateDynamicSlice(
          ShapeUtil::MakeShape(U32, {1}), id_table, {partition_id}, {1}))));
}

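// Wraps `creator` so that collectives operate within each device group:
// partition subgroups are expanded per group (an empty subgroup list expands
// to one subgroup per device group), and the partition id becomes the
// in-group id.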
SPMDCollectiveOpsCreator GetPerGroupCollectiveOpsCreator(
    const SPMDCollectiveOpsCreator& creator,
    const std::vector<std::vector<int64>>& device_groups) {
  SPMDCollectiveOpsCreator result;
  result.create_partition_id = [creator, device_groups](SpmdBuilder* b) {
    return GetInGroupPartitionId(creator.create_partition_id(b), device_groups,
                                 b);
  };
  auto expand_partition_groups =
      [device_groups](
          const std::vector<std::vector<int64>>& partition_subgroups) {
        if (partition_subgroups.empty()) {
          return device_groups;
        }
        std::vector<std::vector<int64>> result(partition_subgroups.size() *
                                               device_groups.size());
        for (int64 g = 0; g < device_groups.size(); ++g) {
          for (int64 i = 0; i < partition_subgroups.size(); ++i) {
            result[g * partition_subgroups.size() + i].resize(
                partition_subgroups[i].size());
            for (int64 j = 0; j < partition_subgroups[i].size(); ++j) {
              result[g * partition_subgroups.size() + i][j] =
                  device_groups[g][partition_subgroups[i][j]];
            }
          }
        }
        return result;
      };
  result.create_cross_partition_all_reduce =
      [creator, expand_partition_groups](
          SpmdBuilder* b, HloInstruction* operand, HloComputation* reduction,
          const std::vector<std::vector<int64>>& partition_subgroups,
          int64 channel_id) {
        return creator.create_cross_partition_all_reduce(
            b, operand, reduction, expand_partition_groups(partition_subgroups),
            channel_id);
      };
  result.create_cross_partition_collective_permute =
      [creator, device_groups](
          SpmdBuilder* b, HloInstruction* operand,
          std::vector<std::pair<int64, int64>>& src_dst_pairs,
          int64 next_channel_id) {
        std::vector<std::pair<int64, int64>> expanded_pairs(
            src_dst_pairs.size() * device_groups.size());
        for (int64 g = 0; g < device_groups.size(); ++g) {
          for (int64 i = 0; i < src_dst_pairs.size(); ++i) {
            expanded_pairs[g * src_dst_pairs.size() + i] =
                std::pair<int64, int64>{
                    device_groups[g][src_dst_pairs[i].first],
                    device_groups[g][src_dst_pairs[i].second]};
          }
        }
        return creator.create_cross_partition_collective_permute(
            b, operand, expanded_pairs, next_channel_id);
      };
  result.create_cross_partition_all_to_all =
      [creator, expand_partition_groups](
          SpmdBuilder* b, absl::Span<HloInstruction* const> operands,
          const std::vector<std::vector<int64>>& partition_subgroups,
          int64 channel_id, absl::optional<int64> split_dimension) {
        return creator.create_cross_partition_all_to_all(
            b, operands, expand_partition_groups(partition_subgroups),
            channel_id, split_dimension);
      };
  if (creator.create_cross_partition_all_gather) {
    result.create_cross_partition_all_gather =
        [creator, expand_partition_groups](
            SpmdBuilder* b, HloInstruction* operand, const Shape& ag_shape,
            const std::vector<std::vector<int64>>& partition_subgroups,
            int64 channel_id, int64 all_gather_dimension) {
          return creator.create_cross_partition_all_gather(
              b, operand, ag_shape,
              expand_partition_groups(partition_subgroups), channel_id,
              all_gather_dimension);
        };
  }
  return result;
}

}  // namespace

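// Builds a partitioning state scoped to `device_groups`: the partition id and
// all collectives operate within a group, and resharding uses a cache keyed
// by the grouping.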
PartitionedHlo::PartitioningState CreatePerGroupPartitioningState(
    const PartitionedHlo::PartitioningState& state,
    const std::vector<std::vector<int64>>& device_groups, SpmdBuilder* b) {
  auto result = state;
  result.collective_ops_creator = GetPerGroupCollectiveOpsCreator(
      state.collective_ops_creator, device_groups);
  result.partition_id =
      GetInGroupPartitionId(state.partition_id, device_groups, b);
  // Create a string key for the groups.
  std::vector<std::string> per_group_strings(device_groups.size());
  for (int64 i = 0; i < per_group_strings.size(); ++i) {
    per_group_strings[i] = absl::StrJoin(device_groups[i], ",");
  }
  auto& grouped_cache =
      state.reshard_cache->groupd_caches[absl::StrJoin(per_group_strings, ";")];
  if (!grouped_cache) {
    grouped_cache = absl::make_unique<PartitionedHlo::ReshardCache>();
  }
  result.reshard_cache = grouped_cache.get();
  return result;
}

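// Slices this group's shard out of a replicated value: a constant table maps
// the partition id to its group id, which then serves as the shard offset of
// a group-level tiled sharding over `group_dims`.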
HloInstruction* PerGroupSliceFromReplicated(
    HloInstruction* replicated, HloInstruction* partition_id,
    const std::vector<std::vector<int64>>& device_groups,
    absl::Span<const int64> group_dims, absl::Span<const int64> group_dim_sizes,
    SpmdBuilder* b) {
  std::vector<uint32> group_ids(device_groups.size() * device_groups[0].size());
  for (int64 g = 0; g < device_groups.size(); ++g) {
    for (int64 device : device_groups[g]) {
      group_ids[device] = g;
    }
  }
  auto group_id_table = b->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<uint32>(group_ids)));
  auto group_id = b->AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeScalarShape(U32),
      b->AddInstruction(HloInstruction::CreateDynamicSlice(
          ShapeUtil::MakeShape(U32, {1}), group_id_table, {partition_id},
          {1}))));
  std::vector<int64> group_level_tile_dims(replicated->shape().rank(), 1);
  for (int64 i = 0; i < group_dims.size(); ++i) {
    group_level_tile_dims[group_dims[i]] = group_dim_sizes[i];
  }
  Array<int64> group_level_tile(group_level_tile_dims);
  group_level_tile.Each([&](absl::Span<const int64> indices, int64* group) {
    *group = 0;
    for (int64 dim : group_dims) {
      *group *= group_level_tile.dim(dim);
      *group += indices[dim];
    }
  });
  auto group_level_sharding = HloSharding::Tile(group_level_tile);
  auto padded_hlo = PadBaseShapeBeforeUnevenTiledSharding(
      replicated, group_level_sharding, b);
  auto shard_shape =
      MakePartitionedShape(replicated->shape(), group_level_sharding);
  return b->AddInstruction(HloInstruction::CreateDynamicSlice(
      shard_shape, padded_hlo,
      MakePartitionOffsets(replicated->shape(), group_level_sharding, group_id,
                           b),
      shard_shape.dimensions()));
}

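// If the reduction computation is a single elementwise binary op applied to
// its two parameters, returns that op's opcode; otherwise returns nullopt.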
absl::optional<HloOpcode> ParseReductionComputation(
    const HloComputation* reduction_comp) {
  if (reduction_comp->num_parameters() != 2) {
    return absl::nullopt;
  }
  auto root = reduction_comp->root_instruction();
  if (!root->IsElementwiseBinary()) {
    return absl::nullopt;
  }
  if (!absl::c_linear_search(root->operands(),
                             reduction_comp->parameter_instruction(0)) ||
      !absl::c_linear_search(root->operands(),
                             reduction_comp->parameter_instruction(1))) {
    return absl::nullopt;
  }
  return root->opcode();
}

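// Finds the tile dimensions of `sharding` along which all devices in a group
// share the same tile index, i.e. the dimensions that reproduce exactly the
// given device grouping. Returns nullopt if no such set of dimensions exists.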
absl::optional<std::vector<int64>> FindMatchingPartitionedDimsForGrouping(
    const HloSharding& sharding,
    const std::vector<std::vector<int64>>& device_groups) {
  if (sharding.NumTiles() < device_groups.size() || device_groups.size() < 2 ||
      device_groups[0].size() < 2) {
    return absl::nullopt;
  }
  int64 rank = sharding.tile_assignment().num_dimensions();
  if (sharding.ReplicateOnLastTileDim()) {
    rank--;
  }
  absl::flat_hash_map<int64, std::vector<int64>> device_to_index;
  sharding.tile_assignment().Each(
      [&](absl::Span<const int64> index, int64 device) {
        device_to_index[device] =
            std::vector<int64>(index.begin(), index.begin() + rank);
      });
  std::vector<int64> dims;
  int64 group_count = 1;
  for (int64 i = 0; i < rank; ++i) {
    if (device_to_index[device_groups[0][0]][i] ==
        device_to_index[device_groups[0][1]][i]) {
      dims.push_back(i);
      group_count *= sharding.tile_assignment().dim(i);
    }
  }
  if (group_count != device_groups.size()) {
    return absl::nullopt;
  }
  for (const auto& group : device_groups) {
    for (int64 i = 1; i < group.size(); ++i) {
      if (absl::c_any_of(dims, [&](const int64 dim) {
            return device_to_index[group[i]][dim] !=
                   device_to_index[group[0]][dim];
          })) {
        return absl::nullopt;
      }
    }
  }
  return dims;
}

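// Creates a sharding for `target_shape` that mirrors `source_sharding` on the
// given parallel dimensions: target_dims[i] gets source_dims[i]'s shard
// count, and any partitioning of the source across other dimensions becomes
// partial replication. The result is aligned with the source's device order.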
HloSharding CreateMatchingShardingOnDims(const Shape& target_shape,
                                         const HloSharding& source_sharding,
                                         absl::Span<const int64> target_dims,
                                         absl::Span<const int64> source_dims) {
  CHECK(target_dims.size() == source_dims.size())
      << "Expected 1:1 match between parallel dimensions";
  if (source_sharding.IsReplicated()) {
    return HloSharding::Replicate();
  }
  absl::InlinedVector<int64, 4> tile_dims(target_shape.dimensions_size(), 1);
  int num_tiles = 1;
  for (int i = 0, end = target_dims.size(); i < end; ++i) {
    num_tiles *= source_sharding.tile_assignment().dim(source_dims[i]);
    tile_dims[target_dims[i]] =
        source_sharding.tile_assignment().dim(source_dims[i]);
  }
  // If the source is also partitioned across non-parallel dimensions, the new
  // sharding becomes partially replicated over those partitions.
  bool to_be_partially_replicated = false;
  if (num_tiles != source_sharding.tile_assignment().num_elements()) {
    CHECK_EQ(source_sharding.tile_assignment().num_elements() % num_tiles, 0);
    to_be_partially_replicated = true;
    tile_dims.push_back(source_sharding.tile_assignment().num_elements() /
                        num_tiles);
  }
  auto tgt_tile_assignment = source_sharding.tile_assignment();
  tgt_tile_assignment.Reshape(tile_dims);
  if (to_be_partially_replicated) {
    return AlignShardingOnDims(HloSharding::PartialTile(tgt_tile_assignment),
                               target_dims, source_sharding, source_dims);
  } else {
    return AlignShardingOnDims(HloSharding::Tile(tgt_tile_assignment),
                               target_dims, source_sharding, source_dims);
  }
}

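// Computes shardings for the gather indices and operand so that their
// parallel dimensions are partitioned identically. When the tile counts
// differ, the less-partitioned side borrows partitions from its partial
// replication dimension; returns nullopt if no compatible pair can be built.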
absl::optional<GatherParallelDimSharding>
GatherOperandsShardedAcrossParallelDims(
    const HloInstruction& operand, const HloInstruction& indices,
    const hlo_sharding_util::GatherParallelDims& parallel_dims) {
  auto& indices_parallel_dims = parallel_dims.indices_parallel_dims;
  auto& operand_parallel_dims = parallel_dims.operand_parallel_dims;
  if (indices_parallel_dims.size() != operand_parallel_dims.size()) {
    return absl::nullopt;
  }
  auto new_index_shard = indices.sharding();
  auto new_operand_shard = operand.sharding();
  int idx_parallel_tiles_num = new_index_shard.NumTiles(indices_parallel_dims);
  int op_parallel_tiles_num = new_operand_shard.NumTiles(operand_parallel_dims);
  if (idx_parallel_tiles_num == 1 && op_parallel_tiles_num == 1) {
    return absl::nullopt;
  }
  absl::InlinedVector<int64, 1> indices_parallel_dims_ordered_as_operand;
  for (int idx : parallel_dims.index_parallel_in_dim) {
    if (idx != -1) {
      indices_parallel_dims_ordered_as_operand.push_back(idx);
    }
  }
  if (new_index_shard.IsReplicated()) {
    return GatherParallelDimSharding{
        CreateMatchingShardingOnDims(indices.shape(), new_operand_shard,
                                     indices_parallel_dims_ordered_as_operand,
                                     operand_parallel_dims),
        new_operand_shard};
  }
  if (new_operand_shard.IsReplicated()) {
    return GatherParallelDimSharding{
        new_index_shard,
        CreateMatchingShardingOnDims(operand.shape(), new_index_shard,
                                     operand_parallel_dims,
                                     indices_parallel_dims_ordered_as_operand)};
  }

  // Parallel dimension distribution needs to be the same, so try to steal
  // sharding from partial replication to compensate.
  if (idx_parallel_tiles_num != op_parallel_tiles_num) {
    auto to_adjust_dims = operand_parallel_dims;
    auto target_dims = indices_parallel_dims_ordered_as_operand;
    HloSharding* target = &new_index_shard;
    HloSharding* to_adjust = &new_operand_shard;
    if (idx_parallel_tiles_num < op_parallel_tiles_num) {
      std::swap(to_adjust_dims, target_dims);
      std::swap(to_adjust, target);
    }
    if (!to_adjust->ReplicateOnLastTileDim()) {
      return absl::nullopt;
    }
    auto new_tile_assignment_dims = to_adjust->tile_assignment().dimensions();
    for (int i = 0; i < to_adjust_dims.size(); ++i) {
      int64 target_dim = target->tile_assignment().dim(target_dims[i]);
      int64 to_adjust_dim = to_adjust->tile_assignment().dim(to_adjust_dims[i]);
      if (target_dim < to_adjust_dim) {
        return absl::nullopt;
      }
      if (target_dim == to_adjust_dim) {
        continue;
      }
      int64 ratio = target_dim / to_adjust_dim;
      if (target_dim % to_adjust_dim != 0 ||
          new_tile_assignment_dims.back() % ratio != 0) {
        return absl::nullopt;
      }
      new_tile_assignment_dims[to_adjust_dims[i]] *= ratio;
      new_tile_assignment_dims.back() /= ratio;
    }
    CHECK_GE(new_tile_assignment_dims.back(), 1);
    bool to_partially_replicate = true;
    if (new_tile_assignment_dims.back() == 1) {
      new_tile_assignment_dims.pop_back();
      to_partially_replicate = false;
    }
    auto new_tile_assignment = to_adjust->tile_assignment();
    new_tile_assignment.Reshape(new_tile_assignment_dims);
    if (to_partially_replicate) {
      *to_adjust =
          AlignShardingOnDims(HloSharding::PartialTile(new_tile_assignment),
                              to_adjust_dims, *target, target_dims);
    } else {
      *to_adjust = AlignShardingOnDims(HloSharding::Tile(new_tile_assignment),
                                       to_adjust_dims, *target, target_dims);
    }
  }
  // Make sure that the parallel dimensions are aligned.
  auto operand_shard_tile_dims =
      new_operand_shard.tile_assignment().dimensions();
  for (int i = 0; i < indices_parallel_dims_ordered_as_operand.size(); ++i) {
    operand_shard_tile_dims[operand_parallel_dims[i]] =
        new_index_shard.tile_assignment().dim(
            indices_parallel_dims_ordered_as_operand[i]);
  }
  auto operand_shard_tiles = new_operand_shard.tile_assignment();
  operand_shard_tiles.Reshape(operand_shard_tile_dims);
  new_operand_shard =
      AlignShardingOnDims(new_operand_shard.ReplicateOnLastTileDim()
                              ? HloSharding::PartialTile(operand_shard_tiles)
                              : HloSharding::Tile(operand_shard_tiles),
                          operand_parallel_dims, new_index_shard,
                          indices_parallel_dims_ordered_as_operand);
  return GatherParallelDimSharding{new_index_shard, new_operand_shard};
}

}  // namespace spmd
}  // namespace xla