/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"

#include <float.h>

#include <functional>
#include <memory>
#include <unordered_map>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_cse.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/numbers.h"

namespace xla {
namespace spmd {

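// Returns a human-readable report of the recorded SPMD memory log entries,
// sorted by size, listing at most report_instruction_count_ of the largest
// entries.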
string SpmdLogger::MakeReport() {
  string report;
  absl::StrAppend(&report,
                  "\n\n***** SPMD memory during transformation *****\n");

  std::sort(entries_.begin(), entries_.end(),
            [](auto const& entry0, auto const& entry1) {
              return entry0.first > entry1.first;
            });
  for (int64 i = 0;
       i < std::min<int64>(report_instruction_count_, entries_.size()); ++i) {
    absl::StrAppend(
        &report, "\n  ",
        tensorflow::strings::HumanReadableNumBytes(entries_[i].first), " : ",
        entries_[i].second, "\n");
  }

  return report;
}

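// Records a log entry for `hlo`: the entry text lists `hlo` and the
// instructions in `group`, and the entry size is the largest array shape size
// (in bytes) found in the group.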
void SpmdLogger::RegisterLogEntry(HloInstruction* hlo,
                                  const std::vector<HloInstruction*>& group) {
  string report = hlo->ToString();
  int64 max_value = -1;
  for (HloInstruction* inst : group) {
    if (!inst->shape().IsArray()) {
      continue;
    }
    max_value = std::max<int64>(max_value, ShapeSizeInBytes(inst->shape()));
    absl::StrAppend(&report, "     * ", inst->ToString(), "\n");
  }
  entries_.push_back(std::make_pair(max_value, report));
}

/* static */ string SpmdLogger::ReportBeforePartition(
    const HloModule& module, int64 report_instruction_count) {
  string report;
  absl::StrAppend(&report,
                  "\n\n***** SPMD memory usage before partition *****\n");
  absl::StrAppend(&report, "\n  ** Replicated instructions\n");
  absl::StrAppend(&report, ReportMemoryUsage(
                               module,
                               [](const HloInstruction* hlo) {
                                 return !hlo->has_sharding() ||
                                        hlo->sharding().IsReplicated();
                               },
                               report_instruction_count));
  absl::StrAppend(&report, "\n  ** All instructions\n");
  absl::StrAppend(&report,
                  ReportMemoryUsage(
                      module, [](const HloInstruction* hlo) { return true; },
                      report_instruction_count));
  return report;
}

/* static */ string SpmdLogger::ReportAfterPartition(
    const HloModule& module, int64 report_instruction_count) {
  string report;
  absl::StrAppend(&report,
                  "\n\n***** SPMD memory usage after partition *****\n");
  absl::StrAppend(&report,
                  ReportMemoryUsage(
                      module, [](const HloInstruction* hlo) { return true; },
                      report_instruction_count));
  return report;
}

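// Builds a memory-usage report for `module`, listing the largest non-tuple,
// non-scalar instructions that pass `filter` (at most
// report_instruction_count of them), sorted by shape size.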
template <typename F>
/* static */ string SpmdLogger::ReportMemoryUsage(
    const HloModule& module, const F& filter, int64 report_instruction_count) {
  string report;
  std::vector<HloInstruction*> instructions;
  instructions.reserve(module.instruction_count());

  for (auto computation : module.computations()) {
    if (computation->IsFusionComputation()) {
      continue;
    }
    for (auto hlo : computation->instructions()) {
      if (hlo->shape().IsTuple() ||
          ShapeUtil::IsEffectiveScalar(hlo->shape())) {
        continue;
      }
      if (filter(hlo)) {
        instructions.push_back(hlo);
      }
    }
  }

  const auto add_report = [&](std::vector<HloInstruction*>* insts) {
    std::sort(insts->begin(), insts->end(),
              [](const HloInstruction* inst0, const HloInstruction* inst1) {
                return ShapeSizeInBytes(inst0->shape()) >
                       ShapeSizeInBytes(inst1->shape());
              });
    for (int64 i = 0;
         i < std::min<int64>(report_instruction_count, insts->size()); ++i) {
      absl::StrAppend(&report, "  ",
                      tensorflow::strings::HumanReadableNumBytes(
                          ShapeSizeInBytes((*insts)[i]->shape())),
                      " : ", (*insts)[i]->ToString(), "\n");
    }
  };

  add_report(&instructions);
  return report;
}

namespace {

// Clears all sharding attributes from instructions in the module. This must be
// called only after all SPMD transformation is complete.
Status ClearShardingAttributes(HloModule* module) {
  for (HloComputation* computation : module->computations()) {
    for (HloInstruction* hlo : computation->instructions()) {
      // Keep sharding annotation on Infeed and entry parameters since they're
      // used by HloReplicationAnalysis later (for ArCrsCombiner).
      if (hlo->opcode() == HloOpcode::kInfeed) {
        continue;
      }
      if (hlo->opcode() == HloOpcode::kParameter &&
          computation == module->entry_computation()) {
        continue;
      }
      hlo->clear_sharding();
    }
  }
  return Status::OK();
}

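// Returns the partition groups across which data is replicated when the given
// sharding is replicated along `replication_dims`: each group contains the
// partitions whose tile-assignment indices differ only in those dimensions.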
std::vector<std::vector<int64>> GetPartitionGroupsForReplication(
    const HloSharding& sharding, absl::Span<const int64> replication_dims) {
  int64 group_size = 1;
  for (int64 i : replication_dims) {
    group_size *= sharding.tile_assignment().dim(i);
  }
  std::vector<std::vector<int64>> partition_groups(
      sharding.tile_assignment().num_elements() / group_size);
  sharding.tile_assignment().Each(
      [&](absl::Span<const int64> indices, int64 partition) {
        int64 group_id = 0;
        for (int64 i = 0; i < indices.size(); ++i) {
          if (!absl::c_linear_search(replication_dims, i)) {
            group_id *= sharding.tile_assignment().dim(i);
            group_id += indices[i];
          }
        }
        partition_groups[group_id].push_back(partition);
      });
  return partition_groups;
}

}  // namespace

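// Adds an instruction to the builder, remembering which HLO is currently being
// visited. Also tracks, per instruction, the dimensions along which the data
// is known to be equal to a broadcast (broadcast_dims_), propagating that
// information through elementwise ops, transpose, reshape, slice,
// dynamic-slice, and pad.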
HloInstruction* SpmdBuilder::AddInstruction(
    std::unique_ptr<HloInstruction> instruction) {
  HloInstruction* hlo =
      HloComputation::Builder::AddInstruction(std::move(instruction));
  if (visiting_hlo_) {
    instructions_[visiting_hlo_].push_back(hlo);
  }
  if (hlo->opcode() == HloOpcode::kBroadcast) {
    for (int64 i = 0; i < hlo->shape().rank(); ++i) {
      if (!absl::c_linear_search(hlo->dimensions(), i)) {
        broadcast_dims_[hlo].insert(i);
      }
    }
  }
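  // For elementwise ops, a dimension remains a broadcast dimension only if it
  // is a broadcast dimension of every operand.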
  if (hlo->IsElementwise() && hlo->operand_count() > 0) {
    absl::flat_hash_set<int64> broadcast_dims;
    for (int64 i = 0; i < hlo->shape().rank(); ++i) {
      broadcast_dims.insert(i);
    }
    for (int64 i = 0; i < hlo->operand_count(); ++i) {
      auto it = broadcast_dims_.find(hlo->operand(i));
      if (it == broadcast_dims_.end()) {
        broadcast_dims.clear();
        break;
      }
      for (int64 i = 0; i < hlo->shape().rank(); ++i) {
        if (!it->second.contains(i)) {
          broadcast_dims.erase(i);
        }
      }
    }
    if (!broadcast_dims.empty()) {
      broadcast_dims_[hlo] = std::move(broadcast_dims);
    }
  }
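  // For transpose, map each broadcast dimension of the operand to the
  // corresponding output dimension via the inverse permutation.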
  if (hlo->opcode() == HloOpcode::kTranspose) {
    auto it = broadcast_dims_.find(hlo->operand(0));
    if (it != broadcast_dims_.end()) {
      absl::flat_hash_set<int64> xpose_broadcast_dims;
      std::vector<int64> reverse_map(hlo->shape().rank());
      for (int64 i = 0; i < reverse_map.size(); ++i) {
        reverse_map[hlo->dimensions(i)] = i;
      }
      for (int64 dim : it->second) {
        xpose_broadcast_dims.insert(reverse_map[dim]);
      }
      broadcast_dims_[hlo] = std::move(xpose_broadcast_dims);
    }
  }
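  // For reshape, walk the operand and result dimension sizes together
  // (splitting or merging dimensions as needed) and keep a result dimension as
  // a broadcast dimension only if all of its source dimensions are broadcast
  // dimensions.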
  if (hlo->opcode() == HloOpcode::kReshape &&
      Product(hlo->shape().dimensions()) > 0) {
    auto it = broadcast_dims_.find(hlo->operand(0));
    if (it != broadcast_dims_.end()) {
      absl::flat_hash_set<int64> reshape_broadcast_dims;
      for (int64 i = 0; i < hlo->shape().rank(); ++i) {
        reshape_broadcast_dims.insert(i);
      }
      std::vector<int64> before_dim_size_stack;
      std::vector<int64> after_dim_size_stack;
      for (int64 i = hlo->operand(0)->shape().rank() - 1; i >= 0; --i) {
        before_dim_size_stack.push_back(hlo->operand(0)->shape().dimensions(i));
      }
      for (int64 i = hlo->shape().rank() - 1; i >= 0; --i) {
        after_dim_size_stack.push_back(hlo->shape().dimensions(i));
      }
      while (!before_dim_size_stack.empty() && !after_dim_size_stack.empty()) {
        int64 before_size = before_dim_size_stack.back();
        int64 after_size = after_dim_size_stack.back();
        int64 current_before_dim =
            hlo->operand(0)->shape().rank() - before_dim_size_stack.size();
        int64 current_after_dim =
            hlo->shape().rank() - after_dim_size_stack.size();
        before_dim_size_stack.pop_back();
        after_dim_size_stack.pop_back();
        if (!it->second.contains(current_before_dim)) {
          reshape_broadcast_dims.erase(current_after_dim);
        }
        if (before_size == after_size) {
          continue;
        }
        if (before_size % after_size == 0) {
          // Split dim.
          before_dim_size_stack.push_back(before_size / after_size);
        } else if (after_size % before_size == 0) {
          // Merge dim.
          after_dim_size_stack.push_back(after_size / before_size);
        } else {
          // Other cases, mark all remaining dims as non-broadcast.
          for (int64 i = current_after_dim; i < hlo->shape().rank(); ++i) {
            reshape_broadcast_dims.erase(i);
          }
          break;
        }
      }
      if (!before_dim_size_stack.empty() || !after_dim_size_stack.empty()) {
        reshape_broadcast_dims.clear();
      }
      if (!reshape_broadcast_dims.empty()) {
        broadcast_dims_[hlo] = std::move(reshape_broadcast_dims);
      }
    }
  }
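  // Slice and dynamic-slice preserve the broadcast dimensions of their
  // operand; pad preserves only the dimensions that are not actually padded.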
  if (hlo->opcode() == HloOpcode::kSlice ||
      hlo->opcode() == HloOpcode::kDynamicSlice) {
    auto it = broadcast_dims_.find(hlo->operand(0));
    if (it != broadcast_dims_.end()) {
      auto dims = it->second;
      broadcast_dims_[hlo] = std::move(dims);
    }
  }
  if (hlo->opcode() == HloOpcode::kPad) {
    auto it = broadcast_dims_.find(hlo->operand(0));
    if (it != broadcast_dims_.end()) {
      absl::flat_hash_set<int64> pad_broadcast_dims;
      for (int64 i = 0; i < hlo->shape().rank(); ++i) {
        const auto& dim = hlo->padding_config().dimensions(i);
        if (dim.edge_padding_low() == 0 && dim.edge_padding_high() == 0 &&
            dim.interior_padding() == 0 && it->second.contains(i)) {
          pad_broadcast_dims.insert(i);
        }
      }
      if (!pad_broadcast_dims.empty()) {
        broadcast_dims_[hlo] = std::move(pad_broadcast_dims);
      }
    }
  }
  return hlo;
}

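// Reshards this PartitionedHlo to the target sharding, reusing a previously
// computed reshard from the per-HLO cache when possible. Reshards that reduce
// the number of tiles are cached only when cache_all_gather is enabled.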
PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target) {
  if (sharding() == target) {
    return *this;
  }
  auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;
  const bool is_to_replicate =
      hlo_->shape().IsArray() && target.NumTiles() < sharding().NumTiles();
  if (!is_to_replicate || state_.partitioner->options().cache_all_gather) {
    for (auto& entry : cache) {
      if (entry.first == target) {
        return entry.second;
      }
    }
  }
  auto resharded = ReshardNoCache(target);
  state_.reshard_cache->per_hlo_cache[resharded.hlo()]
      .reshard_cache.emplace_back(sharding(), *this);
  if (!is_to_replicate || state_.partitioner->options().cache_all_gather) {
    cache.emplace_back(target, std::move(resharded));
    return cache.back().second;
  }
  return resharded;
}

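// Resharding without cache lookup: handles tuples recursively, then tries
// collective-permute, all-to-all, and partial-replication based reshards, and
// finally falls back to replicating first and slicing out the local shard.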
PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) {
  VLOG(2) << "Resharding " << hlo_->ToString() << " from "
          << hlo_->sharding().ToString() << " to " << target.ToString();
  const Shape& shape = hlo_->shape();
  if (shape.element_type() == TOKEN) {
    return *this;
  }
  CHECK(shape.IsTuple() || !target.IsTuple());

  // Tuple shape instructions may have non-tuple sharding, which means that the
  // same sharding applies to all the leaves.
  if (shape.IsTuple() && !target.IsTuple()) {
    return Reshard(target.GetTupleSharding(shape).ValueOrDie());
  }

  // For a tuple shape, recursively apply Reshard to all the leaves and return
  // a tuple instruction.
  if (shape.IsTuple()) {
    std::vector<HloInstruction*> elements;
    for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
      auto subshape = ShapeUtil::GetTupleElementShape(shape, i);
      auto element = state_.b->AddInstruction(
          HloInstruction::CreateGetTupleElement(subshape, hlo(), i));
      element->set_sharding(sharding().GetSubSharding(shape, {i}));
      elements.push_back(
          PartitionedHlo(
              element, ShapeUtil::GetTupleElementShape(base_shape_, i), state_)
              .Reshard(target.GetSubSharding(shape, {i}))
              .hlo());
    }
    auto tuple =
        state_.b->AddInstruction(HloInstruction::CreateTuple(elements));
    tuple->set_sharding(target);
    return PartitionedHlo(tuple, base_shape_, state_);
  }

  if (sharding() == target) {
    return *this;
  }

  if (CanReshardWithCollectivePermute(sharding(), target)) {
    return ReshardWithCollectivePermute(target);
  }

  if (auto src_tgt_dims =
          GetReshardAllToAllSourceTargetDims(sharding(), target)) {
    return ReshardWithAllToAll(target, *src_tgt_dims);
  }

  if (!target.IsTileMaximal() && sharding().ReplicateOnLastTileDim()) {
    auto try_reshard = ReshardFromPartialReplicateWithDynamicSlice(target);
    if (try_reshard.has_value()) {
      return try_reshard.value();
    }
    try_reshard = ReshardPartialReplicateWithAllToAll(target);
    if (try_reshard.has_value()) {
      return try_reshard.value();
    }
  }

  if (!sharding().IsTileMaximal() && target.ReplicateOnLastTileDim()) {
    auto try_reshard = ReshardToPartialReplicateWithAllGather(target);
    if (try_reshard.has_value()) {
      return try_reshard.value();
    }
    try_reshard = ReshardPartialReplicateWithAllToAll(target);
    if (try_reshard.has_value()) {
      return try_reshard.value();
    }
  }

  // If not replicated yet, first replicate and then reshard to use one of the
  // implementations below.
  if (!sharding().IsReplicated()) {
    return Replicate().Reshard(target);
  }

  // 'Replicated' to 'SingleDevice'.
  if (target.IsTileMaximal()) {
    auto copy = state_.b->AddInstruction(
        HloInstruction::CreateUnary(hlo_->shape(), HloOpcode::kCopy, hlo_));
    copy->set_sharding(target);
    return PartitionedHlo(copy, base_shape_, state_);
  }

  // 'Replicated' to partial replicated.
  if (target.ReplicateOnLastTileDim()) {
    std::vector<int64> group_dims(target.tile_assignment().num_dimensions() -
                                  1);
    std::iota(group_dims.begin(), group_dims.end(), 0);
    auto target_grouped = GroupShardingOnDims(target, group_dims);
    auto partially_sharded = PerGroupSliceFromReplicated(
        hlo_, state_.partition_id, target_grouped.device_groups, group_dims,
        target_grouped.group_dim_sizes, state_.b);
    partially_sharded->set_sharding(target);
    return PartitionedHlo(partially_sharded, base_shape(), state_);
  }

  // 'Replicated' to 'Tiled'.
  auto padded_hlo =
      PadBaseShapeBeforeUnevenTiledSharding(hlo_, target, state_.b);
  auto shard_shape = MakePartitionedShape(shape, target);
  auto slice = state_.b->AddInstruction(HloInstruction::CreateDynamicSlice(
      shard_shape, padded_hlo,
      MakePartitionOffsets(shape, target, state_.partition_id, state_.b),
      shard_shape.dimensions()));
  slice->set_sharding(target);
  return PartitionedHlo(slice, base_shape_, state_);
}

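// Returns a copy of this PartitionedHlo where the regions that fall outside
// the base shape (due to uneven tiling) are replaced with pad_value.
// Dimensions in skipped_dims are left untouched, and dimensions in
// left_padded_dims are assumed to be padded on the low side.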
PartitionedHlo PartitionedHlo::PadWithValue(
    HloInstruction* pad_value, absl::Span<const int64> left_padded_dims,
    absl::Span<const int64> skipped_dims) const {
  const HloSharding& sharding = hlo_->sharding();
  const Shape& shape = hlo_->shape();
  CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);
  if (sharding.IsReplicated() || EvenlyPartitions(base_shape_, sharding)) {
    return *this;
  }
  CHECK(!sharding.IsTileMaximal());
  auto index_shape = ShapeUtil::ChangeElementType(shape, S32);
  auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED);
  auto get_mask_for_dim = [&](int64 dim, HloInstruction* start_index) {
    // Comparison: iota + start_index < valid_size
    auto iota =
        state_.b->AddInstruction(HloInstruction::CreateIota(index_shape, dim));
    auto broadcast_start_index = state_.b->AddInstruction(
        HloInstruction::CreateBroadcast(index_shape, start_index, {}));
    auto index_in_full_shape =
        state_.b->AddInstruction(HloInstruction::CreateBinary(
            index_shape, HloOpcode::kAdd, iota, broadcast_start_index));
    ComparisonDirection direction = ComparisonDirection::kLt;
    int64 index_limit = base_shape_.dimensions(dim);
    if (absl::c_linear_search(left_padded_dims, dim)) {
      direction = ComparisonDirection::kGe;
      index_limit =
          index_shape.dimensions(dim) * sharding.tile_assignment().dim(dim) -
          index_limit;
    }
    auto limit = state_.b->AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::CreateR0<int32>(index_limit)));
    auto broadcast_limit = state_.b->AddInstruction(
        HloInstruction::CreateBroadcast(index_shape, limit, {}));
    return state_.b->AddInstruction(HloInstruction::CreateCompare(
        mask_shape, index_in_full_shape, broadcast_limit, direction));
  };

  HloInstruction* mask = nullptr;
  auto offsets = MakePartitionOffsets(base_shape_, sharding,
                                      state_.partition_id, state_.b);
  for (int64 i = 0; i < shape.rank(); ++i) {
    if (base_shape_.dimensions(i) % sharding.tile_assignment().dim(i) == 0 ||
        absl::c_linear_search(skipped_dims, i)) {
      continue;
    }
    if (mask == nullptr) {
      mask = get_mask_for_dim(i, offsets[i]);
    } else {
      mask = state_.b->AddInstruction(
          HloInstruction::CreateBinary(mask->shape(), HloOpcode::kAnd, mask,
                                       get_mask_for_dim(i, offsets[i])));
    }
  }

  if (mask == nullptr) {
    return *this;
  }

  auto broadcast_pad_value = state_.b->AddInstruction(
      HloInstruction::CreateBroadcast(shape, pad_value, {}));
  auto result = state_.b->AddInstruction(HloInstruction::CreateTernary(
      shape, HloOpcode::kSelect, mask, hlo_, broadcast_pad_value));
  result->set_sharding(sharding);
  return PartitionedHlo(result, base_shape_, state_);
}

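// Reshards the data as the input of a windowed op (e.g., a convolution or
// reduce-window) under the target sharding. Returns the per-shard input, the
// adjusted window, and an optional dynamic-slice offset for the output, or
// nullopt if this window configuration is not supported.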
absl::optional<PartitionedHlo::WindowedInputShardReturnValue>
PartitionedHlo::ReshardAsWindowedInput(const Window& window,
                                       const HloSharding& target,
                                       HloInstruction* pad_value,
                                       bool mask_invalid_region) {
  auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].window_reshard_cache;
  for (auto& entry : cache) {
    if (std::get<0>(entry) == target &&
        protobuf_util::ProtobufEquals(std::get<1>(entry), window)) {
      return std::get<2>(entry);
    }
  }
  auto update_cache = [&](WindowedInputShardReturnValue result) {
    cache.emplace_back(target, window, std::move(result));
    return std::get<2>(cache.back());
  };
  VLOG(2) << "ReshardAsWindowedInput()\n"
          << "\twindow:" << window_util::ToString(window)
          << "\ttarget sharding:" << target.ToString();

  CHECK(!target.IsTileMaximal());
  auto partition_ordinals =
      MakeTiledPartitionOrdinals(target, state_.partition_id, state_.b);
  auto shard_shape = base_shape_;

  std::vector<MultiplyAddDivideOffsetCalculation> start_on_padded_calculations(
      base_shape_.rank());
  std::vector<MultiplyAddDivideOffsetCalculation> limit_on_padded_calculations(
      base_shape_.rank());
  std::vector<HloInstruction*> dynamic_slice_offset_on_output(
      base_shape_.rank(), nullptr);

  Window shard_window = window;
  auto padded_shape = base_shape_;
  std::vector<HloInstruction*> offsets_on_padded_shape(base_shape_.rank());
  std::vector<int64> per_shard_window_counts(base_shape_.rank());
  std::vector<int64> explicit_left_padding(base_shape_.rank());
  for (int64 i = 0; i < base_shape_.rank(); ++i) {
    // Do not pad non-partitioned dimensions.
    int64 shard_count = target.tile_assignment().dim(i);
    if (shard_count == 1) {
      offsets_on_padded_shape[i] = state_.b->AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
      continue;
    }
    const auto& wd = window.dimensions(i);
    const auto dilated_size = 1 + (wd.size() - 1) * wd.window_dilation();
    int64 full_size =
        base_shape_.dimensions(i) +
        (wd.base_dilation() - 1) * (base_shape_.dimensions(i) - 1) +
        wd.padding_high() + wd.padding_low();
    if (full_size < dilated_size) {
      VLOG(2) << "Failed to reshard window operand because the window size is "
                 "larger than padded base size";
      return absl::nullopt;
    }
    int64 window_count = (full_size - dilated_size) / wd.stride() + 1;
    per_shard_window_counts[i] = CeilOfRatio(window_count, shard_count);
    if (wd.stride() != 1 &&
        (wd.stride() * per_shard_window_counts[i]) % wd.base_dilation() != 0) {
      // TODO(yuanzx): Support this case.
      VLOG(2) << "Failed to reshard window operand due to non-trivial dilation";
      return absl::nullopt;
    }

    // We use explicit padding for full dilations, then use padding_low and
    // padding_high on the sharded op for the remaining. padding_low and
    // padding_high are now given initial values, which will be later updated if
    // dilation is not 1.
    auto swd = shard_window.mutable_dimensions(i);
    explicit_left_padding[i] = wd.padding_low() / wd.base_dilation();
    swd->set_padding_low(wd.padding_low() % wd.base_dilation());
    swd->set_padding_high(0);

    // Calculation for the first element needed on the 'padded-but-not-dilated'
    // shape. The start on the dilated shape could be a hole, so we add
    // wd.base_dilation() - 1 to the constant term to skip the leading holes.
    start_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation(
        wd.stride() * per_shard_window_counts[i],
        wd.base_dilation() - 1 - swd->padding_low(), wd.base_dilation());
    int64 dilated_shard_size =
        wd.stride() * (per_shard_window_counts[i] - 1) + dilated_size;
    limit_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation(
        wd.stride() * per_shard_window_counts[i],
        dilated_shard_size + wd.base_dilation() - 1 - swd->padding_low(),
        wd.base_dilation());

    offsets_on_padded_shape[i] = start_on_padded_calculations[i].Calculate(
        partition_ordinals[i], state_.b);

    auto shard_size_function =
        limit_on_padded_calculations[i] - start_on_padded_calculations[i];
    int64 max_shard_size = shard_size_function.MaxInRange(0, shard_count);
    shard_shape.set_dimensions(i, max_shard_size);
    padded_shape.set_dimensions(
        i, limit_on_padded_calculations[i].Calculate(shard_count - 1));

    // For base dilation, calculate the needed padding_low and padding_high, as
    // well as the offset for the output if a dynamic slice is needed after the
    // sharded op.
    if (wd.base_dilation() != 1) {
      // Returns the offset of a shard's first valid element in the dilated
      // shard.
      auto get_first_valid_element_offset_on_dilated_shard =
          [&](int64 shard_ordinal) {
            return start_on_padded_calculations[i].Calculate(shard_ordinal) *
                       wd.base_dilation() +
                   swd->padding_low() -
                   wd.stride() * per_shard_window_counts[i] * shard_ordinal;
          };
      CHECK_EQ(get_first_valid_element_offset_on_dilated_shard(0),
               swd->padding_low());

      // Determine swd->padding_high.
      for (int64 shard_ordinal = 0; shard_ordinal < shard_count;
           ++shard_ordinal) {
        int64 wanted_limit_on_dilated_shard =
            wd.stride() * (per_shard_window_counts[i] - 1) + dilated_size;
        int64 actual_limit_on_dilated_shard_without_pad_high =
            get_first_valid_element_offset_on_dilated_shard(shard_ordinal) +
            (max_shard_size - 1) * wd.base_dilation() + 1;
        swd->set_padding_high(std::max<int64>(
            swd->padding_high(),
            wanted_limit_on_dilated_shard -
                actual_limit_on_dilated_shard_without_pad_high));
      }

      // Determine swd->padding_low and output dynamic slice index.
      if (wd.stride() == 1) {
        int64 max_pad_low = get_first_valid_element_offset_on_dilated_shard(0);
        bool all_same = true;
        for (int64 shard_ordinal = 1; shard_ordinal < shard_count;
             ++shard_ordinal) {
          int64 start =
              get_first_valid_element_offset_on_dilated_shard(shard_ordinal);
          if (start != swd->padding_low()) {
            all_same = false;
          }
          max_pad_low = std::max(max_pad_low, start);
        }
        if (!all_same) {
          auto start_on_padded_input =
              start_on_padded_calculations[i].Calculate(partition_ordinals[i],
                                                        state_.b);
          // We will calculate
          //   max_pad_low - (first_window - required_first_window)
          // which equals
          //   required_first_window - (first_window - max_pad_low)
          auto first_window_minus_max_pad_low =
              MultiplyAddDivideOffsetCalculation(
                  wd.base_dilation(), swd->padding_low() - max_pad_low, 1)
                  .Calculate(start_on_padded_input, state_.b);
          auto required_first_window =
              MultiplyAddDivideOffsetCalculation(per_shard_window_counts[i], 0,
                                                 1)
                  .Calculate(partition_ordinals[i], state_.b);
          dynamic_slice_offset_on_output[i] =
              state_.b->AddInstruction(HloInstruction::CreateBinary(
                  required_first_window->shape(), HloOpcode::kSubtract,
                  required_first_window, first_window_minus_max_pad_low));
        }
        swd->set_padding_low(max_pad_low);
      } else {
        if ((wd.stride() * per_shard_window_counts[i]) % wd.base_dilation() !=
            0) {
          // General base dilation not yet implemented.
          return absl::nullopt;
        }
        // padding_low on all shards should equal the initially assigned
        // swd->padding_low(), i.e., the padding_low() on the original window.
      }
    }
  }

  // Returns the output dynamic slice offset when needed, and absl::nullopt
  // otherwise.
  auto get_dynamic_slice_offset_on_output_if_needed =
      [&]() -> absl::optional<std::vector<HloInstruction*>> {
    if (absl::c_all_of(
            dynamic_slice_offset_on_output,
            [](HloInstruction* offset) { return offset == nullptr; })) {
      return absl::nullopt;
    }
    auto zero = state_.b->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
    for (int64 i = 0; i < dynamic_slice_offset_on_output.size(); ++i) {
      if (dynamic_slice_offset_on_output[i] == nullptr) {
        dynamic_slice_offset_on_output[i] = zero;
      }
    }
    return dynamic_slice_offset_on_output;
  };

  // If the current HLO is replicated, pad then slice.
  if (sharding().IsReplicated()) {
    PaddingConfig padding_config;
    for (int64 i = 0; i < base_shape_.rank(); ++i) {
      auto padding_config_dim = padding_config.add_dimensions();
      padding_config_dim->set_interior_padding(0);
      // Do not pad non-partitioned dimensions.
      if (target.tile_assignment().dim(i) == 1) {
        padding_config_dim->set_edge_padding_low(0);
        padding_config_dim->set_edge_padding_high(0);
        continue;
      }
      padding_config_dim->set_edge_padding_low(explicit_left_padding[i]);
      padding_config_dim->set_edge_padding_high(padded_shape.dimensions(i) -
                                                explicit_left_padding[i] -
                                                base_shape_.dimensions(i));
    }
    auto padded_hlo = ShapeUtil::Compatible(padded_shape, base_shape_)
                          ? hlo_
                          : state_.b->AddInstruction(HloInstruction::CreatePad(
                                padded_shape, hlo_, pad_value, padding_config));
    auto sharded_input =
        state_.b->AddInstruction(HloInstruction::CreateDynamicSlice(
            shard_shape, padded_hlo, offsets_on_padded_shape,
            shard_shape.dimensions()));
    return update_cache(WindowedInputShardReturnValue{
        sharded_input, shard_window,
        get_dynamic_slice_offset_on_output_if_needed()});
  }

  if (target != sharding()) {
    return Reshard(target).ReshardAsWindowedInput(window, target, pad_value);
  }

  // Halo exchange.
  HloInstruction* visiting_hlo = hlo_;
  auto original_shard_shape = MakePartitionedShape(base_shape_, target);

  std::vector<OffsetCalculation> left_halo_size_functions(base_shape_.rank());
  std::vector<OffsetCalculation> right_halo_size_functions(base_shape_.rank());
  // TODO(yuanzx): We are concatenating on each sharded dimension one at a
  // time, and in the second dimension (and beyond) we create halos by slicing
  // the concat in the previous dimension, which is not optimal. We should
  // generate halos by concatenating slices, instead of slicing concats.
  for (int dim = 0; dim < base_shape_.rank(); ++dim) {
    int64 shard_count = target.tile_assignment().dim(dim);
    if (shard_count == 1) {
      continue;
    }
    int64 input_shard_size =
        CeilOfRatio(base_shape_.dimensions(dim), shard_count);

    // Left halo. The size of the halo is derived by subtracting the first read
    // element offset of the i'th partition from the limit of the (i-1)'th
    // partition.
    MultiplyAddDivideOffsetCalculation shard_limit_of_previous_on_padded(
        input_shard_size, explicit_left_padding[dim], 1);
    left_halo_size_functions[dim] =
        shard_limit_of_previous_on_padded - start_on_padded_calculations[dim];

    // Right halo.
    MultiplyAddDivideOffsetCalculation shard_start_of_next_on_padded(
        input_shard_size, input_shard_size + explicit_left_padding[dim], 1);
    right_halo_size_functions[dim] =
        limit_on_padded_calculations[dim] - shard_start_of_next_on_padded;

    auto resharded = ExchangeHaloAndGetValidData(
        visiting_hlo, base_shape_, left_halo_size_functions[dim],
        right_halo_size_functions[dim], explicit_left_padding[dim],
        padded_shape.dimensions(dim), shard_shape.dimensions(dim), dim, target,
        offsets_on_padded_shape[dim], pad_value, partition_ordinals[dim],
        state_.collective_ops_creator, state_.next_channel_id, state_.b,
        mask_invalid_region);
    if (!resharded) {
      VLOG(1) << "ReshardAsWindowedInput failed without replicate first: halo "
                 "is beyond the neighbor.";
      return Replicate().ReshardAsWindowedInput(window, target, pad_value);
    }
    visiting_hlo = *resharded;
  }
  return update_cache(WindowedInputShardReturnValue{
      visiting_hlo, shard_window,
      get_dynamic_slice_offset_on_output_if_needed()});
}

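// Reshards to a fully replicated sharding, using the reshard cache when
// all-gather caching is enabled.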
PartitionedHlo PartitionedHlo::Replicate() {
  auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;
  if (state_.partitioner->options().cache_all_gather) {
    for (auto& entry : cache) {
      if (entry.first.IsReplicated()) {
        return entry.second;
      }
    }
  }
  const HloSharding& sharding = hlo_->sharding();
  const Shape& shape = hlo_->shape();
  CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);

  if (sharding.IsReplicated()) {
    return *this;
  }
  for (auto& entry : cache) {
    if (entry.first.IsReplicated()) {
      return entry.second;
    }
  }
  auto update_cache = [&](PartitionedHlo resharded) {
    state_.reshard_cache->per_hlo_cache[resharded.hlo()]
        .reshard_cache.emplace_back(sharding, *this);
    if (state_.partitioner->options().cache_all_gather) {
      cache.emplace_back(HloSharding::Replicate(), std::move(resharded));
      return cache.back().second;
    }
    return resharded;
  };
  // 'Single Device' to 'Replicated'.
  if (sharding.IsTileMaximal()) {
    return update_cache(Broadcast());
  }

  // 'Tiled' to 'Replicated'.
  std::vector<int64> all_dims(shape.rank());
  std::iota(all_dims.begin(), all_dims.end(), 0);
  HloInstruction* result = ReplicatePartial(all_dims);
  result->set_sharding(HloSharding::Replicate());
  return update_cache(PartitionedHlo(result, base_shape_, state_));
}

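// Replicates the data along the given tile dimensions, either with an
// all-gather or, when all-gather is unavailable, with a dynamic-update-slice
// into zeros followed by an all-reduce. Dimensions not listed in `dims` stay
// tiled.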
HloInstruction* PartitionedHlo::ReplicatePartial(absl::Span<const int64> dims) {
  CHECK(!sharding().IsTileMaximal());
  const Shape& shard_shape = hlo()->shape();
  Shape target_shape = shard_shape;
  Shape padded_target_shape = shard_shape;
  for (int64 i : dims) {
    padded_target_shape.set_dimensions(
        i, shard_shape.dimensions(i) * sharding().tile_assignment().dim(i));
    target_shape.set_dimensions(i, base_shape().dimensions(i));
  }

  HloInstruction* result = nullptr;
  if (state_.collective_ops_creator.create_cross_partition_all_gather) {
    result = state_.partitioner->AllGatherShards(state_.b, hlo_, sharding(),
                                                 state_.next_channel_id, dims,
                                                 state_.collective_ops_creator);
  }
  if (result == nullptr) {
    auto zero = state_.b->AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::Zero(shard_shape.element_type())));
    auto zero_bcast = state_.b->AddInstruction(
        HloInstruction::CreateBroadcast(padded_target_shape, zero, {}));
    auto offsets = MakePartitionOffsets(padded_target_shape, sharding(),
                                        state_.partition_id, state_.b, dims);
    auto dus =
        state_.b->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
            padded_target_shape, zero_bcast, hlo_, offsets));
    HloComputation* reduction =
        MakeBinaryAdd(shard_shape.element_type(), state_.module);
    result = state_.partitioner->AllReduceAlongShardingDims(
        state_.b, dus, sharding(), state_.next_channel_id, dims,
        state_.collective_ops_creator, reduction);
  }
  if (!ShapeUtil::Compatible(target_shape, padded_target_shape)) {
    std::vector<int64> start_indices(target_shape.rank(), 0);
    std::vector<int64> strides(target_shape.rank(), 1);
    result = state_.b->AddInstruction(
        HloInstruction::CreateSlice(target_shape, result, start_indices,
                                    target_shape.dimensions(), strides));
  }
  return result;
}

absl::optional<PartitionedHlo>
PartitionedHlo::ReshardToPartialReplicateWithAllGather(
    const HloSharding& target) {
  if (!target.ReplicateOnLastTileDim()) {
    return absl::nullopt;
  }
  // Tiled/partial replicate to partial replicate.
  // Get the compatible sharding to target with resharding by all reduce.
  auto compatible_sharding =
      PartialReplicateReshardCompatibleSharding(target, sharding());
  if (!compatible_sharding.has_value()) {
    return absl::nullopt;
  }

  const auto& temp_sharding = compatible_sharding.value();
  auto partitioned_hlo = *this;
  // Use collective permute to adjust device assignment if needed.
  if (CanReshardWithCollectivePermute(sharding(), temp_sharding)) {
    partitioned_hlo =
        partitioned_hlo.ReshardWithCollectivePermute(temp_sharding);
  }

  // Get the replicate dims and the replicate factor of each dimension.
  int64 rank = hlo_->shape().rank();
  std::vector<int64> replicate_dims;
  std::vector<int64> replicate_factors;
  for (int64 dim = 0; dim < rank; dim++) {
    int64 replicate_factor = temp_sharding.tile_assignment().dim(dim) /
                             target.tile_assignment().dim(dim);
    if (replicate_factor > 1) {
      replicate_dims.emplace_back(dim);
      replicate_factors.emplace_back(replicate_factor);
    }
  }

  // Do left halo exchange if all-reduce directly will remove useful data
  // from the source.
  auto halo_exchange = TileToPartialReplicateHaloExchange(
      partitioned_hlo.hlo_, base_shape_, temp_sharding, target, replicate_dims,
      partitioned_hlo.state().collective_ops_creator,
      partitioned_hlo.state().next_channel_id,
      partitioned_hlo.state().partition_id, partitioned_hlo.state().b);
  if (!halo_exchange.has_value()) {
    return absl::nullopt;
  }
  auto halo_exchange_hlo = halo_exchange.value();
  // Grouped on replicate dimensions.
  auto sharding_grouped =
      GroupShardingOnDims(temp_sharding, replicate_dims, replicate_factors);
  auto per_group_partitioner_state = CreatePerGroupPartitioningState(
      partitioned_hlo.state(), sharding_grouped.device_groups,
      partitioned_hlo.state().b);
  auto base_shape = MakePartitionedShape(base_shape_, target);
  // It's possible that halo_exchange_hlo == hlo.hlo().
  // Record the sharding of hlo here, and reset it before return.
  auto original_sharding = partitioned_hlo.sharding();
  halo_exchange_hlo->set_sharding(sharding_grouped.sharding);
  auto partial_replicate_hlo = PartitionedHlo(halo_exchange_hlo, base_shape,
                                              per_group_partitioner_state);
  HloInstruction* result =
      partial_replicate_hlo.ReplicatePartial(replicate_dims);
  partitioned_hlo.hlo()->set_sharding(original_sharding);
  result->set_sharding(target);
  return PartitionedHlo(result, base_shape_, partitioned_hlo.state());
}

absl::optional<PartitionedHlo>
PartitionedHlo::ReshardFromPartialReplicateWithDynamicSlice(
    const HloSharding& target) {
  if (!sharding().ReplicateOnLastTileDim()) {
    return absl::nullopt;
  }

  // Get the temp sharding target from partial replicate to target tile dims.
  // target_compatible_sharding has the same tile_assignment dimensions
  // as the target and can reshard to target by collective permute.
  // target_compatible_sharding may have a different device assignment than
  // the target. sharding() can reshard to target_compatible_sharding by
  // dynamic slice.
  auto target_compatible_sharding =
      PartialReplicateReshardCompatibleSharding(sharding(), target);
  // Reshard to target_compatible_sharding by dynamic slice.
  if (!target_compatible_sharding.has_value()) {
    return absl::nullopt;
  }
  std::vector<int64> expand_tile_dims;
  std::vector<int64> tiling_dim_factors;
  int64 rank = hlo_->shape().rank();
  tiling_dim_factors.reserve(target.tile_assignment().num_dimensions());
  const auto& temp_target_sharding = target_compatible_sharding.value();
  for (int64 dim = 0; dim < rank; dim++) {
    if (temp_target_sharding.tile_assignment().dim(dim) >
        sharding().tile_assignment().dim(dim)) {
      expand_tile_dims.push_back(dim);
    }
    tiling_dim_factors.emplace_back(
        temp_target_sharding.tile_assignment().dim(dim) /
        sharding().tile_assignment().dim(dim));
  }

  // Add another dimension in tiling_dim_factors if target is partial replicate.
  if (target.ReplicateOnLastTileDim()) {
    tiling_dim_factors.emplace_back(
        target.tile_assignment().dimensions().back());
  }

  // 2. Get the padded_hlo, do right halo exchange if needed.
  auto padded_hlo = PadFromPartialReplicateShape(
      hlo_, base_shape_, sharding(), temp_target_sharding, expand_tile_dims,
      state_.collective_ops_creator, state_.next_channel_id,
      state_.partition_id, state_.b);
  if (!padded_hlo.has_value()) {
    return absl::nullopt;
  }
  // 3. Slice out the tile from replicate ones.
  auto shard_shape = MakePartitionedShape(base_shape_, temp_target_sharding);
  // Since we are just slicing, we can just use the differences between the new
  // and old offsets in the full shape as the dynamic-slice offsets.
  auto padded_base_shape = shard_shape;
  for (int64 i = 0; i < padded_base_shape.rank(); ++i) {
    padded_base_shape.set_dimensions(
        i, padded_base_shape.dimensions(i) *
               temp_target_sharding.tile_assignment().dim(i));
  }
  auto offsets = MakePartitionOffsets(padded_base_shape, temp_target_sharding,
                                      state_.partition_id, state_.b);
  auto old_offsets = MakePartitionOffsets(padded_base_shape, sharding(),
                                          state_.partition_id, state_.b);
  for (int64 i = 0; i < offsets.size(); ++i) {
    offsets[i] = state_.b->AddInstruction(HloInstruction::CreateBinary(
        offsets[i]->shape(), HloOpcode::kSubtract, offsets[i], old_offsets[i]));
  }
  auto slice = state_.b->AddInstruction(HloInstruction::CreateDynamicSlice(
      shard_shape, padded_hlo.value(), offsets, shard_shape.dimensions()));
  slice->set_sharding(temp_target_sharding);
  auto result = PartitionedHlo(slice, base_shape_, state_);
  // If temp_target_sharding's device assignment is different from target,
  // use collective permute to reshard.
  if (CanReshardWithCollectivePermute(temp_target_sharding, target)) {
    return result.ReshardWithCollectivePermute(target);
  }
  // If device assignment in temp_target_sharding and target are the same,
  // return result directly.
  return result;
}

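// Broadcasts single-device data to all partitions: every partition except the
// source contributes zeros, and a cross-partition all-reduce sums them into a
// replicated value.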
PartitionedHlo PartitionedHlo::Broadcast() const {
  const Shape& shape = hlo_->shape();
  const HloSharding& sharding = hlo_->sharding();
  CHECK(sharding.HasUniqueDevice());
  CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);

  auto src_core_id = state_.b->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR0<uint32>(sharding.GetUniqueDevice())));
  Shape bcast_shape = ShapeUtil::ChangeElementType(shape, PRED);
  auto is_src_core = state_.b->AddInstruction(HloInstruction::CreateBroadcast(
      bcast_shape,
      state_.b->AddInstruction(HloInstruction::CreateCompare(
          ShapeUtil::MakeShape(PRED, {}), state_.partition_id, src_core_id,
          ComparisonDirection::kEq)),
      {}));

  auto zero = state_.b->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
  auto zero_bcast = state_.b->AddInstruction(
      HloInstruction::CreateBroadcast(shape, zero, {}));
  auto operand = state_.b->AddInstruction(HloInstruction::CreateTernary(
      shape, HloOpcode::kSelect, is_src_core, hlo(), zero_bcast));
  HloComputation* reduction =
      MakeBinaryAdd(shape.element_type(), state_.module);

  auto result = state_.collective_ops_creator.create_cross_partition_all_reduce(
      state_.b, operand, reduction, {}, NewChannel());
  result->set_sharding(HloSharding::Replicate());
  return PartitionedHlo(result, base_shape_, state_);
}

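// Reshards by swapping pairs of (source, target) sharded dimensions with
// all-to-all, handling one pair per call and recursing on the rest. When no
// dimension pairs remain, falls back to a collective permute if the device
// order still differs from the target.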
PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
    const HloSharding& target,
    absl::Span<const std::pair<int64, int64>> source_target_dims) const {
  if (source_target_dims.empty()) {
    if (target == sharding()) {
      return *this;
    }
    // If the device order is different in the target, fix the order with
    // ReshardWithCollectivePermute.
    return ReshardWithCollectivePermute(target);
  }

  // Swap one pair of dimensions.
  int64 source_dim = source_target_dims[0].first;
  int64 target_dim = source_target_dims[0].second;
  const int64 group_size = sharding().tile_assignment().dim(source_dim) /
                           sharding().tile_assignment().dim(target_dim);

  auto temp_target_tile = sharding().tile_assignment();
  {
    std::vector<int64> reshape_tile_dims(temp_target_tile.num_dimensions() + 2);
    int64 i = 0;
    int64 added_source_dim = -1;
    int64 added_target_dim = -1;
    for (int64 j = 0; j < temp_target_tile.num_dimensions(); ++j) {
      if (source_dim == j) {
        reshape_tile_dims[i] = temp_target_tile.dim(j) / group_size;
        reshape_tile_dims[++i] = group_size;
        added_source_dim = i;
      } else if (target_dim == j) {
        reshape_tile_dims[i] = temp_target_tile.dim(j);
        reshape_tile_dims[++i] = 1;
        added_target_dim = i;
      } else {
        reshape_tile_dims[i] = temp_target_tile.dim(j);
      }
      ++i;
    }
    temp_target_tile.Reshape(reshape_tile_dims);
    std::vector<int64> xpose_dims(temp_target_tile.num_dimensions());
    std::iota(xpose_dims.begin(), xpose_dims.end(), 0);
    xpose_dims[added_source_dim] = added_target_dim;
    xpose_dims[added_target_dim] = added_source_dim;
    temp_target_tile = hlo_sharding_util::TransposeSharding(
                           HloSharding::Tile(temp_target_tile), xpose_dims)
                           .tile_assignment();
    auto temp_target_tile_dims = sharding().tile_assignment().dimensions();
    temp_target_tile_dims[source_dim] =
        sharding().tile_assignment().dim(target_dim);
    temp_target_tile_dims[target_dim] =
        sharding().tile_assignment().dim(source_dim);
    temp_target_tile.Reshape(temp_target_tile_dims);
  }
  auto temp_target = target.ReplicateOnLastTileDim()
                         ? HloSharding::PartialTile(temp_target_tile)
                         : HloSharding::Tile(temp_target_tile);
  auto padded_shape = hlo_->shape();
  padded_shape.set_dimensions(
      target_dim,
      RoundUpToNearest(padded_shape.dimensions(target_dim),
                       temp_target.tile_assignment().dim(target_dim)));
  auto padded_hlo = PadToShape(hlo_, padded_shape, state_.b);

  // The order of ids in the group must follow the temp_target sharding.
  std::vector<std::vector<int64>> groups(
      temp_target.tile_assignment().num_elements() / group_size);
  temp_target.tile_assignment().Each(
      [&](absl::Span<const int64> indices, int64 device) {
        int64 group_id = 0;
        for (int64 dim = 0; dim < indices.size(); ++dim) {
          if (dim == target_dim) {
            group_id *= temp_target.tile_assignment().dim(dim) / group_size;
            group_id += indices[dim] / group_size;
          } else {
            group_id *= temp_target.tile_assignment().dim(dim);
            group_id += indices[dim];
          }
        }
        groups[group_id].push_back(device);
      });

  HloInstruction* result = nullptr;

  // Split along the split dimension (target_dim) of the all-to-all
  // output.
  std::vector<int64> dimensions;
  for (int64 i = 0; i < base_shape_.rank(); ++i) {
    if (i == target_dim) {
      dimensions.push_back(group_size);
      dimensions.push_back(padded_hlo->shape().dimensions(i) / group_size);
    } else {
      dimensions.push_back(padded_hlo->shape().dimensions(i));
    }
  }
  auto reshape = state_.b->AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(base_shape_.element_type(), dimensions),
      padded_hlo));
  // After the reshape, it is guaranteed to have at least 3 dimensions.
  auto all_to_all =
      state_.collective_ops_creator.create_cross_partition_all_to_all(
          state_.b, {reshape}, groups, (*state_.next_channel_id)++, target_dim);

  // Reorder the split dimension of the reshape to be located in front of the
  // input partition dimension, so the two dimensions can be combined.
  int64 new_source_dim =
      (target_dim < source_dim) ? source_dim + 1 : source_dim;
  std::vector<int64> permutation;
  for (int64 i = 0; i < all_to_all->shape().rank(); ++i) {
    if (i == target_dim) {
      continue;
    }
    if (i == new_source_dim) {
      permutation.push_back(target_dim);
    }
    permutation.push_back(i);
  }
  auto transpose = state_.b->AddInstruction(HloInstruction::CreateTranspose(
      ShapeInference::InferTransposeShape(all_to_all->shape(), permutation)
          .ValueOrDie(),
      all_to_all, permutation));

  // Combine the split dimension and the input partition dimension.
  auto new_shape = ShapeInference::InferAllToAllShape(
                       padded_hlo->shape(), target_dim, source_dim, group_size)
                       .ValueOrDie();
  result = state_.b->AddInstruction(
      HloInstruction::CreateReshape(new_shape, transpose));

  const Shape result_shape = MakePartitionedShape(base_shape_, temp_target);
  if (result_shape != result->shape()) {
    result = state_.b->AddInstruction(HloInstruction::CreateSlice(
        result_shape, result, std::vector<int64>(result_shape.rank(), 0),
        result_shape.dimensions(), std::vector<int64>(result_shape.rank(), 1)));
  }
  result->set_sharding(temp_target);
  auto remaining_source_target_dims = source_target_dims;
  remaining_source_target_dims.remove_prefix(1);
  return PartitionedHlo(result, base_shape_, state_)
      .ReshardWithAllToAll(target, remaining_source_target_dims);
}

1227 absl::optional<PartitionedHlo>
1228 PartitionedHlo::ReshardPartialReplicateWithAllToAll(const HloSharding& target) {
1229   bool source_is_partial_replicate = sharding().ReplicateOnLastTileDim();
1230   const auto& partial_replicate_sharding =
1231       source_is_partial_replicate ? sharding() : target;
1232   // If neither the source nor the target is partial replicate, return null.
1233   if (!partial_replicate_sharding.ReplicateOnLastTileDim()) {
1234     return absl::nullopt;
1235   }
1236   const auto& tile_sharding = source_is_partial_replicate ? target : sharding();
1237   // If both source and target are partial replicate, this should already be
1238   // supported by Reshard with AllToAll.
1239   if (tile_sharding.ReplicateOnLastTileDim() || tile_sharding.IsTileMaximal()) {
1240     return absl::nullopt;
1241   }
1242 
1243   // Only support resharding from sharding={devices=[2,3]0,1,2,3,4,5}
1244   // to sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}, where
1245   // the last tile dim will be replicated first before the all-to-all.
1246   // Or resharding from
1247   // sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
1248   // to sharding={devices=[2,3]0,1,2,3,4,5}, where
1249   // the last tile dim will be sharded after all-to-all.
1250   const int num_replicas =
1251       partial_replicate_sharding.tile_assignment().dimensions().back();
1252   if (((tile_sharding.tile_assignment().num_dimensions() + 1) !=
1253        partial_replicate_sharding.tile_assignment().num_dimensions()) ||
1254       (partial_replicate_sharding.tile_assignment().dim(0) != 1)) {
1255     return absl::nullopt;
1256   }
1257   int to_replicate_dim = -1;
1258   for (int i = tile_sharding.tile_assignment().num_dimensions() - 1; i >= 0;
1259        --i) {
1260     if (tile_sharding.tile_assignment().dim(i) > 1 &&
1261         (to_replicate_dim == -1)) {
1262       if (tile_sharding.tile_assignment().dim(i) != num_replicas) {
1263         return absl::nullopt;
1264       }
1265       to_replicate_dim = i;
1266     }
1267 
1268     if (tile_sharding.tile_assignment().dim(i) !=
1269         partial_replicate_sharding.tile_assignment().dim(i + 1)) {
1270       return absl::nullopt;
1271     }
1272   }
1273 
1274   if (to_replicate_dim == -1) {
1275     return absl::nullopt;
1276   }
1277 
1278   // Check if core assignments for source and the target are the same.
1279   auto reshape_tile_assignment = partial_replicate_sharding.tile_assignment();
1280   reshape_tile_assignment.Reshape(tile_sharding.tile_assignment().dimensions());
1281   if (reshape_tile_assignment != tile_sharding.tile_assignment()) {
1282     return absl::nullopt;
1283   }
1284 
1285   auto tmp_tile_assignment = tile_sharding.tile_assignment();
1286   auto tmp_tile_assignment_dimensions =
1287       tile_sharding.tile_assignment().dimensions();
1288   tmp_tile_assignment_dimensions[to_replicate_dim] = 1;
1289   tmp_tile_assignment_dimensions.push_back(num_replicas);
1290   tmp_tile_assignment.Reshape(tmp_tile_assignment_dimensions);
1291   auto tmp_partial_replicate_sharding =
1292       HloSharding::PartialTile(tmp_tile_assignment);
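  // Illustrative example: with tile sharding {devices=[2,3]0,1,2,3,4,5} and
  // partial replicate sharding {devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
  // (num_replicas=3), the temporary sharding above is
  // {devices=[2,1,3]0,1,2,3,4,5 last_tile_dim_replicate}.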
1293 
1294   if (source_is_partial_replicate) {
1295     if (auto src_tgt_dims = GetReshardAllToAllSourceTargetDims(
1296             sharding(), tmp_partial_replicate_sharding)) {
1297       auto partitioned_hlo =
1298           ReshardWithAllToAll(tmp_partial_replicate_sharding, *src_tgt_dims);
1299       return partitioned_hlo.Reshard(target);
1300     }
1301   } else {
1302     auto partitioned_hlo = Reshard(tmp_partial_replicate_sharding);
1303 
1304     if (auto src_tgt_dims = GetReshardAllToAllSourceTargetDims(
1305             partitioned_hlo.sharding(), target)) {
1306       return partitioned_hlo.ReshardWithAllToAll(target, *src_tgt_dims);
1307     }
1308   }
1309 
1310   return absl::nullopt;
1311 }
1312 
1313 PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute(
1314     const HloSharding& target) const {
1315   CHECK(CanReshardWithCollectivePermute(sharding(), target))
1316       << sharding().ToString() << " to " << target.ToString();
1317   if (auto broadcast_dims = state_.b->BroadcastDimsForCreatedHlo(hlo())) {
1318     if (!(*broadcast_dims)->empty()) {
1319       // If hlo() has broadcast dims, check if data is already the same between
1320       // source/destination pairs.
1321       std::vector<int64> broadcast_dims_vector;
1322       for (int64 i = 0; i < hlo()->shape().rank(); ++i) {
1323         if ((*broadcast_dims)->contains(i)) {
1324           broadcast_dims_vector.push_back(i);
1325         }
1326       }
1327       if (hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
1328               sharding(), broadcast_dims_vector) ==
1329           hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
1330               target, broadcast_dims_vector)) {
1331         auto copy = state_.b->AddInstruction(HloInstruction::CreateUnary(
1332             hlo()->shape(), HloOpcode::kCopy, hlo()));
1333         copy->set_sharding(target);
1334         return PartitionedHlo(copy, base_shape_, state_);
1335       }
1336     }
1337   }
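  // Build one (source, destination) pair per tile position: each shard is sent
  // from the device that owns that position under the current sharding to the
  // device that owns it under the target sharding.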
1338   std::vector<std::pair<int64, int64>> src_dst_pairs;
1339   sharding().tile_assignment().Each(
1340       [&](absl::Span<const int64> indices, int64 src_device) {
1341         int64 dst_device = target.tile_assignment()(indices);
1342         src_dst_pairs.emplace_back(src_device, dst_device);
1343       });
1344   auto cp =
1345       state_.collective_ops_creator.create_cross_partition_collective_permute(
1346           state_.b, hlo(), src_dst_pairs, (*state_.next_channel_id)++);
1347   cp->set_sharding(target);
1348   return PartitionedHlo(cp, base_shape_, state_);
1349 }
1350 
1351 SpmdPartitioningVisitor::SpmdPartitioningVisitor(
1352     HloComputation* computation, int64 num_partitions, int64 num_replicas,
1353     const SPMDCollectiveOpsCreator& collective_ops_creator,
1354     int64* next_channel_id, SpmdLogger* logger, SpmdPartitionerOptions options,
1355     SpmdPartitioner* partitioner)
1356     : changed_(false),
1357       module_(computation->parent()),
1358       num_partitions_(num_partitions),
1359       num_replicas_(num_replicas),
1360       collective_ops_creator_(collective_ops_creator),
1361       next_channel_id_(next_channel_id),
1362       b_(SpmdBuilder(computation->name() + "_spmd", /*hlo=*/nullptr)),
1363       partition_id_(collective_ops_creator_.create_partition_id(&b_)),
1364       logger_(logger),
1365       options_(std::move(options)),
1366       partitioner_(partitioner) {}
1367 
1368 Status SpmdPartitioningVisitor::DefaultAction(HloInstruction* hlo) {
1369   if (hlo->HasSideEffect()) {
1370     return Unimplemented("Side-effect ops cannot be replicated: %s",
1371                          hlo->ToString());
1372   }
1373 
1374   if (hlo->IsElementwise() && hlo->operand_count() > 0) {
1375     return HandleElementwise(hlo);
1376   }
1377 
1378   if (!hlo->sharding().IsTileMaximal()) {
1379     VLOG(1) << "Not partitioned in SPMD mode (DefaultAction):"
1380             << hlo->ToString();
1381     for (int64 i = 0; i < hlo->operand_count(); ++i) {
1382       VLOG(1) << "  operand " << i
1383               << " sharding:" << hlo->operand(i)->sharding().ToString();
1384     }
1385   }
1386 
1387   HloSharding sharding = hlo->sharding().HasUniqueDevice()
1388                              ? hlo->sharding()
1389                              : HloSharding::Replicate();
1390 
1391   // If the instruction cannot be partitioned, replicate the instruction unless
1392   // it has a side effect.
1393   std::vector<HloInstruction*> new_operands;
1394   for (HloInstruction* operand : hlo->operands()) {
1395     new_operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo());
1396   }
1397   auto clone =
1398       b_.AddInstruction(hlo->CloneWithNewOperands(hlo->shape(), new_operands));
1399   clone->set_sharding(sharding);
1400   clone->set_metadata(hlo->metadata());
1401   SetPartitionedHlo(hlo,
1402                     PartitionedHlo(clone, hlo->shape(), MakePartitioningState())
1403                         .Reshard(hlo->sharding()));
1404   return Status::OK();
1405 }
1406 
1407 Status SpmdPartitioningVisitor::Preprocess(HloInstruction* hlo) {
1408   visiting_hlo_ = hlo;
1409   b_.set_visiting_hlo(hlo);
1410   // Temporarily replace manual sharding with one-device sharding so that the
1411   // partitioner will not change the HLOs.
1412   auto manual_to_onedevice = [&](const Shape& shape,
1413                                  const HloSharding& sharding) {
1414     // If a tuple's elements are all manual, then sharding.IsManual() is true,
1415     // so we test whether it is a tuple first.
1416     if (sharding.IsTuple()) {
1417       std::vector<HloSharding> subshardings = sharding.tuple_elements();
1418       for (HloSharding& subsharding : subshardings) {
1419         if (subsharding.IsManual()) {
1420           subsharding = HloSharding::AssignDevice(0);
1421         }
1422       }
1423       return HloSharding::Tuple(shape, subshardings);
1424     }
1425     if (sharding.IsManual()) {
1426       return HloSharding::AssignDevice(0);
1427     }
1428     return sharding;
1429   };
1430   const bool has_manual_sharding =
1431       hlo->sharding().IsManual() ||
1432       (hlo->sharding().IsTuple() &&
1433        absl::c_any_of(
1434            hlo->sharding().tuple_elements(),
1435            [](const HloSharding& sharding) { return sharding.IsManual(); }));
1436   if (has_manual_sharding && !hlo->IsCustomCall("SPMDFullToShardShape")) {
1437     visiting_hlo_sharding_ = hlo->sharding();
1438     hlo->set_sharding(
1439         manual_to_onedevice(hlo->shape(), *visiting_hlo_sharding_));
1440 
1441     visiting_hlo_operand_shardings_.reserve(hlo->operand_count());
1442     for (auto operand : hlo->operands()) {
1443       visiting_hlo_operand_shardings_.push_back(operand->sharding());
1444       operand->set_sharding(
1445           manual_to_onedevice(operand->shape(), operand->sharding()));
1446       GetPartitionedHlo(operand).hlo()->set_sharding(operand->sharding());
1447     }
1448   }
1449   return Status::OK();
1450 }
1451 
1452 Status SpmdPartitioningVisitor::Postprocess(HloInstruction* hlo) {
1453   logger_->RegisterLogEntry(GetPartitionedHlo(hlo).hlo(),
1454                             b_.derived_instructions(hlo));
1455   visiting_hlo_ = nullptr;
1456   b_.set_visiting_hlo(nullptr);
1457   // Revert fake one-device shardings for manually partitioned ops.
1458   if (visiting_hlo_sharding_) {
1459     hlo->set_sharding(*visiting_hlo_sharding_);
1460     GetPartitionedHlo(hlo).hlo()->set_sharding(*visiting_hlo_sharding_);
1461     for (int64 i = 0; i < hlo->operand_count(); ++i) {
1462       auto operand = hlo->mutable_operand(i);
1463       operand->set_sharding(visiting_hlo_operand_shardings_[i]);
1464       GetPartitionedHlo(operand).hlo()->set_sharding(operand->sharding());
1465     }
1466     visiting_hlo_sharding_.reset();
1467     visiting_hlo_operand_shardings_.clear();
1468   }
1469   return Status::OK();
1470 }
1471 
1472 Status SpmdPartitioningVisitor::HandleElementwise(HloInstruction* hlo) {
1473   std::vector<HloInstruction*> new_operands;
1474   for (HloInstruction* operand : hlo->operands()) {
1475     new_operands.push_back(
1476         GetPartitionedHlo(operand).Reshard(hlo->sharding()).hlo());
1477   }
1478   SetPartitionedHlo(hlo, [&] {
1479     return b_.AddInstruction(hlo->CloneWithNewOperands(
1480         MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands));
1481   });
1482   return Status::OK();
1483 }
1484 
1485 Status SpmdPartitioningVisitor::HandleConcatenate(HloInstruction* hlo) {
1486   const HloSharding& sharding = hlo->sharding();
1487   if (sharding.IsTileMaximal()) {
1488     return DefaultAction(hlo);
1489   }
1490 
1491   const Shape shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
1492   const int64 dimension = hlo->concatenate_dimension();
1493   if (sharding.tile_assignment().dim(dimension) == 1) {
1494     std::vector<HloInstruction*> new_operands;
1495     for (HloInstruction* operand : hlo->operands()) {
1496       new_operands.push_back(
1497           GetPartitionedHlo(operand).Reshard(sharding).hlo());
1498     }
1499     SetPartitionedHlo(hlo, [&] {
1500       return b_.AddInstruction(
1501           hlo->CloneWithNewOperands(shard_shape, new_operands));
1502     });
1503     return Status::OK();
1504   }
1505 
1506   // If the concatenate dimension is along one of the partitioned dimensions,
1507   // allocate the full output shape, each partition updates its owned region,
1508   // all-reduce across partitions, and then slice its output region.
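  // Illustrative example: concatenating two [4,8] operands along dimension 1
  // with sharding {devices=[1,2]}: temp_output is a [4,16] zero tensor; each
  // partition dynamic-update-slices its [4,4] shards of the operands into its
  // own columns, the all-reduce combines the partitions' contributions, and a
  // dynamic-slice extracts the partition's final [4,8] output region.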
1509 
1510   // temp_output_shape is the output shape where the concatenate dimension
1511   // is changed to the full (and padded to be divisible by the shard count) dimension size.
1512   auto temp_output_shape = MakePartitionedShape(hlo->shape(), sharding);
1513   auto last_operand_padded_shape =
1514       MakePartitionedShape(hlo->operands().back()->shape(), sharding);
1515   // If the last operand has more padding than the temp_output padding, we need
1516   // to add extra padding to avoid the dynamic update slice going out of bounds.
1517   int last_operand_padding =
1518       last_operand_padded_shape.dimensions(dimension) *
1519           sharding.tile_assignment().dim(dimension) -
1520       hlo->operands().back()->shape().dimensions(dimension);
1521   int temp_output_padding = temp_output_shape.dimensions(dimension) *
1522                                 sharding.tile_assignment().dim(dimension) -
1523                             hlo->shape().dimensions(dimension);
1524   int padding_for_last_operand =
1525       last_operand_padding < temp_output_padding
1526           ? 0
1527           : last_operand_padding - temp_output_padding;
1528   temp_output_shape.set_dimensions(
1529       dimension, temp_output_shape.dimensions(dimension) *
1530                          sharding.tile_assignment().dim(dimension) +
1531                      padding_for_last_operand);
1532   auto temp_output = CreateZero(temp_output_shape, &b_);
1533 
1534   // Offset of each operand along the concatenate dimension.
1535   int64 offset = 0;
1536   for (HloInstruction* operand : hlo->operands()) {
1537     auto spmd_operand = GetPartitionedHlo(operand).Reshard(sharding).hlo();
1538     std::vector<HloInstruction*> start_indices(
1539         hlo->shape().rank(), b_.AddInstruction(HloInstruction::CreateConstant(
1540                                  LiteralUtil::Zero(S32))));
1541     start_indices[dimension] =
1542         MultiplyAddDivideOffsetCalculation(
1543             spmd_operand->shape().dimensions(dimension), offset, 1)
1544             .Calculate(MakeTiledPartitionOrdinals(sharding, partition_id_,
1545                                                   &b_)[dimension],
1546                        &b_);
1547     temp_output = b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
1548         temp_output_shape, temp_output, spmd_operand, start_indices));
1549     offset += operand->shape().dimensions(dimension);
1550   }
1551   std::vector<int64> non_concat_dims;
1552   non_concat_dims.reserve(hlo->shape().rank() - 1);
1553   for (int64 i = 0; i < hlo->shape().rank(); ++i) {
1554     if (i != dimension) {
1555       non_concat_dims.push_back(i);
1556     }
1557   }
1558   auto grouped = GroupShardingOnDims(sharding, non_concat_dims);
1559   auto per_group_partitioner_state = CreatePerGroupPartitioningState(
1560       MakePartitioningState(), grouped.device_groups, &b_);
1561   auto all_reduce = per_group_partitioner_state.collective_ops_creator
1562                         .create_cross_partition_all_reduce(
1563                             &b_, temp_output,
1564                             MakeBinaryAdd(hlo->shape().element_type(), module_),
1565                             {}, NewChannel());
1566   SetPartitionedHlo(hlo, [&] {
1567     auto start_indices = MakeTiledPartitionOrdinals(
1568         grouped.sharding, per_group_partitioner_state.partition_id, &b_);
1569     start_indices[dimension] = MultiplyAddDivideOffsetCalculation(
1570                                    shard_shape.dimensions(dimension), 0, 1)
1571                                    .Calculate(start_indices[dimension], &b_);
1572     return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
1573         shard_shape, all_reduce, start_indices, shard_shape.dimensions()));
1574   });
1575 
1576   return Status::OK();
1577 }
1578 
1579 Status SpmdPartitioningVisitor::HandleSlice(HloInstruction* hlo) {
1580   const HloSharding& sharding = hlo->sharding();
1581   if (sharding.IsTileMaximal()) {
1582     return DefaultAction(hlo);
1583   }
1584 
1585   auto operand = GetPartitionedHlo(hlo->operand(0)).Reshard(sharding);
1586 
1587   // Create a window config to represent the slice.
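  // For example (illustrative), a slice [2:6:2] of a dimension of size 10 maps
  // to size=1, stride=2, padding_low=-2, padding_high=6-10=-4; the negative
  // padding trims the regions outside the slice.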
1588   Window window;
1589   for (int64 i = 0; i < hlo->shape().rank(); ++i) {
1590     WindowDimension* dim = window.add_dimensions();
1591     dim->set_size(1);
1592     dim->set_stride(hlo->slice_strides(i));
1593     dim->set_window_dilation(1);
1594     dim->set_window_reversal(false);
1595     dim->set_padding_low(-hlo->slice_starts(i));
1596     dim->set_padding_high(hlo->slice_limits(i) -
1597                           operand.base_shape().dimensions(i));
1598     dim->set_base_dilation(1);
1599   }
1600 
1601   auto reshard_operand = operand.ReshardAsWindowedInput(
1602       window, sharding,
1603       CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_),
1604       /*mask_invalid_region=*/false);
1605   if (!reshard_operand.has_value()) {
1606     return DefaultAction(hlo);
1607   }
1608   TF_RET_CHECK(!reshard_operand->dynamic_slice_index_on_output.has_value());
1609   const Shape& operand_shape = reshard_operand->sharded_input->shape();
1610 
1611   std::vector<int64> start_indices = hlo->slice_starts();
1612   std::vector<int64> limit_indices = hlo->slice_limits();
1613   std::vector<int64> strides = hlo->slice_strides();
1614   bool need_slice = false;
1615   for (int64 i = 0; i < hlo->shape().rank(); ++i) {
1616     auto dim = reshard_operand->shard_window.dimensions(i);
1617     start_indices[i] = -dim.padding_low();
1618     limit_indices[i] = operand_shape.dimensions(i) + dim.padding_high();
1619     if (start_indices[i] != 0 || strides[i] != 1 ||
1620         limit_indices[i] != operand_shape.dimensions(i)) {
1621       need_slice = true;
1622     }
1623   }
1624 
1625   SetPartitionedHlo(hlo, [&] {
1626     if (need_slice) {
1627       auto shard_shape = MakePartitionedShape(hlo->shape(), sharding);
1628       return b_.AddInstruction(HloInstruction::CreateSlice(
1629           shard_shape, reshard_operand->sharded_input, start_indices,
1630           limit_indices, strides));
1631     }
1632     auto data = reshard_operand->sharded_input;
1633     // Create a copy so that it will not share the resharding cache.
1634     return b_.AddInstruction(
1635         HloInstruction::CreateUnary(data->shape(), HloOpcode::kCopy, data));
1636   });
1637 
1638   return Status::OK();
1639 }
1640 
1641 Status SpmdPartitioningVisitor::HandleSort(HloInstruction* hlo) {
1642   HloSharding sharding = hlo->sharding();
1643   // Special handling for sort in TopK when the first operand is partitioned at
1644   // the sort dimension.
1645   auto k = GetKValueInTopKWhenPartitionSortDim(hlo);
1646   if (k.has_value()) {
1647     // When the first operand is partitioned at the sort dimension:
1648     //   1. Partition sort computation to different partitions;
1649     //   2. Slice TopK value and index from different partitions;
1650     //   3. Gather and replicate value and index from different partitions,
1651     //      the shape of replicated value and index will be
1652     //      [batch_size, ..., partition_count * k, ...];
1653     //   4. Final sort uses replicated value and index from different partitions
1654     //      as input.
1655     // GetTupleElement and Slice after the non-partitioned sort won't change
1656     // at this point, as HandleGetTupleElement and HandleSlice will update them.
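    // Illustrative example: with input [batch, 1024] partitioned 4 ways on the
    // sort dimension and k=8, each partition sorts its 256-element shard, the
    // top 8 values/indices are sliced out and replicated into [batch, 32], and
    // the final replicated sort orders those candidates (the consuming Slice
    // then takes the first k).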
1657     HloSortInstruction* sort = DynCast<HloSortInstruction>(hlo);
1658     const int64 sort_dim = sort->sort_dimension();
1659     auto input = hlo->operand(0);
1660     auto index = hlo->operand(1);
1661     const HloSharding& input_sharding = input->sharding();
1662     const int64 partition_count =
1663         input_sharding.tile_assignment().dim(sort_dim);
1664     const int64 input_size = input->shape().dimensions(sort_dim);
1665     const int64 per_partition_size = CeilOfRatio(input_size, partition_count);
1666     const auto element_type = input->shape().element_type();
1667     const auto index_type = index->shape().element_type();
1668 
1669     // Partition and pad input and index.
1670     // Pad input with minimal value.
1671     auto partitioned_input = GetPartitionedHlo(input).PadWithValue(
1672         CreateFirstWithType(element_type, &b_));
1673     // Pad index with max value.
1674     auto partitioned_index =
1675         GetPartitionedHlo(index)
1676             .Reshard(input_sharding)
1677             .PadWithValue(CreateLastWithType(index_type, &b_));
1678 
1679     // Each partition needs to do TopK separately, thus the base shape
1680     // becomes the padded shape.
1681     std::vector<int64> replicated_dimensions(
1682         input->shape().dimensions().begin(), input->shape().dimensions().end());
1683     replicated_dimensions[sort_dim] = per_partition_size * partition_count;
1684     const Shape replicated_shape = ShapeUtil::MakeTupleShape(
1685         {ShapeUtil::MakeShape(element_type, replicated_dimensions),
1686          ShapeUtil::MakeShape(index_type, replicated_dimensions)});
1687 
1688     // Partition original topk to different shards.
1689     auto topk_sharding =
1690         input_sharding.GetTupleSharding(replicated_shape).ValueOrDie();
1691     auto shard_shape = MakePartitionedShape(replicated_shape, topk_sharding);
1692     auto topk = b_.AddInstruction(hlo->CloneWithNewOperands(
1693         shard_shape, {partitioned_input.hlo(), partitioned_index.hlo()}));
1694 
1695     // Get value and index from the first sort.
1696     HloInstruction* value_gte =
1697         b_.AddInstruction(HloInstruction::CreateGetTupleElement(
1698             topk->shape().tuple_shapes(0), topk, 0));
1699     HloInstruction* index_gte =
1700         b_.AddInstruction(HloInstruction::CreateGetTupleElement(
1701             topk->shape().tuple_shapes(1), topk, 1));
1702 
1703     // Slice top K value from the first partitioned sort.
1704     replicated_dimensions[sort_dim] = k.value() * partition_count;
1705     auto slice_input = SliceFirstK(value_gte, &b_, sort_dim, k.value());
1706     slice_input->set_sharding(input_sharding);
1707     PartitionedHlo partitioned_slice_input(
1708         slice_input, ShapeUtil::MakeShape(element_type, replicated_dimensions),
1709         MakePartitioningState());
1710     // Reshard value to be replicated.
1711     auto replicated_slice_input =
1712         partitioned_slice_input.Reshard(HloSharding::Replicate()).hlo();
1713 
1714     // Slice top K index from the first partitioned sort.
1715     auto slice_index = SliceFirstK(index_gte, &b_, sort_dim, k.value());
1716     slice_index->set_sharding(input_sharding);
1717     PartitionedHlo partitioned_slice_index(
1718         slice_index, ShapeUtil::MakeShape(index_type, replicated_dimensions),
1719         MakePartitioningState());
1720     // Reshard index to be replicated.
1721     auto replicated_slice_index =
1722         partitioned_slice_index.Reshard(HloSharding::Replicate()).hlo();
1723 
1724     // Create a replicated sort to do TopK; the input is value and index pairs
1725     // from all the partitions.
1726     const Shape final_topk_shape = ShapeUtil::MakeTupleShape(
1727         {ShapeUtil::MakeShape(element_type, replicated_dimensions),
1728          ShapeUtil::MakeShape(index_type, replicated_dimensions)});
1729     HloInstruction* final_sort = b_.AddInstruction(HloInstruction::CreateSort(
1730         final_topk_shape, sort_dim,
1731         {replicated_slice_input, replicated_slice_index}, sort->to_apply(),
1732         sort->is_stable()));
1733     final_sort->set_sharding(HloSharding::Replicate()
1734                                  .GetTupleSharding(final_sort->shape())
1735                                  .ValueOrDie());
1736     PartitionedHlo replicated_sort(final_sort, final_sort->shape(),
1737                                    MakePartitioningState());
1738     SetPartitionedHlo(hlo, replicated_sort.Reshard(hlo->sharding()));
1739 
1740     return Status::OK();
1741   }
1742 
1743   if (hlo->shape().IsTuple()) {
1744     // Check that all elements are sharded in the same way.
1745     if (hlo->shape().tuple_shapes_size() == 0) {
1746       return DefaultAction(hlo);
1747     }
1748     sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0});
1749     for (int64 i = 1; i < hlo->operand_count(); ++i) {
1750       if (sharding != hlo->sharding().GetSubSharding(hlo->shape(), {i})) {
1751         return DefaultAction(hlo);
1752       }
1753     }
1754   }
1755   if (sharding.IsTileMaximal()) {
1756     return DefaultAction(hlo);
1757   }
1758   for (int64 dim : hlo->dimensions()) {
1759     if (sharding.tile_assignment().dim(dim) > 1) {
1760       return DefaultAction(hlo);
1761     }
1762   }
1763   // Reshard operands to the same as the output.
1764   std::vector<HloInstruction*> new_operands;
1765   for (HloInstruction* operand : hlo->operands()) {
1766     new_operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo());
1767   }
1768   SetPartitionedHlo(hlo, [&] {
1769     return b_.AddInstruction(hlo->CloneWithNewOperands(
1770         MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands));
1771   });
1772   return Status::OK();
1773 }
1774 
1775 Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) {
1776   if (hlo->custom_call_target() == "SPMDFullToShardShape") {
1777     // This op switches from auto partitioning to manual partitioning.
1778     auto input_partitioned = GetPartitionedHlo(hlo->operand(0));
1779     if (!EvenlyPartitions(hlo->shape(), input_partitioned.sharding())) {
1780       input_partitioned = input_partitioned.PadWithValue(
1781           CreateR0WithType(hlo->shape().element_type(), 0, &b_));
1782     }
1783     auto input = input_partitioned.hlo();
1784     CHECK(hlo->sharding().IsManual());
1785     CHECK(ShapeUtil::Compatible(input->shape(), hlo->shape()));
1786     auto copy = b_.AddInstruction(
1787         HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input));
1788     SetPartitionedHlo(hlo, [&] { return copy; });
1789     return Status::OK();
1790   }
1791   if (hlo->custom_call_target() == "SPMDShardToFullShape") {
1792     // This op switches from manual partitioning to auto partitioning.
1793     auto input = GetPartitionedHlo(hlo->operand(0)).hlo();
1794     CHECK(input->sharding().IsManual());
1795     auto copy = b_.AddInstruction(
1796         HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input));
1797     CHECK(ShapeUtil::Compatible(
1798         copy->shape(), MakePartitionedShape(hlo->shape(), hlo->sharding())));
1799     SetPartitionedHlo(hlo, [&] { return copy; });
1800     return Status::OK();
1801   }
1802   if (hlo->custom_call_target() != "TopK") {
1803     return DefaultAction(hlo);
1804   }
1805 
1806   if (!hlo->operand(0)->has_sharding()) {
1807     return DefaultAction(hlo);
1808   }
1809 
1810   const HloSharding& sharding = hlo->operand(0)->sharding();
1811   if (sharding.IsTileMaximal() || sharding.IsReplicated()) {
1812     return DefaultAction(hlo);
1813   }
1814 
1815   const int64 sort_dim = 1;
1816   const int64 shard_count = sharding.tile_assignment().dim(sort_dim);
1817 
1818   if (shard_count <= 1) {
1819     return DefaultAction(hlo);
1820   }
1821 
1822   const int64 input_size = hlo->operand(0)->shape().dimensions(sort_dim);
1823   const int64 batch_size = hlo->shape().tuple_shapes(0).dimensions(0);
1824   const int64 k = hlo->shape().tuple_shapes(0).dimensions(sort_dim);
1825   const int64 per_partition_size = CeilOfRatio(input_size, shard_count);
1826 
1827   if (k >= per_partition_size) {
1828     return DefaultAction(hlo);
1829   }
1830 
1831   auto input = hlo->operand(0);
1832   const auto element_type = input->shape().element_type();
1833 
1834   auto partitioned_input = GetPartitionedHlo(input).PadWithValue(
1835       CreateFirstWithType(element_type, &b_));
1836 
1837   // Each partition needs to do TopK separately, thus the base shape
1838   // becomes [batch_size, k * shard_count].
1839   const Shape replicated_shape = ShapeUtil::MakeTupleShape(
1840       {ShapeUtil::MakeShape(hlo->operand(0)->shape().element_type(),
1841                             {batch_size, k * shard_count}),
1842        ShapeUtil::MakeShape(S32, {batch_size, k * shard_count})});
1843   auto custom_call_sharding =
1844       sharding.GetTupleSharding(replicated_shape).ValueOrDie();
1845   auto shard_shape =
1846       MakePartitionedShape(replicated_shape, custom_call_sharding);
1847   auto topk = b_.AddInstruction(
1848       hlo->CloneWithNewOperands(shard_shape, {partitioned_input.hlo()}));
1849   topk->set_sharding(custom_call_sharding);
1850   // Partition customcall.
1851   PartitionedHlo partitioned_topk(topk, replicated_shape,
1852                                   MakePartitioningState());
1853   topk = partitioned_topk.hlo();
1854 
1855   // Get value from TopK.
1856   HloInstruction* value_gte =
1857       b_.AddInstruction(HloInstruction::CreateGetTupleElement(
1858           topk->shape().tuple_shapes(0), topk, 0));
1859   value_gte->set_sharding(sharding);
1860   // Partition GetTupleElement of value.
1861   PartitionedHlo value_partitioned_gte(
1862       value_gte, partitioned_topk.base_shape().tuple_shapes(0),
1863       MakePartitioningState());
1864   // Reshard value to be replicated.
1865   auto replicated_value_gte =
1866       value_partitioned_gte.Reshard(HloSharding::Replicate()).hlo();
1867 
1868   // Get index from TopK.
1869   HloInstruction* index_gte =
1870       b_.AddInstruction(HloInstruction::CreateGetTupleElement(
1871           topk->shape().tuple_shapes(1), topk, 1));
1872   auto partition_id_s32 = b_.AddInstruction(HloInstruction::CreateConvert(
1873       ShapeUtil::MakeShape(S32, partition_id_->shape().dimensions()),
1874       partition_id_));
1875   // Add a per-partition offset to the index; the index returned from CustomCall
1876   // always starts from 0.
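  // E.g. (illustrative) partition p turns local index j into
  // p * per_partition_size + j.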
1877   auto index_offset = b_.AddInstruction(HloInstruction::CreateBroadcast(
1878       index_gte->shape(),
1879       b_.AddInstruction(HloInstruction::CreateBinary(
1880           partition_id_s32->shape(), HloOpcode::kMultiply, partition_id_s32,
1881           b_.AddInstruction(HloInstruction::CreateConstant(
1882               LiteralUtil::CreateR0<int32>(per_partition_size))))),
1883       {}));
1884   index_gte = b_.AddInstruction(HloInstruction::CreateBinary(
1885       index_offset->shape(), HloOpcode::kAdd, index_gte, index_offset));
1886   index_gte->set_sharding(sharding);
1887   // Partition GetTupleElement of index.
1888   PartitionedHlo index_partitioned_gte(
1889       index_gte, partitioned_topk.base_shape().tuple_shapes(1),
1890       MakePartitioningState());
1891   // Reshard index to be replicated.
1892   auto replicated_index_gte =
1893       index_partitioned_gte.Reshard(HloSharding::Replicate()).hlo();
1894 
1895   // Create a replicated sort to do TopK; the input is value and index pairs
1896   // from all the partitions. The reason to use Sort instead of the TopK
1897   // CustomCall is that CustomCall only takes values as input; an extra Gather
1898   // would be needed to get the correct indices if CustomCall were used here.
1899 
1900   // Create comparator for the sort.
1901   XlaBuilder b("Sort.Compare");
1902   XlaComputation comparator = CreateScalarComparisonComputation(
1903       "compare-value-and-index", {input->shape().element_type(), S32}, {Gt, Lt},
1904       &b);
1905   TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape());
1906   HloModuleConfig config(program_shape);
1907   TF_ASSIGN_OR_RETURN(auto new_module,
1908                       HloModule::CreateFromProto(comparator.proto(), config));
1909   HloCloneContext context(module_);
1910   auto compare_computation =
1911       module_->DeepCloneComputation(new_module->entry_computation(), &context);
1912   auto sort = b_.AddInstruction(HloInstruction::CreateSort(
1913       replicated_shape, sort_dim, {replicated_value_gte, replicated_index_gte},
1914       compare_computation, true));
1915   sort->set_sharding(
1916       HloSharding::Replicate().GetTupleSharding(sort->shape()).ValueOrDie());
1917   PartitionedHlo replicated_sort(sort, replicated_shape,
1918                                  MakePartitioningState());
1919 
1920   // Slice value and index from top-k for output.
1921   HloInstruction* sort_value_gte =
1922       b_.AddInstruction(HloInstruction::CreateGetTupleElement(
1923           replicated_sort.hlo()->shape().tuple_shapes(0), replicated_sort.hlo(),
1924           0));
1925   HloInstruction* sort_index_gte =
1926       b_.AddInstruction(HloInstruction::CreateGetTupleElement(
1927           replicated_sort.hlo()->shape().tuple_shapes(1), replicated_sort.hlo(),
1928           1));
1929   // Slice value from final sort.
1930   HloInstruction* slice_sort_value =
1931       SliceFirstK(sort_value_gte, &b_, sort_dim, k);
1932   // Slice index from final sort.
1933   HloInstruction* slice_index_value =
1934       SliceFirstK(sort_index_gte, &b_, sort_dim, k);
1935   auto create_tuple = b_.AddInstruction(
1936       HloInstruction::CreateTuple({slice_sort_value, slice_index_value}));
1937   create_tuple->set_sharding(HloSharding::Replicate());
1938 
1939   SetPartitionedHlo(hlo, PartitionedHlo(create_tuple, create_tuple->shape(),
1940                                         MakePartitioningState())
1941                              .Reshard(hlo->sharding()));
1942 
1943   return Status::OK();
1944 }
1945 
1946 Status SpmdPartitioningVisitor::HandleTranspose(HloInstruction* hlo) {
1947   const HloSharding& sharding = hlo->sharding();
1948   if (sharding.IsTileMaximal()) {
1949     return DefaultAction(hlo);
1950   }
1951 
1952   std::vector<int64> inverse_dimensions(hlo->shape().rank());
1953   for (int64 i = 0; i < hlo->shape().rank(); ++i) {
1954     inverse_dimensions[hlo->dimensions(i)] = i;
1955   }
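  // For example, a transpose with dimensions {1,2,0} yields the inverse
  // permutation {2,0,1}; resharding the operand with the inverse lets each
  // output shard be produced by a local transpose.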
1956   auto desired_operand_sharding =
1957       hlo_sharding_util::TransposeSharding(sharding, inverse_dimensions);
1958 
1959   auto operand = GetPartitionedHlo(hlo->operand(0))
1960                      .Reshard(desired_operand_sharding)
1961                      .hlo();
1962   SetPartitionedHlo(hlo, [&] {
1963     return b_.AddInstruction(hlo->CloneWithNewOperands(
1964         MakePartitionedShape(hlo->shape(), hlo->sharding()), {operand}));
1965   });
1966   return Status::OK();
1967 }
1968 
1969 Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) {
1970   const HloSharding& sharding = hlo->sharding();
1971   if (sharding.IsTileMaximal()) {
1972     return DefaultAction(hlo);
1973   }
1974 
1975   auto operand = GetPartitionedHlo(hlo->operand(0));
1976   // The output shape is the source and the operand shape is the target to get
1977   // the aligned sharding for the operand.
1978   auto desired_operand_sharding = hlo_sharding_util::ReshapeSharding(
1979       hlo->shape(), hlo->operand(0)->shape(), hlo->sharding());
1980   if (desired_operand_sharding.has_value()) {
1981     auto operand_hlo = operand.Reshard(*desired_operand_sharding).hlo();
1982     SetPartitionedHlo(hlo, [&] {
1983       return b_.AddInstruction(hlo->CloneWithNewOperands(
1984           MakePartitionedShape(hlo->shape(), hlo->sharding()), {operand_hlo}));
1985     });
1986     return Status::OK();
1987   }
1988 
1989   // Check that the operand sharding and the output sharding are both tiled or
1990   // both partial replicate; if partial replicate, their replication counts must match.
1991   if (operand.sharding().ReplicateOnLastTileDim() !=
1992           sharding.ReplicateOnLastTileDim() ||
1993       (sharding.ReplicateOnLastTileDim() &&
1994        (operand.sharding().tile_assignment().dimensions().back() !=
1995         sharding.tile_assignment().dimensions().back()))) {
1996     return DefaultAction(hlo);
1997   }
1998 
1999   // Try to use halo exchange for certain split-dim/merge-dims cases.
2000   // ReshapeSharding failed in these cases probably due to uneven partitioning,
2001   // where halo exchange could help. Specifically we check the following
2002   // conditions to detect supported cases:
2003   // 1) Both input and output are partitioned on one dimension.
2004   // 2) The combined size of the dimensions before the partitioned dimension is the
2005   // same on input and output. This means we don't need to consider the major
2006   // dimensions.
2007   // 3) Let A = the input size on the partitioned dimension, and
2008   //        B = the output size on the partitioned dimension; then
2009   //    either A % B == 0 (split dim) or B % A == 0 (merge dims).
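  // Illustrative example: reshaping [10] to [5,2] with 4 partitions on the
  // first dimension is a split-dim case (A=10, B=5, A % B == 0 with
  // split_factor=2); halo exchange realigns the shards so that each partition
  // can perform a local reshape into its [2,2] output shard.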
2010   auto maybe_input_sharded_dim = UniqueTiledDim(operand.sharding());
2011   auto maybe_output_sharded_dim = UniqueTiledDim(sharding);
2012   if (!maybe_input_sharded_dim || !maybe_output_sharded_dim) {
2013     return DefaultAction(hlo);
2014   }
2015   int64 input_sharded_dim = *maybe_input_sharded_dim;
2016   int64 output_sharded_dim = *maybe_output_sharded_dim;
2017   // Check that the major dims before the sharded dim have the same total size
2018   // for input and output.
2019   int64 input_major_dims_size = 1;
2020   for (int64 i = 0; i < input_sharded_dim; ++i) {
2021     input_major_dims_size *= operand.base_shape().dimensions(i);
2022   }
2023   int64 output_major_dims_size = 1;
2024   for (int64 i = 0; i < output_sharded_dim; ++i) {
2025     output_major_dims_size *= hlo->shape().dimensions(i);
2026   }
2027   if (input_major_dims_size != output_major_dims_size) {
2028     return DefaultAction(hlo);
2029   }
2030   // Fix potential device ordering mismatch in tile assignment.
2031   Array<int64> new_input_tile_assignment = sharding.tile_assignment();
2032   new_input_tile_assignment.Reshape(
2033       operand.sharding().tile_assignment().dimensions());
2034   auto aligned_sharding =
2035       sharding.ReplicateOnLastTileDim()
2036           ? HloSharding::PartialTile(new_input_tile_assignment)
2037           : HloSharding::Tile(new_input_tile_assignment);
2038   operand = operand.Reshard(aligned_sharding);
2039   auto replication_count = sharding.ReplicateOnLastTileDim()
2040                                ? sharding.tile_assignment().dimensions().back()
2041                                : 1;
2042 
2043   int64 input_dim_size = operand.base_shape().dimensions(input_sharded_dim);
2044   int64 output_dim_size = hlo->shape().dimensions(output_sharded_dim);
2045   auto input_shard_shape =
2046       MakePartitionedShape(operand.base_shape(), operand.sharding());
2047   auto output_shard_shape = MakePartitionedShape(hlo->shape(), sharding);
2048   if (input_dim_size % output_dim_size == 0) {
2049     // Split dim.
2050     int64 split_factor = input_dim_size / output_dim_size;
2051     int64 output_shard_size = output_shard_shape.dimensions(output_sharded_dim);
2052     // Use halo exchange to fix misaligned data.
2053     Window window;
2054     for (int64 i = 0; i < hlo->shape().rank(); ++i) {
2055       WindowDimension* dim = window.add_dimensions();
2056       dim->set_size(1);
2057       dim->set_stride(1);
2058       dim->set_window_dilation(1);
2059       dim->set_window_reversal(false);
2060       dim->set_base_dilation(1);
2061       dim->set_padding_low(0);
2062       if (i == input_sharded_dim) {
2063         dim->set_padding_high(output_shard_size * split_factor *
2064                                   num_partitions_ / replication_count -
2065                               input_dim_size);
2066       } else {
2067         dim->set_padding_high(0);
2068       }
2069     }
2070 
2071     auto reshard_operand = operand.ReshardAsWindowedInput(
2072         window, operand.sharding(),
2073         CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_),
2074         /*mask_invalid_region=*/false);
2075     if (!reshard_operand.has_value()) {
2076       return DefaultAction(hlo);
2077     }
2078     TF_RET_CHECK(!reshard_operand->dynamic_slice_index_on_output.has_value());
2079     CHECK_EQ(
2080         reshard_operand->sharded_input->shape().dimensions(input_sharded_dim),
2081         output_shard_size * split_factor);
2082     SetPartitionedHlo(hlo, [&] {
2083       // Do a local reshape.
2084       return b_.AddInstruction(HloInstruction::CreateReshape(
2085           output_shard_shape, reshard_operand->sharded_input));
2086     });
2087     return Status::OK();
2088   } else if (output_dim_size % input_dim_size == 0) {
2089     // Merge dims.
2090     int64 merge_factor = output_dim_size / input_dim_size;
2091     // First reshape locally. (The sharded dimension could include padded data.)
2092     auto tmp_shard_shape = output_shard_shape;
2093     tmp_shard_shape.set_dimensions(
2094         output_sharded_dim,
2095         input_shard_shape.dimensions(input_sharded_dim) * merge_factor);
2096     auto tmp_reshape = b_.AddInstruction(
2097         HloInstruction::CreateReshape(tmp_shard_shape, operand.hlo()));
2098     tmp_reshape->set_metadata(hlo->metadata());
2099     tmp_reshape->set_sharding(hlo->sharding());
2100     auto tmp_full_shape = tmp_shard_shape;
2101     tmp_full_shape.set_dimensions(
2102         output_sharded_dim, tmp_shard_shape.dimensions(output_sharded_dim) *
2103                                 num_partitions_ / replication_count);
2104     auto tmp_output =
2105         PartitionedHlo(tmp_reshape, tmp_full_shape, MakePartitioningState());
2106 
2107     // Use halo exchange to fix misaligned data.
2108     Window window;
2109     for (int64 i = 0; i < tmp_shard_shape.rank(); ++i) {
2110       WindowDimension* dim = window.add_dimensions();
2111       dim->set_size(1);
2112       dim->set_stride(1);
2113       dim->set_window_dilation(1);
2114       dim->set_window_reversal(false);
2115       dim->set_base_dilation(1);
2116       dim->set_padding_low(0);
2117       if (i == output_sharded_dim) {
2118         dim->set_padding_high(output_dim_size -
2119                               tmp_shard_shape.dimensions(output_sharded_dim) *
2120                                   num_partitions_ / replication_count);
2121       } else {
2122         dim->set_padding_high(0);
2123       }
2124     }
2125 
2126     auto reshard_output = tmp_output.ReshardAsWindowedInput(
2127         window, sharding,
2128         CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_),
2129         /*mask_invalid_region=*/false);
2130     if (!reshard_output.has_value()) {
2131       return DefaultAction(hlo);
2132     }
2133     TF_RET_CHECK(!reshard_output->dynamic_slice_index_on_output.has_value());
2134     CHECK_EQ(
2135         reshard_output->sharded_input->shape().dimensions(output_sharded_dim),
2136         output_shard_shape.dimensions(output_sharded_dim));
2137     SetPartitionedHlo(hlo, [&] { return reshard_output->sharded_input; });
2138     return Status::OK();
2139   }
2140   return DefaultAction(hlo);
2141 }
2142 
2143 Status SpmdPartitioningVisitor::HandleIota(HloInstruction* hlo) {
2144   const HloSharding& sharding = hlo->sharding();
2145   if (sharding.IsTileMaximal()) {
2146     return DefaultAction(hlo);
2147   }
2148 
2149   SetPartitionedHlo(hlo, [&] {
2150     int64 dimension = Cast<HloIotaInstruction>(hlo)->iota_dimension();
2151     auto iota = b_.AddInstruction(HloInstruction::CreateIota(
2152         MakePartitionedShape(hlo->shape(), sharding), dimension));
2153 
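    // If the iota dimension is partitioned (illustrative example: global size
    // 16 across 4 partitions), each partition creates a local iota of size 4
    // and adds partition_ordinal * 4 as an offset below.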
2154     if (sharding.tile_assignment().dim(dimension) > 1) {
2155       auto partition_ordinals =
2156           MakeTiledPartitionOrdinals(sharding, partition_id_, &b_);
2157       auto multiplier = b_.AddInstruction(HloInstruction::CreateConstant(
2158           LiteralUtil::CreateR0<int32>(iota->shape().dimensions(dimension))));
2159       auto offset = b_.AddInstruction(HloInstruction::CreateBinary(
2160           ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply,
2161           partition_ordinals[dimension], multiplier));
2162       if (iota->shape().element_type() != S32) {
2163         offset = b_.AddInstruction(HloInstruction::CreateConvert(
2164             ShapeUtil::MakeShape(iota->shape().element_type(), {}), offset));
2165       }
2166       auto broadcast = b_.AddInstruction(
2167           HloInstruction::CreateBroadcast(iota->shape(), offset, {}));
2168       return b_.AddInstruction(HloInstruction::CreateBinary(
2169           iota->shape(), HloOpcode::kAdd, iota, broadcast));
2170     }
2171 
2172     return iota;
2173   });
2174 
2175   return Status::OK();
2176 }
2177 
2178 Status SpmdPartitioningVisitor::HandleSingleDevice(const HloInstruction* hlo) {
2179   TF_RET_CHECK(hlo->sharding().HasUniqueDevice());
2180   int64 device = hlo->sharding().GetUniqueDevice();
2181   const HloSharding sharding = HloSharding::AssignDevice(device);
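  // The op runs inside a conditional predicated on partition_id == device: the
  // owning partition executes a clone of the original op, while all other
  // partitions produce zeros of the same shape.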
2182 
2183   std::vector<HloInstruction*> operands;
2184   std::vector<Shape> operand_shapes;
2185   for (const HloInstruction* operand : hlo->operands()) {
2186     operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo());
2187     operand_shapes.push_back(operand->shape());
2188   }
2189   auto operand = b_.AddInstruction(HloInstruction::CreateTuple(operands));
2190   auto operand_shape = ShapeUtil::MakeTupleShape(operand_shapes);
2191 
2192   auto on_device = b_.AddInstruction(
2193       HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(device)));
2194   auto pred = b_.AddInstruction(HloInstruction::CreateCompare(
2195       ShapeUtil::MakeShape(PRED, {}), partition_id_, on_device,
2196       ComparisonDirection::kEq));
2197 
2198   SpmdBuilder true_b("true_computation", visiting_hlo_);
2199   HloComputation* true_computation;
2200   {
2201     auto param = true_b.AddInstruction(HloInstruction::CreateParameter(
2202         /*parameter_number=*/0, operand_shape, "true_branch_param"));
2203     std::vector<HloInstruction*> new_operands;
2204     for (int64 i = 0; i < operands.size(); ++i) {
2205       new_operands.push_back(true_b.AddInstruction(
2206           HloInstruction::CreateGetTupleElement(operand_shapes[i], param, i)));
2207     }
2208     auto root = true_b.AddInstruction(
2209         hlo->CloneWithNewOperands(hlo->shape(), new_operands));
2210     true_computation = module_->AddEmbeddedComputation(true_b.Build(root));
2211   }
2212 
2213   SpmdBuilder false_b("false_computation", visiting_hlo_);
2214   HloComputation* false_computation;
2215   {
2216     false_b.AddInstruction(HloInstruction::CreateParameter(
2217         /*parameter_number=*/0, operand_shape, "false_branch_param"));
2218     auto root = CreateZero(hlo->shape(), &false_b);
2219     false_computation = module_->AddEmbeddedComputation(false_b.Build(root));
2220   }
2221 
2222   SetPartitionedHlo(hlo, [&]() {
2223     return b_.AddInstruction(HloInstruction::CreateConditional(
2224         hlo->shape(), pred, operand, true_computation, operand,
2225         false_computation));
2226   });
2227   return Status::OK();
2228 }
2229 
2230 Status SpmdPartitioningVisitor::HandleAllReduce(HloInstruction* hlo) {
2231   if (hlo->IsCrossReplicaAllReduce() && hlo->operand_count() == 1) {
2232     return HandleElementwise(hlo);
2233   }
2234   return DefaultAction(hlo);
2235 }
2236 
2237 Status SpmdPartitioningVisitor::HandleBroadcast(HloInstruction* hlo) {
2238   if (hlo->sharding().IsTileMaximal()) {
2239     return DefaultAction(hlo);
2240   }
2241 
2242   auto& operand = GetPartitionedHlo(hlo->operand(0));
2243 
2244   // Tiled output.
2245   std::vector<int64> new_dims;
2246   for (int64 i = 0; i < hlo->shape().rank(); ++i) {
2247     if (!absl::c_linear_search(hlo->dimensions(), i)) {
2248       new_dims.push_back(i);
2249     }
2250   }
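  // new_dims are the dimensions introduced by the broadcast; the operand is
  // resharded to the output sharding with those dimensions replicated and then
  // removed, so each partition can form its output shard with a local broadcast.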
2251   auto desired_input_sharding = hlo_sharding_util::RemoveShapeDimensions(
2252       hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(hlo->sharding(),
2253                                                                new_dims),
2254       new_dims);
2255   auto input = operand.Reshard(desired_input_sharding).hlo();
2256   auto output_shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
2257   SetPartitionedHlo(hlo, [&] {
2258     return b_.AddInstruction(
2259         hlo->CloneWithNewOperands(output_shard_shape, {input}));
2260   });
2261   return Status::OK();
2262 }
2263 
2264 Status SpmdPartitioningVisitor::HandleConstant(HloInstruction* hlo) {
2265   const Literal& literal = hlo->literal();
2266   if (literal.shape().IsTuple() ||
2267       (!hlo->sharding().IsTileMaximal() &&
2268        (!EvenlyPartitions(hlo->shape(), hlo->sharding()) ||
2269         !literal.IsAllFirst()))) {
2270     return DefaultAction(hlo);
2271   }
2272 
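  // Reaching here, either the sharding is tile-maximal or the literal has all
  // elements equal to its first element (IsAllFirst) and is evenly partitioned,
  // so slicing the literal from the origin to the shard shape is correct on
  // every partition.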
2273   SetPartitionedHlo(hlo, [&]() {
2274     auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
2275     std::vector<int64> start_indices(hlo->shape().rank(), 0);
2276     auto constant = b_.AddInstruction(HloInstruction::CreateConstant(
2277         literal.Slice(start_indices, shard_shape.dimensions())));
2278     *constant->mutable_shape() = shard_shape;
2279     return constant;
2280   });
2281   return Status::OK();
2282 }
2283 
2284 Status SpmdPartitioningVisitor::HandleDynamicSlice(HloInstruction* hlo) {
2285   if (hlo->sharding().IsTileMaximal()) {
2286     return DefaultAction(hlo);
2287   }
2288   for (int64 i = 0; i < hlo->shape().rank(); ++i) {
2289     if (hlo->sharding().tile_assignment().dim(i) != 1 &&
2290         (hlo->dynamic_slice_sizes()[i] != hlo->shape().dimensions(i) ||
2291          !hlo->operand(i + 1)->IsConstant() ||
2292          !hlo->operand(i + 1)->literal().IsZero({}))) {
2293       // We currently do not partition the sliced dimensions.
2294       return DefaultAction(hlo);
2295     }
2296   }
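  // A dimension may remain partitioned only when the slice covers it entirely
  // and starts at a constant 0, so the dynamic-slice below can be applied to
  // each shard locally with replicated start indices.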
2297   std::vector<HloInstruction*> new_indices(hlo->shape().rank());
2298   auto new_input =
2299       GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo();
2300   for (int64 i = 0; i < new_indices.size(); ++i) {
2301     // Replicate the indices.
2302     new_indices[i] = GetPartitionedHlo(hlo->operand(i + 1))
2303                          .Reshard(HloSharding::Replicate())
2304                          .hlo();
2305   }
2306   SetPartitionedHlo(hlo, [&]() {
2307     auto partitioned_shape =
2308         MakePartitionedShape(hlo->shape(), hlo->sharding());
2309     return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
2310         partitioned_shape, new_input, new_indices,
2311         partitioned_shape.dimensions()));
2312   });
2313   return Status::OK();
2314 }
2315 
2316 Status SpmdPartitioningVisitor::HandleDynamicUpdateSlice(HloInstruction* hlo) {
2317   if (hlo->sharding().IsTileMaximal()) {
2318     return DefaultAction(hlo);
2319   }
2320 
2321   std::vector<int64> partitioned_slice_dims;
2322   std::vector<int64> slice_dims;
2323   std::vector<int64> partitioned_non_slice_dims;
2324   std::vector<int64> partitioned_slice_offsets;
2325   for (int64 i = 0; i < hlo->shape().rank(); ++i) {
2326     if (hlo->operand(1)->shape().dimensions(i) != hlo->shape().dimensions(i)) {
2327       slice_dims.push_back(i);
2328       if (hlo->sharding().tile_assignment().dim(i) != 1) {
2329         if (!hlo->operand(i + 2)->IsConstant()) {
2330           return DefaultAction(hlo);
2331         }
2332         partitioned_slice_dims.push_back(i);
2333         partitioned_slice_offsets.push_back(
2334             hlo->operand(i + 2)->literal().Get<int>({}));
2335       }
2336     } else if (hlo->sharding().tile_assignment().dim(i) != 1) {
2337       if (!hlo->operand(i + 2)->IsConstant() ||
2338           !hlo->operand(i + 2)->literal().IsZero({})) {
2339         return DefaultAction(hlo);
2340       }
2341       partitioned_non_slice_dims.push_back(i);
2342     }
2343   }
2344 
2345   // Handle the case where a slice dim is partitioned.
2346   if (!partitioned_slice_dims.empty()) {
2347     auto add_hlo = [&](std::unique_ptr<HloInstruction> to_add) {
2348       return b_.AddInstruction(std::move(to_add));
2349     };
2350     std::vector<HloInstruction*> new_indices(hlo->shape().rank());
2351     for (int64 i = 0; i < new_indices.size(); ++i) {
2352       // Replicate the indices.
2353       new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2))
2354                            .Reshard(HloSharding::Replicate())
2355                            .hlo();
2356     }
2357 
2358     // Get partitioned input.
2359     const auto& dus_sharding = hlo->sharding();
2360     const auto& partitioned_input =
2361         GetPartitionedHlo(hlo->operand(0)).Reshard(dus_sharding).hlo();
2362 
2363     // Get the replicated update.
2364     auto update_sharding = HloSharding::Replicate();
2365     if (!partitioned_non_slice_dims.empty()) {
2366       // Partially replicate the update if non-slice dims are partitioned.
2367       update_sharding =
2368           hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(dus_sharding,
2369                                                                    slice_dims);
2370     }
2371     HloInstruction* replicate_update =
2372         GetPartitionedHlo(hlo->operand(1)).Reshard(update_sharding).hlo();
2373 
2374     const auto& update_shape = replicate_update->shape();
2375     const auto& partitioned_shape = partitioned_input->shape();
2376     auto partition_ordinals =
2377         MakeTiledPartitionOrdinals(hlo->sharding(), partition_id_, &b_);
2378     HloInstruction* all_dims_within_partition = add_hlo(
2379         HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
2380 
2381     for (int i = 0; i < partitioned_slice_dims.size(); ++i) {
2382       int dim = partitioned_slice_dims[i];
2383       // Calculate per partition size.
2384       const int64 per_partition_size = partitioned_shape.dimensions(dim);
2385 
2386       // Only an update that falls within a single partition is supported.
2387       if ((partitioned_slice_offsets[i] / per_partition_size) !=
2388           ((partitioned_slice_offsets[i] + update_shape.dimensions(dim) - 1) /
2389            per_partition_size)) {
2390         return DefaultAction(hlo);
2391       }
2392 
2393       // within_partition = (offset >= partition_id * per_partition_size) &&
2394       //                    (offset < (partition_id + 1) * per_partition_size)
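      // Illustrative example: with per_partition_size=8 and a constant offset
      // of 10, only the partition whose ordinal is 1 (covering [8, 16)) performs
      // the update, using the local offset 10 - 8 = 2; all other partitions keep
      // their input unchanged via the final select.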
2395       const Shape& compare_shape =
2396           ShapeUtil::ChangeElementType(partition_id_->shape(), PRED);
2397       auto per_partition_size_hlo = add_hlo(HloInstruction::CreateConstant(
2398           LiteralUtil::CreateR0<int>(per_partition_size)));
2399       const Shape& offset_shape = per_partition_size_hlo->shape();
2400       auto partition_offset = add_hlo(HloInstruction::CreateBinary(
2401           offset_shape, HloOpcode::kMultiply, partition_ordinals[dim],
2402           per_partition_size_hlo));
2403       // offset >= partition_id * per_partition_size
2404       auto offset_ge = add_hlo(HloInstruction::CreateCompare(
2405           compare_shape, new_indices[dim], partition_offset,
2406           ComparisonDirection::kGe));
2407       // offset < (partition_id + 1) * per_partition_size
2408       auto offset_lt = add_hlo(HloInstruction::CreateCompare(
2409           compare_shape, new_indices[dim],
2410           add_hlo(HloInstruction::CreateBinary(
2411               offset_shape, HloOpcode::kMultiply,
2412               add_hlo(HloInstruction::CreateBinary(
2413                   offset_shape, HloOpcode::kAdd, partition_ordinals[dim],
2414                   add_hlo(HloInstruction::CreateConstant(
2415                       LiteralUtil::CreateR0<int>(1))))),
2416               per_partition_size_hlo)),
2417           ComparisonDirection::kLt));
2418       auto update_within_partition = add_hlo(HloInstruction::CreateBinary(
2419           compare_shape, HloOpcode::kAnd, offset_ge, offset_lt));
2420 
2421       all_dims_within_partition = add_hlo(HloInstruction::CreateBinary(
2422           compare_shape, HloOpcode::kAnd, all_dims_within_partition,
2423           update_within_partition));
2424 
2425       // Calculate offset.
2426       // slice dim offset =
2427       //  within_partition ?
2428       //  offset - partition_id * per_partition_size : 0
2429       new_indices[dim] = add_hlo(HloInstruction::CreateTernary(
2430           new_indices[dim]->shape(), HloOpcode::kSelect,
2431           update_within_partition,
2432           add_hlo(HloInstruction::CreateBinary(
2433               new_indices[dim]->shape(), HloOpcode::kSubtract, new_indices[dim],
2434               partition_offset)),
2435           add_hlo(
2436               HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)))));
2437     }
2438 
2439     // Create dynamic update slice.
2440     auto dus = add_hlo(HloInstruction::CreateDynamicUpdateSlice(
2441         partitioned_shape, partitioned_input, replicate_update, new_indices));
2442     SetPartitionedHlo(hlo, [&]() {
2443       // Keep the updated data only if the update falls within this partition.
2444       return add_hlo(HloInstruction::CreateTernary(
2445           dus->shape(), HloOpcode::kSelect,
2446           add_hlo(HloInstruction::CreateBroadcast(
2447               ShapeUtil::ChangeElementType(dus->shape(), PRED),
2448               all_dims_within_partition, {})),
2449           dus, partitioned_input));
2450     });
2451     return Status::OK();
2452   }
2453 
2454   // Only non-slice dims are partitioned.
2455   std::vector<HloInstruction*> new_indices(hlo->shape().rank());
2456   auto new_input =
2457       GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo();
2458   auto new_update =
2459       GetPartitionedHlo(hlo->operand(1)).Reshard(hlo->sharding()).hlo();
2460   for (int64 i = 0; i < new_indices.size(); ++i) {
2461     // Replicate the indices.
2462     new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2))
2463                          .Reshard(HloSharding::Replicate())
2464                          .hlo();
2465   }
2466   SetPartitionedHlo(hlo, [&]() {
2467     auto partitioned_shape =
2468         MakePartitionedShape(hlo->shape(), hlo->sharding());
2469     return b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
2470         partitioned_shape, new_input, new_update, new_indices));
2471   });
2472   return Status::OK();
2473 }
2474 
2475 Status SpmdPartitioningVisitor::HandleGetTupleElement(HloInstruction* hlo) {
2476   const auto& tuple = GetPartitionedHlo(hlo->operand(0));
2477   auto gte = b_.AddInstruction(HloInstruction::CreateGetTupleElement(
2478       ShapeUtil::GetTupleElementShape(tuple.hlo()->shape(), hlo->tuple_index()),
2479       tuple.hlo(), hlo->tuple_index()));
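       // The element takes the sub-sharding of the tuple's sharding at its index,
       // and is then resharded to this instruction's target sharding.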
2480   const auto source_sharding =
2481       tuple.sharding().GetSubSharding(tuple.base_shape(), {hlo->tuple_index()});
2482   gte->set_sharding(source_sharding);
2483   PartitionedHlo source_partitioned_gte(
2484       gte, tuple.base_shape().tuple_shapes(hlo->tuple_index()),
2485       MakePartitioningState());
2486   source_partitioned_gte = source_partitioned_gte.Reshard(hlo->sharding());
2487   SetPartitionedHlo(hlo, source_partitioned_gte);
2488   return Status::OK();
2489 }
2490 
2491 Status SpmdPartitioningVisitor::HandleInfeed(HloInstruction* hlo) {
2492   const Shape& shape = ShapeUtil::GetTupleElementShape(hlo->shape(), 0);
2493   auto token = GetPartitionedHlo(hlo->operand(0)).hlo();
2494   if (ShapeUtil::GetLeafCount(shape) == 0) {
2495     // TODO(b/155819021): HloSharding has issues with tuple-shaped sharding: it
2496     // requires one element for an empty tuple, but leaf-count number of
2497     // elements for non-empty tuple. So if it has a nested empty tuple, we
2498     // cannot invoke GetSubSharding() since it expects a sharding for the empty
2499     // tuple. This is a workaround for that case.
2500     SetPartitionedHlo(hlo, [&]() {
2501       return b_.AddInstruction(
2502           HloInstruction::CreateInfeed(shape, token, hlo->infeed_config()));
2503     });
2504     return Status::OK();
2505   }
2506   auto sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0});
2507   auto shard_shape = MakePartitionedShape(shape, sharding);
2508   if (EvenlyPartitions(shape, sharding)) {
2509     SetPartitionedHlo(hlo, [&]() {
2510       return b_.AddInstruction(HloInstruction::CreateInfeed(
2511           shard_shape, token, hlo->infeed_config()));
2512     });
2513     return Status::OK();
2514   }
2515 
2516   if (hlo->sharding().HasUniqueDevice()) {
2517     return HandleSingleDevice(hlo);
2518   }
2519 
2520   // Create a branch for each unique partitioned shape.
2521   std::vector<Shape> per_branch_partitioned_shapes;
2522   std::vector<int32> conditional_branch_indices(num_partitions_);
2523   for (int64 i = 0; i < num_partitions_; ++i) {
2524     auto partitioned_shape =
2525         MakeNonPaddedShapeForGivenPartition(shape, sharding, i);
2526     int64 matching_existing_index = 0;
2527     for (; matching_existing_index < per_branch_partitioned_shapes.size();
2528          ++matching_existing_index) {
2529       if (ShapeUtil::Compatible(
2530               partitioned_shape,
2531               per_branch_partitioned_shapes[matching_existing_index])) {
2532         break;
2533       }
2534     }
2535     if (matching_existing_index < per_branch_partitioned_shapes.size()) {
2536       conditional_branch_indices[i] = matching_existing_index;
2537     } else {
2538       conditional_branch_indices[i] = per_branch_partitioned_shapes.size();
2539       per_branch_partitioned_shapes.push_back(std::move(partitioned_shape));
2540     }
2541   }
2542 
2543   HloInstruction* branch_index;
2544   if (per_branch_partitioned_shapes.size() == num_partitions_) {
2545     // Use partition ID as the branch index if each partition has its own
2546     // branch.
2547     branch_index = partition_id_;
2548     // PartitionId's output is U32 but conditional requires S32.
2549     if (branch_index->shape().element_type() != S32) {
2550       branch_index = b_.AddInstruction(HloInstruction::CreateConvert(
2551           ShapeUtil::ChangeElementType(branch_index->shape(), S32),
2552           branch_index));
2553     }
2554   } else {
2555     // Otherwise, use a constant table to look up the branch index.
2556     auto branch_index_table = b_.AddInstruction(HloInstruction::CreateConstant(
2557         LiteralUtil::CreateR1<int32>(conditional_branch_indices)));
2558     branch_index = b_.AddInstruction(HloInstruction::CreateDynamicSlice(
2559         ShapeUtil::MakeShape(S32, {1}), branch_index_table, {partition_id_},
2560         {1}));
2561     branch_index = b_.AddInstruction(HloInstruction::CreateReshape(
2562         ShapeUtil::MakeShape(S32, {}), branch_index));
2563   }
2564 
2565   std::vector<HloComputation*> branches(per_branch_partitioned_shapes.size());
2566   for (int64 i = 0; i < branches.size(); ++i) {
2567     SpmdBuilder branch_b(absl::StrCat("infeed_branch_", i), visiting_hlo_);
2568     auto param = branch_b.AddInstruction(HloInstruction::CreateParameter(
2569         /*parameter_number=*/0, token->shape(), "infeed_token_param"));
2570     auto infeed = branch_b.AddInstruction(HloInstruction::CreateInfeed(
2571         per_branch_partitioned_shapes[i], param, hlo->infeed_config()));
2572     if (!ShapeUtil::Compatible(per_branch_partitioned_shapes[i], shard_shape)) {
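           // Recursively pad each leaf of this branch's infeed result up to the
           // per-shard shape; the token element (index {1}) is passed through as is.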
2573       std::function<HloInstruction*(const ShapeIndex&, HloInstruction*)>
2574           pad_infeed = [&](const ShapeIndex& index,
2575                            HloInstruction* infeed_element) -> HloInstruction* {
2576         if (index == ShapeIndex({1})) {
2577           // Token.
2578           return infeed_element;
2579         }
2580         const Shape& element_shape =
2581             ShapeUtil::GetSubshape(infeed->shape(), index);
2582         if (element_shape.IsTuple() && element_shape.tuple_shapes_size() > 0) {
2583           std::vector<HloInstruction*> padded_elements(
2584               element_shape.tuple_shapes_size());
2585           for (int64 i = 0; i < padded_elements.size(); ++i) {
2586             auto sub_index = index;
2587             sub_index.push_back(i);
2588             padded_elements[i] = pad_infeed(
2589                 sub_index,
2590                 branch_b.AddInstruction(HloInstruction::CreateGetTupleElement(
2591                     ShapeUtil::GetSubshape(element_shape, {i}), infeed_element,
2592                     i)));
2593           }
2594           return branch_b.AddInstruction(
2595               HloInstruction::CreateTuple(padded_elements));
2596         }
2597         const Shape& pad_shape =
2598             ShapeUtil::GetSubshape(shard_shape, ShapeIndexView(index, 1));
2599         if (ShapeUtil::Compatible(element_shape, pad_shape)) {
2600           return infeed_element;
2601         }
2602         if (element_shape.IsArray()) {
2603           CHECK(pad_shape.IsArray());
2604           return PadToShape(infeed_element, pad_shape, &branch_b);
2605         }
2606         CHECK(element_shape.IsTuple());
2607         CHECK(element_shape.tuple_shapes().empty());
2608         return CreateZero(pad_shape, &branch_b);
2609       };
2610       pad_infeed({}, infeed);
2611     }
2612     branches[i] = module_->AddEmbeddedComputation(branch_b.Build());
2613   }
2614   SetPartitionedHlo(hlo, [&]() {
2615     return b_.AddInstruction(HloInstruction::CreateConditional(
2616         ShapeUtil::MakeTupleShape({shard_shape, token->shape()}), branch_index,
2617         branches, std::vector<HloInstruction*>(branches.size(), token)));
2618   });
2619   return Status::OK();
2620 }
2621 
2622 Status SpmdPartitioningVisitor::HandlePad(HloInstruction* hlo) {
2623   if (hlo->sharding().IsTileMaximal()) {
2624     return DefaultAction(hlo);
2625   }
2626   auto lhs = GetPartitionedHlo(hlo->operand(0));
2627   // Create a window config to represent the pad.
2628   Window window;
2629   for (int64 i = 0; i < hlo->shape().rank(); ++i) {
2630     const auto& pd = hlo->padding_config().dimensions(i);
2631     WindowDimension* dim = window.add_dimensions();
2632     dim->set_size(1);
2633     dim->set_stride(1);
2634     dim->set_window_dilation(1);
2635     dim->set_window_reversal(false);
2636     dim->set_padding_low(pd.edge_padding_low());
2637     dim->set_padding_high(pd.edge_padding_high());
2638     dim->set_base_dilation(pd.interior_padding() + 1);
2639   }
2640 
2641   auto replicated_rhs = GetPartitionedHlo(hlo->operand(1))
2642                             .Reshard(HloSharding::Replicate())
2643                             .hlo();
2644   auto reshard_operand =
2645       lhs.ReshardAsWindowedInput(window, hlo->sharding(), replicated_rhs,
2646                                  /*mask_invalid_region=*/false);
2647   if (!reshard_operand.has_value()) {
2648     return DefaultAction(hlo);
2649   }
2650   PaddingConfig sharded_padding_config;
2651   bool need_pad = false;
2652   for (int64 i = 0; i < hlo->shape().rank(); ++i) {
2653     auto dim = sharded_padding_config.add_dimensions();
2654     const auto& wd = reshard_operand->shard_window.dimensions(i);
2655     dim->set_edge_padding_low(wd.padding_low());
2656     dim->set_edge_padding_high(wd.padding_high());
2657     dim->set_interior_padding(wd.base_dilation() - 1);
2658     if (wd.padding_low() != 0 || wd.padding_high() != 0 ||
2659         wd.base_dilation() != 1) {
2660       need_pad = true;
2661     }
2662   }
2663   auto sharded_pad = reshard_operand->sharded_input;
2664   if (need_pad) {
2665     TF_ASSIGN_OR_RETURN(auto sharded_pad_shape,
2666                         ShapeInference::InferPadShape(sharded_pad->shape(),
2667                                                       replicated_rhs->shape(),
2668                                                       sharded_padding_config));
2669     sharded_pad = b_.AddInstruction(hlo->CreatePad(sharded_pad_shape,
2670                                                    sharded_pad, replicated_rhs,
2671                                                    sharded_padding_config));
2672   }
2673 
2674   SetPartitionedHlo(hlo, [&]() {
2675     if (!reshard_operand->dynamic_slice_index_on_output) {
2676       return sharded_pad;
2677     }
2678     auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
2679     return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
2680         shard_shape, sharded_pad,
2681         *reshard_operand->dynamic_slice_index_on_output,
2682         shard_shape.dimensions()));
2683   });
2684   return Status::OK();
2685 }
2686 
2687 Status SpmdPartitioningVisitor::HandleParameter(HloInstruction* hlo) {
2688   SetPartitionedHlo(hlo, [&]() {
2689     auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
2690     auto new_param = b_.AddInstruction(HloInstruction::CreateParameter(
2691         hlo->parameter_number(), shard_shape, "param"));
2692     if (hlo->parameter_replicated_at_leaf_buffers()) {
2693       new_param->set_parameter_replicated_at_leaf_buffers(
2694           *hlo->parameter_replicated_at_leaf_buffers());
2695     }
2696     return new_param;
2697   });
2698   return Status::OK();
2699 }
2700 
2701 Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) {
2702   int64 input_count = 1;
2703   auto per_input_sharding = hlo->sharding();
2704   if (hlo->shape().IsTuple()) {
2705     input_count = hlo->shape().tuple_shapes_size();
2706     CHECK_GT(input_count, 0);
2707     per_input_sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0});
2708   }
2709 
2710   std::vector<PartitionedHlo> inputs;
2711   std::vector<HloInstruction*> inits;
2712   std::vector<int64> preserved_dims;
2713   for (int64 i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
2714     if (!absl::c_linear_search(hlo->dimensions(), i)) {
2715       preserved_dims.push_back(i);
2716     }
2717   }
2718 
2719   for (int64 operand_id = 0; operand_id < input_count; ++operand_id) {
2720     inits.push_back(GetPartitionedHlo(hlo->operand(operand_id + input_count))
2721                         .Reshard(HloSharding::Replicate())
2722                         .hlo());
2723     inputs.push_back(GetPartitionedHlo(hlo->operand(operand_id)));
2724     if (operand_id > 0) {
2725       // Make sure all operands are sharded in the same way.
2726       inputs.back() = inputs.back().Reshard(inputs[0].sharding());
2727     }
2728     if (!inputs[0].sharding().IsTileMaximal()) {
2729       inputs.back() =
2730           inputs.back().PadWithValue(inits[operand_id], /*left_padded_dims=*/{},
2731                                      /*skipped_dims=*/preserved_dims);
2732     }
2733   }
2734 
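       // Collect the sharded input shapes and replicated init shapes to infer the
       // shard shape of the local reduce.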
2735   std::vector<Shape*> new_operand_shapes(input_count * 2);
2736   for (int64 i = 0; i < input_count; ++i) {
2737     new_operand_shapes[i] = inputs[i].hlo()->mutable_shape();
2738     new_operand_shapes[i + input_count] = inits[i]->mutable_shape();
2739   }
2740   // Create the shard shape of the reduce result.
2741   TF_ASSIGN_OR_RETURN(
2742       auto reduce_shape,
2743       ShapeInference::InferReduceShape(new_operand_shapes, hlo->dimensions(),
2744                                        hlo->to_apply()->ComputeProgramShape()));
2745 
2746   std::vector<HloInstruction*> input_hlos(input_count);
2747   for (int64 i = 0; i < input_count; ++i) {
2748     input_hlos[i] = inputs[i].hlo();
2749   }
2750   auto local_reduce = b_.AddInstruction(HloInstruction::CreateReduce(
2751       reduce_shape, input_hlos, inits, hlo->dimensions(), hlo->to_apply()));
2752   local_reduce->set_metadata(hlo->metadata());
2753 
2754   SetPartitionedHlo(hlo, [&]() {
2755     HloInstruction* reduce = local_reduce;
2756     const bool reduce_sharded_dimension =
2757         !inputs[0].sharding().IsTileMaximal() &&
2758         absl::c_any_of(hlo->dimensions(), [&](int64 i) {
2759           return inputs[0].sharding().tile_assignment().dim(i) > 1;
2760         });
2761     if (reduce_sharded_dimension) {
2762       if (inputs[0].sharding().ReplicateOnLastTileDim()) {
2763         preserved_dims.push_back(inputs[0].base_shape().rank());
2764       }
2765       if (local_reduce->shape().IsArray()) {
2766         reduce = partitioner_->AllReduceAlongShardingDims(
2767             &b_, local_reduce, inputs[0].sharding(), next_channel_id_,
2768             hlo->dimensions(), collective_ops_creator_, hlo->to_apply());
2769       } else {
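             // Tuple-shaped (variadic) reduce: a cross-partition all-reduce cannot
             // combine the partial tuples directly, so all-gather the partial
             // results along the reduced dimensions within each group and reduce
             // them again locally.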
2770         auto grouped =
2771             GroupShardingOnDims(inputs[0].sharding(), preserved_dims);
2772         auto grouped_state = CreatePerGroupPartitioningState(
2773             inputs[0].state(), grouped.device_groups, &b_);
2774         std::vector<HloInstruction*> all_gathered_partial_results(input_count);
2775         for (int64 i = 0; i < input_count; ++i) {
2776           auto gte = b_.AddInstruction(HloInstruction::CreateGetTupleElement(
2777               ShapeUtil::GetTupleElementShape(reduce_shape, i), local_reduce,
2778               i));
2779           auto expanded_shape = input_hlos[i]->shape();
2780           auto all_gather_shape = input_hlos[i]->shape();
2781           for (int64 dim : hlo->dimensions()) {
2782             expanded_shape.set_dimensions(dim, 1);
2783             all_gather_shape.set_dimensions(
2784                 dim, inputs[0].sharding().tile_assignment().dim(dim));
2785           }
2786           auto reshape = b_.AddInstruction(
2787               HloInstruction::CreateReshape(expanded_shape, gte));
2788           // Replicate per group.
2789           reshape->set_sharding(grouped.sharding);
2790           all_gathered_partial_results[i] =
2791               PartitionedHlo(reshape, all_gather_shape, grouped_state)
2792                   .Replicate()
2793                   .hlo();
2794         }
2795         reduce = b_.AddInstruction(HloInstruction::CreateReduce(
2796             reduce_shape, all_gathered_partial_results, inits,
2797             hlo->dimensions(), hlo->to_apply()));
2798       }
2799     }
2800     auto sharding = hlo_sharding_util::RemoveShapeDimensions(
2801         hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
2802             inputs[0].sharding(), hlo->dimensions()),
2803         hlo->dimensions());
2804     if (local_reduce->shape().IsArray()) {
2805       reduce->set_sharding(sharding);
2806     } else {
2807       reduce->set_sharding(HloSharding::Tuple(
2808           reduce->shape(), std::vector<HloSharding>(input_count, sharding)));
2809     }
2810     return PartitionedHlo(reduce, hlo->shape(), MakePartitioningState())
2811         .Reshard(hlo->sharding())
2812         .hlo();
2813   });
2814   return Status::OK();
2815 }
2816 
2817 Status SpmdPartitioningVisitor::HandleReverse(HloInstruction* hlo) {
2818   auto reverse = Cast<HloReverseInstruction>(hlo);
2819   if (reverse->sharding().IsTileMaximal()) {
2820     return DefaultAction(hlo);
2821   }
2822   auto operand = GetPartitionedHlo(reverse->operand(0))
2823                      .Reshard(hlo_sharding_util::ReverseSharding(
2824                          reverse->sharding(), reverse->dimensions()));
2825   auto left_padded_operand =
2826       HaloExchangeToPadOnLeft(operand, reverse->dimensions());
2827   if (!left_padded_operand) {
2828     return DefaultAction(hlo);
2829   }
2830   SetPartitionedHlo(hlo, [&] {
2831     return b_.AddInstruction(hlo->CloneWithNewOperands(
2832         left_padded_operand->shape(), {left_padded_operand}));
2833   });
2834   return Status::OK();
2835 }
2836 
2837 Status SpmdPartitioningVisitor::HandleWhile(HloInstruction* hlo) {
2838   const HloSharding& sharding = hlo->sharding();
2839 
2840   // Shardings for the body parameter, body root, and cond parameter must be
2841   // the same, and the condition root must be replicated so that all partitions
2842   // follow the same control flow.
2843   hlo->while_condition()->parameter_instruction(0)->set_sharding(sharding);
2844   hlo->while_body()->parameter_instruction(0)->set_sharding(sharding);
2845   TF_RETURN_IF_ERROR(partitioner_
2846                          ->PartitionComputation(hlo->while_condition(),
2847                                                 HloSharding::Replicate(),
2848                                                 next_channel_id_, logger_)
2849                          .status());
2850   TF_RETURN_IF_ERROR(partitioner_
2851                          ->PartitionComputation(hlo->while_body(), sharding,
2852                                                 next_channel_id_, logger_)
2853                          .status());
2854   SetPartitionedHlo(hlo, [&] {
2855     return b_.AddInstruction(HloInstruction::CreateWhile(
2856         MakePartitionedShape(hlo->shape(), sharding), hlo->while_condition(),
2857         hlo->while_body(),
2858         GetPartitionedHlo(hlo->operand(0)).Reshard(sharding).hlo()));
2859   });
2860   return Status::OK();
2861 }
2862 
2863 Status SpmdPartitioningVisitor::HandleConditional(HloInstruction* hlo) {
2864   std::vector<HloInstruction*> branch_args;
2865   for (int64 i = 0; i < hlo->branch_count(); ++i) {
2866     HloComputation* computation = hlo->branch_computation(i);
2867 
2868     // Shardings of the branch computation parameter and its argument must be
2869     // the same.
2870     computation->parameter_instruction(0)->set_sharding(
2871         hlo->operand(i + 1)->sharding());
2872     branch_args.push_back(GetPartitionedHlo(hlo->operand(i + 1)).hlo());
2873   }
2874 
2875   // The root of the branch computations must follow the sharding of the
2876   // conditional instruction.
2877   for (int64 i = 0; i < hlo->branch_count(); ++i) {
2878     HloComputation* computation = hlo->branch_computation(i);
2879     TF_RETURN_IF_ERROR(partitioner_
2880                            ->PartitionComputation(computation, hlo->sharding(),
2881                                                   next_channel_id_, logger_)
2882                            .status());
2883   }
2884 
2885   // We replicate the predicate of the conditional (the first operand) so that
2886   // all partitions follow the same control flow.
2887   SetPartitionedHlo(hlo, [&] {
2888     return b_.AddInstruction(HloInstruction::CreateConditional(
2889         MakePartitionedShape(hlo->shape(), hlo->sharding()),
2890         GetPartitionedHlo(hlo->operand(0))
2891             .Reshard(HloSharding::Replicate())
2892             .hlo(),
2893         hlo->called_computations(), branch_args));
2894   });
2895   return Status::OK();
2896 }
2897 
2898 Status SpmdPartitioningVisitor::HandleOutfeed(HloInstruction* hlo) {
2899   if (hlo->sharding().HasUniqueDevice()) {
2900     return HandleSingleDevice(hlo);
2901   }
2902 
2903   const auto& sharding = hlo->sharding();
2904   const Shape& shape = hlo->operand(0)->shape();
2905   auto partitioned_operand =
2906       GetPartitionedHlo(hlo->operand(0)).Reshard(sharding);
2907   const auto& shard_shape = partitioned_operand.hlo()->shape();
2908   const auto& operand = partitioned_operand.hlo();
2909   auto token = GetPartitionedHlo(hlo->operand(1)).hlo();
2910 
2911   if (EvenlyPartitions(shape, sharding)) {
2912     Shape outfeed_shape = operand->shape();
2913     TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(hlo->outfeed_shape(),
2914                                                            &outfeed_shape));
2915     SetPartitionedHlo(hlo, [&]() {
2916       return b_.AddInstruction(HloInstruction::CreateOutfeed(
2917           outfeed_shape, operand, token, hlo->outfeed_config()));
2918     });
2919     return Status::OK();
2920   }
2921 
2922   // Create a branch for each unique partitioned shape.
2923   std::vector<Shape> per_branch_partitioned_shapes;
2924   std::vector<int32> conditional_branch_indices(num_partitions_);
2925   for (int64 i = 0; i < num_partitions_; ++i) {
2926     auto partitioned_shape =
2927         MakeNonPaddedShapeForGivenPartition(shape, sharding, i);
2928     int64 matching_existing_index = 0;
2929     for (; matching_existing_index < per_branch_partitioned_shapes.size();
2930          ++matching_existing_index) {
2931       if (ShapeUtil::Compatible(
2932               partitioned_shape,
2933               per_branch_partitioned_shapes[matching_existing_index])) {
2934         break;
2935       }
2936     }
2937     if (matching_existing_index < per_branch_partitioned_shapes.size()) {
2938       conditional_branch_indices[i] = matching_existing_index;
2939     } else {
2940       conditional_branch_indices[i] = per_branch_partitioned_shapes.size();
2941       per_branch_partitioned_shapes.push_back(std::move(partitioned_shape));
2942     }
2943   }
2944 
2945   // Get branch index for this partition.
2946   HloInstruction* branch_index;
2947   if (per_branch_partitioned_shapes.size() == num_partitions_) {
2948     // Use partition ID as the branch index if each partition has its own
2949     // branch.
2950     branch_index = partition_id_;
2951     // PartitionId's output is U32 but conditional requires S32.
2952     if (branch_index->shape().element_type() != S32) {
2953       branch_index = b_.AddInstruction(HloInstruction::CreateConvert(
2954           ShapeUtil::ChangeElementType(branch_index->shape(), S32),
2955           branch_index));
2956     }
2957   } else {
2958     // Otherwise, use a constant table to look up the branch index.
2959     auto branch_index_table = b_.AddInstruction(HloInstruction::CreateConstant(
2960         LiteralUtil::CreateR1<int32>(conditional_branch_indices)));
2961     branch_index = b_.AddInstruction(HloInstruction::CreateDynamicSlice(
2962         ShapeUtil::MakeShape(S32, {1}), branch_index_table, {partition_id_},
2963         {1}));
2964     branch_index = b_.AddInstruction(HloInstruction::CreateReshape(
2965         ShapeUtil::MakeShape(S32, {}), branch_index));
2966   }
2967 
2968   // Create conditional for the outfeed.
2969   std::vector<HloComputation*> branches(per_branch_partitioned_shapes.size());
2970   for (int64 i = 0; i < branches.size(); ++i) {
2971     SpmdBuilder branch_b(absl::StrCat("outfeed_branch_", i), visiting_hlo_);
2972     // Create tuple param within the branch.
2973     auto param = branch_b.AddInstruction(HloInstruction::CreateParameter(
2974         /*parameter_number=*/0,
2975         ShapeUtil::MakeTupleShape({operand->shape(), token->shape()}),
2976         "outfeed_token_param"));
2977     auto outfeed_data = branch_b.AddInstruction(
2978         HloInstruction::CreateGetTupleElement(operand->shape(), param, 0));
2979     auto outfeed_token = branch_b.AddInstruction(
2980         HloInstruction::CreateGetTupleElement(token->shape(), param, 1));
2981     if (!ShapeUtil::Compatible(per_branch_partitioned_shapes[i], shard_shape)) {
2982       std::function<HloInstruction*(const ShapeIndex&, HloInstruction*)>
2983           slice_outfeed =
2984               [&](const ShapeIndex& index,
2985                   HloInstruction* outfeed_operand) -> HloInstruction* {
2986         // Get outfeed element shape.
2987         const Shape& element_shape =
2988             ShapeUtil::GetSubshape(outfeed_data->shape(), index);
2989         // Recursively call slice_outfeed for tuple shapes.
2990         if (element_shape.IsTuple() && element_shape.tuple_shapes_size() > 0) {
2991           std::vector<HloInstruction*> slice_elements(
2992               element_shape.tuple_shapes_size());
2993           for (int64 i = 0; i < slice_elements.size(); ++i) {
2994             auto sub_index = index;
2995             sub_index.push_back(i);
2996             slice_elements[i] = slice_outfeed(
2997                 sub_index,
2998                 branch_b.AddInstruction(HloInstruction::CreateGetTupleElement(
2999                     ShapeUtil::GetSubshape(element_shape, {i}), outfeed_operand,
3000                     i)));
3001           }
3002           return branch_b.AddInstruction(
3003               HloInstruction::CreateTuple(slice_elements));
3004         }
3005         // Get the slice shape.
3006         const Shape& slice_shape = ShapeUtil::GetSubshape(
3007             per_branch_partitioned_shapes[i], ShapeIndexView(index));
3008         if (ShapeUtil::Compatible(element_shape, slice_shape)) {
3009           return outfeed_operand;
3010         }
3011         // Slice out useful data.
3012         if (element_shape.IsArray()) {
3013           CHECK(slice_shape.IsArray());
3014           std::vector<int64> start_indices(slice_shape.rank(), 0);
3015           std::vector<int64> slice_strides(slice_shape.rank(), 1);
3016           return branch_b.AddInstruction(HloInstruction::CreateSlice(
3017               slice_shape, outfeed_operand, start_indices,
3018               slice_shape.dimensions(), slice_strides));
3019         }
3020         CHECK(element_shape.IsTuple());
3021         CHECK(element_shape.tuple_shapes().empty());
3022         return outfeed_operand;
3023       };
3024       outfeed_data = slice_outfeed({}, outfeed_data);
3025     }
3026     TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
3027         hlo->outfeed_shape(), &per_branch_partitioned_shapes[i]));
3028     branch_b.AddInstruction(HloInstruction::CreateOutfeed(
3029         per_branch_partitioned_shapes[i], outfeed_data, outfeed_token,
3030         hlo->outfeed_config()));
3031     branches[i] = module_->AddEmbeddedComputation(branch_b.Build());
3032   }
3033   SetPartitionedHlo(hlo, [&]() {
3034     return b_.AddInstruction(HloInstruction::CreateConditional(
3035         token->shape(), branch_index, branches,
3036         std::vector<HloInstruction*>(
3037             branches.size(),
3038             b_.AddInstruction(HloInstruction::CreateTuple({operand, token})))));
3039   });
3040   return Status::OK();
3041 }
3042 
3043 Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) {
3044   if (hlo->sharding().HasUniqueDevice()) {
3045     return HandleSingleDevice(hlo);
3046   }
3047 
3048   if (hlo->sharding().IsReplicated()) {
3049     SetPartitionedHlo(hlo, [&] {
3050       // Run on a single device (0) and distribute the data to all other cores.
3051       std::vector<HloInstruction*> new_operands;
3052       for (int64 i = 0; i < hlo->operand_count(); ++i) {
3053         new_operands.push_back(GetPartitionedHlo(hlo->operand(i))
3054                                    .Reshard(HloSharding::AssignDevice(0))
3055                                    .hlo());
3056       }
3057       auto clone = b_.AddInstruction(
3058           hlo->CloneWithNewOperands(hlo->shape(), new_operands));
3059       clone->set_sharding(HloSharding::AssignDevice(0));
3060       return PartitionedHlo(clone, hlo->shape(), MakePartitioningState())
3061           .Reshard(HloSharding::Replicate())
3062           .hlo();
3063     });
3064     return Status::OK();
3065   }
3066 
3067   TF_RET_CHECK(!hlo->sharding().IsTileMaximal());
3068   // Replicate the operands and run partitioned Rng on all devices.
3069   std::vector<HloInstruction*> new_operands;
3070   for (int64 i = 0; i < hlo->operand_count(); ++i) {
3071     new_operands.push_back(GetPartitionedHlo(hlo->operand(i))
3072                                .Reshard(HloSharding::Replicate())
3073                                .hlo());
3074   }
3075 
3076   if (!hlo->sharding().ReplicateOnLastTileDim()) {
3077     SetPartitionedHlo(hlo, [&] {
3078       return b_.AddInstruction(HloInstruction::CreateRng(
3079           MakePartitionedShape(hlo->shape(), hlo->sharding()),
3080           hlo->random_distribution(), new_operands));
3081     });
3082   } else {
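         // Partial replication on the last tile dimension: generate the rng once per
         // replication group and broadcast it within the group, so replicas of the
         // same shard see identical random values.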
3083     std::vector<int64> group_dims(
3084         hlo->sharding().tile_assignment().num_dimensions() - 1);
3085     std::iota(group_dims.begin(), group_dims.end(), 0);
3086     auto sharding_grouped = GroupShardingOnDims(hlo->sharding(), group_dims);
3087     auto per_group_state = CreatePerGroupPartitioningState(
3088         MakePartitioningState(), sharding_grouped.device_groups, &b_);
3089     auto rng = b_.AddInstruction(HloInstruction::CreateRng(
3090         MakePartitionedShape(hlo->shape(), hlo->sharding()),
3091         hlo->random_distribution(), new_operands));
3092     rng->set_sharding(HloSharding::AssignDevice(0));
3093     SetPartitionedHlo(hlo, [&]() {
3094       return PartitionedHlo(rng, rng->shape(), per_group_state)
3095           .Replicate()
3096           .hlo();
3097     });
3098   }
3099   return Status::OK();
3100 }
3101 
3102 Status SpmdPartitioningVisitor::HandleReduceWindow(HloInstruction* hlo) {
3103   // TODO(b/73062247) Variadic reduce window not yet supported in partitioner.
3104   if (hlo->shape().IsTuple()) {
3105     return DefaultAction(hlo);
3106   }
3107   auto& operand = GetPartitionedHlo(hlo->operand(0));
3108   if (hlo->sharding().IsTileMaximal()) {
3109     return DefaultAction(hlo);
3110   }
3111 
3112   // Replicate init
3113   auto replicated_init = GetPartitionedHlo(hlo->mutable_operand(1))
3114                              .Reshard(HloSharding::Replicate());
3115   auto resharded_operand_and_window = operand.ReshardAsWindowedInput(
3116       hlo->window(), hlo->sharding(), replicated_init.hlo());
3117   if (!resharded_operand_and_window.has_value()) {
3118     return DefaultAction(hlo);
3119   }
3120 
3121   TF_ASSIGN_OR_RETURN(Shape sharded_rw_shape,
3122                       ShapeInference::InferReduceWindowShape(
3123                           resharded_operand_and_window->sharded_input->shape(),
3124                           replicated_init.hlo()->shape(),
3125                           resharded_operand_and_window->shard_window,
3126                           hlo->to_apply()->ComputeProgramShape()));
3127   auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
3128   *sharded_rw_shape.mutable_layout() = shard_shape.layout();
3129   SetPartitionedHlo(hlo, [&]() {
3130     auto sharded_rw = b_.AddInstruction(HloInstruction::CreateReduceWindow(
3131         sharded_rw_shape, resharded_operand_and_window->sharded_input,
3132         replicated_init.hlo(), resharded_operand_and_window->shard_window,
3133         hlo->to_apply()));
3134     if (!resharded_operand_and_window->dynamic_slice_index_on_output
3135              .has_value()) {
3136       CHECK(ShapeUtil::Compatible(shard_shape, sharded_rw->shape()));
3137       return sharded_rw;
3138     }
3139     return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
3140         shard_shape, sharded_rw,
3141         *resharded_operand_and_window->dynamic_slice_index_on_output,
3142         shard_shape.dimensions()));
3143   });
3144   return Status::OK();
3145 }
3146 
3147 Status SpmdPartitioningVisitor::HandleSelectAndScatter(HloInstruction* hlo) {
3148   if (hlo->sharding().IsTileMaximal()) {
3149     return DefaultAction(hlo);
3150   }
3151   auto operand = GetPartitionedHlo(hlo->operand(0));
3152   auto source = GetPartitionedHlo(hlo->mutable_operand(1));
3153   if (hlo->sharding() != operand.sharding()) {
3154     operand = operand.Reshard(hlo->sharding());
3155   }
3156   if (hlo->sharding() != source.sharding()) {
3157     source = source.Reshard(hlo->sharding());
3158   }
3159 
3160   // For F32 and BF16 types, we can use NaN padding to work around the issue
3161   // with low/high padding, since comparison will return false with NaN input.
3162   if (hlo->shape().element_type() != F32 &&
3163       hlo->shape().element_type() != BF16) {
3164     return DefaultAction(hlo);
3165   }
3166 
3167   auto select = hlo->called_computations()[0];
3168   auto select_root = select->root_instruction();
3169   if (select_root->opcode() != HloOpcode::kCompare ||
3170       select_root->operand(0)->opcode() != HloOpcode::kParameter ||
3171       select_root->operand(1)->opcode() != HloOpcode::kParameter ||
3172       select_root->operand(0)->parameter_number() +
3173               select_root->operand(1)->parameter_number() !=
3174           1) {
3175     return DefaultAction(hlo);
3176   }
3177 
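       // Pick a padding value that can never be selected by the select computation,
       // based on its comparison direction and parameter order.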
3178   float float_pad_value;
3179   if (select_root->comparison_direction() == ComparisonDirection::kGe ||
3180       select_root->comparison_direction() == ComparisonDirection::kGt) {
3181     if (select_root->operand(0)->parameter_number() == 0) {
3182       float_pad_value = -std::numeric_limits<float>::infinity();
3183     } else {
3184       float_pad_value = std::numeric_limits<float>::infinity();
3185     }
3186   } else if (select_root->comparison_direction() == ComparisonDirection::kLe ||
3187              select_root->comparison_direction() == ComparisonDirection::kLt) {
3188     if (select_root->operand(0)->parameter_number() == 0) {
3189       float_pad_value = std::numeric_limits<float>::infinity();
3190     } else {
3191       float_pad_value = -std::numeric_limits<float>::infinity();
3192     }
3193   } else {
3194     return DefaultAction(hlo);
3195   }
3196 
3197   auto pad_value = b_.AddInstruction(HloInstruction::CreateConstant(
3198       hlo->shape().element_type() == BF16
3199           ? LiteralUtil::CreateR0<bfloat16>(
3200                 static_cast<bfloat16>(float_pad_value))
3201           : LiteralUtil::CreateR0<float>(float_pad_value)));
3202 
3203   // Replicate init
3204   auto replicated_init = GetPartitionedHlo(hlo->mutable_operand(2))
3205                              .Reshard(HloSharding::Replicate());
3206 
3207   auto partition_ordinals =
3208       MakeTiledPartitionOrdinals(hlo->sharding(), partition_id_, &b_);
3209 
3210   // The first window for each dimension that overlaps with the shard area.
3211   std::vector<MultiplyAddDivideOffsetCalculation> first_window(
3212       hlo->shape().rank());
3213   // The first window for each dimension that goes beyond the shard area.
3214   std::vector<MultiplyAddDivideOffsetCalculation> limit_window(
3215       hlo->shape().rank());
3216   std::vector<OffsetCalculation> data_left_halo_sizes(hlo->shape().rank());
3217   std::vector<OffsetCalculation> data_right_halo_sizes(hlo->shape().rank());
3218   std::vector<OffsetCalculation> source_left_halo_sizes(hlo->shape().rank());
3219   std::vector<OffsetCalculation> source_right_halo_sizes(hlo->shape().rank());
3220   auto unpadded_data_shard_shape =
3221       MakePartitionedShape(hlo->shape(), hlo->sharding());
3222   auto unpadded_source_shard_shape =
3223       MakePartitionedShape(hlo->operand(1)->shape(), hlo->sharding());
3224   auto source_shard_hlo = source.hlo();
3225   auto data_shard_hlo = operand.hlo();
3226   for (int64 i = 0; i < hlo->shape().rank(); ++i) {
3227     int64 shard_count = hlo->sharding().tile_assignment().dim(i);
3228     if (shard_count == 1) {
3229       continue;
3230     }
3231     // If stride > window_size, there will be gaps between windows. These gaps
3232     // will also exist in the output, so we keep them during halo exchange.
3233     //
3234     // TODO(yuanzx): This could introduce overhead if partitions start at
3235     // different offsets in a gap.
3236     auto wd = hlo->window().dimensions(i);
3237     if (wd.stride() > wd.size()) {
3238       wd.set_size(wd.stride());
3239     }
3240     // shard_size * i < stride * k - pad_low + window_size  =>
3241     //   k > (shard_size * i + pad_low - window_size) / stride  =>
3242     //   first_k == (shard_size * i + pad_low - window_size + stride) / stride
3243     first_window[i] = MultiplyAddDivideOffsetCalculation(
3244         unpadded_data_shard_shape.dimensions(i),
3245         wd.padding_low() - wd.size() + wd.stride(), wd.stride());
3246     // shard_size * (i + 1) <= stride * k - pad_low  =>
3247     //   k >= (shard_size * i + shard_size + pad_low) / stride  =>
3248     //   limit_k == (shard_size * i + shard_size + pad_low + stride - 1) /
3249     //     stride
3250     limit_window[i] = MultiplyAddDivideOffsetCalculation(
3251         unpadded_data_shard_shape.dimensions(i),
3252         unpadded_data_shard_shape.dimensions(i) + wd.padding_low() +
3253             wd.stride() - 1,
3254         wd.stride());
3255     source_left_halo_sizes[i] =
3256         MultiplyAddDivideOffsetCalculation(
3257             unpadded_source_shard_shape.dimensions(i), 0, 1) -
3258         first_window[i];
3259     source_right_halo_sizes[i] =
3260         limit_window[i] - MultiplyAddDivideOffsetCalculation(
3261                               unpadded_source_shard_shape.dimensions(i),
3262                               unpadded_source_shard_shape.dimensions(i), 1);
3263     data_left_halo_sizes[i] =
3264         OffsetCalculation(MultiplyAddDivideOffsetCalculation(
3265             unpadded_data_shard_shape.dimensions(i), wd.padding_low(), 1)) -
3266         OffsetCalculation(
3267             HloOpcode::kMultiply, first_window[i],
3268             MultiplyAddDivideOffsetCalculation(0, wd.stride(), 1));
3269     data_right_halo_sizes[i] =
3270         OffsetCalculation(
3271             HloOpcode::kMultiply, limit_window[i],
3272             MultiplyAddDivideOffsetCalculation(0, wd.stride(), 1)) -
3273         OffsetCalculation(MultiplyAddDivideOffsetCalculation(
3274             unpadded_data_shard_shape.dimensions(i),
3275             unpadded_data_shard_shape.dimensions(i) + wd.stride() +
3276                 wd.padding_low() - wd.size(),
3277             1));
3278 
3279     int64 max_windows =
3280         (limit_window[i] - first_window[i]).MaxInRange(0, shard_count);
3281     auto first_window_hlo =
3282         first_window[i].Calculate(partition_ordinals[i], &b_);
3283     // Padding on the source is filled with the init value so it does not change
3284     // the data on overlapping windows.
3285     auto resharded_source = ExchangeHaloAndGetValidData(
3286         source_shard_hlo, source.base_shape(), source_left_halo_sizes[i],
3287         source_right_halo_sizes[i], 0,
3288         limit_window[i].Calculate(shard_count - 1), max_windows, i,
3289         hlo->sharding(), first_window_hlo, replicated_init.hlo(),
3290         partition_ordinals[i], collective_ops_creator_, next_channel_id_, &b_);
3291     if (!resharded_source) {
3292       return DefaultAction(hlo);
3293     }
3294     source_shard_hlo = *resharded_source;
3295 
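         // Start offset of this shard's first window (first_window * stride).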
3296     auto offset_start_in_data =
3297         MultiplyAddDivideOffsetCalculation(wd.stride(), 0, 1)
3298             .Calculate(first_window_hlo, &b_);
3299     int64 padded_data_size =
3300         (limit_window[i].Calculate(shard_count - 1) - 1) * wd.stride() +
3301         wd.size();
3302     int64 data_shard_size = (max_windows - 1) * wd.stride() + wd.size();
3303     auto resharded_data = ExchangeHaloAndGetValidData(
3304         data_shard_hlo, operand.base_shape(), data_left_halo_sizes[i],
3305         data_right_halo_sizes[i], wd.padding_low(), padded_data_size,
3306         data_shard_size, i, hlo->sharding(), offset_start_in_data, pad_value,
3307         partition_ordinals[i], collective_ops_creator_, next_channel_id_, &b_);
3308     if (!resharded_data) {
3309       return DefaultAction(hlo);
3310     }
3311     data_shard_hlo = *resharded_data;
3312   }
3313 
3314   Window window_on_shard = hlo->window();
3315   for (int64 i = 0; i < window_on_shard.dimensions_size(); ++i) {
3316     int64 shard_count = hlo->sharding().tile_assignment().dim(i);
3317     if (shard_count == 1) {
3318       continue;
3319     }
3320     auto reshard_wd = window_on_shard.mutable_dimensions(i);
3321     // The shards are already explicitly padded.
3322     reshard_wd->set_padding_low(0);
3323     reshard_wd->set_padding_high(0);
3324   }
3325 
3326   auto sharded_select_and_scatter =
3327       b_.AddInstruction(HloInstruction::CreateSelectAndScatter(
3328           data_shard_hlo->shape(), data_shard_hlo, select, window_on_shard,
3329           source_shard_hlo, replicated_init.hlo(),
3330           hlo->called_computations()[1]));
3331   SetPartitionedHlo(hlo, [&]() {
3332     auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
3333     if (ShapeUtil::Compatible(sharded_select_and_scatter->shape(),
3334                               shard_shape)) {
3335       return sharded_select_and_scatter;
3336     }
3337     auto zero = b_.AddInstruction(
3338         HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
3339     std::vector<HloInstruction*> slice_offsets(shard_shape.rank(), zero);
3340     for (int64 i = 0; i < window_on_shard.dimensions_size(); ++i) {
3341       if (hlo->sharding().tile_assignment().dim(i) == 1) {
3342         continue;
3343       }
3344       int64 pad_low = hlo->window().dimensions(i).padding_low();
3345       auto left_halo_size =
3346           data_left_halo_sizes[i].Calculate(partition_ordinals[i], &b_);
3347       if (data_left_halo_sizes[i].Calculate(0) == pad_low) {
3348         slice_offsets[i] = left_halo_size;
3349       } else {
3350         auto is_shard0 = b_.AddInstruction(HloInstruction::CreateCompare(
3351             ShapeUtil::MakeShape(PRED, {}), zero, partition_ordinals[i],
3352             ComparisonDirection::kEq));
3353         auto pad_low_hlo = b_.AddInstruction(HloInstruction::CreateConstant(
3354             LiteralUtil::CreateR0<int32>(pad_low)));
3355         slice_offsets[i] = b_.AddInstruction(HloInstruction::CreateTernary(
3356             zero->shape(), HloOpcode::kSelect, is_shard0, pad_low_hlo,
3357             left_halo_size));
3358       }
3359     }
3360     return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
3361         shard_shape, sharded_select_and_scatter, slice_offsets,
3362         shard_shape.dimensions()));
3363   });
3364   return Status::OK();
3365 }
3366 
3367 Status SpmdPartitioningVisitor::HandleTuple(HloInstruction* hlo) {
3368   std::vector<HloInstruction*> new_operands;
3369   for (int64 i = 0; i < hlo->operand_count(); ++i) {
3370     new_operands.push_back(
3371         GetPartitionedHlo(hlo->operand(i))
3372             .Reshard(hlo->sharding().GetSubSharding(hlo->shape(), {i}))
3373             .hlo());
3374   }
3375   SetPartitionedHlo(hlo, [&]() {
3376     return b_.AddInstruction(HloInstruction::CreateTuple(new_operands));
3377   });
3378   return Status::OK();
3379 }
3380 
3381 StatusOr<bool> SpmdPartitioningVisitor::DoPartition(
3382     HloComputation* computation, const HloSharding& root_sharding,
3383     const SpmdPartitionerOptions& options) {
3384   VLOG(2) << "Partitioning computation " << computation->name() << " for "
3385           << num_replicas_ << " replicas and " << num_partitions_
3386           << " partitions";
3387   TF_RETURN_IF_ERROR(computation->Accept(this));
3388 
3389   HloModule* module = computation->parent();
3390   auto new_root =
3391       GetPartitionedHlo(computation->root_instruction()).Reshard(root_sharding);
3392   auto new_computation =
3393       module->AddEmbeddedComputation(b_.Build(new_root.hlo()));
3394   TF_RETURN_IF_ERROR(
3395       DoCodeMotionForWindowedDotGeneralLoops(new_computation, options));
3396 
3397   // Replace the original computation with the new SPMD computation.
3398   std::unordered_map<HloComputation*, HloComputation*> replacement;
3399   replacement[computation] = new_computation;
3400   module->ReplaceComputations(replacement);
3401   return changed_;
3402 }
3403 
3404 Status SpmdPartitioningVisitor::HandlePartitionId(HloInstruction* hlo) {
3405   return Unimplemented(
3406       "PartitionId instruction is not supported for SPMD partitioning since "
3407       "the meaning is ambiguous -- whether the instruction is replicated or "
3408       "the data is replicated, and if the latter which data is replicated.");
3409 }
3410 
3411 SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64 num_partitions,
3412                                                         int64 num_replicas) {
3413   return {
3414       [](SpmdBuilder* b) {
3415         return b->AddInstruction(HloInstruction::CreatePartitionId());
3416       },
3417       [num_replicas, num_partitions](
3418           SpmdBuilder* b, HloInstruction* operand, HloComputation* reduction,
3419           const std::vector<std::vector<int64>>& partition_subgroups,
3420           int64 channel_id) {
3421         if (partition_subgroups.size() <= 1) {
3422           std::vector<ReplicaGroup> groups(num_replicas);
3423           // TODO(yuanzx): Unify subgroup definition with AllToAll.
3424           for (int64 i = 0; i < num_replicas; ++i) {
3425             groups[i].add_replica_ids(i);
3426           }
3427           return b->AddInstruction(HloInstruction::CreateAllReduce(
3428               operand->shape(), {operand}, reduction, groups,
3429               /*constrain_layout=*/false, channel_id,
3430               /*use_global_device_ids=*/false));
3431         }
3432 
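             // Subgroups are given as partition IDs; convert them to global device
             // IDs (replica_id * num_partitions + partition_id) and mark the
             // all-reduce with use_global_device_ids.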
3433         std::vector<ReplicaGroup> device_groups;
3434         device_groups.reserve(partition_subgroups.size() * num_replicas);
3435         for (int64 i = 0; i < num_replicas; ++i) {
3436           for (const auto& pgroup : partition_subgroups) {
3437             device_groups.emplace_back();
3438             for (int64 pid : pgroup) {
3439               device_groups.back().add_replica_ids(i * num_partitions + pid);
3440             }
3441           }
3442         }
3443         return b->AddInstruction(HloInstruction::CreateAllReduce(
3444             operand->shape(), {operand}, reduction, device_groups,
3445             /*constrain_layout=*/false, channel_id,
3446             /*use_global_device_ids=*/true));
3447       },
3448       [](SpmdBuilder* b, HloInstruction* operand,
3449          std::vector<std::pair<int64, int64>>& src_dst_pairs,
3450          int64 channel_id) {
3451         return b->AddInstruction(HloInstruction::CreateCollectivePermute(
3452             operand->shape(), operand, src_dst_pairs, channel_id));
3453       },
3454       [](SpmdBuilder* b, absl::Span<HloInstruction* const> operands,
3455          const std::vector<std::vector<int64>>& partition_subgroups,
3456          int64 channel_id, absl::optional<int64> split_dimension) {
3457         std::vector<Shape> shapes(operands.size(), operands[0]->shape());
3458         const Shape output_shape = (shapes.size() == 1)
3459                                        ? shapes[0]
3460                                        : ShapeUtil::MakeTupleShape(shapes);
3461         std::vector<ReplicaGroup> groups(partition_subgroups.size());
3462         for (int64 i = 0; i < groups.size(); ++i) {
3463           for (int64 id : partition_subgroups[i]) {
3464             groups[i].add_replica_ids(id);
3465           }
3466         }
3467         return b->AddInstruction(HloInstruction::CreateAllToAll(
3468             output_shape, operands, groups,
3469             /*constrain_layout=*/false, channel_id, split_dimension));
3470       },
3471       [num_replicas, num_partitions](
3472           SpmdBuilder* b, HloInstruction* operand, const Shape& ag_shape,
3473           const std::vector<std::vector<int64>>& partition_subgroups,
3474           int64 channel_id, int64 all_gather_dimension) {
3475         std::vector<ReplicaGroup> device_groups;
3476         device_groups.reserve(partition_subgroups.size() * num_replicas);
3477         for (int64 i = 0; i < num_replicas; ++i) {
3478           for (const auto& pgroup : partition_subgroups) {
3479             device_groups.emplace_back();
3480             for (int64 pid : pgroup) {
3481               device_groups.back().add_replica_ids(i * num_partitions + pid);
3482             }
3483           }
3484         }
3485         return b->AddInstruction(HloInstruction::CreateAllGather(
3486             ag_shape, operand, all_gather_dimension, device_groups,
3487             /*constrain_layout=*/false, channel_id,
3488             /*use_global_device_ids=*/true));
3489       },
3490   };
3491 }
3492 
3493 SpmdPartitioner::SpmdPartitioner(int64 num_partitions, int64 num_replicas,
3494                                  SpmdPartitionerOptions options)
3495     : SpmdPartitioner(
3496           num_partitions, num_replicas, std::move(options),
3497           GetDefaultCollectiveOpsCreator(num_partitions, num_replicas)) {}
3498 
3499 HloInstruction* SpmdPartitioner::AllGatherShards(
3500     SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
3501     int64* next_channel_id, absl::Span<const int64> selected_dims,
3502     const SPMDCollectiveOpsCreator& collectives_creator) {
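       // Gather each selected dimension with its own all-gather.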
3503   return AllGatherShardsInternal(b, operand, sharding, next_channel_id,
3504                                  selected_dims, collectives_creator,
3505                                  /*per_dim_ag=*/true);
3506 }
3507 
3508 HloInstruction* SpmdPartitioner::AllGatherShardsInternal(
3509     SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
3510     int64* next_channel_id, absl::Span<const int64> selected_dims,
3511     const SPMDCollectiveOpsCreator& collectives_creator, bool per_dim_ag) {
3512   if (selected_dims.empty()) {
3513     return operand;
3514   }
3515   CHECK(!sharding.IsTileMaximal());
3516   // Add one leading dimension to gather all partitions.
3517   std::vector<int64> shape;
3518   shape.push_back(1);
3519   for (int64 dim : operand->shape().dimensions()) {
3520     shape.push_back(dim);
3521   }
3522   auto reshape = b->AddInstruction(HloInstruction::CreateReshape(
3523       ShapeUtil::MakeShape(operand->shape().element_type(), shape), operand));
3524   HloInstruction* result = reshape;
3525   if (per_dim_ag) {
3526     for (auto it = selected_dims.rbegin(); it != selected_dims.rend(); ++it) {
3527       if (sharding.tile_assignment().dim(*it) == 1) {
3528         continue;
3529       }
3530       auto partition_subgroups =
3531           GetPartitionGroupsForReplication(sharding, {*it});
3532       shape[0] *= partition_subgroups[0].size();
3533       result = collectives_creator.create_cross_partition_all_gather(
3534           b, result,
3535           ShapeUtil::MakeShape(operand->shape().element_type(), shape),
3536           partition_subgroups, (*next_channel_id)++,
3537           /*all_gather_dimension=*/0);
3538     }
3539   } else {
3540     auto partition_subgroups =
3541         GetPartitionGroupsForReplication(sharding, selected_dims);
3542     shape[0] *= partition_subgroups[0].size();
3543     result = collectives_creator.create_cross_partition_all_gather(
3544         b, result, ShapeUtil::MakeShape(operand->shape().element_type(), shape),
3545         partition_subgroups, (*next_channel_id)++,
3546         /*all_gather_dimension=*/0);
3547   }
3548   // If n > 1 dimensions are partitioned, split the leading dimension into n.
3549   std::vector<int64> tiled_dims;
3550   for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
3551     if (sharding.tile_assignment().dim(i) > 1 &&
3552         absl::c_linear_search(selected_dims, i)) {
3553       tiled_dims.push_back(i);
3554     }
3555   }
3556   if (tiled_dims.size() > 1) {
3557     std::vector<int64> split_dim_shape;
3558     split_dim_shape.reserve(tiled_dims.size() + operand->shape().rank());
3559     for (int64 i : tiled_dims) {
3560       split_dim_shape.push_back(sharding.tile_assignment().dim(i));
3561     }
3562     for (int64 dim : operand->shape().dimensions()) {
3563       split_dim_shape.push_back(dim);
3564     }
3565     result = b->AddInstruction(HloInstruction::CreateReshape(
3566         ShapeUtil::MakeShape(operand->shape().element_type(), split_dim_shape),
3567         result));
3568   }
3569   // Transpose the gathered dimensions to next to their corresponding
3570   // partitioned dimensions.
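  //
  // For example (illustrative shapes), with a per-partition operand of shape
  // [a,b] and a 2x2 tile assignment over both dimensions:
  //   [1,a,b] -> (all-gather) [4,a,b] -> (reshape) [2,2,a,b]
  //   -> (transpose) [2,a,2,b] -> (final reshape) [2*a,2*b].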
  std::vector<int64> xpose_permutation(result->shape().rank());
  int64 split_dims_added = 0;
  for (int64 i = 0; i < xpose_permutation.size(); ++i) {
    if (sharding.tile_assignment().dim(i - split_dims_added) == 1 ||
        !absl::c_linear_search(selected_dims, i - split_dims_added)) {
      xpose_permutation[i] = i + tiled_dims.size() - split_dims_added;
    } else {
      xpose_permutation[i] = split_dims_added;
      xpose_permutation[i + 1] = i + tiled_dims.size() - split_dims_added;
      split_dims_added++;
      i++;
    }
  }
  result = b->AddInstruction(HloInstruction::CreateTranspose(
      ShapeInference::InferTransposeShape(result->shape(), xpose_permutation)
          .ValueOrDie(),
      result, xpose_permutation));
  // Reshape to the desired shape.
  auto ag_shape = operand->shape();
  for (int64 i : tiled_dims) {
    ag_shape.set_dimensions(
        i, ag_shape.dimensions(i) * sharding.tile_assignment().dim(i));
  }
  result = b->AddInstruction(HloInstruction::CreateReshape(ag_shape, result));
  return result;
}

HloInstruction* SpmdPartitioner::AllReduceAlongShardingDims(
    SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
    int64* next_channel_id, absl::Span<const int64> selected_dims,
    const SPMDCollectiveOpsCreator& collectives_creator,
    HloComputation* reduction) {
  return AllReduceAlongShardingDimsInternal(
      b, operand, sharding, next_channel_id, selected_dims, collectives_creator,
      reduction, /*per_dim_ar=*/true);
}

HloInstruction* SpmdPartitioner::AllReduceAlongShardingDimsInternal(
    SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
    int64* next_channel_id, absl::Span<const int64> selected_dims,
    const SPMDCollectiveOpsCreator& collectives_creator,
    HloComputation* reduction, bool per_dim_ar) {
  if (!per_dim_ar) {
    auto partition_subgroups =
        GetPartitionGroupsForReplication(sharding, selected_dims);
    return collectives_creator.create_cross_partition_all_reduce(
        b, operand, reduction, partition_subgroups, (*next_channel_id)++);
  }
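  // In per-dim mode, issue one all-reduce per selected sharding dimension
  // (iterating selected_dims in reverse order), skipping dimensions that are
  // not actually partitioned.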
  auto result = operand;
  for (auto it = selected_dims.rbegin(); it != selected_dims.rend(); ++it) {
    if (sharding.tile_assignment().dim(*it) == 1) {
      continue;
    }
    auto partition_subgroups =
        GetPartitionGroupsForReplication(sharding, {*it});
    result = collectives_creator.create_cross_partition_all_reduce(
        b, result, reduction, partition_subgroups, (*next_channel_id)++);
  }
  return result;
}

StatusOr<bool> SpmdPartitioner::PartitionComputation(
    HloComputation* computation, const HloSharding& root_sharding,
    int64* next_channel_id, SpmdLogger* logger) {
  auto visitor =
      CreateVisitor(computation, num_partitions_, num_replicas_,
                    collective_ops_creator_, next_channel_id, logger, options_);
  return visitor->DoPartition(computation, root_sharding, options_);
}

std::unique_ptr<SpmdPartitioningVisitor> SpmdPartitioner::CreateVisitor(
    HloComputation* computation, int64 num_partitions, int64 num_replicas,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, SpmdLogger* logger,
    SpmdPartitionerOptions options) {
  return absl::make_unique<SpmdPartitioningVisitor>(
      computation, num_partitions, num_replicas, collective_ops_creator,
      next_channel_id, logger, std::move(options), this);
}

StatusOr<bool> SpmdPartitioner::Run(HloModule* module) {
  TF_RETURN_IF_ERROR(PreprocessSharding(module));

  XLA_VLOG_LINES(1, SpmdLogger::ReportBeforePartition(
                        *module, options_.report_instruction_count));

  // Add the parameters' and output's shardings to the module.
  std::vector<HloSharding> entry_params_shardings;
  for (int64 i = 0; i < module->entry_computation()->num_parameters(); ++i) {
    auto param = module->entry_computation()->parameter_instruction(i);
    CHECK(param->has_sharding()) << "Missing sharding in entry parameter " << i;
    entry_params_shardings.push_back(param->sharding());
  }
  module->set_spmd_parameters_shardings(entry_params_shardings);
  auto entry_root = module->entry_computation()->root_instruction();
  CHECK(entry_root->has_sharding()) << "Missing sharding in entry root.";
  module->set_spmd_output_sharding(entry_root->sharding());

  FlattenCallGraph flatten;
  TF_ASSIGN_OR_RETURN(auto changed, flatten.Run(module));

  SpmdLogger logger(options_.report_instruction_count);
  auto program_shape = module->entry_computation()->ComputeProgramShape();
  int64 next_channel_id = hlo_query::NextChannelId(*module);
  // Copy the root sharding since the partitioner visitor may temporarily
  // change the sharding to work around manual sharding.
  HloSharding root_sharding = entry_root->sharding();
  TF_ASSIGN_OR_RETURN(
      bool partition_changed,
      PartitionComputation(module->entry_computation(), root_sharding,
                           &next_channel_id, &logger));
  changed |= partition_changed;

  // For the entry computation, make sure that the root instruction and the
  // parameters preserve their signatures.
  auto new_program_shape = module->entry_computation()->ComputeProgramShape();
  if (!options_.allow_module_signature_change) {
    TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
        program_shape.result(), new_program_shape.result()))
        << "Result shape changed for the entry computation";
    TF_RET_CHECK(program_shape.parameters_size() ==
                 new_program_shape.parameters_size())
        << "Parameter count changed for the entry computation";
    for (int64 i = 0; i < program_shape.parameters_size(); ++i) {
      TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
          program_shape.parameters(i), new_program_shape.parameters(i)))
          << "Parameter shape changed for the entry computation";
    }
  } else {
    const auto& old_entry_layout = module->entry_computation_layout();
    // Shapes can change but the layout should still remain the same.
    for (int64 i = 0; i < new_program_shape.parameters_size(); ++i) {
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          old_entry_layout.parameter_shape(i),
          new_program_shape.mutable_parameters(i)));
    }
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        old_entry_layout.result_shape(), new_program_shape.mutable_result()));

    HloModuleConfig config = module->config();
    *config.mutable_entry_computation_layout() =
        ComputationLayout(new_program_shape, /*ignore_layouts=*/false);
    module->set_config(config);
  }

  XLA_VLOG_LINES(1, SpmdLogger::ReportAfterPartition(
                        *module, options_.report_instruction_count));
  XLA_VLOG_LINES(1, logger.MakeReport());

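  // Partitioning tends to leave behind redundant tuples, dead instructions,
  // and duplicated expressions, so run a small cleanup pipeline (and
  // re-flatten the call graph) when anything changed.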
  if (changed) {
    HloPassPipeline pass("spmd-cleanup");
    pass.AddPass<TupleSimplifier>();
    pass.AddPass<HloDCE>();
    pass.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
    pass.AddPass<FlattenCallGraph>();
    TF_RETURN_IF_ERROR(pass.Run(module).status());
  }

  TF_RETURN_IF_ERROR(ClearShardingAttributes(module));
  return changed;
}

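// Validates the sharding annotations in the module and fills in missing ones
// before partitioning: side-effecting ops must be explicitly sharded,
// unannotated ops default to replicated (Rng is pinned to device 0), and tile
// shardings must cover all partitions.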
Status SpmdPartitioner::PreprocessSharding(HloModule* module) {
  for (HloComputation* computation : module->computations()) {
    for (HloInstruction* hlo : computation->instructions()) {
      if (hlo->HasSideEffectNoRecurse() && hlo->opcode() != HloOpcode::kRng) {
        TF_RET_CHECK(hlo->has_sharding())
            << "Side-effect HLO must have sharding: " << hlo->ToString();
        TF_RET_CHECK(!HasReplicatedSharding(hlo->sharding()) ||
                     hlo->opcode() == HloOpcode::kInfeed ||
                     hlo->opcode() == HloOpcode::kOutfeed)
            << "Non-infeed side-effect HLO cannot have a replicated sharding: "
            << hlo->ToString();
      }

      // For unassigned HLOs, annotate with replicated sharding.
      //
      // Among side-effecting ops, only Rng is allowed to omit the annotation.
      // In that case, we currently force it to run on core 0, since we don't
      // support partitioning or replicating the Rng op (the values depend on
      // the seed provided to each device).
      //
      // TODO(hyouklee): Should we also convert single-device shardings (without
      // side-effects) into replicated?
      if (!hlo->has_sharding()) {
        if (hlo->opcode() == HloOpcode::kRng) {
          hlo->set_sharding(HloSharding::AssignDevice(0));
        } else {
          hlo->set_sharding(
              HloSharding::Single(hlo->shape(), HloSharding::Replicate()));
        }
      } else if (!hlo->sharding().IsTileMaximal() &&
                 !hlo->sharding().IsManual()) {
        std::vector<int64> available(num_partitions_);
        std::iota(available.begin(), available.end(), 0);
        TF_RET_CHECK(num_partitions_ == hlo_sharding_util::DevicesForSharding(
                                            hlo->sharding(), available)
                                            .size())
            << "num_partitions:" << num_partitions_ << "\n"
            << "SPMD partitioner only supports tile sharding that includes all "
               "partitions. If you didn't add this sharding annotation in the "
               "model, please file a bug against the XLA team.\n"
            << hlo->ToString();
      }
    }
  }

  // Entry computation's parameter and root sharding must be either all
  // replicated or all on a single device.
  if (!options_.allow_module_signature_change) {
    const HloComputation* entry = module->entry_computation();
    TF_RET_CHECK(entry->root_instruction()->has_sharding());
    const HloSharding& root_sharding = entry->root_instruction()->sharding();
    TF_RET_CHECK(root_sharding.IsReplicated() ||
                 root_sharding.UniqueDevice().has_value())
        << "Unsupported entry root sharding: " << root_sharding.ToString();

    for (const HloInstruction* param : entry->parameter_instructions()) {
      TF_RET_CHECK(param->has_sharding());
      TF_RET_CHECK(param->sharding().IsReplicated() ||
                   param->sharding().UniqueDevice().has_value())
          << "Unsupported entry parameter sharding: "
          << param->sharding().ToString();
    }
  }

  return Status::OK();
}

}  // namespace spmd
}  // namespace xla