1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
17
18 #include <float.h>
19
20 #include <functional>
21 #include <memory>
22 #include <unordered_map>
23 #include <vector>
24
25 #include "absl/algorithm/container.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/container/inlined_vector.h"
28 #include "absl/memory/memory.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/types/optional.h"
31 #include "absl/types/span.h"
32 #include "tensorflow/compiler/xla/client/lib/comparators.h"
33 #include "tensorflow/compiler/xla/comparison_util.h"
34 #include "tensorflow/compiler/xla/literal_util.h"
35 #include "tensorflow/compiler/xla/protobuf_util.h"
36 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
37 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
38 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
39 #include "tensorflow/compiler/xla/service/hlo_computation.h"
40 #include "tensorflow/compiler/xla/service/hlo_cse.h"
41 #include "tensorflow/compiler/xla/service/hlo_dce.h"
42 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
43 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
44 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
45 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
46 #include "tensorflow/compiler/xla/service/hlo_query.h"
47 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
48 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
49 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
50 #include "tensorflow/compiler/xla/service/shape_inference.h"
51 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
52 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
53 #include "tensorflow/compiler/xla/shape_util.h"
54 #include "tensorflow/compiler/xla/util.h"
55 #include "tensorflow/compiler/xla/window_util.h"
56 #include "tensorflow/compiler/xla/xla_data.pb.h"
57 #include "tensorflow/core/platform/numbers.h"
58
59 namespace xla {
60 namespace spmd {
61
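// Returns a human-readable report of the largest memory entries recorded
// during the transformation, sorted by size in decreasing order and truncated
// to report_instruction_count_ entries.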
string SpmdLogger::MakeReport() {
63 string report;
64 absl::StrAppend(&report,
65 "\n\n***** SPMD memory during transformation *****\n");
66
67 std::sort(entries_.begin(), entries_.end(),
68 [](auto const& entry0, auto const& entry1) {
69 return entry0.first > entry1.first;
70 });
71 for (int64 i = 0;
72 i < std::min<int64>(report_instruction_count_, entries_.size()); ++i) {
73 absl::StrAppend(
74 &report, "\n ",
75 tensorflow::strings::HumanReadableNumBytes(entries_[i].first), " : ",
76 entries_[i].second, "\n");
77 }
78
79 return report;
80 }
81
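// Registers a log entry for `hlo`: the entry text lists the instructions in
// `group` that were generated for it, and the entry is keyed by the largest
// array size (in bytes) found in the group.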
void SpmdLogger::RegisterLogEntry(HloInstruction* hlo,
                                  const std::vector<HloInstruction*>& group) {
84 string report = hlo->ToString();
85 int64 max_value = -1;
86 for (HloInstruction* inst : group) {
87 if (!inst->shape().IsArray()) {
88 continue;
89 }
90 max_value = std::max<int64>(max_value, ShapeSizeInBytes(inst->shape()));
91 absl::StrAppend(&report, " * ", inst->ToString(), "\n");
92 }
93 entries_.push_back(std::make_pair(max_value, report));
94 }
95
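// Reports per-instruction memory usage of the module before partitioning,
// listing replicated instructions separately from all instructions.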
/* static */ string SpmdLogger::ReportBeforePartition(
    const HloModule& module, int64 report_instruction_count) {
98 string report;
99 absl::StrAppend(&report,
100 "\n\n***** SPMD memory usage before partition *****\n");
101 absl::StrAppend(&report, "\n ** Replicated instructions\n");
102 absl::StrAppend(&report, ReportMemoryUsage(
103 module,
104 [](const HloInstruction* hlo) {
105 return !hlo->has_sharding() ||
106 hlo->sharding().IsReplicated();
107 },
108 report_instruction_count));
109 absl::StrAppend(&report, "\n ** All instructions\n");
110 absl::StrAppend(&report,
111 ReportMemoryUsage(
112 module, [](const HloInstruction* hlo) { return true; },
113 report_instruction_count));
114 return report;
115 }
116
/* static */ string SpmdLogger::ReportAfterPartition(
    const HloModule& module, int64 report_instruction_count) {
119 string report;
120 absl::StrAppend(&report,
121 "\n\n***** SPMD memory usage after partition *****\n");
122 absl::StrAppend(&report,
123 ReportMemoryUsage(
124 module, [](const HloInstruction* hlo) { return true; },
125 report_instruction_count));
126 return report;
127 }
128
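// Reports the largest instructions (by shape size) in the module that pass
// `filter`, skipping tuple shapes, effective scalars, and fusion computations.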
template <typename F>
/* static */ string SpmdLogger::ReportMemoryUsage(
    const HloModule& module, const F& filter, int64 report_instruction_count) {
132 string report;
133 std::vector<HloInstruction*> instructions;
134 instructions.reserve(module.instruction_count());
135
136 for (auto computation : module.computations()) {
137 if (computation->IsFusionComputation()) {
138 continue;
139 }
140 for (auto hlo : computation->instructions()) {
141 if (hlo->shape().IsTuple() ||
142 ShapeUtil::IsEffectiveScalar(hlo->shape())) {
143 continue;
144 }
145 if (filter(hlo)) {
146 instructions.push_back(hlo);
147 }
148 }
149 }
150
151 const auto add_report = [&](std::vector<HloInstruction*>* insts) {
152 std::sort(insts->begin(), insts->end(),
153 [](const HloInstruction* inst0, const HloInstruction* inst1) {
154 return ShapeSizeInBytes(inst0->shape()) >
155 ShapeSizeInBytes(inst1->shape());
156 });
157 for (int64 i = 0;
158 i < std::min<int64>(report_instruction_count, insts->size()); ++i) {
159 absl::StrAppend(&report, " ",
160 tensorflow::strings::HumanReadableNumBytes(
161 ShapeSizeInBytes((*insts)[i]->shape())),
162 " : ", (*insts)[i]->ToString(), "\n");
163 }
164 };
165
166 add_report(&instructions);
167 return report;
168 }
169
170 namespace {
171
// Clears all sharding attributes from instructions in the module. This must be
// called only after all SPMD transformations are complete.
Status ClearShardingAttributes(HloModule* module) {
175 for (HloComputation* computation : module->computations()) {
176 for (HloInstruction* hlo : computation->instructions()) {
177 // Keep sharding annotation on Infeed and entry parameters since they're
178 // used by HloReplicationAnalysis later (for ArCrsCombiner).
179 if (hlo->opcode() == HloOpcode::kInfeed) {
180 continue;
181 }
182 if (hlo->opcode() == HloOpcode::kParameter &&
183 computation == module->entry_computation()) {
184 continue;
185 }
186 hlo->clear_sharding();
187 }
188 }
189 return Status::OK();
190 }
191
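// Returns groups of partition IDs whose tile-assignment indices differ only
// along `replication_dims`; partitions in the same group hold data to be
// combined when replicating along those dimensions.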
std::vector<std::vector<int64>> GetPartitionGroupsForReplication(
    const HloSharding& sharding, absl::Span<const int64> replication_dims) {
194 int64 group_size = 1;
195 for (int64 i : replication_dims) {
196 group_size *= sharding.tile_assignment().dim(i);
197 }
198 std::vector<std::vector<int64>> partition_groups(
199 sharding.tile_assignment().num_elements() / group_size);
200 sharding.tile_assignment().Each(
201 [&](absl::Span<const int64> indices, int64 partition) {
202 int64 group_id = 0;
203 for (int64 i = 0; i < indices.size(); ++i) {
204 if (!absl::c_linear_search(replication_dims, i)) {
205 group_id *= sharding.tile_assignment().dim(i);
206 group_id += indices[i];
207 }
208 }
209 partition_groups[group_id].push_back(partition);
210 });
211 return partition_groups;
212 }
213
214 } // namespace
215
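// Adds an instruction and records bookkeeping used later: which instructions
// were derived from the HLO currently being visited (for logging), and which
// output dimensions only carry data broadcast along that dimension. Broadcast
// dimensions are propagated through elementwise ops, transpose, reshape,
// slice, dynamic-slice, and pad.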
HloInstruction* SpmdBuilder::AddInstruction(
    std::unique_ptr<HloInstruction> instruction) {
218 HloInstruction* hlo =
219 HloComputation::Builder::AddInstruction(std::move(instruction));
220 if (visiting_hlo_) {
221 instructions_[visiting_hlo_].push_back(hlo);
222 }
223 if (hlo->opcode() == HloOpcode::kBroadcast) {
224 for (int64 i = 0; i < hlo->shape().rank(); ++i) {
225 if (!absl::c_linear_search(hlo->dimensions(), i)) {
226 broadcast_dims_[hlo].insert(i);
227 }
228 }
229 }
230 if (hlo->IsElementwise() && hlo->operand_count() > 0) {
231 absl::flat_hash_set<int64> broadcast_dims;
232 for (int64 i = 0; i < hlo->shape().rank(); ++i) {
233 broadcast_dims.insert(i);
234 }
235 for (int64 i = 0; i < hlo->operand_count(); ++i) {
236 auto it = broadcast_dims_.find(hlo->operand(i));
237 if (it == broadcast_dims_.end()) {
238 broadcast_dims.clear();
239 break;
240 }
241 for (int64 i = 0; i < hlo->shape().rank(); ++i) {
242 if (!it->second.contains(i)) {
243 broadcast_dims.erase(i);
244 }
245 }
246 }
247 if (!broadcast_dims.empty()) {
248 broadcast_dims_[hlo] = std::move(broadcast_dims);
249 }
250 }
251 if (hlo->opcode() == HloOpcode::kTranspose) {
252 auto it = broadcast_dims_.find(hlo->operand(0));
253 if (it != broadcast_dims_.end()) {
254 absl::flat_hash_set<int64> xpose_broadcast_dims;
255 std::vector<int64> reverse_map(hlo->shape().rank());
256 for (int64 i = 0; i < reverse_map.size(); ++i) {
257 reverse_map[hlo->dimensions(i)] = i;
258 }
259 for (int64 dim : it->second) {
260 xpose_broadcast_dims.insert(reverse_map[dim]);
261 }
262 broadcast_dims_[hlo] = std::move(xpose_broadcast_dims);
263 }
264 }
265 if (hlo->opcode() == HloOpcode::kReshape &&
266 Product(hlo->shape().dimensions()) > 0) {
267 auto it = broadcast_dims_.find(hlo->operand(0));
268 if (it != broadcast_dims_.end()) {
269 absl::flat_hash_set<int64> reshape_broadcast_dims;
270 for (int64 i = 0; i < hlo->shape().rank(); ++i) {
271 reshape_broadcast_dims.insert(i);
272 }
273 std::vector<int64> before_dim_size_stack;
274 std::vector<int64> after_dim_size_stack;
275 for (int64 i = hlo->operand(0)->shape().rank() - 1; i >= 0; --i) {
276 before_dim_size_stack.push_back(hlo->operand(0)->shape().dimensions(i));
277 }
278 for (int64 i = hlo->shape().rank() - 1; i >= 0; --i) {
279 after_dim_size_stack.push_back(hlo->shape().dimensions(i));
280 }
281 while (!before_dim_size_stack.empty() && !after_dim_size_stack.empty()) {
282 int64 before_size = before_dim_size_stack.back();
283 int64 after_size = after_dim_size_stack.back();
284 int64 current_before_dim =
285 hlo->operand(0)->shape().rank() - before_dim_size_stack.size();
286 int64 current_after_dim =
287 hlo->shape().rank() - after_dim_size_stack.size();
288 before_dim_size_stack.pop_back();
289 after_dim_size_stack.pop_back();
290 if (!it->second.contains(current_before_dim)) {
291 reshape_broadcast_dims.erase(current_after_dim);
292 }
293 if (before_size == after_size) {
294 continue;
295 }
296 if (before_size % after_size == 0) {
297 // Split dim.
298 before_dim_size_stack.push_back(before_size / after_size);
299 } else if (after_size % before_size == 0) {
300 // Merge dim.
301 after_dim_size_stack.push_back(after_size / before_size);
302 } else {
303 // Other cases, mark all remaining dims as non-broadcast.
304 for (int64 i = current_after_dim; i < hlo->shape().rank(); ++i) {
305 reshape_broadcast_dims.erase(i);
306 }
307 break;
308 }
309 }
310 if (!before_dim_size_stack.empty() || !after_dim_size_stack.empty()) {
311 reshape_broadcast_dims.clear();
312 }
313 if (!reshape_broadcast_dims.empty()) {
314 broadcast_dims_[hlo] = std::move(reshape_broadcast_dims);
315 }
316 }
317 }
318 if (hlo->opcode() == HloOpcode::kSlice ||
319 hlo->opcode() == HloOpcode::kDynamicSlice) {
320 auto it = broadcast_dims_.find(hlo->operand(0));
321 if (it != broadcast_dims_.end()) {
322 auto dims = it->second;
323 broadcast_dims_[hlo] = std::move(dims);
324 }
325 }
326 if (hlo->opcode() == HloOpcode::kPad) {
327 auto it = broadcast_dims_.find(hlo->operand(0));
328 if (it != broadcast_dims_.end()) {
329 absl::flat_hash_set<int64> pad_broadcast_dims;
330 for (int64 i = 0; i < hlo->shape().rank(); ++i) {
331 const auto& dim = hlo->padding_config().dimensions(i);
332 if (dim.edge_padding_low() == 0 && dim.edge_padding_high() == 0 &&
333 dim.interior_padding() == 0 && it->second.contains(i)) {
334 pad_broadcast_dims.insert(i);
335 }
336 }
337 if (!pad_broadcast_dims.empty()) {
338 broadcast_dims_[hlo] = std::move(pad_broadcast_dims);
339 }
340 }
341 }
342 return hlo;
343 }
344
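// Reshards this PartitionedHlo to the target sharding, reusing and populating
// the per-HLO reshard cache. Reshards that decrease the number of tiles
// (all-gather-like) are only cached when cache_all_gather is enabled.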
PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target) {
346 if (sharding() == target) {
347 return *this;
348 }
349 auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;
350 const bool is_to_replicate =
351 hlo_->shape().IsArray() && target.NumTiles() < sharding().NumTiles();
352 if (!is_to_replicate || state_.partitioner->options().cache_all_gather) {
353 for (auto& entry : cache) {
354 if (entry.first == target) {
355 return entry.second;
356 }
357 }
358 }
359 auto resharded = ReshardNoCache(target);
360 state_.reshard_cache->per_hlo_cache[resharded.hlo()]
361 .reshard_cache.emplace_back(sharding(), *this);
362 if (!is_to_replicate || state_.partitioner->options().cache_all_gather) {
363 cache.emplace_back(target, std::move(resharded));
364 return cache.back().second;
365 }
366 return resharded;
367 }
368
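// Reshards without consulting the cache. Tuple shapes are resharded leaf by
// leaf; otherwise collective-permute, all-to-all, and partial-replication
// based strategies are tried in turn, falling back to replicating and then
// slicing into the target sharding.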
PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) {
370 VLOG(2) << "Resharding " << hlo_->ToString() << " from "
371 << hlo_->sharding().ToString() << " to " << target.ToString();
372 const Shape& shape = hlo_->shape();
373 if (shape.element_type() == TOKEN) {
374 return *this;
375 }
376 CHECK(shape.IsTuple() || !target.IsTuple());
377
378 // Tuple shape instructions may have non-tuple sharding, which means that the
379 // same sharding applies to all the leaves.
380 if (shape.IsTuple() && !target.IsTuple()) {
381 return Reshard(target.GetTupleSharding(shape).ValueOrDie());
382 }
383
384 // For a tuple shape, recursively apply Reshard to all the leaves and return
385 // a tuple instruction.
386 if (shape.IsTuple()) {
387 std::vector<HloInstruction*> elements;
388 for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
389 auto subshape = ShapeUtil::GetTupleElementShape(shape, i);
390 auto element = state_.b->AddInstruction(
391 HloInstruction::CreateGetTupleElement(subshape, hlo(), i));
392 element->set_sharding(sharding().GetSubSharding(shape, {i}));
393 elements.push_back(
394 PartitionedHlo(
395 element, ShapeUtil::GetTupleElementShape(base_shape_, i), state_)
396 .Reshard(target.GetSubSharding(shape, {i}))
397 .hlo());
398 }
399 auto tuple =
400 state_.b->AddInstruction(HloInstruction::CreateTuple(elements));
401 tuple->set_sharding(target);
402 return PartitionedHlo(tuple, base_shape_, state_);
403 }
404
405 if (sharding() == target) {
406 return *this;
407 }
408
409 if (CanReshardWithCollectivePermute(sharding(), target)) {
410 return ReshardWithCollectivePermute(target);
411 }
412
413 if (auto src_tgt_dims =
414 GetReshardAllToAllSourceTargetDims(sharding(), target)) {
415 return ReshardWithAllToAll(target, *src_tgt_dims);
416 }
417
418 if (!target.IsTileMaximal() && sharding().ReplicateOnLastTileDim()) {
419 auto try_reshard = ReshardFromPartialReplicateWithDynamicSlice(target);
420 if (try_reshard.has_value()) {
421 return try_reshard.value();
422 }
423 try_reshard = ReshardPartialReplicateWithAllToAll(target);
424 if (try_reshard.has_value()) {
425 return try_reshard.value();
426 }
427 }
428
429 if (!sharding().IsTileMaximal() && target.ReplicateOnLastTileDim()) {
430 auto try_reshard = ReshardToPartialReplicateWithAllGather(target);
431 if (try_reshard.has_value()) {
432 return try_reshard.value();
433 }
434 try_reshard = ReshardPartialReplicateWithAllToAll(target);
435 if (try_reshard.has_value()) {
436 return try_reshard.value();
437 }
438 }
439
440 // If not replicated yet, first replicate and then reshard to use one of the
441 // two implementations below.
442 if (!sharding().IsReplicated()) {
443 return Replicate().Reshard(target);
444 }
445
446 // 'Replicated' to 'SingleDevice'.
447 if (target.IsTileMaximal()) {
448 auto copy = state_.b->AddInstruction(
449 HloInstruction::CreateUnary(hlo_->shape(), HloOpcode::kCopy, hlo_));
450 copy->set_sharding(target);
451 return PartitionedHlo(copy, base_shape_, state_);
452 }
453
454 // 'Replicated' to partial replicated.
455 if (target.ReplicateOnLastTileDim()) {
456 std::vector<int64> group_dims(target.tile_assignment().num_dimensions() -
457 1);
458 std::iota(group_dims.begin(), group_dims.end(), 0);
459 auto target_grouped = GroupShardingOnDims(target, group_dims);
460 auto partially_sharded = PerGroupSliceFromReplicated(
461 hlo_, state_.partition_id, target_grouped.device_groups, group_dims,
462 target_grouped.group_dim_sizes, state_.b);
463 partially_sharded->set_sharding(target);
464 return PartitionedHlo(partially_sharded, base_shape(), state_);
465 }
466
467 // 'Replicated' to 'Tiled'.
468 auto padded_hlo =
469 PadBaseShapeBeforeUnevenTiledSharding(hlo_, target, state_.b);
470 auto shard_shape = MakePartitionedShape(shape, target);
471 auto slice = state_.b->AddInstruction(HloInstruction::CreateDynamicSlice(
472 shard_shape, padded_hlo,
473 MakePartitionOffsets(shape, target, state_.partition_id, state_.b),
474 shard_shape.dimensions()));
475 slice->set_sharding(target);
476 return PartitionedHlo(slice, base_shape_, state_);
477 }
478
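// Returns a copy of this PartitionedHlo in which elements that fall outside
// the base shape due to uneven tiling are replaced with `pad_value`.
// Dimensions in `skipped_dims` are left untouched, and dimensions in
// `left_padded_dims` are treated as padded on the low side.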
PartitionedHlo PartitionedHlo::PadWithValue(
    HloInstruction* pad_value, absl::Span<const int64> left_padded_dims,
    absl::Span<const int64> skipped_dims) const {
482 const HloSharding& sharding = hlo_->sharding();
483 const Shape& shape = hlo_->shape();
484 CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);
485 if (sharding.IsReplicated() || EvenlyPartitions(base_shape_, sharding)) {
486 return *this;
487 }
488 CHECK(!sharding.IsTileMaximal());
489 auto index_shape = ShapeUtil::ChangeElementType(shape, S32);
490 auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED);
491 auto get_mask_for_dim = [&](int64 dim, HloInstruction* start_index) {
492 // Comparison: iota + start_index < valid_size
493 auto iota =
494 state_.b->AddInstruction(HloInstruction::CreateIota(index_shape, dim));
495 auto broadcast_start_index = state_.b->AddInstruction(
496 HloInstruction::CreateBroadcast(index_shape, start_index, {}));
497 auto index_in_full_shape =
498 state_.b->AddInstruction(HloInstruction::CreateBinary(
499 index_shape, HloOpcode::kAdd, iota, broadcast_start_index));
500 ComparisonDirection direction = ComparisonDirection::kLt;
501 int64 index_limit = base_shape_.dimensions(dim);
502 if (absl::c_linear_search(left_padded_dims, dim)) {
503 direction = ComparisonDirection::kGe;
504 index_limit =
505 index_shape.dimensions(dim) * sharding.tile_assignment().dim(dim) -
506 index_limit;
507 }
508 auto limit = state_.b->AddInstruction(HloInstruction::CreateConstant(
509 LiteralUtil::CreateR0<int32>(index_limit)));
510 auto broadcast_limit = state_.b->AddInstruction(
511 HloInstruction::CreateBroadcast(index_shape, limit, {}));
512 return state_.b->AddInstruction(HloInstruction::CreateCompare(
513 mask_shape, index_in_full_shape, broadcast_limit, direction));
514 };
515
516 HloInstruction* mask = nullptr;
517 auto offsets = MakePartitionOffsets(base_shape_, sharding,
518 state_.partition_id, state_.b);
519 for (int64 i = 0; i < shape.rank(); ++i) {
520 if (base_shape_.dimensions(i) % sharding.tile_assignment().dim(i) == 0 ||
521 absl::c_linear_search(skipped_dims, i)) {
522 continue;
523 }
524 if (mask == nullptr) {
525 mask = get_mask_for_dim(i, offsets[i]);
526 } else {
527 mask = state_.b->AddInstruction(
528 HloInstruction::CreateBinary(mask->shape(), HloOpcode::kAnd, mask,
529 get_mask_for_dim(i, offsets[i])));
530 }
531 }
532
533 if (mask == nullptr) {
534 return *this;
535 }
536
537 auto broadcast_pad_value = state_.b->AddInstruction(
538 HloInstruction::CreateBroadcast(shape, pad_value, {}));
539 auto result = state_.b->AddInstruction(HloInstruction::CreateTernary(
540 shape, HloOpcode::kSelect, mask, hlo_, broadcast_pad_value));
541 result->set_sharding(sharding);
542 return PartitionedHlo(result, base_shape_, state_);
543 }
544
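// Reshards the HLO for use as an input to a windowed op with the given window
// configuration, handling explicit padding, base dilation, and halo exchange
// between neighboring partitions. Returns the sharded input, the per-shard
// window, and an optional output dynamic-slice offset, or nullopt if the case
// is unsupported.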
absl::optional<PartitionedHlo::WindowedInputShardReturnValue>
PartitionedHlo::ReshardAsWindowedInput(const Window& window,
                                       const HloSharding& target,
                                       HloInstruction* pad_value,
                                       bool mask_invalid_region) {
550 auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].window_reshard_cache;
551 for (auto& entry : cache) {
552 if (std::get<0>(entry) == target &&
553 protobuf_util::ProtobufEquals(std::get<1>(entry), window)) {
554 return std::get<2>(entry);
555 }
556 }
557 auto update_cache = [&](WindowedInputShardReturnValue result) {
558 cache.emplace_back(target, window, std::move(result));
559 return std::get<2>(cache.back());
560 };
561 VLOG(2) << "ReshardAsWindowedInput()\n"
562 << "\twindow:" << window_util::ToString(window)
563 << "\ttarget sharding:" << target.ToString();
564
565 CHECK(!target.IsTileMaximal());
566 auto partition_ordinals =
567 MakeTiledPartitionOrdinals(target, state_.partition_id, state_.b);
568 auto shard_shape = base_shape_;
569
570 std::vector<MultiplyAddDivideOffsetCalculation> start_on_padded_calculations(
571 base_shape_.rank());
572 std::vector<MultiplyAddDivideOffsetCalculation> limit_on_padded_calculations(
573 base_shape_.rank());
574 std::vector<HloInstruction*> dynamic_slice_offset_on_output(
575 base_shape_.rank(), nullptr);
576
577 Window shard_window = window;
578 auto padded_shape = base_shape_;
579 std::vector<HloInstruction*> offsets_on_padded_shape(base_shape_.rank());
580 std::vector<int64> per_shard_window_counts(base_shape_.rank());
581 std::vector<int64> explicit_left_padding(base_shape_.rank());
582 for (int64 i = 0; i < base_shape_.rank(); ++i) {
583 // Do not pad non-partitioned dimensions.
584 int64 shard_count = target.tile_assignment().dim(i);
585 if (shard_count == 1) {
586 offsets_on_padded_shape[i] = state_.b->AddInstruction(
587 HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
588 continue;
589 }
590 const auto& wd = window.dimensions(i);
591 const auto dilated_size = 1 + (wd.size() - 1) * wd.window_dilation();
592 int64 full_size =
593 base_shape_.dimensions(i) +
594 (wd.base_dilation() - 1) * (base_shape_.dimensions(i) - 1) +
595 wd.padding_high() + wd.padding_low();
596 if (full_size < dilated_size) {
597 VLOG(2) << "Failed to reshard window operand because the window size is "
598 "larger than padded base size";
599 return absl::nullopt;
600 }
601 int64 window_count = (full_size - dilated_size) / wd.stride() + 1;
602 per_shard_window_counts[i] = CeilOfRatio(window_count, shard_count);
603 if (wd.stride() != 1 &&
604 (wd.stride() * per_shard_window_counts[i]) % wd.base_dilation() != 0) {
605 // TODO(yuanzx): Support this case.
606 VLOG(2) << "Failed to reshard window operand due to non-trivial dilation";
607 return absl::nullopt;
608 }
609
610 // We use explicit padding for full dilations, then use padding_low and
611 // padding_high on the sharded op for the remaining. padding_low and
612 // padding_high are now given initial values, which will be later updated if
613 // dilation is not 1.
614 auto swd = shard_window.mutable_dimensions(i);
615 explicit_left_padding[i] = wd.padding_low() / wd.base_dilation();
616 swd->set_padding_low(wd.padding_low() % wd.base_dilation());
617 swd->set_padding_high(0);
618
619 // Calculation for the first element needed on the 'padded-but-not-dilated'
620 // shape. The start on the dilated shape could be a hole, so we add
621 // wd.base_dilation() - 1 to the constant term to skip the leading holes.
622 start_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation(
623 wd.stride() * per_shard_window_counts[i],
624 wd.base_dilation() - 1 - swd->padding_low(), wd.base_dilation());
625 int64 dilated_shard_size =
626 wd.stride() * (per_shard_window_counts[i] - 1) + dilated_size;
627 limit_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation(
628 wd.stride() * per_shard_window_counts[i],
629 dilated_shard_size + wd.base_dilation() - 1 - swd->padding_low(),
630 wd.base_dilation());
631
632 offsets_on_padded_shape[i] = start_on_padded_calculations[i].Calculate(
633 partition_ordinals[i], state_.b);
634
635 auto shard_size_function =
636 limit_on_padded_calculations[i] - start_on_padded_calculations[i];
637 int64 max_shard_size = shard_size_function.MaxInRange(0, shard_count);
638 shard_shape.set_dimensions(i, max_shard_size);
639 padded_shape.set_dimensions(
640 i, limit_on_padded_calculations[i].Calculate(shard_count - 1));
641
642 // For base dilation, calculate the needed padding_low and padding_high, as
643 // well as the offset for the output if a dynamic slice is needed after the
644 // sharded op.
645 if (wd.base_dilation() != 1) {
646 // Returns the offset of a shard's first valid element in the dilated
647 // shard.
648 auto get_first_valid_element_offset_on_dilated_shard =
649 [&](int64 shard_ordinal) {
650 return start_on_padded_calculations[i].Calculate(shard_ordinal) *
651 wd.base_dilation() +
652 swd->padding_low() -
653 wd.stride() * per_shard_window_counts[i] * shard_ordinal;
654 };
655 CHECK_EQ(get_first_valid_element_offset_on_dilated_shard(0),
656 swd->padding_low());
657
658 // Determine swd->padding_high.
659 for (int64 shard_ordinal = 0; shard_ordinal < shard_count;
660 ++shard_ordinal) {
661 int64 wanted_limit_on_dilated_shard =
662 wd.stride() * (per_shard_window_counts[i] - 1) + dilated_size;
663 int64 actual_limit_on_dilated_shard_without_pad_high =
664 get_first_valid_element_offset_on_dilated_shard(shard_ordinal) +
665 (max_shard_size - 1) * wd.base_dilation() + 1;
666 swd->set_padding_high(std::max<int64>(
667 swd->padding_high(),
668 wanted_limit_on_dilated_shard -
669 actual_limit_on_dilated_shard_without_pad_high));
670 }
671
672 // Determine swd->padding_low and output dynamic slice index.
673 if (wd.stride() == 1) {
674 int64 max_pad_low = get_first_valid_element_offset_on_dilated_shard(0);
675 bool all_same = true;
676 for (int64 shard_ordinal = 1; shard_ordinal < shard_count;
677 ++shard_ordinal) {
678 int64 start =
679 get_first_valid_element_offset_on_dilated_shard(shard_ordinal);
680 if (start != swd->padding_low()) {
681 all_same = false;
682 }
683 max_pad_low = std::max(max_pad_low, start);
684 }
685 if (!all_same) {
686 auto start_on_padded_input =
687 start_on_padded_calculations[i].Calculate(partition_ordinals[i],
688 state_.b);
689 // We will calculate
690 // max_pad_low - (first_window - required_first_window)
691 // which equals
692 // required_first_window - (first_window - max_pad_low)
693 auto first_window_minus_max_pad_low =
694 MultiplyAddDivideOffsetCalculation(
695 wd.base_dilation(), swd->padding_low() - max_pad_low, 1)
696 .Calculate(start_on_padded_input, state_.b);
697 auto required_first_window =
698 MultiplyAddDivideOffsetCalculation(per_shard_window_counts[i], 0,
699 1)
700 .Calculate(partition_ordinals[i], state_.b);
701 dynamic_slice_offset_on_output[i] =
702 state_.b->AddInstruction(HloInstruction::CreateBinary(
703 required_first_window->shape(), HloOpcode::kSubtract,
704 required_first_window, first_window_minus_max_pad_low));
705 }
706 swd->set_padding_low(max_pad_low);
707 } else {
708 if ((wd.stride() * per_shard_window_counts[i]) % wd.base_dilation() !=
709 0) {
710 // General base dilation not yet implemented.
711 return absl::nullopt;
712 }
713 // padding_low on all shards should equal the initially assigned
714 // swd->padding_low(), i.e., the padding_low() on the original window.
715 }
716 }
717 }
718
719 // Returns the output dynamic slice offset when needed, and absl::nullopt
720 // otherwise.
721 auto get_dynamic_slice_offset_on_output_if_needed =
722 [&]() -> absl::optional<std::vector<HloInstruction*>> {
723 if (absl::c_all_of(
724 dynamic_slice_offset_on_output,
725 [](HloInstruction* offset) { return offset == nullptr; })) {
726 return absl::nullopt;
727 }
728 auto zero = state_.b->AddInstruction(
729 HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
730 for (int64 i = 0; i < dynamic_slice_offset_on_output.size(); ++i) {
731 if (dynamic_slice_offset_on_output[i] == nullptr) {
732 dynamic_slice_offset_on_output[i] = zero;
733 }
734 }
735 return dynamic_slice_offset_on_output;
736 };
737
  // If the current HLO is replicated, pad then slice.
739 if (sharding().IsReplicated()) {
740 PaddingConfig padding_config;
741 for (int64 i = 0; i < base_shape_.rank(); ++i) {
742 auto padding_config_dim = padding_config.add_dimensions();
743 padding_config_dim->set_interior_padding(0);
744 // Do not pad non-partitioned dimensions.
745 if (target.tile_assignment().dim(i) == 1) {
746 padding_config_dim->set_edge_padding_low(0);
747 padding_config_dim->set_edge_padding_high(0);
748 continue;
749 }
750 padding_config_dim->set_edge_padding_low(explicit_left_padding[i]);
751 padding_config_dim->set_edge_padding_high(padded_shape.dimensions(i) -
752 explicit_left_padding[i] -
753 base_shape_.dimensions(i));
754 }
755 auto padded_hlo = ShapeUtil::Compatible(padded_shape, base_shape_)
756 ? hlo_
757 : state_.b->AddInstruction(HloInstruction::CreatePad(
758 padded_shape, hlo_, pad_value, padding_config));
759 auto sharded_input =
760 state_.b->AddInstruction(HloInstruction::CreateDynamicSlice(
761 shard_shape, padded_hlo, offsets_on_padded_shape,
762 shard_shape.dimensions()));
763 return update_cache(WindowedInputShardReturnValue{
764 sharded_input, shard_window,
765 get_dynamic_slice_offset_on_output_if_needed()});
766 }
767
768 if (target != sharding()) {
769 return Reshard(target).ReshardAsWindowedInput(window, target, pad_value);
770 }
771
772 // Halo exchange.
773 HloInstruction* visiting_hlo = hlo_;
774 auto original_shard_shape = MakePartitionedShape(base_shape_, target);
775
776 std::vector<OffsetCalculation> left_halo_size_functions(base_shape_.rank());
777 std::vector<OffsetCalculation> right_halo_size_functions(base_shape_.rank());
  // TODO(yuanzx): We are concatenating on each sharded dimension one at a
  // time, and in the second dimension (and beyond) we create halos by slicing
  // the concat in the previous dimension, which is not optimal. We should
  // generate halos by only concatenating slices, instead of slicing concats.
782 for (int dim = 0; dim < base_shape_.rank(); ++dim) {
783 int64 shard_count = target.tile_assignment().dim(dim);
784 if (shard_count == 1) {
785 continue;
786 }
787 int64 input_shard_size =
788 CeilOfRatio(base_shape_.dimensions(dim), shard_count);
789
790 // Left halo. The size of the halo is derived by subtracting the first read
791 // element offset of the i'th partition from the limit of the (i-1)'th
792 // partition.
793 MultiplyAddDivideOffsetCalculation shard_limit_of_previous_on_padded(
794 input_shard_size, explicit_left_padding[dim], 1);
795 left_halo_size_functions[dim] =
796 shard_limit_of_previous_on_padded - start_on_padded_calculations[dim];
797
798 // Right halo.
799 MultiplyAddDivideOffsetCalculation shard_start_of_next_on_padded(
800 input_shard_size, input_shard_size + explicit_left_padding[dim], 1);
801 right_halo_size_functions[dim] =
802 limit_on_padded_calculations[dim] - shard_start_of_next_on_padded;
803
804 auto resharded = ExchangeHaloAndGetValidData(
805 visiting_hlo, base_shape_, left_halo_size_functions[dim],
806 right_halo_size_functions[dim], explicit_left_padding[dim],
807 padded_shape.dimensions(dim), shard_shape.dimensions(dim), dim, target,
808 offsets_on_padded_shape[dim], pad_value, partition_ordinals[dim],
809 state_.collective_ops_creator, state_.next_channel_id, state_.b,
810 mask_invalid_region);
811 if (!resharded) {
812 VLOG(1) << "ReshardAsWindowedInput failed without replicate first: halo "
813 "is beyond the neighbor.";
814 return Replicate().ReshardAsWindowedInput(window, target, pad_value);
815 }
816 visiting_hlo = *resharded;
817 }
818 return update_cache(WindowedInputShardReturnValue{
819 visiting_hlo, shard_window,
820 get_dynamic_slice_offset_on_output_if_needed()});
821 }
822
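// Returns a fully replicated copy of this PartitionedHlo, consulting and
// updating the reshard cache when all-gather caching is enabled.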
PartitionedHlo PartitionedHlo::Replicate() {
824 auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;
825 if (state_.partitioner->options().cache_all_gather) {
826 for (auto& entry : cache) {
827 if (entry.first.IsReplicated()) {
828 return entry.second;
829 }
830 }
831 }
832 const HloSharding& sharding = hlo_->sharding();
833 const Shape& shape = hlo_->shape();
834 CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);
835
836 if (sharding.IsReplicated()) {
837 return *this;
838 }
839 for (auto& entry : cache) {
840 if (entry.first.IsReplicated()) {
841 return entry.second;
842 }
843 }
844 auto update_cache = [&](PartitionedHlo resharded) {
845 state_.reshard_cache->per_hlo_cache[resharded.hlo()]
846 .reshard_cache.emplace_back(sharding, *this);
847 if (state_.partitioner->options().cache_all_gather) {
848 cache.emplace_back(HloSharding::Replicate(), std::move(resharded));
849 return cache.back().second;
850 }
851 return resharded;
852 };
  // 'Single Device' to 'Replicated'.
854 if (sharding.IsTileMaximal()) {
855 return update_cache(Broadcast());
856 }
857
858 // 'Tiled' to 'Replicated'.
859 std::vector<int64> all_dims(shape.rank());
860 std::iota(all_dims.begin(), all_dims.end(), 0);
861 HloInstruction* result = ReplicatePartial(all_dims);
862 result->set_sharding(HloSharding::Replicate());
863 return update_cache(PartitionedHlo(result, base_shape_, state_));
864 }
865
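// Replicates the shard data along `dims` only. Uses an all-gather when the
// collective ops creator provides one; otherwise places the shard into a
// zero-initialized padded buffer with dynamic-update-slice and all-reduces
// along the sharding dimensions, slicing off padding at the end.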
HloInstruction* PartitionedHlo::ReplicatePartial(absl::Span<const int64> dims) {
867 CHECK(!sharding().IsTileMaximal());
868 const Shape& shard_shape = hlo()->shape();
869 Shape target_shape = shard_shape;
870 Shape padded_target_shape = shard_shape;
871 for (int64 i : dims) {
872 padded_target_shape.set_dimensions(
873 i, shard_shape.dimensions(i) * sharding().tile_assignment().dim(i));
874 target_shape.set_dimensions(i, base_shape().dimensions(i));
875 }
876
877 HloInstruction* result = nullptr;
878 if (state_.collective_ops_creator.create_cross_partition_all_gather) {
879 result = state_.partitioner->AllGatherShards(state_.b, hlo_, sharding(),
880 state_.next_channel_id, dims,
881 state_.collective_ops_creator);
882 }
883 if (result == nullptr) {
884 auto zero = state_.b->AddInstruction(HloInstruction::CreateConstant(
885 LiteralUtil::Zero(shard_shape.element_type())));
886 auto zero_bcast = state_.b->AddInstruction(
887 HloInstruction::CreateBroadcast(padded_target_shape, zero, {}));
888 auto offsets = MakePartitionOffsets(padded_target_shape, sharding(),
889 state_.partition_id, state_.b, dims);
890 auto dus =
891 state_.b->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
892 padded_target_shape, zero_bcast, hlo_, offsets));
893 HloComputation* reduction =
894 MakeBinaryAdd(shard_shape.element_type(), state_.module);
895 result = state_.partitioner->AllReduceAlongShardingDims(
896 state_.b, dus, sharding(), state_.next_channel_id, dims,
897 state_.collective_ops_creator, reduction);
898 }
899 if (!ShapeUtil::Compatible(target_shape, padded_target_shape)) {
900 std::vector<int64> start_indices(target_shape.rank(), 0);
901 std::vector<int64> strides(target_shape.rank(), 1);
902 result = state_.b->AddInstruction(
903 HloInstruction::CreateSlice(target_shape, result, start_indices,
904 target_shape.dimensions(), strides));
905 }
906 return result;
907 }
908
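// Reshards tiled or partially replicated data to a partially replicated
// target by replicating along the dimensions whose tile counts shrink, with
// an optional collective permute and halo exchange beforehand. Returns
// nullopt if no compatible intermediate sharding is found.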
absl::optional<PartitionedHlo>
PartitionedHlo::ReshardToPartialReplicateWithAllGather(
    const HloSharding& target) {
912 if (!target.ReplicateOnLastTileDim()) {
913 return absl::nullopt;
914 }
  // Tiled/partial replicate to partial replicate.
  // Get a sharding compatible with the target for resharding by all-reduce.
917 auto compatible_sharding =
918 PartialReplicateReshardCompatibleSharding(target, sharding());
919 if (!compatible_sharding.has_value()) {
920 return absl::nullopt;
921 }
922
923 const auto& temp_sharding = compatible_sharding.value();
924 auto partitioned_hlo = *this;
925 // Use collective permute to adjust device assignment if needed.
926 if (CanReshardWithCollectivePermute(sharding(), temp_sharding)) {
927 partitioned_hlo =
928 partitioned_hlo.ReshardWithCollectivePermute(temp_sharding);
929 }
930
  // Get the replicate dims and the replicate factor of each dimension.
932 int64 rank = hlo_->shape().rank();
933 std::vector<int64> replicate_dims;
934 std::vector<int64> replicate_factors;
935 for (int64 dim = 0; dim < rank; dim++) {
936 int64 replicate_factor = temp_sharding.tile_assignment().dim(dim) /
937 target.tile_assignment().dim(dim);
938 if (replicate_factor > 1) {
939 replicate_dims.emplace_back(dim);
940 replicate_factors.emplace_back(replicate_factor);
941 }
942 }
943
  // Do a left halo exchange if a direct all-reduce would remove useful data
  // from the source.
946 auto halo_exchange = TileToPartialReplicateHaloExchange(
947 partitioned_hlo.hlo_, base_shape_, temp_sharding, target, replicate_dims,
948 partitioned_hlo.state().collective_ops_creator,
949 partitioned_hlo.state().next_channel_id,
950 partitioned_hlo.state().partition_id, partitioned_hlo.state().b);
951 if (!halo_exchange.has_value()) {
952 return absl::nullopt;
953 }
954 auto halo_exchange_hlo = halo_exchange.value();
955 // Grouped on replicate dimensions.
956 auto sharding_grouped =
957 GroupShardingOnDims(temp_sharding, replicate_dims, replicate_factors);
958 auto per_group_partitioner_state = CreatePerGroupPartitioningState(
959 partitioned_hlo.state(), sharding_grouped.device_groups,
960 partitioned_hlo.state().b);
961 auto base_shape = MakePartitionedShape(base_shape_, target);
962 // It's possible that halo_exchange_hlo == hlo.hlo().
963 // Record the sharding of hlo here, and reset it before return.
964 auto original_sharding = partitioned_hlo.sharding();
965 halo_exchange_hlo->set_sharding(sharding_grouped.sharding);
966 auto partial_replicate_hlo = PartitionedHlo(halo_exchange_hlo, base_shape,
967 per_group_partitioner_state);
968 HloInstruction* result =
969 partial_replicate_hlo.ReplicatePartial(replicate_dims);
970 partitioned_hlo.hlo()->set_sharding(original_sharding);
971 result->set_sharding(target);
972 return PartitionedHlo(result, base_shape_, partitioned_hlo.state());
973 }
974
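// Reshards partially replicated data to a more tiled target by dynamic-slicing
// the locally available data, followed by a collective permute if the device
// assignment differs. Returns nullopt if not applicable.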
absl::optional<PartitionedHlo>
PartitionedHlo::ReshardFromPartialReplicateWithDynamicSlice(
    const HloSharding& target) {
978 if (!sharding().ReplicateOnLastTileDim()) {
979 return absl::nullopt;
980 }
981
  // Get the temp sharding target from partial replicate to target tile dims.
  // target_compatible_sharding has the same tile_assignment dimensions
  // as the target and can reshard to the target by collective permute.
  // target_compatible_sharding could have a different device assignment than
  // the target. sharding() can reshard to target_compatible_sharding by
  // dynamic slice.
988 auto target_compatible_sharding =
989 PartialReplicateReshardCompatibleSharding(sharding(), target);
990 // Reshard to target_compatible_sharding by dynamic slice.
991 if (!target_compatible_sharding.has_value()) {
992 return absl::nullopt;
993 }
994 std::vector<int64> expand_tile_dims;
995 std::vector<int64> tiling_dim_factors;
996 int64 rank = hlo_->shape().rank();
997 tiling_dim_factors.reserve(target.tile_assignment().num_dimensions());
998 const auto& temp_target_sharding = target_compatible_sharding.value();
999 for (int64 dim = 0; dim < rank; dim++) {
1000 if (temp_target_sharding.tile_assignment().dim(dim) >
1001 sharding().tile_assignment().dim(dim)) {
1002 expand_tile_dims.push_back(dim);
1003 }
1004 tiling_dim_factors.emplace_back(
1005 temp_target_sharding.tile_assignment().dim(dim) /
1006 sharding().tile_assignment().dim(dim));
1007 }
1008
1009 // Add another dimension in tiling_dim_factors if target is partial replicate.
1010 if (target.ReplicateOnLastTileDim()) {
1011 tiling_dim_factors.emplace_back(
1012 target.tile_assignment().dimensions().back());
1013 }
1014
1015 // 2. Get the padded_hlo, do right halo exchange if needed.
1016 auto padded_hlo = PadFromPartialReplicateShape(
1017 hlo_, base_shape_, sharding(), temp_target_sharding, expand_tile_dims,
1018 state_.collective_ops_creator, state_.next_channel_id,
1019 state_.partition_id, state_.b);
1020 if (!padded_hlo.has_value()) {
1021 return absl::nullopt;
1022 }
1023 // 3. Slice out the tile from replicate ones.
1024 auto shard_shape = MakePartitionedShape(base_shape_, temp_target_sharding);
1025 // Since we are just slicing, we can just use the differences between the new
1026 // and old offsets in the full shape as the dynamic-slice offsets.
1027 auto padded_base_shape = shard_shape;
1028 for (int64 i = 0; i < padded_base_shape.rank(); ++i) {
1029 padded_base_shape.set_dimensions(
1030 i, padded_base_shape.dimensions(i) *
1031 temp_target_sharding.tile_assignment().dim(i));
1032 }
1033 auto offsets = MakePartitionOffsets(padded_base_shape, temp_target_sharding,
1034 state_.partition_id, state_.b);
1035 auto old_offsets = MakePartitionOffsets(padded_base_shape, sharding(),
1036 state_.partition_id, state_.b);
1037 for (int64 i = 0; i < offsets.size(); ++i) {
1038 offsets[i] = state_.b->AddInstruction(HloInstruction::CreateBinary(
1039 offsets[i]->shape(), HloOpcode::kSubtract, offsets[i], old_offsets[i]));
1040 }
1041 auto slice = state_.b->AddInstruction(HloInstruction::CreateDynamicSlice(
1042 shard_shape, padded_hlo.value(), offsets, shard_shape.dimensions()));
1043 slice->set_sharding(temp_target_sharding);
1044 auto result = PartitionedHlo(slice, base_shape_, state_);
1045 // If temp_target_sharding's device assignment is different from target,
1046 // use collective permute to reshard.
1047 if (CanReshardWithCollectivePermute(temp_target_sharding, target)) {
1048 return result.ReshardWithCollectivePermute(target);
1049 }
1050 // If device assignment in temp_target_sharding and target are the same,
1051 // return result directly.
1052 return result;
1053 }
1054
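// Broadcasts data from the single device that owns it to all partitions by
// zeroing the value on non-owning partitions and performing a cross-partition
// all-reduce.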
PartitionedHlo PartitionedHlo::Broadcast() const {
1056 const Shape& shape = hlo_->shape();
1057 const HloSharding& sharding = hlo_->sharding();
1058 CHECK(sharding.HasUniqueDevice());
1059 CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);
1060
1061 auto src_core_id = state_.b->AddInstruction(HloInstruction::CreateConstant(
1062 LiteralUtil::CreateR0<uint32>(sharding.GetUniqueDevice())));
1063 Shape bcast_shape = ShapeUtil::ChangeElementType(shape, PRED);
1064 auto is_src_core = state_.b->AddInstruction(HloInstruction::CreateBroadcast(
1065 bcast_shape,
1066 state_.b->AddInstruction(HloInstruction::CreateCompare(
1067 ShapeUtil::MakeShape(PRED, {}), state_.partition_id, src_core_id,
1068 ComparisonDirection::kEq)),
1069 {}));
1070
1071 auto zero = state_.b->AddInstruction(
1072 HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
1073 auto zero_bcast = state_.b->AddInstruction(
1074 HloInstruction::CreateBroadcast(shape, zero, {}));
1075 auto operand = state_.b->AddInstruction(HloInstruction::CreateTernary(
1076 shape, HloOpcode::kSelect, is_src_core, hlo(), zero_bcast));
1077 HloComputation* reduction =
1078 MakeBinaryAdd(shape.element_type(), state_.module);
1079
1080 auto result = state_.collective_ops_creator.create_cross_partition_all_reduce(
1081 state_.b, operand, reduction, {}, NewChannel());
1082 result->set_sharding(HloSharding::Replicate());
1083 return PartitionedHlo(result, base_shape_, state_);
1084 }
1085
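// Reshards by swapping one pair of sharded dimensions at a time with an
// all-to-all: the data is reshaped to split the target dimension, exchanged
// across partitions, transposed and reshaped back, and the remaining pairs
// are handled recursively.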
PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
    const HloSharding& target,
    absl::Span<const std::pair<int64, int64>> source_target_dims) const {
1089 if (source_target_dims.empty()) {
1090 if (target == sharding()) {
1091 return *this;
1092 }
1093 // If the device order is different in the target, fix the order with
1094 // ReshardWithCollectivePermute.
1095 return ReshardWithCollectivePermute(target);
1096 }
1097
1098 // Swap one pair of dimensions.
1099 int64 source_dim = source_target_dims[0].first;
1100 int64 target_dim = source_target_dims[0].second;
1101 const int64 group_size = sharding().tile_assignment().dim(source_dim) /
1102 sharding().tile_assignment().dim(target_dim);
1103
1104 auto temp_target_tile = sharding().tile_assignment();
1105 {
1106 std::vector<int64> reshape_tile_dims(temp_target_tile.num_dimensions() + 2);
1107 int64 i = 0;
1108 int64 added_source_dim = -1;
1109 int64 added_target_dim = -1;
1110 for (int64 j = 0; j < temp_target_tile.num_dimensions(); ++j) {
1111 if (source_dim == j) {
1112 reshape_tile_dims[i] = temp_target_tile.dim(j) / group_size;
1113 reshape_tile_dims[++i] = group_size;
1114 added_source_dim = i;
1115 } else if (target_dim == j) {
1116 reshape_tile_dims[i] = temp_target_tile.dim(j);
1117 reshape_tile_dims[++i] = 1;
1118 added_target_dim = i;
1119 } else {
1120 reshape_tile_dims[i] = temp_target_tile.dim(j);
1121 }
1122 ++i;
1123 }
1124 temp_target_tile.Reshape(reshape_tile_dims);
1125 std::vector<int64> xpose_dims(temp_target_tile.num_dimensions());
1126 std::iota(xpose_dims.begin(), xpose_dims.end(), 0);
1127 xpose_dims[added_source_dim] = added_target_dim;
1128 xpose_dims[added_target_dim] = added_source_dim;
1129 temp_target_tile = hlo_sharding_util::TransposeSharding(
1130 HloSharding::Tile(temp_target_tile), xpose_dims)
1131 .tile_assignment();
1132 auto temp_target_tile_dims = sharding().tile_assignment().dimensions();
1133 temp_target_tile_dims[source_dim] =
1134 sharding().tile_assignment().dim(target_dim);
1135 temp_target_tile_dims[target_dim] =
1136 sharding().tile_assignment().dim(source_dim);
1137 temp_target_tile.Reshape(temp_target_tile_dims);
1138 }
1139 auto temp_target = target.ReplicateOnLastTileDim()
1140 ? HloSharding::PartialTile(temp_target_tile)
1141 : HloSharding::Tile(temp_target_tile);
1142 auto padded_shape = hlo_->shape();
1143 padded_shape.set_dimensions(
1144 target_dim,
1145 RoundUpToNearest(padded_shape.dimensions(target_dim),
1146 temp_target.tile_assignment().dim(target_dim)));
1147 auto padded_hlo = PadToShape(hlo_, padded_shape, state_.b);
1148
1149 // The order of ids in the group must follow the temp_target sharding.
1150 std::vector<std::vector<int64>> groups(
1151 temp_target.tile_assignment().num_elements() / group_size);
1152 temp_target.tile_assignment().Each(
1153 [&](absl::Span<const int64> indices, int64 device) {
1154 int64 group_id = 0;
1155 for (int64 dim = 0; dim < indices.size(); ++dim) {
1156 if (dim == target_dim) {
1157 group_id *= temp_target.tile_assignment().dim(dim) / group_size;
1158 group_id += indices[dim] / group_size;
1159 } else {
1160 group_id *= temp_target.tile_assignment().dim(dim);
1161 group_id += indices[dim];
1162 }
1163 }
1164 groups[group_id].push_back(device);
1165 });
1166
1167 HloInstruction* result = nullptr;
1168
1169 // Split along the split dimension (target_dim) of the all-to-all
1170 // output.
1171 std::vector<int64> dimensions;
1172 for (int64 i = 0; i < base_shape_.rank(); ++i) {
1173 if (i == target_dim) {
1174 dimensions.push_back(group_size);
1175 dimensions.push_back(padded_hlo->shape().dimensions(i) / group_size);
1176 } else {
1177 dimensions.push_back(padded_hlo->shape().dimensions(i));
1178 }
1179 }
1180 auto reshape = state_.b->AddInstruction(HloInstruction::CreateReshape(
1181 ShapeUtil::MakeShape(base_shape_.element_type(), dimensions),
1182 padded_hlo));
1183 // After the reshape, it is guaranteed to have at least 3 dimensions.
1184 auto all_to_all =
1185 state_.collective_ops_creator.create_cross_partition_all_to_all(
1186 state_.b, {reshape}, groups, (*state_.next_channel_id)++, target_dim);
1187
1188 // Reorder the split dimension of the reshape to be located in front of the
1189 // input partition dimension, so the two dimensions can be combined.
1190 int64 new_source_dim =
1191 (target_dim < source_dim) ? source_dim + 1 : source_dim;
1192 std::vector<int64> permutation;
1193 for (int64 i = 0; i < all_to_all->shape().rank(); ++i) {
1194 if (i == target_dim) {
1195 continue;
1196 }
1197 if (i == new_source_dim) {
1198 permutation.push_back(target_dim);
1199 }
1200 permutation.push_back(i);
1201 }
1202 auto transpose = state_.b->AddInstruction(HloInstruction::CreateTranspose(
1203 ShapeInference::InferTransposeShape(all_to_all->shape(), permutation)
1204 .ValueOrDie(),
1205 all_to_all, permutation));
1206
1207 // Combine the split dimension and the input partition dimension.
1208 auto new_shape = ShapeInference::InferAllToAllShape(
1209 padded_hlo->shape(), target_dim, source_dim, group_size)
1210 .ValueOrDie();
1211 result = state_.b->AddInstruction(
1212 HloInstruction::CreateReshape(new_shape, transpose));
1213
1214 const Shape result_shape = MakePartitionedShape(base_shape_, temp_target);
1215 if (result_shape != result->shape()) {
1216 result = state_.b->AddInstruction(HloInstruction::CreateSlice(
1217 result_shape, result, std::vector<int64>(result_shape.rank(), 0),
1218 result_shape.dimensions(), std::vector<int64>(result_shape.rank(), 1)));
1219 }
1220 result->set_sharding(temp_target);
1221 auto remaining_source_target_dims = source_target_dims;
1222 remaining_source_target_dims.remove_prefix(1);
1223 return PartitionedHlo(result, base_shape_, state_)
1224 .ReshardWithAllToAll(target, remaining_source_target_dims);
1225 }
1226
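// Handles resharding between a tiled sharding and a partially replicated
// sharding (in either direction) through an intermediate partially replicated
// sharding and an all-to-all. Returns nullopt when the pattern is not
// supported.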
absl::optional<PartitionedHlo>
PartitionedHlo::ReshardPartialReplicateWithAllToAll(const HloSharding& target) {
1229 bool source_is_partial_replicate = sharding().ReplicateOnLastTileDim();
1230 const auto& partial_replicate_sharding =
1231 source_is_partial_replicate ? sharding() : target;
1232 // If neither the source nor the target is partial replicate, return null.
1233 if (!partial_replicate_sharding.ReplicateOnLastTileDim()) {
1234 return absl::nullopt;
1235 }
1236 const auto& tile_sharding = source_is_partial_replicate ? target : sharding();
  // If both source and target are partially replicated, this should already be
  // supported in Reshard with AllToAll.
1239 if (tile_sharding.ReplicateOnLastTileDim() || tile_sharding.IsTileMaximal()) {
1240 return absl::nullopt;
1241 }
1242
  // Only support resharding from sharding={devices=[2,3]0,1,2,3,4,5}
  // to sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}, where
  // the last tile dim will be replicated first before the all-to-all.
  // Or resharding from
  // sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
  // to sharding={devices=[2,3]0,1,2,3,4,5}, where
  // the last tile dim will be sharded after the all-to-all.
1250 const int num_replicas =
1251 partial_replicate_sharding.tile_assignment().dimensions().back();
1252 if (((tile_sharding.tile_assignment().num_dimensions() + 1) !=
1253 partial_replicate_sharding.tile_assignment().num_dimensions()) ||
1254 (partial_replicate_sharding.tile_assignment().dim(0) != 1)) {
1255 return absl::nullopt;
1256 }
1257 int to_replicate_dim = -1;
1258 for (int i = tile_sharding.tile_assignment().num_dimensions() - 1; i >= 0;
1259 --i) {
1260 if (tile_sharding.tile_assignment().dim(i) > 1 &&
1261 (to_replicate_dim == -1)) {
1262 if (tile_sharding.tile_assignment().dim(i) != num_replicas) {
1263 return absl::nullopt;
1264 }
1265 to_replicate_dim = i;
1266 }
1267
1268 if (tile_sharding.tile_assignment().dim(i) !=
1269 partial_replicate_sharding.tile_assignment().dim(i + 1)) {
1270 return absl::nullopt;
1271 }
1272 }
1273
1274 if (to_replicate_dim == -1) {
1275 return absl::nullopt;
1276 }
1277
1278 // Check if core assignments for source and the target are the same.
1279 auto reshape_tile_assignment = partial_replicate_sharding.tile_assignment();
1280 reshape_tile_assignment.Reshape(tile_sharding.tile_assignment().dimensions());
1281 if (reshape_tile_assignment != tile_sharding.tile_assignment()) {
1282 return absl::nullopt;
1283 }
1284
1285 auto tmp_tile_assignment = tile_sharding.tile_assignment();
1286 auto tmp_tile_assignment_dimensions =
1287 tile_sharding.tile_assignment().dimensions();
1288 tmp_tile_assignment_dimensions[to_replicate_dim] = 1;
1289 tmp_tile_assignment_dimensions.push_back(num_replicas);
1290 tmp_tile_assignment.Reshape(tmp_tile_assignment_dimensions);
1291 auto tmp_partial_replicate_sharding =
1292 HloSharding::PartialTile(tmp_tile_assignment);
1293
1294 if (source_is_partial_replicate) {
1295 if (auto src_tgt_dims = GetReshardAllToAllSourceTargetDims(
1296 sharding(), tmp_partial_replicate_sharding)) {
1297 auto partitioned_hlo =
1298 ReshardWithAllToAll(tmp_partial_replicate_sharding, *src_tgt_dims);
1299 return partitioned_hlo.Reshard(target);
1300 }
1301 } else {
1302 auto partitioned_hlo = Reshard(tmp_partial_replicate_sharding);
1303
1304 if (auto src_tgt_dims = GetReshardAllToAllSourceTargetDims(
1305 partitioned_hlo.sharding(), target)) {
1306 return partitioned_hlo.ReshardWithAllToAll(target, *src_tgt_dims);
1307 }
1308 }
1309
1310 return absl::nullopt;
1311 }
1312
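// Reshards by sending each shard to the device assigned to the same tile in
// the target sharding. If broadcast dimensions show the data is already
// identical for every source/destination pair, only a copy is emitted.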
PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute(
    const HloSharding& target) const {
1315 CHECK(CanReshardWithCollectivePermute(sharding(), target))
1316 << sharding().ToString() << " to " << target.ToString();
1317 if (auto broadcast_dims = state_.b->BroadcastDimsForCreatedHlo(hlo())) {
1318 if (!(*broadcast_dims)->empty()) {
1319 // If hlo() has broadcast dims, check if data is already the same between
1320 // source/destination pairs.
1321 std::vector<int64> broadcast_dims_vector;
1322 for (int64 i = 0; i < hlo()->shape().rank(); ++i) {
1323 if ((*broadcast_dims)->contains(i)) {
1324 broadcast_dims_vector.push_back(i);
1325 }
1326 }
1327 if (hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
1328 sharding(), broadcast_dims_vector) ==
1329 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
1330 target, broadcast_dims_vector)) {
1331 auto copy = state_.b->AddInstruction(HloInstruction::CreateUnary(
1332 hlo()->shape(), HloOpcode::kCopy, hlo()));
1333 copy->set_sharding(target);
1334 return PartitionedHlo(copy, base_shape_, state_);
1335 }
1336 }
1337 }
1338 std::vector<std::pair<int64, int64>> src_dst_pairs;
1339 sharding().tile_assignment().Each(
1340 [&](absl::Span<const int64> indices, int64 src_device) {
1341 int64 dst_device = target.tile_assignment()(indices);
1342 src_dst_pairs.emplace_back(src_device, dst_device);
1343 });
1344 auto cp =
1345 state_.collective_ops_creator.create_cross_partition_collective_permute(
1346 state_.b, hlo(), src_dst_pairs, (*state_.next_channel_id)++);
1347 cp->set_sharding(target);
1348 return PartitionedHlo(cp, base_shape_, state_);
1349 }
1350
SpmdPartitioningVisitor::SpmdPartitioningVisitor(
    HloComputation* computation, int64 num_partitions, int64 num_replicas,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, SpmdLogger* logger, SpmdPartitionerOptions options,
    SpmdPartitioner* partitioner)
1356 : changed_(false),
1357 module_(computation->parent()),
1358 num_partitions_(num_partitions),
1359 num_replicas_(num_replicas),
1360 collective_ops_creator_(collective_ops_creator),
1361 next_channel_id_(next_channel_id),
1362 b_(SpmdBuilder(computation->name() + "_spmd", /*hlo=*/nullptr)),
1363 partition_id_(collective_ops_creator_.create_partition_id(&b_)),
1364 logger_(logger),
1365 options_(std::move(options)),
1366 partitioner_(partitioner) {}
1367
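// Fallback handler for ops without a dedicated partitioning rule: operands are
// resharded to replicated (or to the op's unique device), the op is cloned on
// those operands, and the result is resharded back to the original sharding.
// Side-effecting ops are rejected and elementwise ops are dispatched to
// HandleElementwise.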
Status SpmdPartitioningVisitor::DefaultAction(HloInstruction* hlo) {
1369 if (hlo->HasSideEffect()) {
1370 return Unimplemented("Side-effect ops cannot be replicated: %s",
1371 hlo->ToString());
1372 }
1373
1374 if (hlo->IsElementwise() && hlo->operand_count() > 0) {
1375 return HandleElementwise(hlo);
1376 }
1377
1378 if (!hlo->sharding().IsTileMaximal()) {
1379 VLOG(1) << "Not partitioned in SPMD mode (DefaultAction):"
1380 << hlo->ToString();
1381 for (int64 i = 0; i < hlo->operand_count(); ++i) {
1382 VLOG(1) << " operand " << i
1383 << " sharding:" << hlo->operand(i)->sharding().ToString();
1384 }
1385 }
1386
1387 HloSharding sharding = hlo->sharding().HasUniqueDevice()
1388 ? hlo->sharding()
1389 : HloSharding::Replicate();
1390
1391 // If the instruction cannot be partitioned, replicate it. Side-effecting
1392 // instructions were already rejected above.
1393 std::vector<HloInstruction*> new_operands;
1394 for (HloInstruction* operand : hlo->operands()) {
1395 new_operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo());
1396 }
1397 auto clone =
1398 b_.AddInstruction(hlo->CloneWithNewOperands(hlo->shape(), new_operands));
1399 clone->set_sharding(sharding);
1400 clone->set_metadata(hlo->metadata());
1401 SetPartitionedHlo(hlo,
1402 PartitionedHlo(clone, hlo->shape(), MakePartitioningState())
1403 .Reshard(hlo->sharding()));
1404 return Status::OK();
1405 }
1406
1407 Status SpmdPartitioningVisitor::Preprocess(HloInstruction* hlo) {
1408 visiting_hlo_ = hlo;
1409 b_.set_visiting_hlo(hlo);
1410 // Temporarily replace manual sharding with one-device sharding so that the
1411 // partitioner will not change the HLOs.
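// For example (illustrative only): a tuple sharding {manual, replicated}
// becomes {device:0, replicated}, and a plain manual sharding becomes
// device:0, so the visitor treats the manually partitioned HLOs as trivially
// assigned to one device and leaves their shapes untouched.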
1412 auto manual_to_onedevice = [&](const Shape& shape,
1413 const HloSharding& sharding) {
1414 // If a tuple's elements are all manual, then sharding.IsManual() == True,
1415 // so we check whether the sharding is a tuple first.
1416 if (sharding.IsTuple()) {
1417 std::vector<HloSharding> subshardings = sharding.tuple_elements();
1418 for (HloSharding& subsharding : subshardings) {
1419 if (subsharding.IsManual()) {
1420 subsharding = HloSharding::AssignDevice(0);
1421 }
1422 }
1423 return HloSharding::Tuple(shape, subshardings);
1424 }
1425 if (sharding.IsManual()) {
1426 return HloSharding::AssignDevice(0);
1427 }
1428 return sharding;
1429 };
1430 const bool has_manual_sharding =
1431 hlo->sharding().IsManual() ||
1432 (hlo->sharding().IsTuple() &&
1433 absl::c_any_of(
1434 hlo->sharding().tuple_elements(),
1435 [](const HloSharding& sharding) { return sharding.IsManual(); }));
1436 if (has_manual_sharding && !hlo->IsCustomCall("SPMDFullToShardShape")) {
1437 visiting_hlo_sharding_ = hlo->sharding();
1438 hlo->set_sharding(
1439 manual_to_onedevice(hlo->shape(), *visiting_hlo_sharding_));
1440
1441 visiting_hlo_operand_shardings_.reserve(hlo->operand_count());
1442 for (auto operand : hlo->operands()) {
1443 visiting_hlo_operand_shardings_.push_back(operand->sharding());
1444 operand->set_sharding(
1445 manual_to_onedevice(operand->shape(), operand->sharding()));
1446 GetPartitionedHlo(operand).hlo()->set_sharding(operand->sharding());
1447 }
1448 }
1449 return Status::OK();
1450 }
1451
1452 Status SpmdPartitioningVisitor::Postprocess(HloInstruction* hlo) {
1453 logger_->RegisterLogEntry(GetPartitionedHlo(hlo).hlo(),
1454 b_.derived_instructions(hlo));
1455 visiting_hlo_ = nullptr;
1456 b_.set_visiting_hlo(nullptr);
1457 // Revert fake one-device shardings for manually partitioned ops.
1458 if (visiting_hlo_sharding_) {
1459 hlo->set_sharding(*visiting_hlo_sharding_);
1460 GetPartitionedHlo(hlo).hlo()->set_sharding(*visiting_hlo_sharding_);
1461 for (int64 i = 0; i < hlo->operand_count(); ++i) {
1462 auto operand = hlo->mutable_operand(i);
1463 operand->set_sharding(visiting_hlo_operand_shardings_[i]);
1464 GetPartitionedHlo(operand).hlo()->set_sharding(operand->sharding());
1465 }
1466 visiting_hlo_sharding_.reset();
1467 visiting_hlo_operand_shardings_.clear();
1468 }
1469 return Status::OK();
1470 }
1471
1472 Status SpmdPartitioningVisitor::HandleElementwise(HloInstruction* hlo) {
1473 std::vector<HloInstruction*> new_operands;
1474 for (HloInstruction* operand : hlo->operands()) {
1475 new_operands.push_back(
1476 GetPartitionedHlo(operand).Reshard(hlo->sharding()).hlo());
1477 }
1478 SetPartitionedHlo(hlo, [&] {
1479 return b_.AddInstruction(hlo->CloneWithNewOperands(
1480 MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands));
1481 });
1482 return Status::OK();
1483 }
1484
1485 Status SpmdPartitioningVisitor::HandleConcatenate(HloInstruction* hlo) {
1486 const HloSharding& sharding = hlo->sharding();
1487 if (sharding.IsTileMaximal()) {
1488 return DefaultAction(hlo);
1489 }
1490
1491 const Shape shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
1492 const int64 dimension = hlo->concatenate_dimension();
1493 if (sharding.tile_assignment().dim(dimension) == 1) {
1494 std::vector<HloInstruction*> new_operands;
1495 for (HloInstruction* operand : hlo->operands()) {
1496 new_operands.push_back(
1497 GetPartitionedHlo(operand).Reshard(sharding).hlo());
1498 }
1499 SetPartitionedHlo(hlo, [&] {
1500 return b_.AddInstruction(
1501 hlo->CloneWithNewOperands(shard_shape, new_operands));
1502 });
1503 return Status::OK();
1504 }
1505
1506 // If the concatenate dimension is one of the partitioned dimensions: allocate
1507 // the full output shape, let each partition update its owned region, all-reduce
1508 // across partitions, and then slice out each partition's output region.
1509
1510 // temp_output_shape is the output shape where the concatenate dimension
1511 // is changed to the full (and padded to shard count) dimension size.
1512 auto temp_output_shape = MakePartitionedShape(hlo->shape(), sharding);
1513 auto last_operand_padded_shape =
1514 MakePartitionedShape(hlo->operands().back()->shape(), sharding);
1515 // If the last operand has more padding than temp_output, add extra padding to
1516 // avoid an out-of-bounds dynamic-update-slice.
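// Worked example (hypothetical sizes): with 4 partitions, a last operand of
// size 10 on the concatenate dimension has shard size 3, so
// last_operand_padding = 3 * 4 - 10 = 2. If the full output has size 30
// (shard size 8), temp_output_padding = 8 * 4 - 30 = 2 and no extra padding
// is needed; if the output size were 32 instead, temp_output_padding = 0 and
// padding_for_last_operand = 2.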
1517 int last_operand_padding =
1518 last_operand_padded_shape.dimensions(dimension) *
1519 sharding.tile_assignment().dim(dimension) -
1520 hlo->operands().back()->shape().dimensions(dimension);
1521 int temp_output_padding = temp_output_shape.dimensions(dimension) *
1522 sharding.tile_assignment().dim(dimension) -
1523 hlo->shape().dimensions(dimension);
1524 int padding_for_last_operand =
1525 last_operand_padding < temp_output_padding
1526 ? 0
1527 : last_operand_padding - temp_output_padding;
1528 temp_output_shape.set_dimensions(
1529 dimension, temp_output_shape.dimensions(dimension) *
1530 sharding.tile_assignment().dim(dimension) +
1531 padding_for_last_operand);
1532 auto temp_output = CreateZero(temp_output_shape, &b_);
1533
1534 // Offset of each operand along the concatenate dimension.
1535 int64 offset = 0;
1536 for (HloInstruction* operand : hlo->operands()) {
1537 auto spmd_operand = GetPartitionedHlo(operand).Reshard(sharding).hlo();
1538 std::vector<HloInstruction*> start_indices(
1539 hlo->shape().rank(), b_.AddInstruction(HloInstruction::CreateConstant(
1540 LiteralUtil::Zero(S32))));
1541 start_indices[dimension] =
1542 MultiplyAddDivideOffsetCalculation(
1543 spmd_operand->shape().dimensions(dimension), offset, 1)
1544 .Calculate(MakeTiledPartitionOrdinals(sharding, partition_id_,
1545 &b_)[dimension],
1546 &b_);
1547 temp_output = b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
1548 temp_output_shape, temp_output, spmd_operand, start_indices));
1549 offset += operand->shape().dimensions(dimension);
1550 }
1551 std::vector<int64> non_concat_dims;
1552 non_concat_dims.reserve(hlo->shape().rank() - 1);
1553 for (int64 i = 0; i < hlo->shape().rank(); ++i) {
1554 if (i != dimension) {
1555 non_concat_dims.push_back(i);
1556 }
1557 }
1558 auto grouped = GroupShardingOnDims(sharding, non_concat_dims);
1559 auto per_group_partitioner_state = CreatePerGroupPartitioningState(
1560 MakePartitioningState(), grouped.device_groups, &b_);
1561 auto all_reduce = per_group_partitioner_state.collective_ops_creator
1562 .create_cross_partition_all_reduce(
1563 &b_, temp_output,
1564 MakeBinaryAdd(hlo->shape().element_type(), module_),
1565 {}, NewChannel());
1566 SetPartitionedHlo(hlo, [&] {
1567 auto start_indices = MakeTiledPartitionOrdinals(
1568 grouped.sharding, per_group_partitioner_state.partition_id, &b_);
1569 start_indices[dimension] = MultiplyAddDivideOffsetCalculation(
1570 shard_shape.dimensions(dimension), 0, 1)
1571 .Calculate(start_indices[dimension], &b_);
1572 return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
1573 shard_shape, all_reduce, start_indices, shard_shape.dimensions()));
1574 });
1575
1576 return Status::OK();
1577 }
1578
1579 Status SpmdPartitioningVisitor::HandleSlice(HloInstruction* hlo) {
1580 const HloSharding& sharding = hlo->sharding();
1581 if (sharding.IsTileMaximal()) {
1582 return DefaultAction(hlo);
1583 }
1584
1585 auto operand = GetPartitionedHlo(hlo->operand(0)).Reshard(sharding);
1586
1587 // Create a window config to represent the slice.
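// Mapping used below (a sketch): for each dimension, slice start s, limit l,
// and stride r become padding_low = -s, padding_high = l - operand_size, and
// stride = r, with window size 1. E.g. slicing [2, 9) with stride 1 from a
// dimension of size 10 gives padding_low = -2 and padding_high = -1.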
1588 Window window;
1589 for (int64 i = 0; i < hlo->shape().rank(); ++i) {
1590 WindowDimension* dim = window.add_dimensions();
1591 dim->set_size(1);
1592 dim->set_stride(hlo->slice_strides(i));
1593 dim->set_window_dilation(1);
1594 dim->set_window_reversal(false);
1595 dim->set_padding_low(-hlo->slice_starts(i));
1596 dim->set_padding_high(hlo->slice_limits(i) -
1597 operand.base_shape().dimensions(i));
1598 dim->set_base_dilation(1);
1599 }
1600
1601 auto reshard_operand = operand.ReshardAsWindowedInput(
1602 window, sharding,
1603 CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_),
1604 /*mask_invalid_region=*/false);
1605 if (!reshard_operand.has_value()) {
1606 return DefaultAction(hlo);
1607 }
1608 TF_RET_CHECK(!reshard_operand->dynamic_slice_index_on_output.has_value());
1609 const Shape& operand_shape = reshard_operand->sharded_input->shape();
1610
1611 std::vector<int64> start_indices = hlo->slice_starts();
1612 std::vector<int64> limit_indices = hlo->slice_limits();
1613 std::vector<int64> strides = hlo->slice_strides();
1614 bool need_slice = false;
1615 for (int64 i = 0; i < hlo->shape().rank(); ++i) {
1616 auto dim = reshard_operand->shard_window.dimensions(i);
1617 start_indices[i] = -dim.padding_low();
1618 limit_indices[i] = operand_shape.dimensions(i) + dim.padding_high();
1619 if (start_indices[i] != 0 || strides[i] != 1 ||
1620 limit_indices[i] != operand_shape.dimensions(i)) {
1621 need_slice = true;
1622 }
1623 }
1624
1625 SetPartitionedHlo(hlo, [&] {
1626 if (need_slice) {
1627 auto shard_shape = MakePartitionedShape(hlo->shape(), sharding);
1628 return b_.AddInstruction(HloInstruction::CreateSlice(
1629 shard_shape, reshard_operand->sharded_input, start_indices,
1630 limit_indices, strides));
1631 }
1632 auto data = reshard_operand->sharded_input;
1633 // Create a copy so that it will not share the resharding cache.
1634 return b_.AddInstruction(
1635 HloInstruction::CreateUnary(data->shape(), HloOpcode::kCopy, data));
1636 });
1637
1638 return Status::OK();
1639 }
1640
1641 Status SpmdPartitioningVisitor::HandleSort(HloInstruction* hlo) {
1642 HloSharding sharding = hlo->sharding();
1643 // Special handling for sort in TopK when the first operand is partitioned at
1644 // the sort dimension.
1645 auto k = GetKValueInTopKWhenPartitionSortDim(hlo);
1646 if (k.has_value()) {
1647 // When the first operand is partitioned at the sort dimension:
1648 // 1. Partition sort computation to different partitions;
1649 // 2. Slice TopK value and index from different partitions;
1650 // 3. Gather and replicate value and index from different partitions,
1651 // the shape of replicated value and index will be
1652 // [batch_size, ..., partition_count * k, ...];
1653 // 4. Final sort uses replicated value and index from different partitions
1654 // as input.
1655 // The GetTupleElement and Slice after the non-partitioned sort are left
1656 // unchanged here; HandleGetTupleElement and HandleSlice will update them.
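// Illustrative example (hypothetical numbers): with input shape [batch, 1024]
// partitioned 4 ways on the sort dimension and k = 8, each partition sorts its
// 256-element shard, slices its local top 8 values and indices, and the
// replicated final sort then runs on shape [batch, 4 * 8 = 32] to produce the
// global top k.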
1657 HloSortInstruction* sort = DynCast<HloSortInstruction>(hlo);
1658 const int64 sort_dim = sort->sort_dimension();
1659 auto input = hlo->operand(0);
1660 auto index = hlo->operand(1);
1661 const HloSharding& input_sharding = input->sharding();
1662 const int64 partition_count =
1663 input_sharding.tile_assignment().dim(sort_dim);
1664 const int64 input_size = input->shape().dimensions(sort_dim);
1665 const int64 per_partition_size = CeilOfRatio(input_size, partition_count);
1666 const auto element_type = input->shape().element_type();
1667 const auto index_type = index->shape().element_type();
1668
1669 // Partition and pad input and index.
1670 // Pad input with minimal value.
1671 auto partitioned_input = GetPartitionedHlo(input).PadWithValue(
1672 CreateFirstWithType(element_type, &b_));
1673 // Pad index with max value.
1674 auto partitioned_index =
1675 GetPartitionedHlo(index)
1676 .Reshard(input_sharding)
1677 .PadWithValue(CreateLastWithType(index_type, &b_));
1678
1679 // Each partition needs to do TopK separately, thus the base shape
1680 // becomes the padded shape.
1681 std::vector<int64> replicated_dimensions(
1682 input->shape().dimensions().begin(), input->shape().dimensions().end());
1683 replicated_dimensions[sort_dim] = per_partition_size * partition_count;
1684 const Shape replicated_shape = ShapeUtil::MakeTupleShape(
1685 {ShapeUtil::MakeShape(element_type, replicated_dimensions),
1686 ShapeUtil::MakeShape(index_type, replicated_dimensions)});
1687
1688 // Partition original topk to different shards.
1689 auto topk_sharding =
1690 input_sharding.GetTupleSharding(replicated_shape).ValueOrDie();
1691 auto shard_shape = MakePartitionedShape(replicated_shape, topk_sharding);
1692 auto topk = b_.AddInstruction(hlo->CloneWithNewOperands(
1693 shard_shape, {partitioned_input.hlo(), partitioned_index.hlo()}));
1694
1695 // Get value from first sort.
1696 HloInstruction* value_gte =
1697 b_.AddInstruction(HloInstruction::CreateGetTupleElement(
1698 topk->shape().tuple_shapes(0), topk, 0));
1699 HloInstruction* index_gte =
1700 b_.AddInstruction(HloInstruction::CreateGetTupleElement(
1701 topk->shape().tuple_shapes(1), topk, 1));
1702
1703 // Slice top K value from the first partitioned sort.
1704 replicated_dimensions[sort_dim] = k.value() * partition_count;
1705 auto slice_input = SliceFirstK(value_gte, &b_, sort_dim, k.value());
1706 slice_input->set_sharding(input_sharding);
1707 PartitionedHlo partitioned_slice_input(
1708 slice_input, ShapeUtil::MakeShape(element_type, replicated_dimensions),
1709 MakePartitioningState());
1710 // Reshard value to be replicated.
1711 auto replicated_slice_input =
1712 partitioned_slice_input.Reshard(HloSharding::Replicate()).hlo();
1713
1714 // Slice top K index from the first partitioned sort.
1715 auto slice_index = SliceFirstK(index_gte, &b_, sort_dim, k.value());
1716 slice_index->set_sharding(input_sharding);
1717 PartitionedHlo partitioned_slice_index(
1718 slice_index, ShapeUtil::MakeShape(index_type, replicated_dimensions),
1719 MakePartitioningState());
1720 // Reshard index to be replicated.
1721 auto replicated_slice_index =
1722 partitioned_slice_index.Reshard(HloSharding::Replicate()).hlo();
1723
1724 // Create a replicated sort to do TopK; the inputs are the value and index
1725 // pairs gathered from all the partitions.
1726 const Shape final_topk_shape = ShapeUtil::MakeTupleShape(
1727 {ShapeUtil::MakeShape(element_type, replicated_dimensions),
1728 ShapeUtil::MakeShape(index_type, replicated_dimensions)});
1729 HloInstruction* final_sort = b_.AddInstruction(HloInstruction::CreateSort(
1730 final_topk_shape, sort_dim,
1731 {replicated_slice_input, replicated_slice_index}, sort->to_apply(),
1732 sort->is_stable()));
1733 final_sort->set_sharding(HloSharding::Replicate()
1734 .GetTupleSharding(final_sort->shape())
1735 .ValueOrDie());
1736 PartitionedHlo replicated_sort(final_sort, final_sort->shape(),
1737 MakePartitioningState());
1738 SetPartitionedHlo(hlo, replicated_sort.Reshard(hlo->sharding()));
1739
1740 return Status::OK();
1741 }
1742
1743 if (hlo->shape().IsTuple()) {
1744 // Check that all elements are sharded in the same way.
1745 if (hlo->shape().tuple_shapes_size() == 0) {
1746 return DefaultAction(hlo);
1747 }
1748 sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0});
1749 for (int64 i = 1; i < hlo->operand_count(); ++i) {
1750 if (sharding != hlo->sharding().GetSubSharding(hlo->shape(), {i})) {
1751 return DefaultAction(hlo);
1752 }
1753 }
1754 }
1755 if (sharding.IsTileMaximal()) {
1756 return DefaultAction(hlo);
1757 }
1758 for (int64 dim : hlo->dimensions()) {
1759 if (sharding.tile_assignment().dim(dim) > 1) {
1760 return DefaultAction(hlo);
1761 }
1762 }
1763 // Reshard operands to the same as the output.
1764 std::vector<HloInstruction*> new_operands;
1765 for (HloInstruction* operand : hlo->operands()) {
1766 new_operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo());
1767 }
1768 SetPartitionedHlo(hlo, [&] {
1769 return b_.AddInstruction(hlo->CloneWithNewOperands(
1770 MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands));
1771 });
1772 return Status::OK();
1773 }
1774
1775 Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) {
1776 if (hlo->custom_call_target() == "SPMDFullToShardShape") {
1777 // This op switches from auto partitioning to manual partitioning.
1778 auto input_partitioned = GetPartitionedHlo(hlo->operand(0));
1779 if (!EvenlyPartitions(hlo->shape(), input_partitioned.sharding())) {
1780 input_partitioned = input_partitioned.PadWithValue(
1781 CreateR0WithType(hlo->shape().element_type(), 0, &b_));
1782 }
1783 auto input = input_partitioned.hlo();
1784 CHECK(hlo->sharding().IsManual());
1785 CHECK(ShapeUtil::Compatible(input->shape(), hlo->shape()));
1786 auto copy = b_.AddInstruction(
1787 HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input));
1788 SetPartitionedHlo(hlo, [&] { return copy; });
1789 return Status::OK();
1790 }
1791 if (hlo->custom_call_target() == "SPMDShardToFullShape") {
1792 // This op switches from manual partitioning to auto partitioning.
1793 auto input = GetPartitionedHlo(hlo->operand(0)).hlo();
1794 CHECK(input->sharding().IsManual());
1795 auto copy = b_.AddInstruction(
1796 HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input));
1797 CHECK(ShapeUtil::Compatible(
1798 copy->shape(), MakePartitionedShape(hlo->shape(), hlo->sharding())));
1799 SetPartitionedHlo(hlo, [&] { return copy; });
1800 return Status::OK();
1801 }
1802 if (hlo->custom_call_target() != "TopK") {
1803 return DefaultAction(hlo);
1804 }
1805
1806 if (!hlo->operand(0)->has_sharding()) {
1807 return DefaultAction(hlo);
1808 }
1809
1810 const HloSharding& sharding = hlo->operand(0)->sharding();
1811 if (sharding.IsTileMaximal() || sharding.IsReplicated()) {
1812 return DefaultAction(hlo);
1813 }
1814
1815 const int64 sort_dim = 1;
1816 const int64 shard_count = sharding.tile_assignment().dim(sort_dim);
1817
1818 if (shard_count <= 1) {
1819 return DefaultAction(hlo);
1820 }
1821
1822 const int64 input_size = hlo->operand(0)->shape().dimensions(sort_dim);
1823 const int64 batch_size = hlo->shape().tuple_shapes(0).dimensions(0);
1824 const int64 k = hlo->shape().tuple_shapes(0).dimensions(sort_dim);
1825 const int64 per_partition_size = CeilOfRatio(input_size, shard_count);
1826
1827 if (k >= per_partition_size) {
1828 return DefaultAction(hlo);
1829 }
1830
1831 auto input = hlo->operand(0);
1832 const auto element_type = input->shape().element_type();
1833
1834 auto partitioned_input = GetPartitionedHlo(input).PadWithValue(
1835 CreateFirstWithType(element_type, &b_));
1836
1837 // Each partition needs to do TopK separately, thus the base shape
1838 // becomes [batch_size, k * shard_count].
1839 const Shape replicated_shape = ShapeUtil::MakeTupleShape(
1840 {ShapeUtil::MakeShape(hlo->operand(0)->shape().element_type(),
1841 {batch_size, k * shard_count}),
1842 ShapeUtil::MakeShape(S32, {batch_size, k * shard_count})});
1843 auto custom_call_sharding =
1844 sharding.GetTupleSharding(replicated_shape).ValueOrDie();
1845 auto shard_shape =
1846 MakePartitionedShape(replicated_shape, custom_call_sharding);
1847 auto topk = b_.AddInstruction(
1848 hlo->CloneWithNewOperands(shard_shape, {partitioned_input.hlo()}));
1849 topk->set_sharding(custom_call_sharding);
1850 // Partition the CustomCall.
1851 PartitionedHlo partitioned_topk(topk, replicated_shape,
1852 MakePartitioningState());
1853 topk = partitioned_topk.hlo();
1854
1855 // Get value from TopK.
1856 HloInstruction* value_gte =
1857 b_.AddInstruction(HloInstruction::CreateGetTupleElement(
1858 topk->shape().tuple_shapes(0), topk, 0));
1859 value_gte->set_sharding(sharding);
1860 // Partition GetTupleElement of value.
1861 PartitionedHlo value_partitioned_gte(
1862 value_gte, partitioned_topk.base_shape().tuple_shapes(0),
1863 MakePartitioningState());
1864 // Reshard value to be replicated.
1865 auto replicated_value_gte =
1866 value_partitioned_gte.Reshard(HloSharding::Replicate()).hlo();
1867
1868 // Get index from TopK.
1869 HloInstruction* index_gte =
1870 b_.AddInstruction(HloInstruction::CreateGetTupleElement(
1871 topk->shape().tuple_shapes(1), topk, 1));
1872 auto partition_id_s32 = b_.AddInstruction(HloInstruction::CreateConvert(
1873 ShapeUtil::MakeShape(S32, partition_id_->shape().dimensions()),
1874 partition_id_));
1875 // Add a per-partition offset to the index, since the index returned from the
1876 // CustomCall always starts from 0.
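// For instance (illustrative): with per_partition_size = 256, partition 2's
// local indices 0..k-1 are shifted by 2 * 256 = 512 so that they refer to
// positions in the original, unpartitioned sort dimension.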
1877 auto index_offset = b_.AddInstruction(HloInstruction::CreateBroadcast(
1878 index_gte->shape(),
1879 b_.AddInstruction(HloInstruction::CreateBinary(
1880 partition_id_s32->shape(), HloOpcode::kMultiply, partition_id_s32,
1881 b_.AddInstruction(HloInstruction::CreateConstant(
1882 LiteralUtil::CreateR0<int32>(per_partition_size))))),
1883 {}));
1884 index_gte = b_.AddInstruction(HloInstruction::CreateBinary(
1885 index_offset->shape(), HloOpcode::kAdd, index_gte, index_offset));
1886 index_gte->set_sharding(sharding);
1887 // Partition GetTupleElement of index.
1888 PartitionedHlo index_partitioned_gte(
1889 index_gte, partitioned_topk.base_shape().tuple_shapes(1),
1890 MakePartitioningState());
1891 // Reshard index to be replicated.
1892 auto replicated_index_gte =
1893 index_partitioned_gte.Reshard(HloSharding::Replicate()).hlo();
1894
1895 // Create a replicated sort to do TopK; the inputs are the value and index
1896 // pairs from all the partitions. Sort is used instead of the TopK CustomCall
1897 // because the CustomCall only takes values as input, so an extra Gather would
1898 // be needed to recover the correct indices.
1899
1900 // Create comparator for the sort.
1901 XlaBuilder b("Sort.Compare");
1902 XlaComputation comparator = CreateScalarComparisonComputation(
1903 "compare-value-and-index", {input->shape().element_type(), S32}, {Gt, Lt},
1904 &b);
1905 TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape());
1906 HloModuleConfig config(program_shape);
1907 TF_ASSIGN_OR_RETURN(auto new_module,
1908 HloModule::CreateFromProto(comparator.proto(), config));
1909 HloCloneContext context(module_);
1910 auto compare_computation =
1911 module_->DeepCloneComputation(new_module->entry_computation(), &context);
1912 auto sort = b_.AddInstruction(HloInstruction::CreateSort(
1913 replicated_shape, sort_dim, {replicated_value_gte, replicated_index_gte},
1914 compare_computation, true));
1915 sort->set_sharding(
1916 HloSharding::Replicate().GetTupleSharding(sort->shape()).ValueOrDie());
1917 PartitionedHlo replicated_sort(sort, replicated_shape,
1918 MakePartitioningState());
1919
1920 // Slice value and index from top-k for output.
1921 HloInstruction* sort_value_gte =
1922 b_.AddInstruction(HloInstruction::CreateGetTupleElement(
1923 replicated_sort.hlo()->shape().tuple_shapes(0), replicated_sort.hlo(),
1924 0));
1925 HloInstruction* sort_index_gte =
1926 b_.AddInstruction(HloInstruction::CreateGetTupleElement(
1927 replicated_sort.hlo()->shape().tuple_shapes(1), replicated_sort.hlo(),
1928 1));
1929 // Slice value from final sort.
1930 HloInstruction* slice_sort_value =
1931 SliceFirstK(sort_value_gte, &b_, sort_dim, k);
1932 // Slice index from final sort.
1933 HloInstruction* slice_index_value =
1934 SliceFirstK(sort_index_gte, &b_, sort_dim, k);
1935 auto create_tuple = b_.AddInstruction(
1936 HloInstruction::CreateTuple({slice_sort_value, slice_index_value}));
1937 create_tuple->set_sharding(HloSharding::Replicate());
1938
1939 SetPartitionedHlo(hlo, PartitionedHlo(create_tuple, create_tuple->shape(),
1940 MakePartitioningState())
1941 .Reshard(hlo->sharding()));
1942
1943 return Status::OK();
1944 }
1945
1946 Status SpmdPartitioningVisitor::HandleTranspose(HloInstruction* hlo) {
1947 const HloSharding& sharding = hlo->sharding();
1948 if (sharding.IsTileMaximal()) {
1949 return DefaultAction(hlo);
1950 }
1951
1952 std::vector<int64> inverse_dimensions(hlo->shape().rank());
1953 for (int64 i = 0; i < hlo->shape().rank(); ++i) {
1954 inverse_dimensions[hlo->dimensions(i)] = i;
1955 }
1956 auto desired_operand_sharding =
1957 hlo_sharding_util::TransposeSharding(sharding, inverse_dimensions);
1958
1959 auto operand = GetPartitionedHlo(hlo->operand(0))
1960 .Reshard(desired_operand_sharding)
1961 .hlo();
1962 SetPartitionedHlo(hlo, [&] {
1963 return b_.AddInstruction(hlo->CloneWithNewOperands(
1964 MakePartitionedShape(hlo->shape(), hlo->sharding()), {operand}));
1965 });
1966 return Status::OK();
1967 }
1968
1969 Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) {
1970 const HloSharding& sharding = hlo->sharding();
1971 if (sharding.IsTileMaximal()) {
1972 return DefaultAction(hlo);
1973 }
1974
1975 auto operand = GetPartitionedHlo(hlo->operand(0));
1976 // To get the aligned sharding for the operand, use the output shape as the
1977 // source and the operand shape as the target.
1978 auto desired_operand_sharding = hlo_sharding_util::ReshapeSharding(
1979 hlo->shape(), hlo->operand(0)->shape(), hlo->sharding());
1980 if (desired_operand_sharding.has_value()) {
1981 auto operand_hlo = operand.Reshard(*desired_operand_sharding).hlo();
1982 SetPartitionedHlo(hlo, [&] {
1983 return b_.AddInstruction(hlo->CloneWithNewOperands(
1984 MakePartitionedShape(hlo->shape(), hlo->sharding()), {operand_hlo}));
1985 });
1986 return Status::OK();
1987 }
1988
1989 // Check that the operand sharding and the output sharding are both tiled or
1990 // both partial replicate; if partial replicate, the replication counts must match.
1991 if (operand.sharding().ReplicateOnLastTileDim() !=
1992 sharding.ReplicateOnLastTileDim() ||
1993 (sharding.ReplicateOnLastTileDim() &&
1994 (operand.sharding().tile_assignment().dimensions().back() !=
1995 sharding.tile_assignment().dimensions().back()))) {
1996 return DefaultAction(hlo);
1997 }
1998
1999 // Try to use halo exchange for certain split-dim/merge-dims cases.
2000 // ReshapeSharding failed in these cases probably due to uneven partitioning,
2001 // where halo exchange could help. Specifically we check the following
2002 // conditions to detect supported cases:
2003 // 1) Both input and output are partitioned on one dimension.
2004 // 2) The combined size of the dimensions before the partitioned dimension is
2005 // the same on input and output, which means we don't need to consider the
2006 // major dimensions.
2007 // 3) Let A = the input size on the partitioned dimension, and
2008 // B = the output size on the partitioned dimension; then
2009 // either A % B == 0 (split dim) or B % A == 0 (merge dims).
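// Illustrative example (hypothetical shapes): reshaping [6, 4] -> [24] with 4
// partitions on dim 0 of both input and output is a merge-dims case (A = 6,
// B = 24, B % A == 0), while reshaping [24] -> [6, 4] the other way is a
// split-dim case (A % B == 0). Both are handled below with a local reshape
// plus a halo exchange to realign unevenly partitioned data.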
2010 auto maybe_input_sharded_dim = UniqueTiledDim(operand.sharding());
2011 auto maybe_output_sharded_dim = UniqueTiledDim(sharding);
2012 if (!maybe_input_sharded_dim || !maybe_output_sharded_dim) {
2013 return DefaultAction(hlo);
2014 }
2015 int64 input_sharded_dim = *maybe_input_sharded_dim;
2016 int64 output_sharded_dim = *maybe_output_sharded_dim;
2017 // Check that the major dims before the sharded dim have the same total size
2018 // for input and output.
2019 int64 input_major_dims_size = 1;
2020 for (int64 i = 0; i < input_sharded_dim; ++i) {
2021 input_major_dims_size *= operand.base_shape().dimensions(i);
2022 }
2023 int64 output_major_dims_size = 1;
2024 for (int64 i = 0; i < output_sharded_dim; ++i) {
2025 output_major_dims_size *= hlo->shape().dimensions(i);
2026 }
2027 if (input_major_dims_size != output_major_dims_size) {
2028 return DefaultAction(hlo);
2029 }
2030 // Fix potential device ordering mismatch in tile assignment.
2031 Array<int64> new_input_tile_assignment = sharding.tile_assignment();
2032 new_input_tile_assignment.Reshape(
2033 operand.sharding().tile_assignment().dimensions());
2034 auto aligned_sharding =
2035 sharding.ReplicateOnLastTileDim()
2036 ? HloSharding::PartialTile(new_input_tile_assignment)
2037 : HloSharding::Tile(new_input_tile_assignment);
2038 operand = operand.Reshard(aligned_sharding);
2039 auto replication_count = sharding.ReplicateOnLastTileDim()
2040 ? sharding.tile_assignment().dimensions().back()
2041 : 1;
2042
2043 int64 input_dim_size = operand.base_shape().dimensions(input_sharded_dim);
2044 int64 output_dim_size = hlo->shape().dimensions(output_sharded_dim);
2045 auto input_shard_shape =
2046 MakePartitionedShape(operand.base_shape(), operand.sharding());
2047 auto output_shard_shape = MakePartitionedShape(hlo->shape(), sharding);
2048 if (input_dim_size % output_dim_size == 0) {
2049 // Split dim.
2050 int64 split_factor = input_dim_size / output_dim_size;
2051 int64 output_shard_size = output_shard_shape.dimensions(output_sharded_dim);
2052 // Use halo exchange to fix misaligned data.
2053 Window window;
2054 for (int64 i = 0; i < hlo->shape().rank(); ++i) {
2055 WindowDimension* dim = window.add_dimensions();
2056 dim->set_size(1);
2057 dim->set_stride(1);
2058 dim->set_window_dilation(1);
2059 dim->set_window_reversal(false);
2060 dim->set_base_dilation(1);
2061 dim->set_padding_low(0);
2062 if (i == input_sharded_dim) {
2063 dim->set_padding_high(output_shard_size * split_factor *
2064 num_partitions_ / replication_count -
2065 input_dim_size);
2066 } else {
2067 dim->set_padding_high(0);
2068 }
2069 }
2070
2071 auto reshard_operand = operand.ReshardAsWindowedInput(
2072 window, operand.sharding(),
2073 CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_),
2074 /*mask_invalid_region=*/false);
2075 if (!reshard_operand.has_value()) {
2076 return DefaultAction(hlo);
2077 }
2078 TF_RET_CHECK(!reshard_operand->dynamic_slice_index_on_output.has_value());
2079 CHECK_EQ(
2080 reshard_operand->sharded_input->shape().dimensions(input_sharded_dim),
2081 output_shard_size * split_factor);
2082 SetPartitionedHlo(hlo, [&] {
2083 // Do a local reshape.
2084 return b_.AddInstruction(HloInstruction::CreateReshape(
2085 output_shard_shape, reshard_operand->sharded_input));
2086 });
2087 return Status::OK();
2088 } else if (output_dim_size % input_dim_size == 0) {
2089 // Merge dims.
2090 int64 merge_factor = output_dim_size / input_dim_size;
2091 // First reshape locally. (The sharded dimension could include padded data.)
2092 auto tmp_shard_shape = output_shard_shape;
2093 tmp_shard_shape.set_dimensions(
2094 output_sharded_dim,
2095 input_shard_shape.dimensions(input_sharded_dim) * merge_factor);
2096 auto tmp_reshape = b_.AddInstruction(
2097 HloInstruction::CreateReshape(tmp_shard_shape, operand.hlo()));
2098 tmp_reshape->set_metadata(hlo->metadata());
2099 tmp_reshape->set_sharding(hlo->sharding());
2100 auto tmp_full_shape = tmp_shard_shape;
2101 tmp_full_shape.set_dimensions(
2102 output_sharded_dim, tmp_shard_shape.dimensions(output_sharded_dim) *
2103 num_partitions_ / replication_count);
2104 auto tmp_output =
2105 PartitionedHlo(tmp_reshape, tmp_full_shape, MakePartitioningState());
2106
2107 // Use halo exchange to fix misaligned data.
2108 Window window;
2109 for (int64 i = 0; i < tmp_shard_shape.rank(); ++i) {
2110 WindowDimension* dim = window.add_dimensions();
2111 dim->set_size(1);
2112 dim->set_stride(1);
2113 dim->set_window_dilation(1);
2114 dim->set_window_reversal(false);
2115 dim->set_base_dilation(1);
2116 dim->set_padding_low(0);
2117 if (i == output_sharded_dim) {
2118 dim->set_padding_high(output_dim_size -
2119 tmp_shard_shape.dimensions(output_sharded_dim) *
2120 num_partitions_ / replication_count);
2121 } else {
2122 dim->set_padding_high(0);
2123 }
2124 }
2125
2126 auto reshard_output = tmp_output.ReshardAsWindowedInput(
2127 window, sharding,
2128 CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_),
2129 /*mask_invalid_region=*/false);
2130 if (!reshard_output.has_value()) {
2131 return DefaultAction(hlo);
2132 }
2133 TF_RET_CHECK(!reshard_output->dynamic_slice_index_on_output.has_value());
2134 CHECK_EQ(
2135 reshard_output->sharded_input->shape().dimensions(output_sharded_dim),
2136 output_shard_shape.dimensions(output_sharded_dim));
2137 SetPartitionedHlo(hlo, [&] { return reshard_output->sharded_input; });
2138 return Status::OK();
2139 }
2140 return DefaultAction(hlo);
2141 }
2142
2143 Status SpmdPartitioningVisitor::HandleIota(HloInstruction* hlo) {
2144 const HloSharding& sharding = hlo->sharding();
2145 if (sharding.IsTileMaximal()) {
2146 return DefaultAction(hlo);
2147 }
2148
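// Sketch of the offset logic below: for an iota of length 16 on a dimension
// partitioned 4 ways, each shard holds a local iota [0, 1, 2, 3] and adds
// partition_ordinal * 4, so partition 2 produces [8, 9, 10, 11].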
2149 SetPartitionedHlo(hlo, [&] {
2150 int64 dimension = Cast<HloIotaInstruction>(hlo)->iota_dimension();
2151 auto iota = b_.AddInstruction(HloInstruction::CreateIota(
2152 MakePartitionedShape(hlo->shape(), sharding), dimension));
2153
2154 if (sharding.tile_assignment().dim(dimension) > 1) {
2155 auto partition_ordinals =
2156 MakeTiledPartitionOrdinals(sharding, partition_id_, &b_);
2157 auto multiplier = b_.AddInstruction(HloInstruction::CreateConstant(
2158 LiteralUtil::CreateR0<int32>(iota->shape().dimensions(dimension))));
2159 auto offset = b_.AddInstruction(HloInstruction::CreateBinary(
2160 ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply,
2161 partition_ordinals[dimension], multiplier));
2162 if (iota->shape().element_type() != S32) {
2163 offset = b_.AddInstruction(HloInstruction::CreateConvert(
2164 ShapeUtil::MakeShape(iota->shape().element_type(), {}), offset));
2165 }
2166 auto broadcast = b_.AddInstruction(
2167 HloInstruction::CreateBroadcast(iota->shape(), offset, {}));
2168 return b_.AddInstruction(HloInstruction::CreateBinary(
2169 iota->shape(), HloOpcode::kAdd, iota, broadcast));
2170 }
2171
2172 return iota;
2173 });
2174
2175 return Status::OK();
2176 }
2177
2178 Status SpmdPartitioningVisitor::HandleSingleDevice(const HloInstruction* hlo) {
2179 TF_RET_CHECK(hlo->sharding().HasUniqueDevice());
2180 int64 device = hlo->sharding().GetUniqueDevice();
2181 const HloSharding sharding = HloSharding::AssignDevice(device);
2182
2183 std::vector<HloInstruction*> operands;
2184 std::vector<Shape> operand_shapes;
2185 for (const HloInstruction* operand : hlo->operands()) {
2186 operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo());
2187 operand_shapes.push_back(operand->shape());
2188 }
2189 auto operand = b_.AddInstruction(HloInstruction::CreateTuple(operands));
2190 auto operand_shape = ShapeUtil::MakeTupleShape(operand_shapes);
2191
2192 auto on_device = b_.AddInstruction(
2193 HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(device)));
2194 auto pred = b_.AddInstruction(HloInstruction::CreateCompare(
2195 ShapeUtil::MakeShape(PRED, {}), partition_id_, on_device,
2196 ComparisonDirection::kEq));
2197
2198 SpmdBuilder true_b("true_computation", visiting_hlo_);
2199 HloComputation* true_computation;
2200 {
2201 auto param = true_b.AddInstruction(HloInstruction::CreateParameter(
2202 /*parameter_number=*/0, operand_shape, "true_branch_param"));
2203 std::vector<HloInstruction*> new_operands;
2204 for (int64 i = 0; i < operands.size(); ++i) {
2205 new_operands.push_back(true_b.AddInstruction(
2206 HloInstruction::CreateGetTupleElement(operand_shapes[i], param, i)));
2207 }
2208 auto root = true_b.AddInstruction(
2209 hlo->CloneWithNewOperands(hlo->shape(), new_operands));
2210 true_computation = module_->AddEmbeddedComputation(true_b.Build(root));
2211 }
2212
2213 SpmdBuilder false_b("false_computation", visiting_hlo_);
2214 HloComputation* false_computation;
2215 {
2216 false_b.AddInstruction(HloInstruction::CreateParameter(
2217 /*parameter_number=*/0, operand_shape, "false_branch_param"));
2218 auto root = CreateZero(hlo->shape(), &false_b);
2219 false_computation = module_->AddEmbeddedComputation(false_b.Build(root));
2220 }
2221
2222 SetPartitionedHlo(hlo, [&]() {
2223 return b_.AddInstruction(HloInstruction::CreateConditional(
2224 hlo->shape(), pred, operand, true_computation, operand,
2225 false_computation));
2226 });
2227 return Status::OK();
2228 }
2229
2230 Status SpmdPartitioningVisitor::HandleAllReduce(HloInstruction* hlo) {
2231 if (hlo->IsCrossReplicaAllReduce() && hlo->operand_count() == 1) {
2232 return HandleElementwise(hlo);
2233 }
2234 return DefaultAction(hlo);
2235 }
2236
2237 Status SpmdPartitioningVisitor::HandleBroadcast(HloInstruction* hlo) {
2238 if (hlo->sharding().IsTileMaximal()) {
2239 return DefaultAction(hlo);
2240 }
2241
2242 auto& operand = GetPartitionedHlo(hlo->operand(0));
2243
2244 // Tiled output.
2245 std::vector<int64> new_dims;
2246 for (int64 i = 0; i < hlo->shape().rank(); ++i) {
2247 if (!absl::c_linear_search(hlo->dimensions(), i)) {
2248 new_dims.push_back(i);
2249 }
2250 }
2251 auto desired_input_sharding = hlo_sharding_util::RemoveShapeDimensions(
2252 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(hlo->sharding(),
2253 new_dims),
2254 new_dims);
2255 auto input = operand.Reshard(desired_input_sharding).hlo();
2256 auto output_shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
2257 SetPartitionedHlo(hlo, [&] {
2258 return b_.AddInstruction(
2259 hlo->CloneWithNewOperands(output_shard_shape, {input}));
2260 });
2261 return Status::OK();
2262 }
2263
2264 Status SpmdPartitioningVisitor::HandleConstant(HloInstruction* hlo) {
2265 const Literal& literal = hlo->literal();
2266 if (literal.shape().IsTuple() ||
2267 (!hlo->sharding().IsTileMaximal() &&
2268 (!EvenlyPartitions(hlo->shape(), hlo->sharding()) ||
2269 !literal.IsAllFirst()))) {
2270 return DefaultAction(hlo);
2271 }
2272
2273 SetPartitionedHlo(hlo, [&]() {
2274 auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
2275 std::vector<int64> start_indices(hlo->shape().rank(), 0);
2276 auto constant = b_.AddInstruction(HloInstruction::CreateConstant(
2277 literal.Slice(start_indices, shard_shape.dimensions())));
2278 *constant->mutable_shape() = shard_shape;
2279 return constant;
2280 });
2281 return Status::OK();
2282 }
2283
2284 Status SpmdPartitioningVisitor::HandleDynamicSlice(HloInstruction* hlo) {
2285 if (hlo->sharding().IsTileMaximal()) {
2286 return DefaultAction(hlo);
2287 }
2288 for (int64 i = 0; i < hlo->shape().rank(); ++i) {
2289 if (hlo->sharding().tile_assignment().dim(i) != 1 &&
2290 (hlo->dynamic_slice_sizes()[i] != hlo->shape().dimensions(i) ||
2291 !hlo->operand(i + 1)->IsConstant() ||
2292 !hlo->operand(i + 1)->literal().IsZero({}))) {
2293 // We currently do not partition the sliced dimensions.
2294 return DefaultAction(hlo);
2295 }
2296 }
2297 std::vector<HloInstruction*> new_indices(hlo->shape().rank());
2298 auto new_input =
2299 GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo();
2300 for (int64 i = 0; i < new_indices.size(); ++i) {
2301 // Replicate the indices.
2302 new_indices[i] = GetPartitionedHlo(hlo->operand(i + 1))
2303 .Reshard(HloSharding::Replicate())
2304 .hlo();
2305 }
2306 SetPartitionedHlo(hlo, [&]() {
2307 auto partitioned_shape =
2308 MakePartitionedShape(hlo->shape(), hlo->sharding());
2309 return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
2310 partitioned_shape, new_input, new_indices,
2311 partitioned_shape.dimensions()));
2312 });
2313 return Status::OK();
2314 }
2315
2316 Status SpmdPartitioningVisitor::HandleDynamicUpdateSlice(HloInstruction* hlo) {
2317 if (hlo->sharding().IsTileMaximal()) {
2318 return DefaultAction(hlo);
2319 }
2320
2321 std::vector<int64> partitioned_slice_dims;
2322 std::vector<int64> slice_dims;
2323 std::vector<int64> partitioned_non_slice_dims;
2324 std::vector<int64> partitioned_slice_offsets;
2325 for (int64 i = 0; i < hlo->shape().rank(); ++i) {
2326 if (hlo->operand(1)->shape().dimensions(i) != hlo->shape().dimensions(i)) {
2327 slice_dims.push_back(i);
2328 if (hlo->sharding().tile_assignment().dim(i) != 1) {
2329 if (!hlo->operand(i + 2)->IsConstant()) {
2330 return DefaultAction(hlo);
2331 }
2332 partitioned_slice_dims.push_back(i);
2333 partitioned_slice_offsets.push_back(
2334 hlo->operand(i + 2)->literal().Get<int>({}));
2335 }
2336 } else if (hlo->sharding().tile_assignment().dim(i) != 1) {
2337 if (!hlo->operand(i + 2)->IsConstant() ||
2338 !hlo->operand(i + 2)->literal().IsZero({})) {
2339 return DefaultAction(hlo);
2340 }
2341 partitioned_non_slice_dims.push_back(i);
2342 }
2343 }
2344
2345 // Handle the case where a slice dimension is partitioned.
2346 if (!partitioned_slice_dims.empty()) {
2347 auto add_hlo = [&](std::unique_ptr<HloInstruction> to_add) {
2348 return b_.AddInstruction(std::move(to_add));
2349 };
2350 std::vector<HloInstruction*> new_indices(hlo->shape().rank());
2351 for (int64 i = 0; i < new_indices.size(); ++i) {
2352 // Replicate the indices.
2353 new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2))
2354 .Reshard(HloSharding::Replicate())
2355 .hlo();
2356 }
2357
2358 // Get partitioned input.
2359 const auto& dus_sharding = hlo->sharding();
2360 const auto& partitioned_input =
2361 GetPartitionedHlo(hlo->operand(0)).Reshard(dus_sharding).hlo();
2362
2363 // Get replicated update.
2364 auto update_sharding = HloSharding::Replicate();
2365 if (!partitioned_non_slice_dims.empty()) {
2366 // Partially replicate the update if non-slice dims are partitioned.
2367 update_sharding =
2368 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(dus_sharding,
2369 slice_dims);
2370 }
2371 HloInstruction* replicate_update =
2372 GetPartitionedHlo(hlo->operand(1)).Reshard(update_sharding).hlo();
2373
2374 const auto& update_shape = replicate_update->shape();
2375 const auto& partitioned_shape = partitioned_input->shape();
2376 auto partition_ordinals =
2377 MakeTiledPartitionOrdinals(hlo->sharding(), partition_id_, &b_);
2378 HloInstruction* all_dims_within_partition = add_hlo(
2379 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
2380
2381 for (int i = 0; i < partitioned_slice_dims.size(); ++i) {
2382 int dim = partitioned_slice_dims[i];
2383 // Calculate per partition size.
2384 const int64 per_partition_size = partitioned_shape.dimensions(dim);
2385
2386 // Only updates that fall within a single partition are supported.
2387 if ((partitioned_slice_offsets[i] / per_partition_size) !=
2388 ((partitioned_slice_offsets[i] + update_shape.dimensions(dim) - 1) /
2389 per_partition_size)) {
2390 return DefaultAction(hlo);
2391 }
2392
2393 // within_partition = (offset >= partition_id * per_partition_size) &&
2394 // (offset < (partition_id + 1) * per_partition_size)
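// E.g. (illustrative): with per_partition_size = 8, an update at offset 10 on
// this dimension lies entirely in partition 1, so only that partition's
// predicate is true and its local offset becomes 10 - 1 * 8 = 2; all other
// partitions keep their input unchanged via the final select.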
2395 const Shape& compare_shape =
2396 ShapeUtil::ChangeElementType(partition_id_->shape(), PRED);
2397 auto per_partition_size_hlo = add_hlo(HloInstruction::CreateConstant(
2398 LiteralUtil::CreateR0<int>(per_partition_size)));
2399 const Shape& offset_shape = per_partition_size_hlo->shape();
2400 auto partition_offset = add_hlo(HloInstruction::CreateBinary(
2401 offset_shape, HloOpcode::kMultiply, partition_ordinals[dim],
2402 per_partition_size_hlo));
2403 // offset >= partition_id * per_partition_size
2404 auto offset_ge = add_hlo(HloInstruction::CreateCompare(
2405 compare_shape, new_indices[dim], partition_offset,
2406 ComparisonDirection::kGe));
2407 // offset < (partition_id + 1) * per_partition_size
2408 auto offset_lt = add_hlo(HloInstruction::CreateCompare(
2409 compare_shape, new_indices[dim],
2410 add_hlo(HloInstruction::CreateBinary(
2411 offset_shape, HloOpcode::kMultiply,
2412 add_hlo(HloInstruction::CreateBinary(
2413 offset_shape, HloOpcode::kAdd, partition_ordinals[dim],
2414 add_hlo(HloInstruction::CreateConstant(
2415 LiteralUtil::CreateR0<int>(1))))),
2416 per_partition_size_hlo)),
2417 ComparisonDirection::kLt));
2418 auto update_within_partition = add_hlo(HloInstruction::CreateBinary(
2419 compare_shape, HloOpcode::kAnd, offset_ge, offset_lt));
2420
2421 all_dims_within_partition = add_hlo(HloInstruction::CreateBinary(
2422 compare_shape, HloOpcode::kAnd, all_dims_within_partition,
2423 update_within_partition));
2424
2425 // Calculate offset.
2426 // slice dim offset =
2427 // within_partition ?
2428 // offset - partition_id * per_partition_size : 0
2429 new_indices[dim] = add_hlo(HloInstruction::CreateTernary(
2430 new_indices[dim]->shape(), HloOpcode::kSelect,
2431 update_within_partition,
2432 add_hlo(HloInstruction::CreateBinary(
2433 new_indices[dim]->shape(), HloOpcode::kSubtract, new_indices[dim],
2434 partition_offset)),
2435 add_hlo(
2436 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)))));
2437 }
2438
2439 // Create dynamic update slice.
2440 auto dus = add_hlo(HloInstruction::CreateDynamicUpdateSlice(
2441 partitioned_shape, partitioned_input, replicate_update, new_indices));
2442 SetPartitionedHlo(hlo, [&]() {
2443 // Select if update is needed.
2444 return add_hlo(HloInstruction::CreateTernary(
2445 dus->shape(), HloOpcode::kSelect,
2446 add_hlo(HloInstruction::CreateBroadcast(
2447 ShapeUtil::ChangeElementType(dus->shape(), PRED),
2448 all_dims_within_partition, {})),
2449 dus, partitioned_input));
2450 });
2451 return Status::OK();
2452 }
2453
2454 // Only non-slice dims are partitioned.
2455 std::vector<HloInstruction*> new_indices(hlo->shape().rank());
2456 auto new_input =
2457 GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo();
2458 auto new_update =
2459 GetPartitionedHlo(hlo->operand(1)).Reshard(hlo->sharding()).hlo();
2460 for (int64 i = 0; i < new_indices.size(); ++i) {
2461 // Replicate the indices.
2462 new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2))
2463 .Reshard(HloSharding::Replicate())
2464 .hlo();
2465 }
2466 SetPartitionedHlo(hlo, [&]() {
2467 auto partitioned_shape =
2468 MakePartitionedShape(hlo->shape(), hlo->sharding());
2469 return b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
2470 partitioned_shape, new_input, new_update, new_indices));
2471 });
2472 return Status::OK();
2473 }
2474
2475 Status SpmdPartitioningVisitor::HandleGetTupleElement(HloInstruction* hlo) {
2476 const auto& tuple = GetPartitionedHlo(hlo->operand(0));
2477 auto gte = b_.AddInstruction(HloInstruction::CreateGetTupleElement(
2478 ShapeUtil::GetTupleElementShape(tuple.hlo()->shape(), hlo->tuple_index()),
2479 tuple.hlo(), hlo->tuple_index()));
2480 const auto source_sharding =
2481 tuple.sharding().GetSubSharding(tuple.base_shape(), {hlo->tuple_index()});
2482 gte->set_sharding(source_sharding);
2483 PartitionedHlo source_partitioned_gte(
2484 gte, tuple.base_shape().tuple_shapes(hlo->tuple_index()),
2485 MakePartitioningState());
2486 source_partitioned_gte = source_partitioned_gte.Reshard(hlo->sharding());
2487 SetPartitionedHlo(hlo, source_partitioned_gte);
2488 return Status::OK();
2489 }
2490
2491 Status SpmdPartitioningVisitor::HandleInfeed(HloInstruction* hlo) {
2492 const Shape& shape = ShapeUtil::GetTupleElementShape(hlo->shape(), 0);
2493 auto token = GetPartitionedHlo(hlo->operand(0)).hlo();
2494 if (ShapeUtil::GetLeafCount(shape) == 0) {
2495 // TODO(b/155819021): HloSharding has issues with tuple-shaped sharding: it
2496 // requires one element for an empty tuple, but leaf-count number of
2497 // elements for non-empty tuple. So if it has a nested empty tuple, we
2498 // cannot invoke GetSubSharding() since it expects a sharding for the empty
2499 // tuple. This is a workaround for that case.
2500 SetPartitionedHlo(hlo, [&]() {
2501 return b_.AddInstruction(
2502 HloInstruction::CreateInfeed(shape, token, hlo->infeed_config()));
2503 });
2504 return Status::OK();
2505 }
2506 auto sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0});
2507 auto shard_shape = MakePartitionedShape(shape, sharding);
2508 if (EvenlyPartitions(shape, sharding)) {
2509 SetPartitionedHlo(hlo, [&]() {
2510 return b_.AddInstruction(HloInstruction::CreateInfeed(
2511 shard_shape, token, hlo->infeed_config()));
2512 });
2513 return Status::OK();
2514 }
2515
2516 if (hlo->sharding().HasUniqueDevice()) {
2517 return HandleSingleDevice(hlo);
2518 }
2519
2520 // Create a branch for each unique partitioned shape.
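// For example (illustrative): with 4 partitions and an unevenly partitioned
// shape where shards 0..2 receive [2, 3] and shard 3 receives [1, 3], two
// branches are built and conditional_branch_indices becomes {0, 0, 0, 1}.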
2521 std::vector<Shape> per_branch_partitioned_shapes;
2522 std::vector<int32> conditional_branch_indices(num_partitions_);
2523 for (int64 i = 0; i < num_partitions_; ++i) {
2524 auto partitioned_shape =
2525 MakeNonPaddedShapeForGivenPartition(shape, sharding, i);
2526 int64 matching_existing_index = 0;
2527 for (; matching_existing_index < per_branch_partitioned_shapes.size();
2528 ++matching_existing_index) {
2529 if (ShapeUtil::Compatible(
2530 partitioned_shape,
2531 per_branch_partitioned_shapes[matching_existing_index])) {
2532 break;
2533 }
2534 }
2535 if (matching_existing_index < per_branch_partitioned_shapes.size()) {
2536 conditional_branch_indices[i] = matching_existing_index;
2537 } else {
2538 conditional_branch_indices[i] = per_branch_partitioned_shapes.size();
2539 per_branch_partitioned_shapes.push_back(std::move(partitioned_shape));
2540 }
2541 }
2542
2543 HloInstruction* branch_index;
2544 if (per_branch_partitioned_shapes.size() == num_partitions_) {
2545 // Use partition ID as the branch index if each partition has its own
2546 // branch.
2547 branch_index = partition_id_;
2548 // PartitionId's output is U32 but conditional requires S32.
2549 if (branch_index->shape().element_type() != S32) {
2550 branch_index = b_.AddInstruction(HloInstruction::CreateConvert(
2551 ShapeUtil::ChangeElementType(branch_index->shape(), S32),
2552 branch_index));
2553 }
2554 } else {
2555 // Otherwise, use a constant table to look up the branch index.
2556 auto branch_index_table = b_.AddInstruction(HloInstruction::CreateConstant(
2557 LiteralUtil::CreateR1<int32>(conditional_branch_indices)));
2558 branch_index = b_.AddInstruction(HloInstruction::CreateDynamicSlice(
2559 ShapeUtil::MakeShape(S32, {1}), branch_index_table, {partition_id_},
2560 {1}));
2561 branch_index = b_.AddInstruction(HloInstruction::CreateReshape(
2562 ShapeUtil::MakeShape(S32, {}), branch_index));
2563 }
2564
2565 std::vector<HloComputation*> branches(per_branch_partitioned_shapes.size());
2566 for (int64 i = 0; i < branches.size(); ++i) {
2567 SpmdBuilder branch_b(absl::StrCat("infeed_branch_", i), visiting_hlo_);
2568 auto param = branch_b.AddInstruction(HloInstruction::CreateParameter(
2569 /*parameter_number=*/0, token->shape(), "infeed_token_param"));
2570 auto infeed = branch_b.AddInstruction(HloInstruction::CreateInfeed(
2571 per_branch_partitioned_shapes[i], param, hlo->infeed_config()));
2572 if (!ShapeUtil::Compatible(per_branch_partitioned_shapes[i], shard_shape)) {
2573 std::function<HloInstruction*(const ShapeIndex&, HloInstruction*)>
2574 pad_infeed = [&](const ShapeIndex& index,
2575 HloInstruction* infeed_element) -> HloInstruction* {
2576 if (index == ShapeIndex({1})) {
2577 // Token.
2578 return infeed_element;
2579 }
2580 const Shape& element_shape =
2581 ShapeUtil::GetSubshape(infeed->shape(), index);
2582 if (element_shape.IsTuple() && element_shape.tuple_shapes_size() > 0) {
2583 std::vector<HloInstruction*> padded_elements(
2584 element_shape.tuple_shapes_size());
2585 for (int64 i = 0; i < padded_elements.size(); ++i) {
2586 auto sub_index = index;
2587 sub_index.push_back(i);
2588 padded_elements[i] = pad_infeed(
2589 sub_index,
2590 branch_b.AddInstruction(HloInstruction::CreateGetTupleElement(
2591 ShapeUtil::GetSubshape(element_shape, {i}), infeed_element,
2592 i)));
2593 }
2594 return branch_b.AddInstruction(
2595 HloInstruction::CreateTuple(padded_elements));
2596 }
2597 const Shape& pad_shape =
2598 ShapeUtil::GetSubshape(shard_shape, ShapeIndexView(index, 1));
2599 if (ShapeUtil::Compatible(element_shape, pad_shape)) {
2600 return infeed_element;
2601 }
2602 if (element_shape.IsArray()) {
2603 CHECK(pad_shape.IsArray());
2604 return PadToShape(infeed_element, pad_shape, &branch_b);
2605 }
2606 CHECK(element_shape.IsTuple());
2607 CHECK(element_shape.tuple_shapes().empty());
2608 return CreateZero(pad_shape, &branch_b);
2609 };
2610 pad_infeed({}, infeed);
2611 }
2612 branches[i] = module_->AddEmbeddedComputation(branch_b.Build());
2613 }
2614 SetPartitionedHlo(hlo, [&]() {
2615 return b_.AddInstruction(HloInstruction::CreateConditional(
2616 ShapeUtil::MakeTupleShape({shard_shape, token->shape()}), branch_index,
2617 branches, std::vector<HloInstruction*>(branches.size(), token)));
2618 });
2619 return Status::OK();
2620 }
2621
2622 Status SpmdPartitioningVisitor::HandlePad(HloInstruction* hlo) {
2623 if (hlo->sharding().IsTileMaximal()) {
2624 return DefaultAction(hlo);
2625 }
2626 auto lhs = GetPartitionedHlo(hlo->operand(0));
2627 // Create a window config to represent the pad.
2628 Window window;
2629 for (int64 i = 0; i < hlo->shape().rank(); ++i) {
2630 const auto& pd = hlo->padding_config().dimensions(i);
2631 WindowDimension* dim = window.add_dimensions();
2632 dim->set_size(1);
2633 dim->set_stride(1);
2634 dim->set_window_dilation(1);
2635 dim->set_window_reversal(false);
2636 dim->set_padding_low(pd.edge_padding_low());
2637 dim->set_padding_high(pd.edge_padding_high());
2638 dim->set_base_dilation(pd.interior_padding() + 1);
2639 }
2640
2641 auto replicated_rhs = GetPartitionedHlo(hlo->operand(1))
2642 .Reshard(HloSharding::Replicate())
2643 .hlo();
2644 auto reshard_operand =
2645 lhs.ReshardAsWindowedInput(window, hlo->sharding(), replicated_rhs,
2646 /*mask_invalid_region=*/false);
2647 if (!reshard_operand.has_value()) {
2648 return DefaultAction(hlo);
2649 }
2650 PaddingConfig sharded_padding_config;
2651 bool need_pad = false;
2652 for (int64 i = 0; i < hlo->shape().rank(); ++i) {
2653 auto dim = sharded_padding_config.add_dimensions();
2654 const auto& wd = reshard_operand->shard_window.dimensions(i);
2655 dim->set_edge_padding_low(wd.padding_low());
2656 dim->set_edge_padding_high(wd.padding_high());
2657 dim->set_interior_padding(wd.base_dilation() - 1);
2658 if (wd.padding_low() != 0 || wd.padding_high() != 0 ||
2659 wd.base_dilation() != 1) {
2660 need_pad = true;
2661 }
2662 }
2663 auto sharded_pad = reshard_operand->sharded_input;
2664 if (need_pad) {
2665 TF_ASSIGN_OR_RETURN(auto sharded_pad_shape,
2666 ShapeInference::InferPadShape(sharded_pad->shape(),
2667 replicated_rhs->shape(),
2668 sharded_padding_config));
2669 sharded_pad = b_.AddInstruction(hlo->CreatePad(sharded_pad_shape,
2670 sharded_pad, replicated_rhs,
2671 sharded_padding_config));
2672 }
2673
2674 SetPartitionedHlo(hlo, [&]() {
2675 if (!reshard_operand->dynamic_slice_index_on_output) {
2676 return sharded_pad;
2677 }
2678 auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
2679 return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
2680 shard_shape, sharded_pad,
2681 *reshard_operand->dynamic_slice_index_on_output,
2682 shard_shape.dimensions()));
2683 });
2684 return Status::OK();
2685 }
2686
2687 Status SpmdPartitioningVisitor::HandleParameter(HloInstruction* hlo) {
2688 SetPartitionedHlo(hlo, [&]() {
2689 auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
2690 auto new_param = b_.AddInstruction(HloInstruction::CreateParameter(
2691 hlo->parameter_number(), shard_shape, "param"));
2692 if (hlo->parameter_replicated_at_leaf_buffers()) {
2693 new_param->set_parameter_replicated_at_leaf_buffers(
2694 *hlo->parameter_replicated_at_leaf_buffers());
2695 }
2696 return new_param;
2697 });
2698 return Status::OK();
2699 }
2700
2701 Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) {
2702 int64 input_count = 1;
2703 auto per_input_sharding = hlo->sharding();
2704 if (hlo->shape().IsTuple()) {
2705 input_count = hlo->shape().tuple_shapes_size();
2706 CHECK_GT(input_count, 0);
2707 per_input_sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0});
2708 }
2709
2710 std::vector<PartitionedHlo> inputs;
2711 std::vector<HloInstruction*> inits;
2712 std::vector<int64> preserved_dims;
2713 for (int64 i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
2714 if (!absl::c_linear_search(hlo->dimensions(), i)) {
2715 preserved_dims.push_back(i);
2716 }
2717 }
2718
2719 for (int64 operand_id = 0; operand_id < input_count; ++operand_id) {
2720 inits.push_back(GetPartitionedHlo(hlo->operand(operand_id + input_count))
2721 .Reshard(HloSharding::Replicate())
2722 .hlo());
2723 inputs.push_back(GetPartitionedHlo(hlo->operand(operand_id)));
2724 if (operand_id > 0) {
2725 // Make sure all operands are sharded in the same way.
2726 inputs.back() = inputs.back().Reshard(inputs[0].sharding());
2727 }
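// Pad uneven shards with the replicated init value so that padded elements
// act as the identity of the reduction and do not change the result.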
2728 if (!inputs[0].sharding().IsTileMaximal()) {
2729 inputs.back() =
2730 inputs.back().PadWithValue(inits[operand_id], /*left_padded_dims=*/{},
2731 /*skipped_dims=*/preserved_dims);
2732 }
2733 }
2734
2735 std::vector<Shape*> new_operand_shapes(input_count * 2);
2736 for (int64 i = 0; i < input_count; ++i) {
2737 new_operand_shapes[i] = inputs[i].hlo()->mutable_shape();
2738 new_operand_shapes[i + input_count] = inits[i]->mutable_shape();
2739 }
2740 // Create the shard shape of the reduce result.
2741 TF_ASSIGN_OR_RETURN(
2742 auto reduce_shape,
2743 ShapeInference::InferReduceShape(new_operand_shapes, hlo->dimensions(),
2744 hlo->to_apply()->ComputeProgramShape()));
2745
2746 std::vector<HloInstruction*> input_hlos(input_count);
2747 for (int64 i = 0; i < input_count; ++i) {
2748 input_hlos[i] = inputs[i].hlo();
2749 }
2750 auto local_reduce = b_.AddInstruction(HloInstruction::CreateReduce(
2751 reduce_shape, input_hlos, inits, hlo->dimensions(), hlo->to_apply()));
2752 local_reduce->set_metadata(hlo->metadata());
2753
2754 SetPartitionedHlo(hlo, [&]() {
2755 HloInstruction* reduce = local_reduce;
2756 const bool reduce_sharded_dimension =
2757 !inputs[0].sharding().IsTileMaximal() &&
2758 absl::c_any_of(hlo->dimensions(), [&](int64 i) {
2759 return inputs[0].sharding().tile_assignment().dim(i) > 1;
2760 });
2761 if (reduce_sharded_dimension) {
2762 if (inputs[0].sharding().ReplicateOnLastTileDim()) {
2763 preserved_dims.push_back(inputs[0].base_shape().rank());
2764 }
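// Combine the per-partition partial results: array-shaped reduces use a
// cross-partition all-reduce along the reduced sharding dimensions;
// tuple-shaped (variadic) reduces all-gather the partials per group and
// reduce them again locally.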
2765 if (local_reduce->shape().IsArray()) {
2766 reduce = partitioner_->AllReduceAlongShardingDims(
2767 &b_, local_reduce, inputs[0].sharding(), next_channel_id_,
2768 hlo->dimensions(), collective_ops_creator_, hlo->to_apply());
2769 } else {
2770 auto grouped =
2771 GroupShardingOnDims(inputs[0].sharding(), preserved_dims);
2772 auto grouped_state = CreatePerGroupPartitioningState(
2773 inputs[0].state(), grouped.device_groups, &b_);
2774 std::vector<HloInstruction*> all_gathered_partial_results(input_count);
2775 for (int64 i = 0; i < input_count; ++i) {
2776 auto gte = b_.AddInstruction(HloInstruction::CreateGetTupleElement(
2777 ShapeUtil::GetTupleElementShape(reduce_shape, i), local_reduce,
2778 i));
2779 auto expanded_shape = input_hlos[i]->shape();
2780 auto all_gather_shape = input_hlos[i]->shape();
2781 for (int64 dim : hlo->dimensions()) {
2782 expanded_shape.set_dimensions(dim, 1);
2783 all_gather_shape.set_dimensions(
2784 dim, inputs[0].sharding().tile_assignment().dim(dim));
2785 }
2786 auto reshape = b_.AddInstruction(
2787 HloInstruction::CreateReshape(expanded_shape, gte));
2788 // Replicate per group.
2789 reshape->set_sharding(grouped.sharding);
2790 all_gathered_partial_results[i] =
2791 PartitionedHlo(reshape, all_gather_shape, grouped_state)
2792 .Replicate()
2793 .hlo();
2794 }
2795 reduce = b_.AddInstruction(HloInstruction::CreateReduce(
2796 reduce_shape, all_gathered_partial_results, inits,
2797 hlo->dimensions(), hlo->to_apply()));
2798 }
2799 }
2800 auto sharding = hlo_sharding_util::RemoveShapeDimensions(
2801 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
2802 inputs[0].sharding(), hlo->dimensions()),
2803 hlo->dimensions());
2804 if (local_reduce->shape().IsArray()) {
2805 reduce->set_sharding(sharding);
2806 } else {
2807 reduce->set_sharding(HloSharding::Tuple(
2808 reduce->shape(), std::vector<HloSharding>(input_count, sharding)));
2809 }
2810 return PartitionedHlo(reduce, hlo->shape(), MakePartitioningState())
2811 .Reshard(hlo->sharding())
2812 .hlo();
2813 });
2814 return Status::OK();
2815 }
2816
2817 Status SpmdPartitioningVisitor::HandleReverse(HloInstruction* hlo) {
2818 auto reverse = Cast<HloReverseInstruction>(hlo);
2819 if (reverse->sharding().IsTileMaximal()) {
2820 return DefaultAction(hlo);
2821 }
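// Reshard the operand with the reverse of the output sharding along the
// reversed dimensions so each partition receives the operand data its output
// shard needs, then shift any uneven-shard padding to the left so that
// reversing each shard locally produces correctly aligned results.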
2822 auto operand = GetPartitionedHlo(reverse->operand(0))
2823 .Reshard(hlo_sharding_util::ReverseSharding(
2824 reverse->sharding(), reverse->dimensions()));
2825 auto left_padded_operand =
2826 HaloExchangeToPadOnLeft(operand, reverse->dimensions());
2827 if (!left_padded_operand) {
2828 return DefaultAction(hlo);
2829 }
2830 SetPartitionedHlo(hlo, [&] {
2831 return b_.AddInstruction(hlo->CloneWithNewOperands(
2832 left_padded_operand->shape(), {left_padded_operand}));
2833 });
2834 return Status::OK();
2835 }
2836
2837 Status SpmdPartitioningVisitor::HandleWhile(HloInstruction* hlo) {
2838 const HloSharding& sharding = hlo->sharding();
2839
2840 // Shardings for the body parameter, body root, and cond parameter must be
2841 // the same, and the condition root must be replicated so that all partitions
2842 // follow the same control flow.
2843 hlo->while_condition()->parameter_instruction(0)->set_sharding(sharding);
2844 hlo->while_body()->parameter_instruction(0)->set_sharding(sharding);
2845 TF_RETURN_IF_ERROR(partitioner_
2846 ->PartitionComputation(hlo->while_condition(),
2847 HloSharding::Replicate(),
2848 next_channel_id_, logger_)
2849 .status());
2850 TF_RETURN_IF_ERROR(partitioner_
2851 ->PartitionComputation(hlo->while_body(), sharding,
2852 next_channel_id_, logger_)
2853 .status());
2854 SetPartitionedHlo(hlo, [&] {
2855 return b_.AddInstruction(HloInstruction::CreateWhile(
2856 MakePartitionedShape(hlo->shape(), sharding), hlo->while_condition(),
2857 hlo->while_body(),
2858 GetPartitionedHlo(hlo->operand(0)).Reshard(sharding).hlo()));
2859 });
2860 return Status::OK();
2861 }
2862
2863 Status SpmdPartitioningVisitor::HandleConditional(HloInstruction* hlo) {
2864 std::vector<HloInstruction*> branch_args;
2865 for (int64 i = 0; i < hlo->branch_count(); ++i) {
2866 HloComputation* computation = hlo->branch_computation(i);
2867
2868 // Shardings of the branch computation parameter and its argument must be
2869 // the same.
2870 computation->parameter_instruction(0)->set_sharding(
2871 hlo->operand(i + 1)->sharding());
2872 branch_args.push_back(GetPartitionedHlo(hlo->operand(i + 1)).hlo());
2873 }
2874
2875 // The root of the branch computations must follow the sharding of the
2876 // conditional instruction.
2877 for (int64 i = 0; i < hlo->branch_count(); ++i) {
2878 HloComputation* computation = hlo->branch_computation(i);
2879 TF_RETURN_IF_ERROR(partitioner_
2880 ->PartitionComputation(computation, hlo->sharding(),
2881 next_channel_id_, logger_)
2882 .status());
2883 }
2884
2885 // We replicate the predicate of the conditional (the first operand) so that
2886 // all partitions follow the same control flow.
2887 SetPartitionedHlo(hlo, [&] {
2888 return b_.AddInstruction(HloInstruction::CreateConditional(
2889 MakePartitionedShape(hlo->shape(), hlo->sharding()),
2890 GetPartitionedHlo(hlo->operand(0))
2891 .Reshard(HloSharding::Replicate())
2892 .hlo(),
2893 hlo->called_computations(), branch_args));
2894 });
2895 return Status::OK();
2896 }
2897
2898 Status SpmdPartitioningVisitor::HandleOutfeed(HloInstruction* hlo) {
2899 if (hlo->sharding().HasUniqueDevice()) {
2900 return HandleSingleDevice(hlo);
2901 }
2902
2903 const auto& sharding = hlo->sharding();
2904 const Shape& shape = hlo->operand(0)->shape();
2905 auto partitioned_operand =
2906 GetPartitionedHlo(hlo->operand(0)).Reshard(sharding);
2907 const auto& shard_shape = partitioned_operand.hlo()->shape();
2908 const auto& operand = partitioned_operand.hlo();
2909 auto token = GetPartitionedHlo(hlo->operand(1)).hlo();
2910
2911 if (EvenlyPartitions(shape, sharding)) {
2912 Shape outfeed_shape = operand->shape();
2913 TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(hlo->outfeed_shape(),
2914 &outfeed_shape));
2915 SetPartitionedHlo(hlo, [&]() {
2916 return b_.AddInstruction(HloInstruction::CreateOutfeed(
2917 outfeed_shape, operand, token, hlo->outfeed_config()));
2918 });
2919 return Status::OK();
2920 }
2921
2922 // Create a branch for each unique partitioned shape.
2923 std::vector<Shape> per_branch_partitioned_shapes;
2924 std::vector<int32> conditional_branch_indices(num_partitions_);
2925 for (int64 i = 0; i < num_partitions_; ++i) {
2926 auto partitioned_shape =
2927 MakeNonPaddedShapeForGivenPartition(shape, sharding, i);
2928 int64 matching_existing_index = 0;
2929 for (; matching_existing_index < per_branch_partitioned_shapes.size();
2930 ++matching_existing_index) {
2931 if (ShapeUtil::Compatible(
2932 partitioned_shape,
2933 per_branch_partitioned_shapes[matching_existing_index])) {
2934 break;
2935 }
2936 }
2937 if (matching_existing_index < per_branch_partitioned_shapes.size()) {
2938 conditional_branch_indices[i] = matching_existing_index;
2939 } else {
2940 conditional_branch_indices[i] = per_branch_partitioned_shapes.size();
2941 per_branch_partitioned_shapes.push_back(std::move(partitioned_shape));
2942 }
2943 }
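// Illustrative example (a sketch): outfeeding an f32[5] operand tiled across
// 2 partitions gives non-padded shard shapes f32[3] and f32[2], so two
// branches are created and conditional_branch_indices is {0, 1}.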
2944
2945 // Get branch index for this partition.
2946 HloInstruction* branch_index;
2947 if (per_branch_partitioned_shapes.size() == num_partitions_) {
2948 // Use partition ID as the branch index if each partition has its own
2949 // branch.
2950 branch_index = partition_id_;
2951 // PartitionId's output is U32 but conditional requires S32.
2952 if (branch_index->shape().element_type() != S32) {
2953 branch_index = b_.AddInstruction(HloInstruction::CreateConvert(
2954 ShapeUtil::ChangeElementType(branch_index->shape(), S32),
2955 branch_index));
2956 }
2957 } else {
2958 // Otherwise, use a constant table to look up the branch index.
2959 auto branch_index_table = b_.AddInstruction(HloInstruction::CreateConstant(
2960 LiteralUtil::CreateR1<int32>(conditional_branch_indices)));
2961 branch_index = b_.AddInstruction(HloInstruction::CreateDynamicSlice(
2962 ShapeUtil::MakeShape(S32, {1}), branch_index_table, {partition_id_},
2963 {1}));
2964 branch_index = b_.AddInstruction(HloInstruction::CreateReshape(
2965 ShapeUtil::MakeShape(S32, {}), branch_index));
2966 }
2967
2968 // Create conditional for the outfeed.
2969 std::vector<HloComputation*> branches(per_branch_partitioned_shapes.size());
2970 for (int64 i = 0; i < branches.size(); ++i) {
2971 SpmdBuilder branch_b(absl::StrCat("outfeed_branch_", i), visiting_hlo_);
2972 // Create tuple param within the branch.
2973 auto param = branch_b.AddInstruction(HloInstruction::CreateParameter(
2974 /*parameter_number=*/0,
2975 ShapeUtil::MakeTupleShape({operand->shape(), token->shape()}),
2976 "outfeed_token_param"));
2977 auto outfeed_data = branch_b.AddInstruction(
2978 HloInstruction::CreateGetTupleElement(operand->shape(), param, 0));
2979 auto outfeed_token = branch_b.AddInstruction(
2980 HloInstruction::CreateGetTupleElement(token->shape(), param, 1));
2981 if (!ShapeUtil::Compatible(per_branch_partitioned_shapes[i], shard_shape)) {
2982 std::function<HloInstruction*(const ShapeIndex&, HloInstruction*)>
2983 slice_outfeed =
2984 [&](const ShapeIndex& index,
2985 HloInstruction* outfeed_operand) -> HloInstruction* {
2986 // Get outfeed element shape.
2987 const Shape& element_shape =
2988 ShapeUtil::GetSubshape(outfeed_data->shape(), index);
2989 // Recursively call slice_outfeed for tuple shapes.
2990 if (element_shape.IsTuple() && element_shape.tuple_shapes_size() > 0) {
2991 std::vector<HloInstruction*> slice_elements(
2992 element_shape.tuple_shapes_size());
2993 for (int64 i = 0; i < slice_elements.size(); ++i) {
2994 auto sub_index = index;
2995 sub_index.push_back(i);
2996 slice_elements[i] = slice_outfeed(
2997 sub_index,
2998 branch_b.AddInstruction(HloInstruction::CreateGetTupleElement(
2999 ShapeUtil::GetSubshape(element_shape, {i}), outfeed_operand,
3000 i)));
3001 }
3002 return branch_b.AddInstruction(
3003 HloInstruction::CreateTuple(slice_elements));
3004 }
3005 // Get the slice shape.
3006 const Shape& slice_shape = ShapeUtil::GetSubshape(
3007 per_branch_partitioned_shapes[i], ShapeIndexView(index));
3008 if (ShapeUtil::Compatible(element_shape, slice_shape)) {
3009 return outfeed_operand;
3010 }
3011 // Slice out useful data.
3012 if (element_shape.IsArray()) {
3013 CHECK(slice_shape.IsArray());
3014 std::vector<int64> start_indices(slice_shape.rank(), 0);
3015 std::vector<int64> slice_strides(slice_shape.rank(), 1);
3016 return branch_b.AddInstruction(HloInstruction::CreateSlice(
3017 slice_shape, outfeed_operand, start_indices,
3018 slice_shape.dimensions(), slice_strides));
3019 }
3020 CHECK(element_shape.IsTuple());
3021 CHECK(element_shape.tuple_shapes().empty());
3022 return outfeed_operand;
3023 };
3024 outfeed_data = slice_outfeed({}, outfeed_data);
3025 }
3026 TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
3027 hlo->outfeed_shape(), &per_branch_partitioned_shapes[i]));
3028 branch_b.AddInstruction(HloInstruction::CreateOutfeed(
3029 per_branch_partitioned_shapes[i], outfeed_data, outfeed_token,
3030 hlo->outfeed_config()));
3031 branches[i] = module_->AddEmbeddedComputation(branch_b.Build());
3032 }
3033 SetPartitionedHlo(hlo, [&]() {
3034 return b_.AddInstruction(HloInstruction::CreateConditional(
3035 token->shape(), branch_index, branches,
3036 std::vector<HloInstruction*>(
3037 branches.size(),
3038 b_.AddInstruction(HloInstruction::CreateTuple({operand, token})))));
3039 });
3040 return Status::OK();
3041 }
3042
3043 Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) {
3044 if (hlo->sharding().HasUniqueDevice()) {
3045 return HandleSingleDevice(hlo);
3046 }
3047
3048 if (hlo->sharding().IsReplicated()) {
3049 SetPartitionedHlo(hlo, [&] {
3050 // Run on a single device (0) and distribute the data to all other cores.
3051 std::vector<HloInstruction*> new_operands;
3052 for (int64 i = 0; i < hlo->operand_count(); ++i) {
3053 new_operands.push_back(GetPartitionedHlo(hlo->operand(i))
3054 .Reshard(HloSharding::AssignDevice(0))
3055 .hlo());
3056 }
3057 auto clone = b_.AddInstruction(
3058 hlo->CloneWithNewOperands(hlo->shape(), new_operands));
3059 clone->set_sharding(HloSharding::AssignDevice(0));
3060 return PartitionedHlo(clone, hlo->shape(), MakePartitioningState())
3061 .Reshard(HloSharding::Replicate())
3062 .hlo();
3063 });
3064 return Status::OK();
3065 }
3066
3067 TF_RET_CHECK(!hlo->sharding().IsTileMaximal());
3068 // Replicate the operands and run partitioned Rng on all devices.
3069 std::vector<HloInstruction*> new_operands;
3070 for (int64 i = 0; i < hlo->operand_count(); ++i) {
3071 new_operands.push_back(GetPartitionedHlo(hlo->operand(i))
3072 .Reshard(HloSharding::Replicate())
3073 .hlo());
3074 }
3075
3076 if (!hlo->sharding().ReplicateOnLastTileDim()) {
3077 SetPartitionedHlo(hlo, [&] {
3078 return b_.AddInstruction(HloInstruction::CreateRng(
3079 MakePartitionedShape(hlo->shape(), hlo->sharding()),
3080 hlo->random_distribution(), new_operands));
3081 });
3082 } else {
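// Partially replicated sharding: devices in the same replication group must
// hold identical random data, so treat the Rng as if it ran on the first
// device of each group and replicate its result within the group.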
3083 std::vector<int64> group_dims(
3084 hlo->sharding().tile_assignment().num_dimensions() - 1);
3085 std::iota(group_dims.begin(), group_dims.end(), 0);
3086 auto sharding_grouped = GroupShardingOnDims(hlo->sharding(), group_dims);
3087 auto per_group_state = CreatePerGroupPartitioningState(
3088 MakePartitioningState(), sharding_grouped.device_groups, &b_);
3089 auto rng = b_.AddInstruction(HloInstruction::CreateRng(
3090 MakePartitionedShape(hlo->shape(), hlo->sharding()),
3091 hlo->random_distribution(), new_operands));
3092 rng->set_sharding(HloSharding::AssignDevice(0));
3093 SetPartitionedHlo(hlo, [&]() {
3094 return PartitionedHlo(rng, rng->shape(), per_group_state)
3095 .Replicate()
3096 .hlo();
3097 });
3098 }
3099 return Status::OK();
3100 }
3101
3102 Status SpmdPartitioningVisitor::HandleReduceWindow(HloInstruction* hlo) {
3103 // TODO(b/73062247) Variadic reduce window not yet supported in partitioner.
3104 if (hlo->shape().IsTuple()) {
3105 return DefaultAction(hlo);
3106 }
3107 auto& operand = GetPartitionedHlo(hlo->operand(0));
3108 if (hlo->sharding().IsTileMaximal()) {
3109 return DefaultAction(hlo);
3110 }
3111
3112 // Replicate the init value.
3113 auto replicated_init = GetPartitionedHlo(hlo->mutable_operand(1))
3114 .Reshard(HloSharding::Replicate());
3115 auto resharded_operand_and_window = operand.ReshardAsWindowedInput(
3116 hlo->window(), hlo->sharding(), replicated_init.hlo());
3117 if (!resharded_operand_and_window.has_value()) {
3118 return DefaultAction(hlo);
3119 }
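// At this point each partition holds all input elements its local windows
// need (halo-exchanged, and padded with the replicated init value where
// necessary), so a local ReduceWindow computes the shard's result.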
3120
3121 TF_ASSIGN_OR_RETURN(Shape sharded_rw_shape,
3122 ShapeInference::InferReduceWindowShape(
3123 resharded_operand_and_window->sharded_input->shape(),
3124 replicated_init.hlo()->shape(),
3125 resharded_operand_and_window->shard_window,
3126 hlo->to_apply()->ComputeProgramShape()));
3127 auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
3128 *sharded_rw_shape.mutable_layout() = shard_shape.layout();
3129 SetPartitionedHlo(hlo, [&]() {
3130 auto sharded_rw = b_.AddInstruction(HloInstruction::CreateReduceWindow(
3131 sharded_rw_shape, resharded_operand_and_window->sharded_input,
3132 replicated_init.hlo(), resharded_operand_and_window->shard_window,
3133 hlo->to_apply()));
3134 if (!resharded_operand_and_window->dynamic_slice_index_on_output
3135 .has_value()) {
3136 CHECK(ShapeUtil::Compatible(shard_shape, sharded_rw->shape()));
3137 return sharded_rw;
3138 }
3139 return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
3140 shard_shape, sharded_rw,
3141 *resharded_operand_and_window->dynamic_slice_index_on_output,
3142 shard_shape.dimensions()));
3143 });
3144 return Status::OK();
3145 }
3146
3147 Status SpmdPartitioningVisitor::HandleSelectAndScatter(HloInstruction* hlo) {
3148 if (hlo->sharding().IsTileMaximal()) {
3149 return DefaultAction(hlo);
3150 }
3151 auto operand = GetPartitionedHlo(hlo->operand(0));
3152 auto source = GetPartitionedHlo(hlo->mutable_operand(1));
3153 if (hlo->sharding() != operand.sharding()) {
3154 operand = operand.Reshard(hlo->sharding());
3155 }
3156 if (hlo->sharding() != source.sharding()) {
3157 source = source.Reshard(hlo->sharding());
3158 }
3159
3160 // For F32 and BF16 types, we can pad with +/-infinity (chosen below from the
3161 // select comparison direction) so padded elements never win the selection.
3162 if (hlo->shape().element_type() != F32 &&
3163 hlo->shape().element_type() != BF16) {
3164 return DefaultAction(hlo);
3165 }
3166
3167 auto select = hlo->called_computations()[0];
3168 auto select_root = select->root_instruction();
3169 if (select_root->opcode() != HloOpcode::kCompare ||
3170 select_root->operand(0)->opcode() != HloOpcode::kParameter ||
3171 select_root->operand(1)->opcode() != HloOpcode::kParameter ||
3172 select_root->operand(0)->parameter_number() +
3173 select_root->operand(1)->parameter_number() !=
3174 1) {
3175 return DefaultAction(hlo);
3176 }
3177
3178 float float_pad_value;
3179 if (select_root->comparison_direction() == ComparisonDirection::kGe ||
3180 select_root->comparison_direction() == ComparisonDirection::kGt) {
3181 if (select_root->operand(0)->parameter_number() == 0) {
3182 float_pad_value = -std::numeric_limits<float>::infinity();
3183 } else {
3184 float_pad_value = std::numeric_limits<float>::infinity();
3185 }
3186 } else if (select_root->comparison_direction() == ComparisonDirection::kLe ||
3187 select_root->comparison_direction() == ComparisonDirection::kLt) {
3188 if (select_root->operand(0)->parameter_number() == 0) {
3189 float_pad_value = std::numeric_limits<float>::infinity();
3190 } else {
3191 float_pad_value = -std::numeric_limits<float>::infinity();
3192 }
3193 } else {
3194 return DefaultAction(hlo);
3195 }
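// The pad value is chosen so that padded elements always lose the select
// comparison (e.g. -infinity when the select keeps the larger value), so
// scattered updates never land in padded regions.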
3196
3197 auto pad_value = b_.AddInstruction(HloInstruction::CreateConstant(
3198 hlo->shape().element_type() == BF16
3199 ? LiteralUtil::CreateR0<bfloat16>(
3200 static_cast<bfloat16>(float_pad_value))
3201 : LiteralUtil::CreateR0<float>(float_pad_value)));
3202
3203 // Replicate the init value.
3204 auto replicated_init = GetPartitionedHlo(hlo->mutable_operand(2))
3205 .Reshard(HloSharding::Replicate());
3206
3207 auto partition_ordinals =
3208 MakeTiledPartitionOrdinals(hlo->sharding(), partition_id_, &b_);
3209
3210 // The first window for each dimension that overlaps with the shard area.
3211 std::vector<MultiplyAddDivideOffsetCalculation> first_window(
3212 hlo->shape().rank());
3213 // The first window for each dimension that goes beyond the shard area.
3214 std::vector<MultiplyAddDivideOffsetCalculation> limit_window(
3215 hlo->shape().rank());
3216 std::vector<OffsetCalculation> data_left_halo_sizes(hlo->shape().rank());
3217 std::vector<OffsetCalculation> data_right_halo_sizes(hlo->shape().rank());
3218 std::vector<OffsetCalculation> source_left_halo_sizes(hlo->shape().rank());
3219 std::vector<OffsetCalculation> source_right_halo_sizes(hlo->shape().rank());
3220 auto unpadded_data_shard_shape =
3221 MakePartitionedShape(hlo->shape(), hlo->sharding());
3222 auto unpadded_source_shard_shape =
3223 MakePartitionedShape(hlo->operand(1)->shape(), hlo->sharding());
3224 auto source_shard_hlo = source.hlo();
3225 auto data_shard_hlo = operand.hlo();
3226 for (int64 i = 0; i < hlo->shape().rank(); ++i) {
3227 int64 shard_count = hlo->sharding().tile_assignment().dim(i);
3228 if (shard_count == 1) {
3229 continue;
3230 }
3231 // If stride > window_size, there will be gaps between windows. These gaps
3232 // will also exist in the output, so we keep them during halo exchange.
3233 //
3234 // TODO(yuanzx): This could introduce overhead if partitions start at
3235 // different offsets in a gap.
3236 auto wd = hlo->window().dimensions(i);
3237 if (wd.stride() > wd.size()) {
3238 wd.set_size(wd.stride());
3239 }
3240 // shard_size * i < stride * k - pad_low + window_size =>
3241 // k > (shard_size * i + pad_low - window_size) / stride =>
3242 // first_k == (shard_size * i + pad_low - window_size + stride) / stride
3243 first_window[i] = MultiplyAddDivideOffsetCalculation(
3244 unpadded_data_shard_shape.dimensions(i),
3245 wd.padding_low() - wd.size() + wd.stride(), wd.stride());
3246 // shard_size * (i + 1) <= stride * k - pad_low =>
3247 // k >= (shard_size * i + shard_size + pad_low) / stride =>
3248 // limit_k == (shard_size * i + shard_size + pad_low + stride - 1) /
3249 // stride
3250 limit_window[i] = MultiplyAddDivideOffsetCalculation(
3251 unpadded_data_shard_shape.dimensions(i),
3252 unpadded_data_shard_shape.dimensions(i) + wd.padding_low() +
3253 wd.stride() - 1,
3254 wd.stride());
3255 source_left_halo_sizes[i] =
3256 MultiplyAddDivideOffsetCalculation(
3257 unpadded_source_shard_shape.dimensions(i), 0, 1) -
3258 first_window[i];
3259 source_right_halo_sizes[i] =
3260 limit_window[i] - MultiplyAddDivideOffsetCalculation(
3261 unpadded_source_shard_shape.dimensions(i),
3262 unpadded_source_shard_shape.dimensions(i), 1);
3263 data_left_halo_sizes[i] =
3264 OffsetCalculation(MultiplyAddDivideOffsetCalculation(
3265 unpadded_data_shard_shape.dimensions(i), wd.padding_low(), 1)) -
3266 OffsetCalculation(
3267 HloOpcode::kMultiply, first_window[i],
3268 MultiplyAddDivideOffsetCalculation(0, wd.stride(), 1));
3269 data_right_halo_sizes[i] =
3270 OffsetCalculation(
3271 HloOpcode::kMultiply, limit_window[i],
3272 MultiplyAddDivideOffsetCalculation(0, wd.stride(), 1)) -
3273 OffsetCalculation(MultiplyAddDivideOffsetCalculation(
3274 unpadded_data_shard_shape.dimensions(i),
3275 unpadded_data_shard_shape.dimensions(i) + wd.stride() +
3276 wd.padding_low() - wd.size(),
3277 1));
3278
3279 int64 max_windows =
3280 (limit_window[i] - first_window[i]).MaxInRange(0, shard_count);
3281 auto first_window_hlo =
3282 first_window[i].Calculate(partition_ordinals[i], &b_);
3283 // Padding on the source is filled with the init value so it does not change
3284 // the data in overlapping windows.
3285 auto resharded_source = ExchangeHaloAndGetValidData(
3286 source_shard_hlo, source.base_shape(), source_left_halo_sizes[i],
3287 source_right_halo_sizes[i], 0,
3288 limit_window[i].Calculate(shard_count - 1), max_windows, i,
3289 hlo->sharding(), first_window_hlo, replicated_init.hlo(),
3290 partition_ordinals[i], collective_ops_creator_, next_channel_id_, &b_);
3291 if (!resharded_source) {
3292 return DefaultAction(hlo);
3293 }
3294 source_shard_hlo = *resharded_source;
3295
3296 auto offset_start_in_data =
3297 MultiplyAddDivideOffsetCalculation(wd.stride(), 0, 1)
3298 .Calculate(first_window_hlo, &b_);
3299 int64 padded_data_size =
3300 (limit_window[i].Calculate(shard_count - 1) - 1) * wd.stride() +
3301 wd.size();
3302 int64 data_shard_size = (max_windows - 1) * wd.stride() + wd.size();
3303 auto resharded_data = ExchangeHaloAndGetValidData(
3304 data_shard_hlo, operand.base_shape(), data_left_halo_sizes[i],
3305 data_right_halo_sizes[i], wd.padding_low(), padded_data_size,
3306 data_shard_size, i, hlo->sharding(), offset_start_in_data, pad_value,
3307 partition_ordinals[i], collective_ops_creator_, next_channel_id_, &b_);
3308 if (!resharded_data) {
3309 return DefaultAction(hlo);
3310 }
3311 data_shard_hlo = *resharded_data;
3312 }
3313
3314 Window window_on_shard = hlo->window();
3315 for (int64 i = 0; i < window_on_shard.dimensions_size(); ++i) {
3316 int64 shard_count = hlo->sharding().tile_assignment().dim(i);
3317 if (shard_count == 1) {
3318 continue;
3319 }
3320 auto reshard_wd = window_on_shard.mutable_dimensions(i);
3321 // The shards are already explicitly padded.
3322 reshard_wd->set_padding_low(0);
3323 reshard_wd->set_padding_high(0);
3324 }
3325
3326 auto sharded_select_and_scatter =
3327 b_.AddInstruction(HloInstruction::CreateSelectAndScatter(
3328 data_shard_hlo->shape(), data_shard_hlo, select, window_on_shard,
3329 source_shard_hlo, replicated_init.hlo(),
3330 hlo->called_computations()[1]));
3331 SetPartitionedHlo(hlo, [&]() {
3332 auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
3333 if (ShapeUtil::Compatible(sharded_select_and_scatter->shape(),
3334 shard_shape)) {
3335 return sharded_select_and_scatter;
3336 }
3337 auto zero = b_.AddInstruction(
3338 HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
3339 std::vector<HloInstruction*> slice_offsets(shard_shape.rank(), zero);
3340 for (int64 i = 0; i < window_on_shard.dimensions_size(); ++i) {
3341 if (hlo->sharding().tile_assignment().dim(i) == 1) {
3342 continue;
3343 }
3344 int64 pad_low = hlo->window().dimensions(i).padding_low();
3345 auto left_halo_size =
3346 data_left_halo_sizes[i].Calculate(partition_ordinals[i], &b_);
3347 if (data_left_halo_sizes[i].Calculate(0) == pad_low) {
3348 slice_offsets[i] = left_halo_size;
3349 } else {
3350 auto is_shard0 = b_.AddInstruction(HloInstruction::CreateCompare(
3351 ShapeUtil::MakeShape(PRED, {}), zero, partition_ordinals[i],
3352 ComparisonDirection::kEq));
3353 auto pad_low_hlo = b_.AddInstruction(HloInstruction::CreateConstant(
3354 LiteralUtil::CreateR0<int32>(pad_low)));
3355 slice_offsets[i] = b_.AddInstruction(HloInstruction::CreateTernary(
3356 zero->shape(), HloOpcode::kSelect, is_shard0, pad_low_hlo,
3357 left_halo_size));
3358 }
3359 }
3360 return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
3361 shard_shape, sharded_select_and_scatter, slice_offsets,
3362 shard_shape.dimensions()));
3363 });
3364 return Status::OK();
3365 }
3366
3367 Status SpmdPartitioningVisitor::HandleTuple(HloInstruction* hlo) {
3368 std::vector<HloInstruction*> new_operands;
3369 for (int64 i = 0; i < hlo->operand_count(); ++i) {
3370 new_operands.push_back(
3371 GetPartitionedHlo(hlo->operand(i))
3372 .Reshard(hlo->sharding().GetSubSharding(hlo->shape(), {i}))
3373 .hlo());
3374 }
3375 SetPartitionedHlo(hlo, [&]() {
3376 return b_.AddInstruction(HloInstruction::CreateTuple(new_operands));
3377 });
3378 return Status::OK();
3379 }
3380
3381 StatusOr<bool> SpmdPartitioningVisitor::DoPartition(
3382 HloComputation* computation, const HloSharding& root_sharding,
3383 const SpmdPartitionerOptions& options) {
3384 VLOG(2) << "Partitioning computation " << computation->name() << " for "
3385 << num_replicas_ << " replicas and " << num_partitions_
3386 << " partitions";
3387 TF_RETURN_IF_ERROR(computation->Accept(this));
3388
3389 HloModule* module = computation->parent();
3390 auto new_root =
3391 GetPartitionedHlo(computation->root_instruction()).Reshard(root_sharding);
3392 auto new_computation =
3393 module->AddEmbeddedComputation(b_.Build(new_root.hlo()));
3394 TF_RETURN_IF_ERROR(
3395 DoCodeMotionForWindowedDotGeneralLoops(new_computation, options));
3396
3397 // Replace the original computation with the new SPMD computation.
3398 std::unordered_map<HloComputation*, HloComputation*> replacement;
3399 replacement[computation] = new_computation;
3400 module->ReplaceComputations(replacement);
3401 return changed_;
3402 }
3403
3404 Status SpmdPartitioningVisitor::HandlePartitionId(HloInstruction* hlo) {
3405 return Unimplemented(
3406 "PartitionId instruction is not supported for SPMD partitioning since "
3407 "the meaning is ambiguous -- whether the instruction is replicated or "
3408 "the data is replicated, and if the latter which data is replicated.");
3409 }
3410
3411 SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64 num_partitions,
3412 int64 num_replicas) {
3413 return {
3414 [](SpmdBuilder* b) {
3415 return b->AddInstruction(HloInstruction::CreatePartitionId());
3416 },
3417 [num_replicas, num_partitions](
3418 SpmdBuilder* b, HloInstruction* operand, HloComputation* reduction,
3419 const std::vector<std::vector<int64>>& partition_subgroups,
3420 int64 channel_id) {
3421 if (partition_subgroups.size() <= 1) {
3422 std::vector<ReplicaGroup> groups(num_replicas);
3423 // TODO(yuanzx): Unify subgroup definition with AllToAll.
3424 for (int64 i = 0; i < num_replicas; ++i) {
3425 groups[i].add_replica_ids(i);
3426 }
3427 return b->AddInstruction(HloInstruction::CreateAllReduce(
3428 operand->shape(), {operand}, reduction, groups,
3429 /*constrain_layout=*/false, channel_id,
3430 /*use_global_device_ids=*/false));
3431 }
3432
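        // Translate partition subgroups into global device-id groups, where a
        // device id is replica_id * num_partitions + partition_id; the
        // all-reduce below uses use_global_device_ids=true.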
3433 std::vector<ReplicaGroup> device_groups;
3434 device_groups.reserve(partition_subgroups.size() * num_replicas);
3435 for (int64 i = 0; i < num_replicas; ++i) {
3436 for (const auto& pgroup : partition_subgroups) {
3437 device_groups.emplace_back();
3438 for (int64 pid : pgroup) {
3439 device_groups.back().add_replica_ids(i * num_partitions + pid);
3440 }
3441 }
3442 }
3443 return b->AddInstruction(HloInstruction::CreateAllReduce(
3444 operand->shape(), {operand}, reduction, device_groups,
3445 /*constrain_layout=*/false, channel_id,
3446 /*use_global_device_ids=*/true));
3447 },
3448 [](SpmdBuilder* b, HloInstruction* operand,
3449 std::vector<std::pair<int64, int64>>& src_dst_pairs,
3450 int64 channel_id) {
3451 return b->AddInstruction(HloInstruction::CreateCollectivePermute(
3452 operand->shape(), operand, src_dst_pairs, channel_id));
3453 },
3454 [](SpmdBuilder* b, absl::Span<HloInstruction* const> operands,
3455 const std::vector<std::vector<int64>>& partition_subgroups,
3456 int64 channel_id, absl::optional<int64> split_dimension) {
3457 std::vector<Shape> shapes(operands.size(), operands[0]->shape());
3458 const Shape output_shape = (shapes.size() == 1)
3459 ? shapes[0]
3460 : ShapeUtil::MakeTupleShape(shapes);
3461 std::vector<ReplicaGroup> groups(partition_subgroups.size());
3462 for (int64 i = 0; i < groups.size(); ++i) {
3463 for (int64 id : partition_subgroups[i]) {
3464 groups[i].add_replica_ids(id);
3465 }
3466 }
3467 return b->AddInstruction(HloInstruction::CreateAllToAll(
3468 output_shape, operands, groups,
3469 /*constrain_layout=*/false, channel_id, split_dimension));
3470 },
3471 [num_replicas, num_partitions](
3472 SpmdBuilder* b, HloInstruction* operand, const Shape& ag_shape,
3473 const std::vector<std::vector<int64>>& partition_subgroups,
3474 int64 channel_id, int64 all_gather_dimension) {
3475 std::vector<ReplicaGroup> device_groups;
3476 device_groups.reserve(partition_subgroups.size() * num_replicas);
3477 for (int64 i = 0; i < num_replicas; ++i) {
3478 for (const auto& pgroup : partition_subgroups) {
3479 device_groups.emplace_back();
3480 for (int64 pid : pgroup) {
3481 device_groups.back().add_replica_ids(i * num_partitions + pid);
3482 }
3483 }
3484 }
3485 return b->AddInstruction(HloInstruction::CreateAllGather(
3486 ag_shape, operand, all_gather_dimension, device_groups,
3487 /*constrain_layout=*/false, channel_id,
3488 /*use_global_device_ids=*/true));
3489 },
3490 };
3491 }
3492
3493 SpmdPartitioner::SpmdPartitioner(int64 num_partitions, int64 num_replicas,
3494 SpmdPartitionerOptions options)
3495 : SpmdPartitioner(
3496 num_partitions, num_replicas, std::move(options),
3497 GetDefaultCollectiveOpsCreator(num_partitions, num_replicas)) {}
3498
3499 HloInstruction* SpmdPartitioner::AllGatherShards(
3500 SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
3501 int64* next_channel_id, absl::Span<const int64> selected_dims,
3502 const SPMDCollectiveOpsCreator& collectives_creator) {
3503 return AllGatherShardsInternal(b, operand, sharding, next_channel_id,
3504 selected_dims, collectives_creator,
3505 /*per_dim_ag=*/true);
3506 }
3507
3508 HloInstruction* SpmdPartitioner::AllGatherShardsInternal(
3509 SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
3510 int64* next_channel_id, absl::Span<const int64> selected_dims,
3511 const SPMDCollectiveOpsCreator& collectives_creator, bool per_dim_ag) {
3512 if (selected_dims.empty()) {
3513 return operand;
3514 }
3515 CHECK(!sharding.IsTileMaximal());
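// Example (a sketch): a per-partition shard of shape f32[4,8] tiled 2 ways on
// dimension 0 is reshaped to f32[1,4,8], all-gathered along the new leading
// dimension into f32[2,4,8], transposed so the gathered dimension sits next
// to dimension 0, and finally reshaped to the full f32[8,8].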
3516 // Add one leading dimension to gather all partitions.
3517 std::vector<int64> shape;
3518 shape.push_back(1);
3519 for (int64 dim : operand->shape().dimensions()) {
3520 shape.push_back(dim);
3521 }
3522 auto reshape = b->AddInstruction(HloInstruction::CreateReshape(
3523 ShapeUtil::MakeShape(operand->shape().element_type(), shape), operand));
3524 HloInstruction* result = reshape;
3525 if (per_dim_ag) {
3526 for (auto it = selected_dims.rbegin(); it != selected_dims.rend(); ++it) {
3527 if (sharding.tile_assignment().dim(*it) == 1) {
3528 continue;
3529 }
3530 auto partition_subgroups =
3531 GetPartitionGroupsForReplication(sharding, {*it});
3532 shape[0] *= partition_subgroups[0].size();
3533 result = collectives_creator.create_cross_partition_all_gather(
3534 b, result,
3535 ShapeUtil::MakeShape(operand->shape().element_type(), shape),
3536 partition_subgroups, (*next_channel_id)++,
3537 /*all_gather_dimension=*/0);
3538 }
3539 } else {
3540 auto partition_subgroups =
3541 GetPartitionGroupsForReplication(sharding, selected_dims);
3542 shape[0] *= partition_subgroups[0].size();
3543 result = collectives_creator.create_cross_partition_all_gather(
3544 b, result, ShapeUtil::MakeShape(operand->shape().element_type(), shape),
3545 partition_subgroups, (*next_channel_id)++,
3546 /*all_gather_dimension=*/0);
3547 }
3548 // If n > 1 dimensions are partitioned, split the leading dimension into n.
3549 std::vector<int64> tiled_dims;
3550 for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
3551 if (sharding.tile_assignment().dim(i) > 1 &&
3552 absl::c_linear_search(selected_dims, i)) {
3553 tiled_dims.push_back(i);
3554 }
3555 }
3556 if (tiled_dims.size() > 1) {
3557 std::vector<int64> split_dim_shape;
3558 split_dim_shape.reserve(tiled_dims.size() + operand->shape().rank());
3559 for (int64 i : tiled_dims) {
3560 split_dim_shape.push_back(sharding.tile_assignment().dim(i));
3561 }
3562 for (int64 dim : operand->shape().dimensions()) {
3563 split_dim_shape.push_back(dim);
3564 }
3565 result = b->AddInstruction(HloInstruction::CreateReshape(
3566 ShapeUtil::MakeShape(operand->shape().element_type(), split_dim_shape),
3567 result));
3568 }
3569 // Transpose the gathered dimensions so that each is next to its
3570 // corresponding partitioned dimension.
3571 std::vector<int64> xpose_permutation(result->shape().rank());
3572 int64 split_dims_added = 0;
3573 for (int64 i = 0; i < xpose_permutation.size(); ++i) {
3574 if (sharding.tile_assignment().dim(i - split_dims_added) == 1 ||
3575 !absl::c_linear_search(selected_dims, i - split_dims_added)) {
3576 xpose_permutation[i] = i + tiled_dims.size() - split_dims_added;
3577 } else {
3578 xpose_permutation[i] = split_dims_added;
3579 xpose_permutation[i + 1] = i + tiled_dims.size() - split_dims_added;
3580 split_dims_added++;
3581 i++;
3582 }
3583 }
3584 result = b->AddInstruction(HloInstruction::CreateTranspose(
3585 ShapeInference::InferTransposeShape(result->shape(), xpose_permutation)
3586 .ValueOrDie(),
3587 result, xpose_permutation));
3588 // Reshape to the desired shape.
3589 auto ag_shape = operand->shape();
3590 for (int64 i : tiled_dims) {
3591 ag_shape.set_dimensions(
3592 i, ag_shape.dimensions(i) * sharding.tile_assignment().dim(i));
3593 }
3594 result = b->AddInstruction(HloInstruction::CreateReshape(ag_shape, result));
3595 return result;
3596 }
3597
3598 HloInstruction* SpmdPartitioner::AllReduceAlongShardingDims(
3599 SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
3600 int64* next_channel_id, absl::Span<const int64> selected_dims,
3601 const SPMDCollectiveOpsCreator& collectives_creator,
3602 HloComputation* reduction) {
3603 return AllReduceAlongShardingDimsInternal(
3604 b, operand, sharding, next_channel_id, selected_dims, collectives_creator,
3605 reduction, /*per_dim_ar=*/true);
3606 }
3607
3608 HloInstruction* SpmdPartitioner::AllReduceAlongShardingDimsInternal(
3609 SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
3610 int64* next_channel_id, absl::Span<const int64> selected_dims,
3611 const SPMDCollectiveOpsCreator& collectives_creator,
3612 HloComputation* reduction, bool per_dim_ar) {
3613 if (!per_dim_ar) {
3614 auto partition_subgroups =
3615 GetPartitionGroupsForReplication(sharding, selected_dims);
3616 return collectives_creator.create_cross_partition_all_reduce(
3617 b, operand, reduction, partition_subgroups, (*next_channel_id)++);
3618 }
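// Per-dimension mode: issue one all-reduce per selected sharding dimension,
// each over the smaller groups of partitions that differ only along that
// dimension.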
3619 auto result = operand;
3620 for (auto it = selected_dims.rbegin(); it != selected_dims.rend(); ++it) {
3621 if (sharding.tile_assignment().dim(*it) == 1) {
3622 continue;
3623 }
3624 auto partition_subgroups =
3625 GetPartitionGroupsForReplication(sharding, {*it});
3626 result = collectives_creator.create_cross_partition_all_reduce(
3627 b, result, reduction, partition_subgroups, (*next_channel_id)++);
3628 }
3629 return result;
3630 }
3631
3632 StatusOr<bool> SpmdPartitioner::PartitionComputation(
3633 HloComputation* computation, const HloSharding& root_sharding,
3634 int64* next_channel_id, SpmdLogger* logger) {
3635 auto visitor =
3636 CreateVisitor(computation, num_partitions_, num_replicas_,
3637 collective_ops_creator_, next_channel_id, logger, options_);
3638 return visitor->DoPartition(computation, root_sharding, options_);
3639 }
3640
3641 std::unique_ptr<SpmdPartitioningVisitor> SpmdPartitioner::CreateVisitor(
3642 HloComputation* computation, int64 num_partitions, int64 num_replicas,
3643 const SPMDCollectiveOpsCreator& collective_ops_creator,
3644 int64* next_channel_id, SpmdLogger* logger,
3645 SpmdPartitionerOptions options) {
3646 return absl::make_unique<SpmdPartitioningVisitor>(
3647 computation, num_partitions, num_replicas, collective_ops_creator,
3648 next_channel_id, logger, std::move(options), this);
3649 }
3650
3651 StatusOr<bool> SpmdPartitioner::Run(HloModule* module) {
3652 TF_RETURN_IF_ERROR(PreprocessSharding(module));
3653
3654 XLA_VLOG_LINES(1, SpmdLogger::ReportBeforePartition(
3655 *module, options_.report_instruction_count));
3656
3657 // Add the parameters' and output's shardings to the module.
3658 std::vector<HloSharding> entry_params_shardings;
3659 for (int64 i = 0; i < module->entry_computation()->num_parameters(); ++i) {
3660 auto param = module->entry_computation()->parameter_instruction(i);
3661 CHECK(param->has_sharding()) << "Missing sharding in entry parameter " << i;
3662 entry_params_shardings.push_back(param->sharding());
3663 }
3664 module->set_spmd_parameters_shardings(entry_params_shardings);
3665 auto entry_root = module->entry_computation()->root_instruction();
3666 CHECK(entry_root->has_sharding()) << "Missing sharding in entry root.";
3667 module->set_spmd_output_sharding(entry_root->sharding());
3668
3669 FlattenCallGraph flatten;
3670 TF_ASSIGN_OR_RETURN(auto changed, flatten.Run(module));
3671
3672 SpmdLogger logger(options_.report_instruction_count);
3673 auto program_shape = module->entry_computation()->ComputeProgramShape();
3674 int64 next_channel_id = hlo_query::NextChannelId(*module);
3675 // Copy the root sharding since the partitioner visitor may temporarily change
3676 // the sharding to work around manual sharding.
3677 HloSharding root_sharding = entry_root->sharding();
3678 TF_ASSIGN_OR_RETURN(
3679 bool partition_changed,
3680 PartitionComputation(module->entry_computation(), root_sharding,
3681 &next_channel_id, &logger));
3682 changed |= partition_changed;
3683
3684 // For the entry computation, make sure that the root instruction and the
3685 // parameters preserve their signatures.
3686 auto new_program_shape = module->entry_computation()->ComputeProgramShape();
3687 if (!options_.allow_module_signature_change) {
3688 TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
3689 program_shape.result(), new_program_shape.result()))
3690 << "Result shape changed for the entry computation";
3691 TF_RET_CHECK(program_shape.parameters_size() ==
3692 new_program_shape.parameters_size())
3693 << "Parameter count changed for the entry computation";
3694 for (int64 i = 0; i < program_shape.parameters_size(); ++i) {
3695 TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
3696 program_shape.parameters(i), new_program_shape.parameters(i)))
3697 << "Parameter shape changed for the entry computation";
3698 }
3699 } else {
3700 const auto& old_entry_layout = module->entry_computation_layout();
3701 // Shapes can change but the layout should still remain the same.
3702 for (int64 i = 0; i < new_program_shape.parameters_size(); ++i) {
3703 TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
3704 old_entry_layout.parameter_shape(i),
3705 new_program_shape.mutable_parameters(i)));
3706 }
3707 TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
3708 old_entry_layout.result_shape(), new_program_shape.mutable_result()));
3709
3710 HloModuleConfig config = module->config();
3711 *config.mutable_entry_computation_layout() =
3712 ComputationLayout(new_program_shape, /*ignore_layouts=*/false);
3713 module->set_config(config);
3714 }
3715
3716 XLA_VLOG_LINES(1, SpmdLogger::ReportAfterPartition(
3717 *module, options_.report_instruction_count));
3718 XLA_VLOG_LINES(1, logger.MakeReport());
3719
3720 if (changed) {
3721 HloPassPipeline pass("spmd-cleanup");
3722 pass.AddPass<TupleSimplifier>();
3723 pass.AddPass<HloDCE>();
3724 pass.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
3725 pass.AddPass<FlattenCallGraph>();
3726 TF_RETURN_IF_ERROR(pass.Run(module).status());
3727 }
3728
3729 TF_RETURN_IF_ERROR(ClearShardingAttributes(module));
3730 return changed;
3731 }
3732
3733 Status SpmdPartitioner::PreprocessSharding(HloModule* module) {
3734 for (HloComputation* computation : module->computations()) {
3735 for (HloInstruction* hlo : computation->instructions()) {
3736 if (hlo->HasSideEffectNoRecurse() && hlo->opcode() != HloOpcode::kRng) {
3737 TF_RET_CHECK(hlo->has_sharding())
3738 << "Side-effect HLO must have sharding: " << hlo->ToString();
3739 TF_RET_CHECK(!HasReplicatedSharding(hlo->sharding()) ||
3740 hlo->opcode() == HloOpcode::kInfeed ||
3741 hlo->opcode() == HloOpcode::kOutfeed)
3742 << "Non-infeed/outfeed side-effect HLO cannot have a replicated sharding: "
3743 << hlo->ToString();
3744 }
3745
3746 // For unassigned HLOs, annotate with replicated sharding.
3747 //
3748 // Among side-effecting ops, only Rng is allowed to omit the annotation.
3749 // In that case, we currently force it to run on core 0, since we don't
3750 // support partitioning or replicating the Rng op (the values depend on
3751 // the seed provided to each device).
3752 //
3753 // TODO(hyouklee): Should we also convert single-device shardings (without
3754 // side-effects) into replicated?
3755 if (!hlo->has_sharding()) {
3756 if (hlo->opcode() == HloOpcode::kRng) {
3757 hlo->set_sharding(HloSharding::AssignDevice(0));
3758 } else {
3759 hlo->set_sharding(
3760 HloSharding::Single(hlo->shape(), HloSharding::Replicate()));
3761 }
3762 } else if (!hlo->sharding().IsTileMaximal() &&
3763 !hlo->sharding().IsManual()) {
3764 std::vector<int64> available(num_partitions_);
3765 std::iota(available.begin(), available.end(), 0);
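        // The tile sharding must use every partition; shardings over a subset
        // of the partitions are not supported.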
3766 TF_RET_CHECK(num_partitions_ == hlo_sharding_util::DevicesForSharding(
3767 hlo->sharding(), available)
3768 .size())
3769 << "num_partitions:" << num_partitions_ << "\n"
3770 << "SPMD partitioner only supports tile sharding that includes all "
3771 "partitions. If you didn't add this sharding annotation in the "
3772 "model, please file a bug to XLA team.\n"
3773 << hlo->ToString();
3774 }
3775 }
3776 }
3777
3778 // Entry computation's parameter and root sharding must be either all
3779 // replicated or all on a single device.
3780 if (!options_.allow_module_signature_change) {
3781 const HloComputation* entry = module->entry_computation();
3782 TF_RET_CHECK(entry->root_instruction()->has_sharding());
3783 const HloSharding& root_sharding = entry->root_instruction()->sharding();
3784 TF_RET_CHECK(root_sharding.IsReplicated() ||
3785 root_sharding.UniqueDevice().has_value())
3786 << "Unsupported entry root sharding: " << root_sharding.ToString();
3787
3788 for (const HloInstruction* param : entry->parameter_instructions()) {
3789 TF_RET_CHECK(param->has_sharding());
3790 TF_RET_CHECK(param->sharding().IsReplicated() ||
3791 param->sharding().UniqueDevice().has_value())
3792 << "Unsupported entry parameter sharding:"
3793 << param->sharding().ToString();
3794 }
3795 }
3796
3797 return Status::OK();
3798 }
3799
3800 } // namespace spmd
3801 } // namespace xla
3802