1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
17 
18 #include <algorithm>
19 #include <memory>
20 
21 #include "absl/algorithm/container.h"
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/container/inlined_vector.h"
24 #include "absl/memory/memory.h"
25 #include "absl/strings/str_join.h"
26 #include "absl/types/optional.h"
27 #include "tensorflow/compiler/xla/literal_util.h"
28 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
29 #include "tensorflow/compiler/xla/service/hlo_computation.h"
30 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
31 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
32 #include "tensorflow/compiler/xla/service/hlo_module.h"
33 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
34 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
35 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
36 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
37 #include "tensorflow/compiler/xla/service/shape_inference.h"
38 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
39 #include "tensorflow/compiler/xla/shape_util.h"
40 #include "tensorflow/compiler/xla/util.h"
41 #include "tensorflow/compiler/xla/window_util.h"
42 #include "tensorflow/compiler/xla/xla_data.pb.h"
43 
44 namespace xla {
45 namespace spmd {
46 
47 bool HasReplicatedSharding(const HloSharding& sharding) {
48   if (sharding.IsTuple()) {
49     return absl::c_any_of(sharding.tuple_elements(), HasReplicatedSharding);
50   }
51   return sharding.IsReplicated();
52 }
53 
54 HloInstruction* CreateConstant(const Shape& shape, Literal value,
55                                SpmdBuilder* b) {
56   if (shape.IsTuple()) {
57     std::vector<HloInstruction*> elements;
58     for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
59       elements.push_back(CreateConstant(
60           ShapeUtil::GetTupleElementShape(shape, i), value.Clone(), b));
61     }
62     return b->AddInstruction(HloInstruction::CreateTuple(elements));
63   }
64 
65   CHECK(
66       ShapeUtil::IsScalarWithElementType(value.shape(), shape.element_type()));
67   auto c = b->AddInstruction(HloInstruction::CreateConstant(std::move(value)));
68   return b->AddInstruction(HloInstruction::CreateBroadcast(shape, c, {}));
69 }
70 
71 HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b) {
72   if (shape.IsTuple()) {
73     std::vector<HloInstruction*> elements;
74     for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
75       elements.push_back(
76           CreateZero(ShapeUtil::GetTupleElementShape(shape, i), b));
77     }
78     return b->AddInstruction(HloInstruction::CreateTuple(elements));
79   }
80 
81   if (shape.IsToken()) {
82     return b->AddInstruction(HloInstruction::CreateToken());
83   }
84   auto zero = b->AddInstruction(
85       HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
86   if (shape.rank() == 0) {
87     return zero;
88   }
89   return b->AddInstruction(HloInstruction::CreateBroadcast(shape, zero, {}));
90 }
91 
92 HloInstruction* CreateOne(const Shape& shape, SpmdBuilder* b) {
93   if (shape.IsTuple()) {
94     std::vector<HloInstruction*> elements;
95     for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
96       elements.push_back(
97           CreateOne(ShapeUtil::GetTupleElementShape(shape, i), b));
98     }
99     return b->AddInstruction(HloInstruction::CreateTuple(elements));
100   }
101 
102   if (shape.IsToken()) {
103     return b->AddInstruction(HloInstruction::CreateToken());
104   }
105   auto one = b->AddInstruction(
106       HloInstruction::CreateConstant(LiteralUtil::One(shape.element_type())));
107   return b->AddInstruction(HloInstruction::CreateBroadcast(shape, one, {}));
108 }
109 
110 HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module) {
111   HloComputation::Builder sum_b("add");
112   auto x = sum_b.AddInstruction(HloInstruction::CreateParameter(
113       /*parameter_number=*/0, ShapeUtil::MakeShape(type, {}), "x"));
114   auto y = sum_b.AddInstruction(HloInstruction::CreateParameter(
115       /*parameter_number=*/1, ShapeUtil::MakeShape(type, {}), "y"));
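  // For predicates there is no integer add, so use logical OR, which acts as a
  // saturating add for booleans.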
116   if (type == PRED) {
117     sum_b.AddInstruction(HloInstruction::CreateBinary(
118         ShapeUtil::MakeShape(type, {}), HloOpcode::kOr, x, y));
119   } else {
120     sum_b.AddInstruction(HloInstruction::CreateBinary(
121         ShapeUtil::MakeShape(type, {}), HloOpcode::kAdd, x, y));
122   }
123   HloComputation* reduction = module->AddEmbeddedComputation(sum_b.Build());
124   return reduction;
125 }
126 
127 bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding) {
128   if (sharding.IsTuple()) {
129     for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
130       if (!EvenlyPartitions(ShapeUtil::GetTupleElementShape(shape, i),
131                             sharding.GetSubSharding(shape, {i}))) {
132         return false;
133       }
134     }
135   }
136 
137   if (sharding.IsTileMaximal()) {
138     return sharding.IsReplicated();
139   }
140   for (int64 i = 0; i < shape.dimensions_size(); ++i) {
141     if (shape.dimensions(i) % sharding.tile_assignment().dim(i) != 0) {
142       return false;
143     }
144   }
145   return true;
146 }
147 
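// For example, tiling a shape of f32[5,8] with a 2x2 tile assignment yields a
// per-shard shape of f32[3,4]; uneven dimensions are rounded up.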
148 Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding) {
149   if (sharding.IsTuple()) {
150     std::vector<Shape> subshapes;
151     for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
152       subshapes.push_back(
153           MakePartitionedShape(ShapeUtil::GetTupleElementShape(shape, i),
154                                sharding.GetSubSharding(shape, {i})));
155     }
156     return ShapeUtil::MakeTupleShape(subshapes);
157   }
158   return sharding.TileShape(shape);
159 }
160 
161 int64 ShapeSizeInBytes(const Shape& shape) {
162   return ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()) *
163          ShapeUtil::ElementsIn(shape);
164 }
165 
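// For example, for f32[5] tiled across 2 partitions (with partition 0 owning
// the first tile), partition 0 gets f32[3] and partition 1 gets f32[2].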
166 Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape,
167                                           const HloSharding& sharding,
168                                           int64 partition_id) {
169   if (sharding.IsTuple()) {
170     std::vector<Shape> subshapes;
171     for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
172       subshapes.push_back(MakeNonPaddedShapeForGivenPartition(
173           ShapeUtil::GetTupleElementShape(shape, i),
174           sharding.GetSubSharding(shape, {i}), partition_id));
175     }
176     return ShapeUtil::MakeTupleShape(subshapes);
177   }
178 
179   if (sharding.IsReplicated()) {
180     return shape;
181   }
182   if (sharding.IsTileMaximal()) {
183     if (partition_id == *sharding.UniqueDevice()) {
184       return shape;
185     }
186     return ShapeUtil::MakeTupleShape({});
187   }
188 
189   auto partition_shape = shape;
190   std::vector<int64> tile_offset =
191       sharding.TileOffsetForDevice(shape, partition_id);
192   std::vector<int64> tile_limit =
193       sharding.TileLimitForDevice(shape, partition_id);
194   for (int64 i = 0; i < tile_offset.size(); ++i) {
195     if (sharding.UsesDevice(partition_id)) {
196       partition_shape.set_dimensions(i, tile_limit[i] - tile_offset[i]);
197     } else {
198       partition_shape.set_dimensions(i, 0);
199     }
200   }
201   return partition_shape;
202 }
203 
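// For example, for shape f32[4,8] with a 2x2 tile assignment [[0,1],[2,3]] and
// no restriction on `dims`, the per-shard shape is [2,4], so the offset table
// indexed by partition id is {0,0,2,2} for dim 0 and {0,4,0,4} for dim 1.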
204 std::vector<HloInstruction*> MakePartitionOffsets(
205     const Shape& shape, const HloSharding& sharding,
206     HloInstruction* partition_id, SpmdBuilder* b,
207     absl::Span<const int64> dims) {
208   CHECK(!shape.IsTuple());
209 
210   std::vector<std::vector<int32>> offset_arrays(shape.rank());
211   for (int64 i = 0; i < shape.rank(); ++i) {
212     offset_arrays[i].resize(sharding.tile_assignment().num_elements());
213   }
214   auto shard_shape = MakePartitionedShape(shape, sharding);
215   sharding.tile_assignment().Each(
216       [&](absl::Span<const int64> indices, int64 device) {
217         for (int64 i = 0; i < shape.rank(); ++i) {
218           offset_arrays[i][device] = indices[i] * shard_shape.dimensions(i);
219         }
220       });
221   std::vector<HloInstruction*> offsets;
222   for (int64 i = 0; i < shape.rank(); ++i) {
223     if (sharding.tile_assignment().dim(i) == 1 ||
224         (!dims.empty() && !absl::c_linear_search(dims, i))) {
225       offsets.push_back(b->AddInstruction(
226           HloInstruction::CreateConstant(LiteralUtil::Zero(S32))));
227     } else {
228       auto offset_table = b->AddInstruction(HloInstruction::CreateConstant(
229           LiteralUtil::CreateR1<int32>(offset_arrays[i])));
230       auto index = b->AddInstruction(HloInstruction::CreateDynamicSlice(
231           ShapeUtil::MakeShape(S32, {1}), offset_table, {partition_id}, {1}));
232       offsets.push_back(b->AddInstruction(
233           HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), index)));
234     }
235   }
236   return offsets;
237 }
238 
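// Reuses MakePartitionOffsets on an S32 table shaped like the tile assignment,
// so each returned ordinal is the current partition's tile index along that
// dimension.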
239 std::vector<HloInstruction*> MakeTiledPartitionOrdinals(
240     const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b) {
241   CHECK(!sharding.IsTileMaximal());
242   auto dimensions = sharding.tile_assignment().dimensions();
243   if (sharding.ReplicateOnLastTileDim()) {
244     dimensions.pop_back();
245   }
246   auto table_shape = ShapeUtil::MakeShape(S32, dimensions);
247   return MakePartitionOffsets(table_shape, sharding, partition_id, b);
248 }
249 
250 HloInstruction* PadToShape(HloInstruction* hlo, const Shape& padded_shape,
251                            SpmdBuilder* b, HloComputation* computation) {
252   CHECK(b == nullptr || computation == nullptr);
253   if (ShapeUtil::Compatible(hlo->shape(), padded_shape)) {
254     return hlo;
255   }
256   PaddingConfig padding_config;
257   for (int64 i = 0; i < padded_shape.rank(); ++i) {
258     auto padding_config_dim = padding_config.add_dimensions();
259     padding_config_dim->set_edge_padding_low(0);
260     padding_config_dim->set_interior_padding(0);
261     padding_config_dim->set_edge_padding_high(padded_shape.dimensions(i) -
262                                               hlo->shape().dimensions(i));
263   }
264   auto add_hlo = [&](std::unique_ptr<HloInstruction> to_add) {
265     if (b == nullptr) {
266       return computation->AddInstruction(std::move(to_add));
267     }
268     return b->AddInstruction(std::move(to_add));
269   };
270   auto zero = add_hlo(HloInstruction::CreateConstant(
271       LiteralUtil::Zero(hlo->shape().element_type())));
272   return add_hlo(
273       HloInstruction::CreatePad(padded_shape, hlo, zero, padding_config));
274 }
275 
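// For example, a base shape of f32[5] tiled 2 ways is padded to f32[6] so that
// both shards have the same padded size of 3.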
276 Shape GetPaddedShapeForUnevenPartitioning(const Shape& base_shape,
277                                           const HloSharding& sharding) {
278   if (sharding.IsTileMaximal()) {
279     return base_shape;
280   }
281   if (EvenlyPartitions(base_shape, sharding)) {
282     return base_shape;
283   }
284   auto shard_shape = MakePartitionedShape(base_shape, sharding);
285   Shape padded_base_shape = base_shape;
286   for (int64 i = 0; i < padded_base_shape.rank(); ++i) {
287     padded_base_shape.set_dimensions(
288         i, shard_shape.dimensions(i) * sharding.tile_assignment().dim(i));
289   }
290   return padded_base_shape;
291 }
292 
293 HloInstruction* PadBaseShapeBeforeUnevenTiledSharding(
294     HloInstruction* hlo, const HloSharding& sharding, SpmdBuilder* b) {
295   auto padded_base_shape =
296       GetPaddedShapeForUnevenPartitioning(hlo->shape(), sharding);
297   if (ShapeUtil::Compatible(padded_base_shape, hlo->shape())) {
298     return hlo;
299   }
300   return PadToShape(hlo, padded_base_shape, b);
301 }
302 
303 absl::optional<HloSharding> PartialReplicateReshardCompatibleSharding(
304     const HloSharding& partial_sharding, const HloSharding& target_sharding) {
305   if (!partial_sharding.ReplicateOnLastTileDim()) {
306     return absl::nullopt;
307   }
308   int64 rank = partial_sharding.tile_assignment().num_dimensions() - 1;
309   int64 target_rank = target_sharding.tile_assignment().num_dimensions() -
310                       (target_sharding.ReplicateOnLastTileDim() ? 1 : 0);
311   if (target_rank != rank) {
312     return absl::nullopt;
313   }
314 
315   absl::flat_hash_map<int64, int64> device_to_replication_group;
316   partial_sharding.tile_assignment().Each(
317       [&](absl::Span<const int64> indices, int64 device) {
318         int64 gid = 0;
319         for (int64 i = 0; i < rank; ++i) {
320           gid *= partial_sharding.tile_assignment().dim(i);
321           gid += indices[i];
322         }
323         device_to_replication_group[device] = gid;
324       });
325 
326   // A dimension is expanded when target_tile_size > partial_tile_size and
327   // target_tile_size % partial_tile_size == 0.
328   // expand_tile_dims_indices[dim] is the index of dim among the expanded dims.
329   std::vector<int64> expand_tile_dims_indices(rank, -1);
330   // expand_tile_size = target_tile_size / partial_tile_size.
331   std::vector<int64> expand_tile_sizes;
332   int num_expand_dims = 0;
333   for (int64 dim = 0; dim < rank; dim++) {
334     int64 partial_tile_size = partial_sharding.tile_assignment().dim(dim);
335     int64 target_tile_size = target_sharding.tile_assignment().dim(dim);
336     if (target_tile_size % partial_tile_size != 0 ||
337         target_tile_size < partial_tile_size) {
338       return absl::nullopt;
339     }
340 
341     if (target_tile_size > partial_tile_size) {
342       expand_tile_dims_indices[dim] = num_expand_dims++;
343       expand_tile_sizes.emplace_back(target_tile_size / partial_tile_size);
344     }
345   }
346 
347   // Reshape the partial replicate tile_dimensions.
348   int64 num_target_replication = 1;
349   if (target_sharding.ReplicateOnLastTileDim()) {
350     num_target_replication =
351         target_sharding.tile_assignment().dimensions().back();
352   }
353   auto reshape_dimensions = partial_sharding.tile_assignment().dimensions();
354   int64 num_replication = reshape_dimensions.back();
355   if (num_replication / num_target_replication != Product(expand_tile_sizes) ||
356       num_replication % num_target_replication != 0) {
357     return absl::nullopt;
358   }
359 
360   reshape_dimensions.pop_back();
361   reshape_dimensions.insert(reshape_dimensions.end(), expand_tile_sizes.begin(),
362                             expand_tile_sizes.end());
363 
364   if (target_sharding.ReplicateOnLastTileDim()) {
365     reshape_dimensions.push_back(num_target_replication);
366   }
367 
368   auto reshape_tile_assignment = partial_sharding.tile_assignment();
369   reshape_tile_assignment.Reshape(reshape_dimensions);
370 
371   // Transpose.
372   std::vector<int64> perm;
373   perm.reserve(rank + expand_tile_sizes.size());
374   for (int64 dim = 0; dim < rank; dim++) {
375     perm.emplace_back(dim);
376     if (expand_tile_dims_indices[dim] > -1) {
377       perm.emplace_back(expand_tile_dims_indices[dim] + rank);
378     }
379   }
380   auto transpose_sharding = hlo_sharding_util::TransposeSharding(
381       target_sharding.ReplicateOnLastTileDim()
382           ? HloSharding::PartialTile(reshape_tile_assignment)
383           : HloSharding::Tile(reshape_tile_assignment),
384       perm);
385 
386   // Reshape to target shape
387   auto transpose_tile_assignment = transpose_sharding.tile_assignment();
388   transpose_tile_assignment.Reshape(
389       target_sharding.tile_assignment().dimensions());
390 
391   bool groups_matching = true;
392   target_sharding.tile_assignment().Each(
393       [&](absl::Span<const int64> indices, int64 device) {
394         if (device_to_replication_group[device] !=
395             device_to_replication_group[transpose_tile_assignment(indices)]) {
396           groups_matching = false;
397         }
398       });
399 
400   if (groups_matching) {
401     return target_sharding;
402   }
403   return target_sharding.ReplicateOnLastTileDim()
404              ? HloSharding::PartialTile(transpose_tile_assignment)
405              : HloSharding::Tile(transpose_tile_assignment);
406 }
407 
408 absl::optional<HloInstruction*> TileToPartialReplicateHaloExchange(
409     HloInstruction* hlo, const Shape& base_shape,
410     const HloSharding& src_sharding, const HloSharding& dst_sharding,
411     const std::vector<int64>& replicate_dims,
412     const SPMDCollectiveOpsCreator& collective_ops_creator,
413     int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) {
414   // Source is tile sharding.
415   auto padded_src_shape =
416       GetPaddedShapeForUnevenPartitioning(base_shape, src_sharding);
417   // Target is partial replicate.
418   auto padded_dst_shape =
419       GetPaddedShapeForUnevenPartitioning(base_shape, dst_sharding);
420   if (ShapeUtil::Compatible(padded_dst_shape, hlo->shape())) {
421     return hlo;
422   }
423 
424   auto partition_ordinals =
425       MakeTiledPartitionOrdinals(dst_sharding, partition_id, b);
426 
427   auto result = hlo;
428   auto hlo_shape = hlo->shape();
429   for (auto dim : replicate_dims) {
430     int64 dst_shard_count = dst_sharding.tile_assignment().dim(dim);
431     int64 src_per_shard_size =
432         padded_src_shape.dimensions(dim) / dst_shard_count;
433     // Calculate per shard size using the sharding to compare if dst_sharding
434     // needs more padding at the end.
435     int64 dst_per_shard_size =
436         padded_dst_shape.dimensions(dim) / dst_shard_count;
437 
438     // Skip if the src shards don't have redundant data.
439     if (src_per_shard_size <= dst_per_shard_size || dst_shard_count == 1) {
440       continue;
441     }
442 
443     // If src_per_shard * replicate_factor > dst_per_shard, we need to
444     // re-distribute the data between the shards using collective permute. For
445     // example, if the dimension size is 6 and it is sharded 4 ways in the src
446     // but 2 ways in the dst, the 4-way sharding has 2 elements in each shard
447     // while the 2-way sharding has 3, so the last element of the first shard
448     // would be sliced out and re-distribution is needed.
449     //
450     // 1. Calculate left_halo size.
451     // left-halo size is
452     //   (src_per_shard_size - dst_per_shard_size) * i / replicate_factor
453     int64 replicate_factor = src_sharding.tile_assignment().dim(dim) /
454                              dst_sharding.tile_assignment().dim(dim);
455     OffsetCalculation left_halo_size_function =
456         OffsetCalculation(MultiplyAddDivideOffsetCalculation(
457             src_per_shard_size - dst_per_shard_size, 0, replicate_factor));
458 
459     // 2. Calculate right_halo size.
460     // right-halo size is 0
461     OffsetCalculation right_halo_size_function =
462         OffsetCalculation(MultiplyAddDivideOffsetCalculation(0, 0, 1));
463 
464     auto concat = result;
465     // 3. Halo exchange.
466     auto halo_exchange_result = ExchangeHalo(
467         result, left_halo_size_function, right_halo_size_function, dim,
468         src_sharding, collective_ops_creator, next_channel_id, b);
469 
470     if (halo_exchange_result.has_value()) {
471       concat = halo_exchange_result.value();
472     } else {
473       return absl::nullopt;
474     }
475 
476     // 4. Slice the valid result.
477     // Slice offset is
478     // (dst_shard_count - i - 1) *
479     // (src_per_shard_size - dst_per_shard_size)
480     // where i is the shard index in dst_sharding.
481     auto zero_s32 = b->AddInstruction(
482         HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
483     OffsetCalculation start_offset_on_padded_concat_calculation =
484         OffsetCalculation(MultiplyAddDivideOffsetCalculation(
485             dst_per_shard_size - src_per_shard_size,
486             (src_per_shard_size - dst_per_shard_size) * (dst_shard_count - 1),
487             1));
488     auto slice_shape = concat->shape();
489     slice_shape.set_dimensions(dim,
490                                padded_src_shape.dimensions(dim) /
491                                    src_sharding.tile_assignment().dim(dim));
492     std::vector<HloInstruction*> slice_offsets(concat->shape().rank(),
493                                                zero_s32);
494     slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate(
495         partition_ordinals[dim], b);
496     result = b->AddInstruction(HloInstruction::CreateDynamicSlice(
497         slice_shape, concat, slice_offsets, slice_shape.dimensions()));
498   }
499   return result;
500 }
501 
502 absl::optional<HloInstruction*> PadFromPartialReplicateShape(
503     HloInstruction* hlo, const Shape& base_shape,
504     const HloSharding& src_sharding, const HloSharding& dst_sharding,
505     const std::vector<int64>& expand_tile_dims,
506     const SPMDCollectiveOpsCreator& collective_ops_creator,
507     int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) {
508   auto padded_src_shape =
509       GetPaddedShapeForUnevenPartitioning(base_shape, src_sharding);
510   auto padded_dst_shape =
511       GetPaddedShapeForUnevenPartitioning(base_shape, dst_sharding);
512   if (ShapeUtil::Compatible(padded_dst_shape, hlo->shape())) {
513     return hlo;
514   }
515 
516   auto partition_ordinals =
517       MakeTiledPartitionOrdinals(src_sharding, partition_id, b);
518 
519   HloInstruction* result = hlo;
520   auto zero = b->AddInstruction(HloInstruction::CreateConstant(
521       LiteralUtil::Zero(hlo->shape().element_type())));
522   std::vector<int64> expand_dims_without_halo_exchange;
523   // Pad the dimensions that need halo exchange and record the padded dims
524   // that won't need halo exchange.
525   for (auto dim : expand_tile_dims) {
526     int64 src_shard_count = src_sharding.tile_assignment().dim(dim);
527     int64 src_per_shard_size =
528         padded_src_shape.dimensions(dim) / src_shard_count;
529     // Calculate per shard size using the sharding to compare if dst_sharding
530     // needs more padding at the end.
531     int64 dst_per_shard_size =
532         padded_dst_shape.dimensions(dim) / src_shard_count;
533 
534     // Skip if dst_sharding doesn't need more padding at the end.
535     if (src_per_shard_size >= dst_per_shard_size) {
536       continue;
537     }
538     // If the src sharding at this dimension is not partitioned, simply pad to
539     // the desired shape.
540     if (src_shard_count == 1) {
541       expand_dims_without_halo_exchange.emplace_back(dim);
542       continue;
543     }
544 
545     // If dst_sharding needs more padding at the end, we need to re-distribute
546     // the data between the shards using collective permute.
547     // For example, if the dimension size is 6, sharded 2 ways in the src but
548     // 4 ways in the dst: the 4-way sharding needs 2 zeros of padding at the
549     // end and has 2 elements in each shard, while the 2-way sharding has 3
550     // elements in each shard, so re-distribution is needed.
551     //
552     // 1. Calculate left_halo size.
553     // left-halo size is 0
554     OffsetCalculation left_halo_size_function =
555         OffsetCalculation(MultiplyAddDivideOffsetCalculation(0, 0, 1));
556 
557     // 2. Calculate right_halo size.
558     // right-halo size is D * (i + 1) - S * (i + 1) = (D - S) * i + (D - S)
559     OffsetCalculation right_halo_size_function =
560         OffsetCalculation(MultiplyAddDivideOffsetCalculation(
561             dst_per_shard_size - src_per_shard_size,
562             dst_per_shard_size - src_per_shard_size, 1));
563 
564     auto concat = result;
565     // 3. Halo exchange.
566     auto halo_exchange_result = ExchangeHalo(
567         result, left_halo_size_function, right_halo_size_function, dim,
568         src_sharding, collective_ops_creator, next_channel_id, b);
569 
570     if (halo_exchange_result.has_value()) {
571       concat = halo_exchange_result.value();
572     } else {
573       return absl::nullopt;
574     }
575 
576     // 4. Pad.
577     std::vector<int64> zero_padding(concat->shape().rank());
578     PaddingConfig pad_config = window_util::MakeSymmetricPadding(zero_padding);
579     pad_config.mutable_dimensions(dim)->set_edge_padding_low(0);
580     int64 max_right_halo_size =
581         right_halo_size_function.MaxInRange(0, src_shard_count - 1);
582     pad_config.mutable_dimensions(dim)->set_edge_padding_high(std::max(
583         int64{0}, padded_dst_shape.dimensions(dim) -
584                       padded_src_shape.dimensions(dim) - max_right_halo_size));
585     auto padded_concat_shape = ShapeInference::InferPadShape(
586                                    concat->shape(), zero->shape(), pad_config)
587                                    .ValueOrDie();
588     concat = b->AddInstruction(HloInstruction::CreatePad(
589         padded_concat_shape, concat, zero, pad_config));
590 
591     // 5. Slice the valid result.
592     // Slice offset is (D-S) * i
593     auto zero_s32 = b->AddInstruction(
594         HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
595     OffsetCalculation start_offset_on_padded_concat_calculation =
596         OffsetCalculation(MultiplyAddDivideOffsetCalculation(
597             dst_per_shard_size - src_per_shard_size, 0, 1));
598     auto slice_shape = concat->shape();
599     slice_shape.set_dimensions(dim, dst_per_shard_size);
600     std::vector<HloInstruction*> slice_offsets(concat->shape().rank(),
601                                                zero_s32);
602     slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate(
603         partition_ordinals[dim], b);
604     result = b->AddInstruction(HloInstruction::CreateDynamicSlice(
605         slice_shape, concat, slice_offsets, slice_shape.dimensions()));
606   }
607 
608   // Pad other dimensions that won't need halo exchange with a single pad.
609   if (!expand_dims_without_halo_exchange.empty()) {
610     std::vector<int64> zero_padding(result->shape().rank());
611     PaddingConfig pad_config = window_util::MakeSymmetricPadding(zero_padding);
612 
613     auto padded_shape = result->shape();
614     for (auto dim : expand_dims_without_halo_exchange) {
615       pad_config.mutable_dimensions(dim)->set_edge_padding_low(0);
616       pad_config.mutable_dimensions(dim)->set_edge_padding_high(
617           padded_dst_shape.dimensions(dim) - padded_src_shape.dimensions(dim));
618       padded_shape.set_dimensions(dim, result->shape().dimensions(dim) +
619                                            padded_dst_shape.dimensions(dim) -
620                                            padded_src_shape.dimensions(dim));
621     }
622     result = b->AddInstruction(
623         HloInstruction::CreatePad(padded_shape, result, zero, pad_config));
624   }
625 
626   return result;
627 }
628 
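// Returns the single dimension that is tiled across more than one partition,
// or nullopt for tile-maximal shardings or when multiple dimensions are tiled.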
629 absl::optional<int64> UniqueTiledDim(const HloSharding& sharding) {
630   if (sharding.IsTileMaximal()) {
631     return absl::nullopt;
632   }
633   int64 dim = -1;
634   int64 rank = sharding.ReplicateOnLastTileDim()
635                    ? sharding.tile_assignment().num_dimensions() - 1
636                    : sharding.tile_assignment().num_dimensions();
637   for (int64 i = 0; i < rank; ++i) {
638     if (sharding.tile_assignment().dim(i) > 1) {
639       if (dim != -1) {
640         return absl::nullopt;
641       }
642       dim = i;
643     }
644   }
645   CHECK_NE(dim, -1);
646   return dim;
647 }
648 
649 MultiplyAddDivideOffsetCalculation::MultiplyAddDivideOffsetCalculation(
650     int64 multiplier, int64 offset, int64 divisor)
651     : multiplier_(multiplier), offset_(offset), divisor_(divisor) {
652   CHECK_GT(divisor_, 0);
653   Simplify();
654 }
655 
656 OffsetCalculation MultiplyAddDivideOffsetCalculation::operator-(
657     const MultiplyAddDivideOffsetCalculation& other) const {
658   if (divisor_ == 1 && other.divisor_ == 1) {
659     return OffsetCalculation(MultiplyAddDivideOffsetCalculation(
660         multiplier_ - other.multiplier_, offset_ - other.offset_, 1));
661   }
662   return OffsetCalculation(HloOpcode::kSubtract, *this, other);
663 }
664 
665 void MultiplyAddDivideOffsetCalculation::Simplify() {
666   // We could simplify the calculation when multiplier_ is a multiple of
667   // divisor_. However, when offset_ is not a multiple of divisor_, we must
668   // make sure that offset_ and multiplier_ are both non-negative or both
669   // non-positive. E.g., (3 * i - 1) / 3 is not equivalent to i or i - 1.
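  // For example, (6 * i + 4) / 2 simplifies to 3 * i + 2.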
670   if (divisor_ != 1 && multiplier_ % divisor_ == 0 &&
671       (offset_ % divisor_ == 0 || offset_ * multiplier_ > 0)) {
672     multiplier_ /= divisor_;
673     offset_ /= divisor_;
674     divisor_ = 1;
675   }
676 }
677 
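// For example, with multiplier 2, offset 1 and divisor 3, shard ordinal 4 maps
// to (4 * 2 + 1) / 3 = 3.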
678 int64 MultiplyAddDivideOffsetCalculation::Calculate(int64 shard_ordinal) const {
679   return (shard_ordinal * multiplier_ + offset_) / divisor_;
680 }
681 
682 HloInstruction* MultiplyAddDivideOffsetCalculation::Calculate(
683     HloInstruction* shard_ordinal, SpmdBuilder* b) const {
684   auto scalar_shape = ShapeUtil::MakeShape(S32, {});
685   if (multiplier_ == 0) {
686     return b->AddInstruction(HloInstruction::CreateConstant(
687         LiteralUtil::CreateR0<int32>(offset_ / divisor_)));
688   }
689   HloInstruction* result = shard_ordinal;
690   if (multiplier_ != 1) {
691     result = b->AddInstruction(HloInstruction::CreateBinary(
692         scalar_shape, HloOpcode::kMultiply, shard_ordinal,
693         b->AddInstruction(HloInstruction::CreateConstant(
694             LiteralUtil::CreateR0<int32>(multiplier_)))));
695   }
696   if (offset_ != 0) {
697     auto offset = b->AddInstruction(
698         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(offset_)));
699     result = b->AddInstruction(HloInstruction::CreateBinary(
700         scalar_shape, HloOpcode::kAdd, result, offset));
701   }
702   if (divisor_ != 1) {
703     auto divisor = b->AddInstruction(
704         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(divisor_)));
705     result = b->AddInstruction(HloInstruction::CreateBinary(
706         scalar_shape, HloOpcode::kDivide, result, divisor));
707   }
708   return result;
709 }
710 
711 int64 MultiplyAddDivideOffsetCalculation::MaxInRange(
712     int64 start_ordinal, int64 limit_ordinal) const {
713   int64 max = Calculate(start_ordinal);
714   for (int64 i = start_ordinal + 1; i < limit_ordinal; ++i) {
715     max = std::max(max, Calculate(i));
716   }
717   return max;
718 }
719 
720 OffsetCalculation& OffsetCalculation::operator=(
721     const OffsetCalculation& other) {
722   opcode_ = other.opcode_;
723   copy_from_ = other.copy_from_;
724   if (opcode_ != HloOpcode::kCopy) {
725     lhs_ = absl::make_unique<OffsetCalculation>(*other.lhs_);
726     rhs_ = absl::make_unique<OffsetCalculation>(*other.rhs_);
727   }
728   return *this;
729 }
730 
731 bool OffsetCalculation::IsConstant() const {
732   if (opcode_ == HloOpcode::kCopy) {
733     return copy_from_.IsConstant();
734   }
735   if (opcode_ == HloOpcode::kSubtract && *lhs_ == *rhs_) {
736     return true;
737   }
738   return lhs_->IsConstant() && rhs_->IsConstant();
739 }
740 
741 OffsetCalculation OffsetCalculation::operator-(
742     const OffsetCalculation& other) const {
743   if (opcode_ == HloOpcode::kCopy && other.opcode_ == HloOpcode::kCopy) {
744     return copy_from_ - other.copy_from_;
745   }
746   return OffsetCalculation(HloOpcode::kSubtract, *this, other);
747 }
748 
749 bool OffsetCalculation::operator==(const OffsetCalculation& other) const {
750   if (opcode_ != other.opcode_) {
751     return false;
752   }
753   if (opcode_ == HloOpcode::kCopy) {
754     return copy_from_ == other.copy_from_;
755   }
756   return *lhs_ == *other.lhs_ && *rhs_ == *other.rhs_;
757 }
758 
759 int64 OffsetCalculation::Calculate(int64 shard_ordinal) const {
760   switch (opcode_) {
761     case HloOpcode::kCopy:
762       return copy_from_.Calculate(shard_ordinal);
763     case HloOpcode::kSubtract:
764       return lhs_->Calculate(shard_ordinal) - rhs_->Calculate(shard_ordinal);
765     case HloOpcode::kMultiply:
766       return lhs_->Calculate(shard_ordinal) * rhs_->Calculate(shard_ordinal);
767     default:
768       LOG(FATAL) << "Should not happen";
769   }
770 }
771 
772 HloInstruction* OffsetCalculation::Calculate(HloInstruction* shard_ordinal,
773                                              SpmdBuilder* b) const {
774   if (opcode_ == HloOpcode::kCopy) {
775     return copy_from_.Calculate(shard_ordinal, b);
776   }
777   auto lhs = lhs_->Calculate(shard_ordinal, b);
778   auto rhs = rhs_->Calculate(shard_ordinal, b);
779   return b->AddInstruction(
780       HloInstruction::CreateBinary(lhs->shape(), opcode_, lhs, rhs));
781 }
782 
783 int64 OffsetCalculation::MaxInRange(int64 start_ordinal,
784                                     int64 limit_ordinal) const {
785   if (IsConstant()) {
786     return Calculate(start_ordinal);
787   }
788   if (opcode_ == HloOpcode::kCopy) {
789     return std::max(Calculate(start_ordinal), Calculate(limit_ordinal - 1));
790   }
791   int64 max = Calculate(start_ordinal);
792   for (int64 i = start_ordinal + 1; i < limit_ordinal; ++i) {
793     max = std::max(max, Calculate(i));
794   }
795   return max;
796 }
797 
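// Exchanges halos along `dim` using collective permutes: slices of up to one
// shard from neighboring partitions (possibly several hops away) are sent in
// both directions and concatenated with the local data along `dim`.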
798 absl::optional<HloInstruction*> ExchangeHalo(
799     HloInstruction* hlo, const OffsetCalculation& left_halo_size_function,
800     const OffsetCalculation& right_halo_size_function, int64 dim,
801     const HloSharding& target,
802     const SPMDCollectiveOpsCreator& collective_ops_creator,
803     int64* next_channel_id, SpmdBuilder* b) {
804   int64 input_shard_size = hlo->shape().dimensions(dim);
805   int64 shard_count = target.tile_assignment().dim(dim);
806 
807   std::vector<HloInstruction*> concat_pieces;
808 
809   int64 max_left_halo_size = left_halo_size_function.MaxInRange(1, shard_count);
810   int64 max_right_halo_size =
811       right_halo_size_function.MaxInRange(0, shard_count - 1);
812   if (max_left_halo_size + max_right_halo_size + input_shard_size >=
813           input_shard_size * shard_count &&
814       (max_left_halo_size > input_shard_size ||
815        max_right_halo_size > input_shard_size)) {
816     return absl::nullopt;
817   }
818   // Left halo.
819   for (int64 i = CeilOfRatio(max_left_halo_size, input_shard_size) - 1; i >= 0;
820        --i) {
821     std::vector<std::pair<int64, int64>> source_target_pairs;
822     target.tile_assignment().Each(
823         [&](absl::Span<const int64> indices, int64 device) {
824           if (indices[dim] > i) {
825             std::vector<int64> source_indices(indices.begin(), indices.end());
826             source_indices[dim] -= i + 1;
827             source_target_pairs.emplace_back(
828                 target.tile_assignment()(source_indices), device);
829           }
830         });
831     int64 halo_size =
832         std::min(max_left_halo_size - input_shard_size * i, input_shard_size);
833     auto halo_shape = hlo->shape();
834     auto source_halo_slice = hlo;
835     if (halo_size != hlo->shape().dimensions(dim)) {
836       halo_shape.set_dimensions(dim, halo_size);
837       std::vector<int64> halo_start_indices(halo_shape.rank(), 0);
838       halo_start_indices[dim] = hlo->shape().dimensions(dim) - halo_size;
839       std::vector<int64> halo_slice_strides(halo_shape.rank(), 1);
840       source_halo_slice = b->AddInstruction(HloInstruction::CreateSlice(
841           halo_shape, hlo, halo_start_indices, hlo->shape().dimensions(),
842           halo_slice_strides));
843     }
844     auto left_halo =
845         collective_ops_creator.create_cross_partition_collective_permute(
846             b, source_halo_slice, source_target_pairs, (*next_channel_id)++);
847     concat_pieces.push_back(left_halo);
848   }
849 
850   concat_pieces.push_back(hlo);
851 
852   // Right halo.
853   for (int64 i = 0; i < CeilOfRatio(max_right_halo_size, input_shard_size);
854        ++i) {
855     std::vector<std::pair<int64, int64>> source_target_pairs;
856     target.tile_assignment().Each(
857         [&](absl::Span<const int64> indices, int64 device) {
858           if (indices[dim] > i) {
859             std::vector<int64> target_indices(indices.begin(), indices.end());
860             target_indices[dim] -= i + 1;
861             source_target_pairs.emplace_back(
862                 device, target.tile_assignment()(target_indices));
863           }
864         });
865     int64 halo_size =
866         std::min(max_right_halo_size - input_shard_size * i, input_shard_size);
867     auto halo_shape = hlo->shape();
868     HloInstruction* source_halo_slice = hlo;
869     if (halo_size != halo_shape.dimensions(dim)) {
870       halo_shape.set_dimensions(dim, halo_size);
871       std::vector<int64> halo_start_indices(halo_shape.rank(), 0);
872       std::vector<int64> halo_slice_strides(halo_shape.rank(), 1);
873       source_halo_slice = b->AddInstruction(HloInstruction::CreateSlice(
874           halo_shape, hlo, halo_start_indices, halo_shape.dimensions(),
875           halo_slice_strides));
876     }
877     auto right_halo =
878         collective_ops_creator.create_cross_partition_collective_permute(
879             b, source_halo_slice, source_target_pairs, (*next_channel_id)++);
880     concat_pieces.push_back(right_halo);
881   }
882 
883   auto concat = hlo;
884   // Concat with halos/padding.
885   if (concat_pieces.size() > 1) {
886     auto concat_shape = hlo->shape();
887     int64 concat_dim_size = 0;
888     for (auto piece : concat_pieces) {
889       concat_dim_size += piece->shape().dimensions(dim);
890     }
891     concat_shape.set_dimensions(dim, concat_dim_size);
892     concat = b->AddInstruction(
893         HloInstruction::CreateConcatenate(concat_shape, concat_pieces, dim));
894   }
895 
896   return concat;
897 }
898 
899 absl::optional<HloInstruction*> ExchangeHalo(
900     HloInstruction* hlo,
901     std::vector<OffsetCalculation> left_halo_size_functions,
902     std::vector<OffsetCalculation> right_halo_size_functions,
903     const HloSharding& target,
904     const SPMDCollectiveOpsCreator& collective_ops_creator,
905     int64* next_channel_id, SpmdBuilder* b) {
906   CHECK(left_halo_size_functions.size() == hlo->shape().rank());
907   CHECK(right_halo_size_functions.size() == hlo->shape().rank());
908 
909   HloInstruction* visiting_hlo = hlo;
910   for (int dim = 0; dim < hlo->shape().rank(); ++dim) {
911     auto concat = ExchangeHalo(visiting_hlo, left_halo_size_functions[dim],
912                                right_halo_size_functions[dim], dim, target,
913                                collective_ops_creator, next_channel_id, b);
914     if (!concat) {
915       return absl::nullopt;
916     }
917     visiting_hlo = *concat;
918   }
919   return visiting_hlo;
920 }
921 
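// Runs ExchangeHalo along `dim`, then pads and slices the concatenated result
// to shard_size_with_halo, and (if mask_invalid_region) replaces elements that
// correspond to padding in the full shape with pad_value.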
922 absl::optional<HloInstruction*> ExchangeHaloAndGetValidData(
923     HloInstruction* hlo, const Shape& base_shape,
924     const OffsetCalculation& left_halo_size_function,
925     const OffsetCalculation& right_halo_size_function,
926     int64 explicit_left_padding_on_full_shape, int64 padded_full_shape_size,
927     int64 shard_size_with_halo, int64 dim, const HloSharding& target,
928     HloInstruction* offset_on_padded_shape, HloInstruction* pad_value,
929     HloInstruction* partition_ordinal,
930     const SPMDCollectiveOpsCreator& collective_ops_creator,
931     int64* next_channel_id, SpmdBuilder* b, bool mask_invalid_region) {
932   auto halo_exchange_result =
933       ExchangeHalo(hlo, left_halo_size_function, right_halo_size_function, dim,
934                    target, collective_ops_creator, next_channel_id, b);
935   if (!halo_exchange_result) {
936     return absl::nullopt;
937   }
938   auto concat = *halo_exchange_result;
939   int64 shard_count = target.tile_assignment().dim(dim);
940   int64 max_left_halo_size = left_halo_size_function.MaxInRange(1, shard_count);
941 
942   // Now we determine if we need extra padding after the concat.
943   //
944   // The max of halo size or the first shard's explicit left padding.
945   int64 max_left_halo_or_padding_size =
946       std::max(std::max(int64{0}, max_left_halo_size),
947                explicit_left_padding_on_full_shape);
948   // The calculation that returns the dynamic slice index for a shard on the
949   // padded concat, which is the difference between
950   // max_left_halo_or_padding_size and its left halo size.
951   auto start_offset_on_padded_concat_calculation =
952       OffsetCalculation(MultiplyAddDivideOffsetCalculation(
953           0, max_left_halo_or_padding_size, 1)) -
954       left_halo_size_function;
955 
956   // See if we need to pad the concat before dynamic slice.
957   int64 extra_left_padding =
958       std::max(int64{0}, max_left_halo_or_padding_size -
959                              std::max(int64{0}, max_left_halo_size));
960   int64 extra_right_padding =
961       start_offset_on_padded_concat_calculation.MaxInRange(0, shard_count) +
962       shard_size_with_halo - concat->shape().dimensions(dim) -
963       extra_left_padding;
964   extra_right_padding = std::max(int64{0}, extra_right_padding);
965   if (extra_left_padding > 0 || extra_right_padding > 0) {
966     PaddingConfig padding_config;
967     auto padded_concat_shape = concat->shape();
968     for (int64 i = 0; i < base_shape.rank(); ++i) {
969       auto padding_config_dim = padding_config.add_dimensions();
970       padding_config_dim->set_interior_padding(0);
971       padding_config_dim->set_edge_padding_low(0);
972       padding_config_dim->set_edge_padding_high(0);
973       if (i != dim) {
974         continue;
975       }
976       padding_config_dim->set_edge_padding_low(extra_left_padding);
977       padding_config_dim->set_edge_padding_high(extra_right_padding);
978       padded_concat_shape.set_dimensions(dim, concat->shape().dimensions(dim) +
979                                                   extra_left_padding +
980                                                   extra_right_padding);
981     }
982     concat = b->AddInstruction(HloInstruction::CreatePad(
983         padded_concat_shape, concat, pad_value, padding_config));
984   }
985 
986   auto valid_slice = concat;
987   if (shard_size_with_halo != concat->shape().dimensions(dim)) {
988     // Concat is bigger than the shard shape, so we need a dynamic slice.
989     CHECK_LT(shard_size_with_halo, concat->shape().dimensions(dim));
990     auto slice_shape = concat->shape();
991     slice_shape.set_dimensions(dim, shard_size_with_halo);
992 
993     if (left_halo_size_function.IsConstant() &&
994         left_halo_size_function.Calculate(0) ==
995             explicit_left_padding_on_full_shape) {
996       std::vector<int64> start_indices(slice_shape.rank(), 0);
997       std::vector<int64> strides(slice_shape.rank(), 1);
998       valid_slice = b->AddInstruction(
999           HloInstruction::CreateSlice(slice_shape, concat, start_indices,
1000                                       slice_shape.dimensions(), strides));
1001     } else {
1002       auto zero = b->AddInstruction(
1003           HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
1004       std::vector<HloInstruction*> slice_offsets(base_shape.rank(), zero);
1005       slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate(
1006           partition_ordinal, b);
1007       valid_slice = b->AddInstruction(HloInstruction::CreateDynamicSlice(
1008           slice_shape, concat, slice_offsets, slice_shape.dimensions()));
1009     }
1010   }
1011 
1012   if (!mask_invalid_region) {
1013     return valid_slice;
1014   }
1015 
1016   int64 total_right_padding = padded_full_shape_size -
1017                               base_shape.dimensions(dim) -
1018                               explicit_left_padding_on_full_shape;
1019   // Mask off garbage data due to uneven partition or low/high padding.
1020   if (explicit_left_padding_on_full_shape > 0 || total_right_padding > 0) {
1021     auto index_shape = ShapeUtil::ChangeElementType(valid_slice->shape(), S32);
1022     auto iota = b->AddInstruction(HloInstruction::CreateIota(index_shape, dim));
1023     auto broadcast_start_index_in_padded_shape =
1024         b->AddInstruction(HloInstruction::CreateBroadcast(
1025             index_shape, offset_on_padded_shape, {}));
1026     auto index_in_padded_shape = b->AddInstruction(
1027         HloInstruction::CreateBinary(index_shape, HloOpcode::kAdd, iota,
1028                                      broadcast_start_index_in_padded_shape));
1029     auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED);
1030     std::vector<HloInstruction*> predicates;
1031     if (explicit_left_padding_on_full_shape > 0) {
1032       auto valid_index_start =
1033           b->AddInstruction(HloInstruction::CreateBroadcast(
1034               index_shape,
1035               b->AddInstruction(
1036                   HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
1037                       explicit_left_padding_on_full_shape))),
1038               {}));
1039       predicates.push_back(b->AddInstruction(HloInstruction::CreateCompare(
1040           mask_shape, index_in_padded_shape, valid_index_start,
1041           ComparisonDirection::kGe)));
1042     }
1043     if (total_right_padding > 0) {
1044       auto valid_index_limit =
1045           b->AddInstruction(HloInstruction::CreateBroadcast(
1046               index_shape,
1047               b->AddInstruction(
1048                   HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
1049                       base_shape.dimensions(dim) +
1050                       explicit_left_padding_on_full_shape))),
1051               {}));
1052       predicates.push_back(b->AddInstruction(HloInstruction::CreateCompare(
1053           mask_shape, index_in_padded_shape, valid_index_limit,
1054           ComparisonDirection::kLt)));
1055     }
1056     CHECK(!predicates.empty());
1057     auto is_valid =
1058         predicates.size() == 2
1059             ? b->AddInstruction(HloInstruction::CreateBinary(
1060                   mask_shape, HloOpcode::kAnd, predicates[0], predicates[1]))
1061             : predicates[0];
1062     auto masking_value = b->AddInstruction(
1063         HloInstruction::CreateBroadcast(valid_slice->shape(), pad_value, {}));
1064     valid_slice = b->AddInstruction(
1065         HloInstruction::CreateTernary(valid_slice->shape(), HloOpcode::kSelect,
1066                                       is_valid, valid_slice, masking_value));
1067   }
1068   return valid_slice;
1069 }
1070 
1071 HloInstruction* HaloExchangeToPadOnLeft(PartitionedHlo& original,
1072                                         absl::Span<const int64> dims) {
1073   if (original.sharding().IsTileMaximal()) {
1074     return original.hlo();
1075   }
1076   // Create a window config to do a halo exchange for unevenly partitioned
1077   // reverse dimensions.
1078   Window window;
1079   for (int64 i = 0; i < original.base_shape().rank(); ++i) {
1080     WindowDimension* dim = window.add_dimensions();
1081     dim->set_size(1);
1082     dim->set_stride(1);
1083     dim->set_window_dilation(1);
1084     dim->set_window_reversal(false);
1085     int64 low_padding = 0;
1086     if (absl::c_linear_search(dims, i)) {
1087       low_padding =
1088           RoundUpToNearest(original.base_shape().dimensions(i),
1089                            original.sharding().tile_assignment().dim(i)) -
1090           original.base_shape().dimensions(i);
1091     }
1092     dim->set_padding_low(low_padding);
1093     dim->set_padding_high(0);
1094     dim->set_base_dilation(1);
1095   }
1096 
1097   auto reshard_window = original.ReshardAsWindowedInput(
1098       window, original.sharding(),
1099       CreateZero(ShapeUtil::MakeShape(original.base_shape().element_type(), {}),
1100                  original.state().b),
1101       /*mask_invalid_region=*/false);
1102   if (!reshard_window.has_value()) {
1103     return nullptr;
1104   }
1105   CHECK(!reshard_window->dynamic_slice_index_on_output.has_value());
1106   return reshard_window->sharded_input;
1107 }
1108 
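// Matches the NaN-safe comparator used by TopK-style sorts, where the keys are
// bitcast to integers and negative values are remapped so that a plain integer
// greater-than induces a total order over the floating-point values.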
1109 bool IsNanSafeGt(HloComputation* comp) {
1110   namespace m = match;
1111   auto match_bitcast_f32 = [](int64 parameter_number) {
1112     auto param = m::Parameter(parameter_number)
1113                      .WithShape(m::Shape().WithElementType(F32));
1114     auto param_s32 =
1115         m::BitcastConvert(param).WithShape(m::Shape().WithElementType(S32));
1116     auto param_u32 =
1117         m::BitcastConvert(param).WithShape(m::Shape().WithElementType(U32));
1118     return m::Select(
1119         m::Lt(param_s32, m::ConstantScalar(0)),
1120         m::BitcastConvert(
1121             m::Subtract(m::ConstantScalar(std::numeric_limits<int32>::max()),
1122                         param_u32))
1123             .WithShape(m::Shape().WithElementType(S32)),
1124         param_s32);
1125   };
1126   auto match_bitcast_bf16 = [](int64 parameter_number) {
1127     auto param = m::Convert(m::Parameter(parameter_number)
1128                                 .WithShape(m::Shape().WithElementType(BF16)))
1129                      .WithShape(m::Shape().WithElementType(F32));
1130     auto param_s32 =
1131         m::BitcastConvert(param).WithShape(m::Shape().WithElementType(S32));
1132     auto param_u32 =
1133         m::BitcastConvert(param).WithShape(m::Shape().WithElementType(U32));
1134     return m::Select(
1135         m::Lt(param_s32, m::ConstantScalar(0)),
1136         m::BitcastConvert(
1137             m::Subtract(m::ConstantScalar(std::numeric_limits<int32>::max()),
1138                         param_u32))
1139             .WithShape(m::Shape().WithElementType(S32)),
1140         param_s32);
1141   };
1142   // If the root is kSelect, it compares indices when the values are equal.
1143   if (comp->root_instruction()->opcode() == HloOpcode::kSelect) {
1144     return Match(comp->root_instruction()->operand(2),
1145                  m::Gt(match_bitcast_f32(0), match_bitcast_f32(1))) ||
1146            Match(comp->root_instruction()->operand(2),
1147                  m::Gt(match_bitcast_bf16(0), match_bitcast_bf16(1)));
1148   }
1149   return Match(comp->root_instruction(),
1150                m::Gt(match_bitcast_f32(0), match_bitcast_f32(1))) ||
1151          Match(comp->root_instruction(),
1152                m::Gt(match_bitcast_bf16(0), match_bitcast_bf16(1)));
1153 }
1154 
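// Recognizes the TopK pattern: a (values, iota) sort whose users slice the
// first k elements along the sort dimension. Returns k only when the sort
// dimension is the only sharded dimension and k is smaller than the size of a
// single partition.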
1155 absl::optional<int64> GetKValueInTopKWhenPartitionSortDim(HloInstruction* hlo) {
1156   HloSortInstruction* sort = DynCast<HloSortInstruction>(hlo);
1157   if (sort == nullptr || sort->operand_count() != 2) {
1158     return absl::nullopt;
1159   }
1160   if (!IsNanSafeGt(sort->to_apply())) {
1161     return absl::nullopt;
1162   }
1163   HloInstruction* data = sort->mutable_operand(0);
1164   HloIotaInstruction* iota =
1165       DynCast<HloIotaInstruction>(sort->mutable_operand(1));
1166   const PrimitiveType element_type = data->shape().element_type();
1167   if (iota == nullptr || iota->shape().element_type() != S32 ||
1168       iota->opcode() != HloOpcode::kIota ||
1169       iota->iota_dimension() != sort->sort_dimension()) {
1170     return absl::nullopt;
1171   }
1172 
1173   const int64 sort_dim = sort->sort_dimension();
1174 
1175   if (element_type != F32 && element_type != BF16 && element_type != S32 &&
1176       element_type != U32) {
1177     return absl::nullopt;
1178   }
1179 
1180   bool supported = true;
1181   absl::optional<int64> k;
1182   for (HloInstruction* gte : sort->users()) {
1183     if (gte->opcode() != HloOpcode::kGetTupleElement) {
1184       supported = false;
1185       break;
1186     }
1187 
1188     const HloInstruction* slice = gte->users()[0];
1189     if (slice->opcode() != HloOpcode::kSlice) {
1190       // A non-slice user means we are not doing a TopK.
1191       supported = false;
1192       break;
1193     }
1194     if (absl::c_any_of(slice->slice_starts(), [](int x) { return x != 0; }) ||
1195         absl::c_any_of(slice->slice_strides(), [](int x) { return x != 1; })) {
1196       // Strided slices or slices not starting at the beginning aren't supported.
1197       supported = false;
1198       break;
1199     }
1200     for (int64 dim = 0; dim < data->shape().dimensions_size(); dim++) {
1201       if (dim == sort_dim) {
1202         continue;
1203       }
1204       if (slice->slice_limits(dim) !=
1205           slice->operand(0)->shape().dimensions(dim)) {
1206         // Slicing along dims other than the sort dim isn't supported.
1207         supported = false;
1208         break;
1209       }
1210     }
1211     if (!k.has_value()) {
1212       k = slice->slice_limits(sort_dim);
1213     } else if (k != slice->slice_limits(sort_dim)) {
1214       // Different k values across operands aren't supported.
1215       supported = false;
1216       break;
1217     }
1218   }
1219   if (k == absl::nullopt || !supported) {
1220     return absl::nullopt;
1221   }
1222 
1223   // Only supported when the sort dimension is sharded.
1224   if (!data->has_sharding()) {
1225     return absl::nullopt;
1226   }
1227   const HloSharding& sharding = sort->operand(0)->sharding();
1228 
1229   if (sharding.IsTileMaximal()) {
1230     return absl::nullopt;
1231   }
1232 
1233   // Check that only the sort dimension is partitioned.
1234   for (int64 dim = 0; dim < sort->shape().tuple_shapes(0).dimensions_size();
1235        ++dim) {
1236     if (sharding.tile_assignment().dim(dim) > 1) {
1237       if (dim != sort_dim) {
1238         return absl::nullopt;
1239       }
1240     }
1241   }
1242 
1243   // Only supported when k is smaller than the per-partition size.
1244   const int64 shard_count = sharding.tile_assignment().dim(sort_dim);
1245 
1246   if (shard_count <= 1) {
1247     return absl::nullopt;
1248   }
1249 
1250   const int64 input_size = hlo->operand(0)->shape().dimensions(sort_dim);
1251   const int64 per_partition_size = CeilOfRatio(input_size, shard_count);
1252 
1253   if (k.value() >= per_partition_size) {
1254     return absl::nullopt;
1255   }
1256 
1257   return k;
1258 }
1259 
1260 // Slices the first k elements along slice_dim.
1261 HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder,
1262                             int64 slice_dim, int64 k) {
1263   const Shape& hlo_shape = hlo->shape();
1264   auto hlo_dims = hlo_shape.dimensions();
1265   std::vector<int64> start_indices(hlo_shape.dimensions_size(), 0);
1266   std::vector<int64> limit_indices(hlo_dims.begin(), hlo_dims.end());
1267   std::vector<int64> strides(hlo_shape.dimensions_size(), 1);
1268   limit_indices[slice_dim] = k;
1269   auto output_shape = hlo_shape;
1270   output_shape.set_dimensions(slice_dim, k);
1271   return builder->AddInstruction(HloInstruction::CreateSlice(
1272       output_shape, hlo, start_indices, limit_indices, strides));
1273 }
1274 
1275 // Returns the number of shards along the given dimension.
1276 int64 ShardCountAtDim(const HloSharding& sharding, int64 dim) {
1277   if (sharding.IsTileMaximal()) {
1278     return 1;
1279   }
1280   return sharding.tile_assignment().dim(dim);
1281 }
1282 
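// Returns the (dim, dim) swaps needed to reshard from `source` to `target`
// with a sequence of AllToAll ops, or nullopt if AllToAll cannot express the
// reshard. For example, going from a [4,2] tile assignment to [2,4] over the
// same eight devices yields the single swap pair (0, 1).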
1283 absl::optional<std::vector<std::pair<int64, int64>>>
1284 GetReshardAllToAllSourceTargetDims(const HloSharding& source,
1285                                    const HloSharding& target) {
1286   if (source.IsTileMaximal() || target.IsTileMaximal() ||
1287       source.tile_assignment().num_dimensions() !=
1288           target.tile_assignment().num_dimensions() ||
1289       source.NumTiles() != target.NumTiles()) {
1290     return absl::nullopt;
1291   }
1292   // Map each partition count to the dimensions that have it, restricted to
1293   // dimensions whose counts differ between source and target.
1294   std::map<int64, std::vector<int64>> source_size_to_dim;
1295   std::map<int64, std::vector<int64>> target_size_to_dim;
1296   for (int64 i = 0; i < source.tile_assignment().num_dimensions(); ++i) {
1297     if (source.tile_assignment().dim(i) == target.tile_assignment().dim(i)) {
1298       continue;
1299     }
1300     source_size_to_dim[source.tile_assignment().dim(i)].push_back(i);
1301     target_size_to_dim[target.tile_assignment().dim(i)].push_back(i);
1302   }
1303   // In order to reshard via AllToAll, source_size_to_dim and
1304   // target_size_to_dim must have the same distribution.
1305   if (source_size_to_dim.empty() ||
1306       source_size_to_dim.size() != target_size_to_dim.size()) {
1307     return absl::nullopt;
1308   }
1309   for (const auto& entry : source_size_to_dim) {
1310     auto target_it = target_size_to_dim.find(entry.first);
1311     if (target_it == target_size_to_dim.end() ||
1312         target_it->second.size() != entry.second.size()) {
1313       return absl::nullopt;
1314     }
1315   }
1316   std::vector<std::pair<int64, int64>> result;
1317   auto remove_entry = [](int64 size, int64 dim,
1318                          std::map<int64, std::vector<int64>>& size_to_dim) {
1319     size_to_dim[size].erase(
1320         std::remove_if(size_to_dim[size].begin(), size_to_dim[size].end(),
1321                        [dim](int64 a) { return a == dim; }),
1322         size_to_dim[size].end());
1323     if (size_to_dim[size].empty()) {
1324       size_to_dim.erase(size);
1325     }
1326   };
1327   // Find one pair of dimensions to swap at a time.
1328   while (!source_size_to_dim.empty()) {
1329     int64 source_size = source_size_to_dim.begin()->first;
1330     int64 i = source_size_to_dim.begin()->second.back();
1331     int64 target_i_size = target.tile_assignment().dim(i);
1332     if (target_i_size == source_size) {
1333       remove_entry(source_size, i, source_size_to_dim);
1334       remove_entry(source_size, i, target_size_to_dim);
1335       continue;
1336     }
1337     auto j_it = source_size_to_dim[target_i_size].begin();
1338     int64 j = *j_it;
1339     if (source_size == 1) {
1340       // If possible, find a j where the target partition count is not one, so
1341       // that when we swap, the resulting size-1 dimension will still be useful
1342       // to other dimensions.
1343       while (target.tile_assignment().dim(j) == 1) {
1344         if (++j_it == source_size_to_dim[target_i_size].end()) {
1345           break;
1346         }
1347         j = *j_it;
1348       }
1349     } else if (target_i_size % source_size == 0) {
1350       // If possible, find a j where the target partition count is source_size,
1351       // so that we can do a single swap.
1352       while (target.tile_assignment().dim(j) != source_size) {
1353         if (++j_it == source_size_to_dim[target_i_size].end()) {
1354           break;
1355         }
1356         j = *j_it;
1357       }
1358     } else {
1359       return absl::nullopt;
1360     }
1361     result.emplace_back(j, i);
1362     remove_entry(target_i_size, i, target_size_to_dim);
1363     source_size_to_dim.begin()->second.back() = j;
1364     remove_entry(target_i_size, j, source_size_to_dim);
1365   }
1366   return result;
1367 }
1368 
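// A collective permute can reshard only when source and target tile the shape
// identically and differ solely in which device owns each tile.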
1369 bool CanReshardWithCollectivePermute(const HloSharding& source,
1370                                      const HloSharding& target) {
1371   return !source.IsTileMaximal() && !target.IsTileMaximal() &&
1372          source.tile_assignment().dimensions() ==
1373              target.tile_assignment().dimensions() &&
1374          source.ReplicateOnLastTileDim() == target.ReplicateOnLastTileDim() &&
1375          source.tile_assignment() != target.tile_assignment();
1376 }
1377 
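// Grouping splits the devices of `sharding` into independent groups along
// `group_dims`, each keeping a smaller per-group sharding. For example,
// assuming a row-major iota tile assignment over devices 0..7, grouping a
// [2,4] tiling on dimension 0 yields the groups {0,1,2,3} and {4,5,6,7},
// each with a per-group tiling of [1,4].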
1378 GroupedSharding GroupShardingOnDims(const HloSharding& sharding,
1379                                     absl::Span<const int64> group_dims) {
1380   std::vector<int64> group_dim_shards(group_dims.size(), 1);
1381   return GroupShardingOnDims(sharding, group_dims, group_dim_shards);
1382 }
1383 
1384 GroupedSharding GroupShardingOnDims(const HloSharding& sharding,
1385                                     absl::Span<const int64> group_dims,
1386                                     absl::Span<const int64> group_dim_shards) {
1387   CHECK(!sharding.IsTileMaximal());
1388   std::vector<int64> grouped_tiling_dims =
1389       sharding.tile_assignment().dimensions();
1390   std::vector<int64> group_dim_sizes(group_dims.size());
1391   for (int64 i = 0; i < group_dims.size(); ++i) {
1392     CHECK_EQ(grouped_tiling_dims[group_dims[i]] % group_dim_shards[i], 0);
1393     group_dim_sizes[i] =
1394         grouped_tiling_dims[group_dims[i]] / group_dim_shards[i];
1395     grouped_tiling_dims[group_dims[i]] = group_dim_shards[i];
1396   }
1397 
1398   std::vector<std::vector<int64>> device_groups(Product(group_dim_sizes));
1399   sharding.tile_assignment().Each(
1400       [&](absl::Span<const int64> indices, int64 device) {
1401         int64 group_id = 0;
1402         for (int64 i = 0; i < group_dims.size(); ++i) {
1403           group_id *= sharding.tile_assignment().dim(group_dims[i]) /
1404                       group_dim_shards[i];
1405           group_id += indices[group_dims[i]] / group_dim_shards[i];
1406         }
1407         device_groups[group_id].push_back(device);
1408       });
1409   auto grouped = GroupedSharding(
1410       std::move(device_groups),
1411       std::vector<int64>(group_dims.begin(), group_dims.end()),
1412       std::move(group_dim_sizes), sharding.tile_assignment().num_dimensions(),
1413       HloSharding::Replicate());
1414   if (sharding.ReplicateOnLastTileDim()) {
1415     grouped.data_rank--;
1416   }
1417   if (Product(grouped_tiling_dims) == 1 ||
1418       (sharding.ReplicateOnLastTileDim() &&
1419        Product(grouped_tiling_dims) == grouped_tiling_dims.back())) {
1420     return grouped;
1421   }
1422   if (sharding.ReplicateOnLastTileDim() && grouped_tiling_dims.back() == 1) {
1423     grouped_tiling_dims.pop_back();
1424   }
1425   Array<int64> grouped_tiling(grouped_tiling_dims);
1426   grouped_tiling.FillIota(0);
1427   grouped.sharding = sharding.ReplicateOnLastTileDim() &&
1428                              grouped_tiling_dims.size() ==
1429                                  sharding.tile_assignment().num_dimensions()
1430                          ? HloSharding::PartialTile(grouped_tiling)
1431                          : HloSharding::Tile(grouped_tiling);
1432   return grouped;
1433 }
1434 
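// Rebuilds the full sharding from a grouped one by tiling each group's
// per-group assignment across the grouped dimensions; roughly the inverse of
// GroupShardingOnDims.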
1435 HloSharding UngroupSharding(const GroupedSharding& grouped_sharding) {
1436   std::vector<int64> tiling_dims;
1437   bool partial_sharding = false;
1438   auto grouped_tiling = grouped_sharding.sharding.tile_assignment();
1439   if (grouped_sharding.sharding.IsTileMaximal()) {
1440     tiling_dims = std::vector<int64>(grouped_sharding.data_rank, 1);
1441     if (grouped_sharding.device_groups[0].size() != 1) {
1442       // This is partial sharding.
1443       tiling_dims.push_back(grouped_sharding.device_groups[0].size());
1444       partial_sharding = true;
1445     }
1446     grouped_tiling = Array<int64>(tiling_dims);
1447     grouped_tiling.FillIota(0);
1448   } else {
1449     partial_sharding = grouped_sharding.sharding.ReplicateOnLastTileDim();
1450     tiling_dims = grouped_sharding.sharding.tile_assignment().dimensions();
1451     if (absl::c_linear_search(grouped_sharding.group_dims,
1452                               tiling_dims.size())) {
1453       tiling_dims.push_back(1);
1454       grouped_tiling.Reshape(tiling_dims);
1455       partial_sharding = true;
1456     }
1457   }
1458   for (int64 i = 0; i < grouped_sharding.group_dims.size(); ++i) {
1459     int64 dim = grouped_sharding.group_dims[i];
1460     tiling_dims[dim] *= grouped_sharding.group_dim_sizes[i];
1461   }
1462   Array<int64> tiling(tiling_dims);
1463   grouped_tiling.Each([&](absl::Span<const int64> indices, int64 device) {
1464     std::vector<int64> ungrouped_inds(indices.begin(), indices.end());
1465     for (int64 g = 0; g < grouped_sharding.device_groups.size(); ++g) {
1466       int64 remaining_group_index = g;
1467       for (int64 i = grouped_sharding.group_dims.size() - 1; i >= 0; --i) {
1468         int64 dim = grouped_sharding.group_dims[i];
1469         int64 groups_in_this_dim = grouped_sharding.group_dim_sizes[i];
1470         ungrouped_inds[dim] = (remaining_group_index % groups_in_this_dim) *
1471                                   grouped_tiling.dim(dim) +
1472                               indices[dim];
1473         remaining_group_index /= groups_in_this_dim;
1474       }
1475       tiling(ungrouped_inds) = grouped_sharding.device_groups[g][device];
1476     }
1477   });
1478   return partial_sharding ? HloSharding::PartialTile(tiling)
1479                           : HloSharding::Tile(tiling);
1480 }
1481 
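// Rewrites `grouped_sharding` to use the same device groups as `reference`:
// when every group maps onto a single reference group, the devices inside the
// per-group sharding are permuted to match; in either case the reference's
// device_groups are adopted.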
1482 GroupedSharding AlignGroupsWith(GroupedSharding grouped_sharding,
1483                                 const GroupedSharding& reference,
1484                                 bool ignore_group_order) {
1485   // Returns src -> dst index mapping.
1486   auto get_permutation = [](absl::Span<const int64> src,
1487                             absl::Span<const int64> dst) {
1488     CHECK_EQ(src.size(), dst.size());
1489     absl::flat_hash_map<int64, int64> dst_reverse_map;
1490     for (int64 i = 0; i < dst.size(); ++i) {
1491       dst_reverse_map[dst[i]] = i;
1492     }
1493     std::vector<int64> permutation(src.size());
1494     for (int64 i = 0; i < src.size(); ++i) {
1495       auto it = dst_reverse_map.find(src[i]);
1496       CHECK(it != dst_reverse_map.end());
1497       permutation[i] = it->second;
1498     }
1499     return permutation;
1500   };
1501   CHECK_EQ(grouped_sharding.device_groups.size(),
1502            reference.device_groups.size());
1503   absl::flat_hash_map<int64, int64> device_to_ref_group;
1504   for (int64 g = 0; g < reference.device_groups.size(); ++g) {
1505     for (int64 device : reference.device_groups[g]) {
1506       device_to_ref_group[device] = g;
1507     }
1508   }
1509   auto unique_ref_dev_group = [&](absl::Span<const int64> devices) -> int64 {
1510     int64 ref_g = -1;
1511     for (int64 device : devices) {
1512       if (ref_g == -1) {
1513         ref_g = device_to_ref_group[device];
1514       } else if (ref_g != device_to_ref_group[device]) {
1515         return -1;
1516       }
1517     }
1518     return ref_g;
1519   };
1520   bool matching_groups = true;
1521   std::vector<int64> original_src_to_ref_permutation;
1522   for (int64 g = 0; g < grouped_sharding.device_groups.size(); ++g) {
1523     int64 ref_g = unique_ref_dev_group(grouped_sharding.device_groups[g]);
1524     if (ref_g < 0 || (!ignore_group_order && g != ref_g)) {
1525       matching_groups = false;
1526       break;
1527     }
1528     if (g == 0) {
1529       original_src_to_ref_permutation = get_permutation(
1530           grouped_sharding.device_groups[g], reference.device_groups[ref_g]);
1531     }
1532   }
1533   if (matching_groups && !grouped_sharding.sharding.IsTileMaximal()) {
1534     auto tiles = grouped_sharding.sharding.tile_assignment();
1535     tiles.Each([&](absl::Span<const int64> indices, int64* device) {
1536       *device = original_src_to_ref_permutation[*device];
1537     });
1538     grouped_sharding.sharding =
1539         grouped_sharding.sharding.ReplicateOnLastTileDim()
1540             ? HloSharding::PartialTile(tiles)
1541             : HloSharding::Tile(tiles);
1542   }
1543   grouped_sharding.device_groups = std::move(reference.device_groups);
1544   return grouped_sharding;
1545 }
1546 
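// Convenience wrapper: groups both shardings on the given dimensions, aligns
// the groups, and ungroups the result.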
1547 HloSharding AlignShardingOnDims(const HloSharding& sharding,
1548                                 absl::Span<const int64> sharding_dims,
1549                                 const HloSharding& reference,
1550                                 absl::Span<const int64> reference_dims) {
1551   auto sharding_grouped = GroupShardingOnDims(sharding, sharding_dims);
1552   auto reference_grouped = GroupShardingOnDims(reference, reference_dims);
1553   return UngroupSharding(AlignGroupsWith(sharding_grouped, reference_grouped));
1554 }
1555 
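// Shrinks the base shape to its per-group size by dividing each grouped
// dimension by its number of groups.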
1556 Shape GetPerGroupBaseShape(const GroupedSharding& grouped_sharding,
1557                            const Shape& original_base_shape) {
1558   auto result = original_base_shape;
1559   for (int64 i = 0; i < grouped_sharding.group_dims.size(); ++i) {
1560     int64 dim = grouped_sharding.group_dims[i];
1561     if (dim >= original_base_shape.rank()) {
1562       continue;
1563     }
1564     int64 groups = grouped_sharding.group_dim_sizes[i];
1565     result.set_dimensions(dim, result.dimensions(dim) / groups);
1566   }
1567   return result;
1568 }
1569 
1570 namespace {
1571 
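// Maps the runtime partition id to its index within its device group by
// building a constant lookup table (global id -> in-group id) and
// dynamic-slicing it with the partition id.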
1572 HloInstruction* GetInGroupPartitionId(
1573     HloInstruction* partition_id,
1574     const std::vector<std::vector<int64>>& device_groups, SpmdBuilder* b) {
1575   int64 total_devices = device_groups.size() * device_groups[0].size();
1576   std::vector<uint32> in_group_ids(total_devices);
1577   for (uint32 i = 0; i < device_groups.size(); ++i) {
1578     for (uint32 j = 0; j < device_groups[i].size(); ++j) {
1579       in_group_ids[device_groups[i][j]] = j;
1580     }
1581   }
1582   auto id_table = b->AddInstruction(HloInstruction::CreateConstant(
1583       LiteralUtil::CreateR1<uint32>(in_group_ids)));
1584   return b->AddInstruction(HloInstruction::CreateReshape(
1585       ShapeUtil::MakeScalarShape(U32),
1586       b->AddInstruction(HloInstruction::CreateDynamicSlice(
1587           ShapeUtil::MakeShape(U32, {1}), id_table, {partition_id}, {1}))));
1588 }
1589 
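// Wraps `creator` so that collectives created within a group operate on
// global device ids: the caller's in-group subgroups are expanded into one
// replica group per device group via device_groups.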
1590 SPMDCollectiveOpsCreator GetPerGroupCollectiveOpsCreator(
1591     const SPMDCollectiveOpsCreator& creator,
1592     const std::vector<std::vector<int64>>& device_groups) {
1593   SPMDCollectiveOpsCreator result;
1594   result.create_partition_id = [creator, device_groups](SpmdBuilder* b) {
1595     return GetInGroupPartitionId(creator.create_partition_id(b), device_groups,
1596                                  b);
1597   };
1598   auto expand_partition_groups =
1599       [device_groups](
1600           const std::vector<std::vector<int64>>& partition_subgroups) {
1601         if (partition_subgroups.empty()) {
1602           return device_groups;
1603         }
1604         std::vector<std::vector<int64>> result(partition_subgroups.size() *
1605                                                device_groups.size());
1606         for (int64 g = 0; g < device_groups.size(); ++g) {
1607           for (int64 i = 0; i < partition_subgroups.size(); ++i) {
1608             result[g * partition_subgroups.size() + i].resize(
1609                 partition_subgroups[i].size());
1610             for (int64 j = 0; j < partition_subgroups[i].size(); ++j) {
1611               result[g * partition_subgroups.size() + i][j] =
1612                   device_groups[g][partition_subgroups[i][j]];
1613             }
1614           }
1615         }
1616         return result;
1617       };
1618   result.create_cross_partition_all_reduce =
1619       [creator, expand_partition_groups](
1620           SpmdBuilder* b, HloInstruction* operand, HloComputation* reduction,
1621           const std::vector<std::vector<int64>>& partition_subgroups,
1622           int64 channel_id) {
1623         return creator.create_cross_partition_all_reduce(
1624             b, operand, reduction, expand_partition_groups(partition_subgroups),
1625             channel_id);
1626       };
1627   result.create_cross_partition_collective_permute =
1628       [creator, device_groups](
1629           SpmdBuilder* b, HloInstruction* operand,
1630           std::vector<std::pair<int64, int64>>& src_dst_pairs,
1631           int64 next_channel_id) {
1632         std::vector<std::pair<int64, int64>> expanded_pairs(
1633             src_dst_pairs.size() * device_groups.size());
1634         for (int64 g = 0; g < device_groups.size(); ++g) {
1635           for (int64 i = 0; i < src_dst_pairs.size(); ++i) {
1636             expanded_pairs[g * src_dst_pairs.size() + i] =
1637                 std::pair<int64, int64>{
1638                     device_groups[g][src_dst_pairs[i].first],
1639                     device_groups[g][src_dst_pairs[i].second]};
1640           }
1641         }
1642         return creator.create_cross_partition_collective_permute(
1643             b, operand, expanded_pairs, next_channel_id);
1644       };
1645   result.create_cross_partition_all_to_all =
1646       [creator, expand_partition_groups](
1647           SpmdBuilder* b, absl::Span<HloInstruction* const> operands,
1648           const std::vector<std::vector<int64>>& partition_subgroups,
1649           int64 channel_id, absl::optional<int64> split_dimension) {
1650         return creator.create_cross_partition_all_to_all(
1651             b, operands, expand_partition_groups(partition_subgroups),
1652             channel_id, split_dimension);
1653       };
1654   if (creator.create_cross_partition_all_gather) {
1655     result.create_cross_partition_all_gather =
1656         [creator, expand_partition_groups](
1657             SpmdBuilder* b, HloInstruction* operand, const Shape& ag_shape,
1658             const std::vector<std::vector<int64>>& partition_subgroups,
1659             int64 channel_id, int64 all_gather_dimension) {
1660           return creator.create_cross_partition_all_gather(
1661               b, operand, ag_shape,
1662               expand_partition_groups(partition_subgroups), channel_id,
1663               all_gather_dimension);
1664         };
1665   }
1666   return result;
1667 }
1668 
1669 }  // namespace
1670 
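// Derives a per-group partitioning state: the partition id and collective ops
// are rewritten to act within a single device group, and a per-group reshard
// cache is keyed by the group assignment.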
1671 PartitionedHlo::PartitioningState CreatePerGroupPartitioningState(
1672     const PartitionedHlo::PartitioningState& state,
1673     const std::vector<std::vector<int64>>& device_groups, SpmdBuilder* b) {
1674   auto result = state;
1675   result.collective_ops_creator = GetPerGroupCollectiveOpsCreator(
1676       state.collective_ops_creator, device_groups);
1677   result.partition_id =
1678       GetInGroupPartitionId(state.partition_id, device_groups, b);
1679   // Create a string key for the groups.
1680   std::vector<std::string> per_group_strings(device_groups.size());
1681   for (int64 i = 0; i < per_group_strings.size(); ++i) {
1682     per_group_strings[i] = absl::StrJoin(device_groups[i], ",");
1683   }
1684   auto& grouped_cache =
1685       state.reshard_cache->groupd_caches[absl::StrJoin(per_group_strings, ";")];
1686   if (!grouped_cache) {
1687     grouped_cache = absl::make_unique<PartitionedHlo::ReshardCache>();
1688   }
1689   result.reshard_cache = grouped_cache.get();
1690   return result;
1691 }
1692 
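// Extracts this group's shard from a fully replicated value: the group id is
// looked up from the partition id, a group-level tiled sharding is built over
// group_dims, and the shard is taken with a dynamic-slice.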
1693 HloInstruction* PerGroupSliceFromReplicated(
1694     HloInstruction* replicated, HloInstruction* partition_id,
1695     const std::vector<std::vector<int64>>& device_groups,
1696     absl::Span<const int64> group_dims, absl::Span<const int64> group_dim_sizes,
1697     SpmdBuilder* b) {
1698   std::vector<uint32> group_ids(device_groups.size() * device_groups[0].size());
1699   for (int64 g = 0; g < device_groups.size(); ++g) {
1700     for (int64 device : device_groups[g]) {
1701       group_ids[device] = g;
1702     }
1703   }
1704   auto group_id_table = b->AddInstruction(
1705       HloInstruction::CreateConstant(LiteralUtil::CreateR1<uint32>(group_ids)));
1706   auto group_id = b->AddInstruction(HloInstruction::CreateReshape(
1707       ShapeUtil::MakeScalarShape(U32),
1708       b->AddInstruction(HloInstruction::CreateDynamicSlice(
1709           ShapeUtil::MakeShape(U32, {1}), group_id_table, {partition_id},
1710           {1}))));
1711   std::vector<int64> group_level_tile_dims(replicated->shape().rank(), 1);
1712   for (int64 i = 0; i < group_dims.size(); ++i) {
1713     group_level_tile_dims[group_dims[i]] = group_dim_sizes[i];
1714   }
1715   Array<int64> group_level_tile(group_level_tile_dims);
1716   group_level_tile.Each([&](absl::Span<const int64> indices, int64* group) {
1717     *group = 0;
1718     for (int64 dim : group_dims) {
1719       *group *= group_level_tile.dim(dim);
1720       *group += indices[dim];
1721     }
1722   });
1723   auto group_level_sharding = HloSharding::Tile(group_level_tile);
1724   auto padded_hlo = PadBaseShapeBeforeUnevenTiledSharding(
1725       replicated, group_level_sharding, b);
1726   auto shard_shape =
1727       MakePartitionedShape(replicated->shape(), group_level_sharding);
1728   return b->AddInstruction(HloInstruction::CreateDynamicSlice(
1729       shard_shape, padded_hlo,
1730       MakePartitionOffsets(replicated->shape(), group_level_sharding, group_id,
1731                            b),
1732       shard_shape.dimensions()));
1733 }
1734 
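// Returns the binary opcode implemented by a two-parameter reduction
// computation whose root is an elementwise binary op over both parameters
// (e.g. add or maximum), or nullopt otherwise.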
1735 absl::optional<HloOpcode> ParseReductionComputation(
1736     const HloComputation* reduction_comp) {
1737   if (reduction_comp->num_parameters() != 2) {
1738     return absl::nullopt;
1739   }
1740   auto root = reduction_comp->root_instruction();
1741   if (!root->IsElementwiseBinary()) {
1742     return absl::nullopt;
1743   }
1744   if (!absl::c_linear_search(root->operands(),
1745                              reduction_comp->parameter_instruction(0)) ||
1746       !absl::c_linear_search(root->operands(),
1747                              reduction_comp->parameter_instruction(1))) {
1748     return absl::nullopt;
1749   }
1750   return root->opcode();
1751 }
1752 
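// Finds the tile dimensions of `sharding` whose index is the same for all
// devices within each given group, such that grouping on those dimensions
// reproduces exactly `device_groups`; returns nullopt if no such dimensions
// exist.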
1753 absl::optional<std::vector<int64>> FindMatchingPartitionedDimsForGrouping(
1754     const HloSharding& sharding,
1755     const std::vector<std::vector<int64>>& device_groups) {
1756   if (sharding.NumTiles() < device_groups.size() || device_groups.size() < 2 ||
1757       device_groups[0].size() < 2) {
1758     return absl::nullopt;
1759   }
1760   int64 rank = sharding.tile_assignment().num_dimensions();
1761   if (sharding.ReplicateOnLastTileDim()) {
1762     rank--;
1763   }
1764   absl::flat_hash_map<int64, std::vector<int64>> device_to_index;
1765   sharding.tile_assignment().Each(
1766       [&](absl::Span<const int64> index, int64 device) {
1767         device_to_index[device] =
1768             std::vector<int64>(index.begin(), index.begin() + rank);
1769       });
1770   std::vector<int64> dims;
1771   int64 group_count = 1;
1772   for (int64 i = 0; i < rank; ++i) {
1773     if (device_to_index[device_groups[0][0]][i] ==
1774         device_to_index[device_groups[0][1]][i]) {
1775       dims.push_back(i);
1776       group_count *= sharding.tile_assignment().dim(i);
1777     }
1778   }
1779   if (group_count != device_groups.size()) {
1780     return absl::nullopt;
1781   }
1782   for (const auto& group : device_groups) {
1783     for (int64 i = 1; i < group.size(); ++i) {
1784       if (absl::c_any_of(dims, [&](const int64 dim) {
1785             return device_to_index[group[i]][dim] !=
1786                    device_to_index[group[0]][dim];
1787           })) {
1788         return absl::nullopt;
1789       }
1790     }
1791   }
1792   return dims;
1793 }
1794 
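// Builds a sharding for `target_shape` that partitions target_dims the same
// way source_sharding partitions the corresponding source_dims; any remaining
// source partitions become partial replication, and the result is aligned
// with the source's device groups.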
1795 HloSharding CreateMatchingShardingOnDims(const Shape& target_shape,
1796                                          const HloSharding& source_sharding,
1797                                          absl::Span<const int64> target_dims,
1798                                          absl::Span<const int64> source_dims) {
1799   CHECK(target_dims.size() == source_dims.size())
1800       << "Expected 1:1 match between parallel dimensions";
1801   if (source_sharding.IsReplicated()) {
1802     return HloSharding::Replicate();
1803   }
1804   absl::InlinedVector<int64, 4> tile_dims(target_shape.dimensions_size(), 1);
1805   int num_tiles = 1;
1806   for (int i = 0, end = target_dims.size(); i < end; ++i) {
1807     num_tiles *= source_sharding.tile_assignment().dim(source_dims[i]);
1808     tile_dims[target_dims[i]] =
1809         source_sharding.tile_assignment().dim(source_dims[i]);
1810   }
1811   // If the source sharding also partitions non-parallel dimensions, fold
1812   // those extra partitions into partial replication in the new sharding.
1813   bool to_be_partially_replicated = false;
1814   if (num_tiles != source_sharding.tile_assignment().num_elements()) {
1815     CHECK_EQ(source_sharding.tile_assignment().num_elements() % num_tiles, 0);
1816     to_be_partially_replicated = true;
1817     tile_dims.push_back(source_sharding.tile_assignment().num_elements() /
1818                         num_tiles);
1819   }
1820   auto tgt_tile_assignment = source_sharding.tile_assignment();
1821   tgt_tile_assignment.Reshape(tile_dims);
1822   if (to_be_partially_replicated) {
1823     return AlignShardingOnDims(HloSharding::PartialTile(tgt_tile_assignment),
1824                                target_dims, source_sharding, source_dims);
1825   } else {
1826     return AlignShardingOnDims(HloSharding::Tile(tgt_tile_assignment),
1827                                target_dims, source_sharding, source_dims);
1828   }
1829 }
1830 
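// For a gather whose operand and indices share parallel dimensions, returns
// shardings in which those dimensions are partitioned identically on both
// sides, borrowing partitions from partial replication when the two sides
// disagree; returns nullopt when they cannot be made to match.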
1831 absl::optional<GatherParallelDimSharding>
1832 GatherOperandsShardedAcrossParallelDims(
1833     const HloInstruction& operand, const HloInstruction& indices,
1834     const hlo_sharding_util::GatherParallelDims& parallel_dims) {
1835   auto& indices_parallel_dims = parallel_dims.indices_parallel_dims;
1836   auto& operand_parallel_dims = parallel_dims.operand_parallel_dims;
1837   if (indices_parallel_dims.size() != operand_parallel_dims.size()) {
1838     return absl::nullopt;
1839   }
1840   auto new_index_shard = indices.sharding();
1841   auto new_operand_shard = operand.sharding();
1842   int idx_parallel_tiles_num = new_index_shard.NumTiles(indices_parallel_dims);
1843   int op_parallel_tiles_num = new_operand_shard.NumTiles(operand_parallel_dims);
1844   if (idx_parallel_tiles_num == 1 && op_parallel_tiles_num == 1) {
1845     return absl::nullopt;
1846   }
1847   absl::InlinedVector<int64, 1> indices_parallel_dims_ordered_as_operand;
1848   for (int idx : parallel_dims.index_parallel_in_dim) {
1849     if (idx != -1) {
1850       indices_parallel_dims_ordered_as_operand.push_back(idx);
1851     }
1852   }
1853   if (new_index_shard.IsReplicated()) {
1854     return GatherParallelDimSharding{
1855         CreateMatchingShardingOnDims(indices.shape(), new_operand_shard,
1856                                      indices_parallel_dims_ordered_as_operand,
1857                                      operand_parallel_dims),
1858         new_operand_shard};
1859   }
1860   if (new_operand_shard.IsReplicated()) {
1861     return GatherParallelDimSharding{
1862         new_index_shard,
1863         CreateMatchingShardingOnDims(operand.shape(), new_index_shard,
1864                                      operand_parallel_dims,
1865                                      indices_parallel_dims_ordered_as_operand)};
1866   }
1867 
1868   // Parallel dimension distribution needs to be the same, so try to steal
1869   // sharding from partial replication to compensate.
1870   if (idx_parallel_tiles_num != op_parallel_tiles_num) {
1871     auto to_adjust_dims = operand_parallel_dims;
1872     auto target_dims = indices_parallel_dims_ordered_as_operand;
1873     HloSharding* target = &new_index_shard;
1874     HloSharding* to_adjust = &new_operand_shard;
1875     if (idx_parallel_tiles_num < op_parallel_tiles_num) {
1876       std::swap(to_adjust_dims, target_dims);
1877       std::swap(to_adjust, target);
1878     }
1879     if (!to_adjust->ReplicateOnLastTileDim()) {
1880       return absl::nullopt;
1881     }
1882     auto new_tile_assignment_dims = to_adjust->tile_assignment().dimensions();
1883     for (int i = 0; i < to_adjust_dims.size(); ++i) {
1884       int64 target_dim = target->tile_assignment().dim(target_dims[i]);
1885       int64 to_adjust_dim = to_adjust->tile_assignment().dim(to_adjust_dims[i]);
1886       if (target_dim < to_adjust_dim) {
1887         return absl::nullopt;
1888       }
1889       if (target_dim == to_adjust_dim) {
1890         continue;
1891       }
1892       int64 ratio = target_dim / to_adjust_dim;
1893       if (target_dim % to_adjust_dim != 0 ||
1894           new_tile_assignment_dims.back() % ratio != 0) {
1895         return absl::nullopt;
1896       }
1897       new_tile_assignment_dims[to_adjust_dims[i]] *= ratio;
1898       new_tile_assignment_dims.back() /= ratio;
1899     }
1900     CHECK_GE(new_tile_assignment_dims.back(), 1);
1901     bool to_partially_replicate = true;
1902     if (new_tile_assignment_dims.back() == 1) {
1903       new_tile_assignment_dims.pop_back();
1904       to_partially_replicate = false;
1905     }
1906     auto new_tile_assignment = to_adjust->tile_assignment();
1907     new_tile_assignment.Reshape(new_tile_assignment_dims);
1908     if (to_partially_replicate) {
1909       *to_adjust =
1910           AlignShardingOnDims(HloSharding::PartialTile(new_tile_assignment),
1911                               to_adjust_dims, *target, target_dims);
1912     } else {
1913       *to_adjust = AlignShardingOnDims(HloSharding::Tile(new_tile_assignment),
1914                                        to_adjust_dims, *target, target_dims);
1915     }
1916   }
1917   // Make sure that the parallel dimensions are aligned.
1918   auto operand_shard_tile_dims =
1919       new_operand_shard.tile_assignment().dimensions();
1920   for (int i = 0; i < indices_parallel_dims_ordered_as_operand.size(); ++i) {
1921     operand_shard_tile_dims[operand_parallel_dims[i]] =
1922         new_index_shard.tile_assignment().dim(
1923             indices_parallel_dims_ordered_as_operand[i]);
1924   }
1925   auto operand_shard_tiles = new_operand_shard.tile_assignment();
1926   operand_shard_tiles.Reshape(operand_shard_tile_dims);
1927   new_operand_shard =
1928       AlignShardingOnDims(new_operand_shard.ReplicateOnLastTileDim()
1929                               ? HloSharding::PartialTile(operand_shard_tiles)
1930                               : HloSharding::Tile(operand_shard_tiles),
1931                           operand_parallel_dims, new_index_shard,
1932                           indices_parallel_dims_ordered_as_operand);
1933   return GatherParallelDimSharding{new_index_shard, new_operand_shard};
1934 }
1935 
1936 }  // namespace spmd
1937 }  // namespace xla
1938