1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
17
18 #include <algorithm>
19 #include <map>
20 #include <vector>
21
22 #include "absl/algorithm/container.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/types/optional.h"
25 #include "tensorflow/compiler/xla/array.h"
26 #include "tensorflow/compiler/xla/literal_util.h"
27 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
28 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
29 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/util.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33
34 namespace xla {
35 namespace hlo_sharding_util {
36
37 bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs) {
38 CHECK_EQ(lhs.IsTuple(), rhs.IsTuple());
39 if (lhs.IsTuple()) {
40 // For tuples we consider lhs to have a better sharding if none of the
41 // elements are worse and at least one element is better than in the rhs
42 // sharding.
43 const auto& lhs_shardings = lhs.tuple_elements();
44 const auto& rhs_shardings = rhs.tuple_elements();
45 CHECK_EQ(lhs_shardings.size(), rhs_shardings.size());
46 bool is_better = false;
47 for (int64 i = 0; i < lhs_shardings.size(); ++i) {
48 if (IsShardingMoreSpecific(rhs_shardings[i], lhs_shardings[i])) {
49 return false;
50 }
51 if (IsShardingMoreSpecific(lhs_shardings[i], rhs_shardings[i])) {
52 is_better = true;
53 }
54 }
55 return is_better;
56 }
57 if (!rhs.IsTileMaximal()) {
58 return lhs.NumTiles() > rhs.NumTiles();
59 } else if (!rhs.IsReplicated()) {
60 // If we are not replicated then only tiled (not tile maximal) shardings
61 // can improve us.
62 return !lhs.IsTileMaximal();
63 } else {
64 // If we are replicated then any non-replicated sharding can improve us.
65 return !lhs.IsReplicated();
66 }
67 }
68
69 bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
70 bool may_combine_partial_sharding) {
71 if (old.IsTuple()) {
72 CHECK(to_merge->IsTuple());
73 bool changed = false;
74 for (int64 i = 0; i < old.tuple_elements().size(); ++i) {
75 changed |=
76 MergeSharding(old.tuple_elements()[i], &to_merge->tuple_elements()[i],
77 may_combine_partial_sharding);
78 }
79 return changed;
80 }
81 if (!may_combine_partial_sharding || !old.ReplicateOnLastTileDim() ||
82 !to_merge->ReplicateOnLastTileDim() ||
83 old.tile_assignment().num_elements() !=
84 to_merge->tile_assignment().num_elements()) {
85 return IsShardingMoreSpecific(*to_merge, old);
86 }
87 // Combine the tile dimension sizes from new and old.
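// For example (illustrative): merging old = {devices=[2,1,2]0,1,2,3
// last_tile_dim_replicate} with to_merge = {devices=[1,2,2]0,2,1,3
// last_tile_dim_replicate} yields {devices=[2,2]0,1,2,3}, since each pair of
// replication groups intersects in exactly one device per merged tile.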
88 int64 num_devices = old.tile_assignment().num_elements();
89 std::vector<int64> new_tile_dims;
90 bool compatible = true;
91 new_tile_dims.reserve(to_merge->tile_assignment().num_dimensions());
92 for (int64 i = 0; i < to_merge->tile_assignment().num_dimensions() - 1; ++i) {
93 int64 new_dim = to_merge->tile_assignment().dim(i);
94 int64 old_dim = old.tile_assignment().dim(i);
95 if (new_dim == 1) {
96 new_tile_dims.push_back(old_dim);
97 } else if (old_dim == 1) {
98 new_tile_dims.push_back(new_dim);
99 } else if (new_dim == old_dim) {
100 new_tile_dims.push_back(new_dim);
101 } else {
102 compatible = false;
103 break;
104 }
105 }
106 int64 replication = num_devices / Product(new_tile_dims);
107 if (!compatible || num_devices % Product(new_tile_dims) != 0 ||
108 replication >= old.tile_assignment().dimensions().back()) {
109 return IsShardingMoreSpecific(*to_merge, old);
110 }
111 new_tile_dims.push_back(replication);
112 Array<int64> new_tile(new_tile_dims);
113 // Maps from replication group ID to sorted members.
114 absl::flat_hash_map<int64, std::set<int64>> old_group_members;
115 absl::flat_hash_map<int64, std::set<int64>> new_group_members;
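// Returns the linearized replication-group id for a tile index, i.e. the
// index over all tile dimensions except the trailing replication dimension,
// using to_merge's dimension sizes as strides.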
116 auto get_group_index = [&](absl::Span<const int64> tile_indices,
117 const HloSharding& sharding) {
118 int64 group_id = 0;
119 for (int64 i = 0; i < tile_indices.size() - 1; ++i) {
120 group_id *= to_merge->tile_assignment().dim(i);
121 group_id += tile_indices[i];
122 }
123 return group_id;
124 };
125 old.tile_assignment().Each(
126 [&](absl::Span<const int64> indices, int64 device) {
127 old_group_members[get_group_index(indices, old)].insert(device);
128 });
129 to_merge->tile_assignment().Each(
130 [&](absl::Span<const int64> indices, int64 device) {
131 new_group_members[get_group_index(indices, *to_merge)].insert(device);
132 });
133 // Try to find the intersection of old and new replication groups, in
134 // order to determine the merged tile assignment.
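// For each position of the merged tile we take the larger of the two groups'
// smallest remaining devices and require it to be present in both groups;
// otherwise the two shardings are considered incompatible.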
135 new_tile.Each([&](absl::Span<const int64> indices, int64* device) {
136 if (!compatible) {
137 return;
138 }
139 std::vector<int64> old_index(indices.begin(), indices.end());
140 std::vector<int64> new_index = old_index;
141 for (int64 i = 0; i < indices.size() - 1; ++i) {
142 if (old.tile_assignment().dim(i) == 1) {
143 old_index[i] = 0;
144 }
145 if (to_merge->tile_assignment().dim(i) == 1) {
146 new_index[i] = 0;
147 }
148 }
149 int64 old_group_id = get_group_index(old_index, old);
150 int64 new_group_id = get_group_index(new_index, *to_merge);
151 if (old_group_members[old_group_id].empty() ||
152 new_group_members[new_group_id].empty()) {
153 compatible = false;
154 return;
155 }
156
157 int64 smallest_old = *old_group_members[old_group_id].begin();
158 int64 smallest_new = *new_group_members[new_group_id].begin();
159 if (smallest_old < smallest_new) {
160 if (old_group_members[old_group_id].count(smallest_new) == 0) {
161 compatible = false;
162 return;
163 }
164 *device = smallest_new;
165 } else {
166 if (new_group_members[new_group_id].count(smallest_old) == 0) {
167 compatible = false;
168 return;
169 }
170 *device = smallest_old;
171 }
172 old_group_members[old_group_id].erase(*device);
173 new_group_members[new_group_id].erase(*device);
174 });
175 if (compatible) {
176 std::vector<OpMetadata> merged_metadata(std::move(to_merge->metadata()));
177 merged_metadata.reserve(merged_metadata.size() + old.metadata().size());
178 const absl::flat_hash_set<OpMetadata, protobuf_util::ProtobufHashWrapper,
179 protobuf_util::ProtobufEqualsWrapper>
180 metadata_set(merged_metadata.begin(), merged_metadata.end());
181 absl::c_copy_if(old.metadata(), std::back_inserter(merged_metadata),
182 [&metadata_set](const OpMetadata& data) {
183 return !ContainsKey(metadata_set, data);
184 });
185 if (replication == 1) {
186 new_tile_dims.pop_back();
187 new_tile.Reshape(new_tile_dims);
188 *to_merge = HloSharding::Tile(new_tile, merged_metadata);
189 } else {
190 *to_merge = HloSharding::PartialTile(new_tile, merged_metadata);
191 }
192 return true;
193 }
194 return IsShardingMoreSpecific(*to_merge, old);
195 }
196
197 absl::optional<int64> SelectDominantDevice(
198 const std::map<int64, int64>& device_map, int64* top_count) {
199 int64 device = 0;
200 int64 count = 0;
201 for (auto& it : device_map) {
202 if (it.second > count) {
203 count = it.second;
204 device = it.first;
205 }
206 }
207 if (top_count != nullptr) {
208 *top_count = count;
209 }
210 return count > 0 ? absl::optional<int64>(device) : absl::optional<int64>();
211 }
212
213 Status AssignComputationDevice(HloComputation* computation, int64 device) {
214 VLOG(4) << "Assigning device " << device << " to " << computation->name()
215 << " computation";
216 for (HloInstruction* instruction : computation->instructions()) {
217 if (!instruction->has_sharding()) {
218 VLOG(4) << "Assigning device " << device << " to " << instruction->name();
219 instruction->set_device_sharding(device);
220 }
221 }
222 return Status::OK();
223 }
224
225 absl::optional<int64> GetMostOccurringDevice(
226 absl::Span<HloInstruction* const> instructions) {
227 std::map<int64, int64> device_map;
228 for (HloInstruction* instruction : instructions) {
229 if (instruction->has_sharding()) {
230 for (auto& it : instruction->sharding().UsedDevices(nullptr)) {
231 // The UsedDevices() API returns a map<device, occurrence_count>.
232 device_map[it.first] += it.second;
233 }
234 }
235 }
236 return SelectDominantDevice(device_map, nullptr);
237 }
238
239 StatusOr<absl::optional<int64>> GetDominantDevice(
240 absl::Span<HloComputation* const> computations, double dominant_factor) {
241 int64 instruction_count = 0;
242 std::map<int64, int64> device_map;
243 for (HloComputation* computation : computations) {
244 for (HloInstruction* instruction : computation->instructions()) {
245 int64 count = 1;
246 if (instruction->has_sharding()) {
247 for (auto& it : instruction->sharding().UsedDevices(&count)) {
248 // The UsedDevices() API returns a map<device, occurrence_count>.
249 device_map[it.first] += it.second;
250 }
251 }
252 instruction_count += count;
253 }
254 }
255 int64 count;
256 absl::optional<int64> device = SelectDominantDevice(device_map, &count);
257 absl::optional<int64> dominant_device;
258 if (device) {
259 double factor =
260 static_cast<double>(count) / static_cast<double>(instruction_count);
261 if (factor >= dominant_factor) {
262 dominant_device = device;
263 }
264 }
265 return dominant_device;
266 }
267
268 HloSharding TransposeSharding(const HloSharding& sharding,
269 const std::vector<int64>& dimensions) {
270 if (sharding.IsTileMaximal()) {
271 return sharding;
272 }
273 auto perm_dimensions = dimensions;
274 if (sharding.ReplicateOnLastTileDim() &&
275 dimensions.size() < sharding.tile_assignment().num_dimensions()) {
276 perm_dimensions.push_back(dimensions.size());
277 }
278 const int64 rank = perm_dimensions.size();
279 std::vector<int64> tile_assignment_dim(rank);
280 for (int64 i = 0; i < rank; ++i) {
281 tile_assignment_dim[i] = sharding.tile_assignment().dim(perm_dimensions[i]);
282 }
283 Array<int64> tile_assignment = sharding.tile_assignment();
284 tile_assignment.Reshape(tile_assignment_dim);
285 tile_assignment.Each([&](absl::Span<const int64> indices, int64* value) {
286 std::vector<int64> src_indices(indices.size(), -1);
287 for (int64 i = 0; i < indices.size(); ++i) {
288 src_indices[perm_dimensions[i]] = indices[i];
289 }
290 *value = sharding.tile_assignment()(src_indices);
291 });
292 return sharding.ReplicateOnLastTileDim()
293 ? HloSharding::PartialTile(tile_assignment, sharding.metadata())
294 : HloSharding::Tile(tile_assignment, sharding.metadata());
295 }
296
297 absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
298 const Shape& target_shape,
299 const HloSharding& sharding) {
300 if (sharding.IsTileMaximal()) {
301 return sharding;
302 }
303
304 // In case of a tiled sharding, the reshaped sharding will be valid if the
305 // reshape is composed of the following operations:
306 // * Adding or removing dimensions with size 1.
307 // * Merging consecutive dimensions where only the most major is sharded.
308 // * Splitting a dimension to consecutive dimensions.
309 // * Any reshaping of unsharded dimensions.
310 // Note that merge and split can happen consecutively on the same dimension,
311 // e.g., f32[1024,256,1024] to f32[128,2048,1024] can be considered that 1024
312 // gets split into 128 and 8, but 8 then gets merged with 256. We use stacks
313 // to make supporting such cases easy.
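// For example (illustrative): reshaping f32[8,16] with sharding
// {devices=[2,1]0,1} to f32[128] propagates to {devices=[2]0,1}, because only
// the most major of the merged dimensions is sharded.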
314 const Shape tile_shape = sharding.TileShape(source_shape);
315 std::vector<int64> target_tile_assignment_dimensions;
316 std::vector<int64> source_dims_stack(source_shape.rank());
317 std::vector<int64> target_dims_stack(target_shape.rank());
318 std::vector<int64> sharding_tile_dims_stack(source_shape.rank());
319 for (int64 i = 0; i < source_shape.rank(); ++i) {
320 source_dims_stack[i] = source_shape.dimensions(source_shape.rank() - 1 - i);
321 sharding_tile_dims_stack[i] =
322 sharding.tile_assignment().dim(source_shape.rank() - 1 - i);
323 }
324 for (int64 i = 0; i < target_shape.rank(); ++i) {
325 target_dims_stack[i] = target_shape.dimensions(target_shape.rank() - 1 - i);
326 }
327 while (!source_dims_stack.empty() || !target_dims_stack.empty()) {
328 if (target_dims_stack.empty()) {
329 if (Product(sharding_tile_dims_stack) != 1) {
330 return absl::nullopt;
331 }
332 break;
333 }
334 int64 s_size = 1;
335 int64 t_size = 1;
336 int64 s_partitions = 1;
337 if (!source_dims_stack.empty()) {
338 s_size = source_dims_stack.back();
339 source_dims_stack.pop_back();
340 s_partitions = sharding_tile_dims_stack.back();
341 sharding_tile_dims_stack.pop_back();
342 }
343 t_size = target_dims_stack.back();
344 target_dims_stack.pop_back();
345 if (s_partitions * Product(sharding_tile_dims_stack) == 1) {
346 // No more partitions left.
347 target_tile_assignment_dimensions.push_back(1);
348 continue;
349 }
350 if (s_size == t_size) {
351 // Same dimension.
352 target_tile_assignment_dimensions.push_back(s_partitions);
353 } else if (t_size == 1) {
354 // Trivial dimension added.
355 target_tile_assignment_dimensions.push_back(1);
356 source_dims_stack.push_back(s_size);
357 sharding_tile_dims_stack.push_back(s_partitions);
358 } else if (s_size == 1) {
359 // Trivial dimension removed.
360 if (s_partitions != 1) {
361 return absl::nullopt;
362 }
363 target_dims_stack.push_back(t_size);
364 } else if (s_size > t_size) {
365 // Dimension split.
366 if (s_size % t_size != 0 || s_size % s_partitions != 0) {
367 return absl::nullopt;
368 }
369 if (t_size % s_partitions == 0) {
370 target_tile_assignment_dimensions.push_back(s_partitions);
371 // We have part of the s_size unprocessed, so put it back on the stack.
372 source_dims_stack.push_back(s_size / t_size);
373 sharding_tile_dims_stack.push_back(1);
374 } else if (s_partitions % t_size == 0) {
375 target_tile_assignment_dimensions.push_back(t_size);
376 // We have part of the s_size unprocessed, so put it back on the stack.
377 source_dims_stack.push_back(s_size / t_size);
378 sharding_tile_dims_stack.push_back(s_partitions / t_size);
379 } else {
380 return absl::nullopt;
381 }
382 } else {
383 // Dimension merge. Also merge the source dimension with the next, and
384 // process it next time.
385 if (s_size % s_partitions != 0) {
386 return absl::nullopt;
387 }
388 CHECK(!source_dims_stack.empty());
389 if (sharding_tile_dims_stack.back() != 1 && s_size != s_partitions) {
390 // If the next dimension to combine is sharded, we require the
391 // current dimension's shard size to be 1. Otherwise, the new shard
392 // would be non-contiguous.
393 return absl::nullopt;
394 }
395 source_dims_stack.back() *= s_size;
396 sharding_tile_dims_stack.back() *= s_partitions;
397 target_dims_stack.push_back(t_size);
398 }
399 }
400 Array<int64> new_tile_assignment = sharding.tile_assignment();
401 if (sharding.ReplicateOnLastTileDim()) {
402 target_tile_assignment_dimensions.push_back(
403 sharding.tile_assignment().dimensions().back());
404 }
405 new_tile_assignment.Reshape(target_tile_assignment_dimensions);
406 return sharding.ReplicateOnLastTileDim()
407 ? HloSharding::PartialTile(new_tile_assignment,
408 sharding.metadata())
409 : HloSharding::Tile(new_tile_assignment, sharding.metadata());
410 }
411
412 HloSharding ReverseSharding(const HloSharding& sharding,
413 absl::Span<const int64> dimensions) {
414 if (sharding.IsTileMaximal() || dimensions.empty()) {
415 return sharding;
416 }
417
418 Array<int64> new_tile_assignment(sharding.tile_assignment().dimensions());
419 new_tile_assignment.Each([&](absl::Span<const int64> indices, int64* device) {
420 std::vector<int64> original_indices(indices.begin(), indices.end());
421 for (int64 d : dimensions) {
422 original_indices[d] =
423 new_tile_assignment.dim(d) - 1 - original_indices[d];
424 }
425 *device = sharding.tile_assignment()(original_indices);
426 });
427 return sharding.ReplicateOnLastTileDim()
428 ? HloSharding::PartialTile(new_tile_assignment,
429 sharding.metadata())
430 : HloSharding::Tile(new_tile_assignment, sharding.metadata());
431 }
432
433 HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim,
434 absl::Span<const int64> dims) {
435 CHECK(!sharding.IsTuple() && !sharding.IsTileMaximal());
436 CHECK_NE(absl::c_find(dims, dim), dims.end()) << "dim is not in dims";
437 // We optimize the tile assignment on the single dimension dim in a way that
438 // minimizes communication among devices caused by the reshard:
439 // +---+---+ +---+---+ +-+-+-+-+
440 // | | | | 0 | | | | | |
441 // | 0 | 1 | +-------+ | | | | |
442 // | | | reshape on | 1 | reshape on | | | | |
443 // +---+---+ dim 0 => +-------+ dim 1 => |0|2|1|3|
444 // | | | | 2 | | | | | |
445 // | 2 | 3 | +-------+ | | | | |
446 // | | | | 3 | | | | | |
447 // +---+---+ +---+---+ +-+-+-+-+
448
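// The code below collapses all partitions along the dimensions in `dims` onto
// the single dimension `dim`, keeping the tiling of the remaining (ignored)
// dimensions unchanged, as in the diagram above.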
449 std::vector<int64> tile_dims(sharding.tile_assignment().num_dimensions(), 1);
450 // Handle the ignored dimensions (those not listed in `dims`).
451 std::vector<int64> ignore_sizes;
452 int64 ignore_size = 1;
453 for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
454 if (absl::c_find(dims, i) == dims.end()) {
455 int64 size = sharding.tile_assignment().dim(i);
456 ignore_sizes.push_back(size);
457 tile_dims[i] = size;
458 ignore_size *= size;
459 }
460 }
461
462 using Buckets = std::vector<std::vector<int64>>;
463 Array<Buckets> buckets(ignore_sizes,
464 Buckets(sharding.tile_assignment().dim(dim)));
465 sharding.tile_assignment().Each(
466 [&](absl::Span<const int64> index, int64 device) {
467 std::vector<int64> ignore_index;
468 for (int64 i = 0; i < index.size(); ++i) {
469 if (absl::c_find(dims, i) == dims.end()) {
470 ignore_index.push_back(index[i]);
471 }
472 }
473 buckets(ignore_index)[index[dim]].push_back(device);
474 });
475 std::vector<int64> devices;
476 buckets.Each([&](absl::Span<const int64> index, const Buckets& buckets) {
477 for (auto& bucket : buckets) {
478 devices.insert(devices.end(), bucket.begin(), bucket.end());
479 }
480 });
481 tile_dims[dim] = devices.size() / ignore_size;
482 Array<int64> tile_assignment(tile_dims);
483 tile_assignment.SetValues(devices);
484 return HloSharding::Tile(tile_assignment, sharding.metadata());
485 }
486
487 bool ContainsTileSharding(const HloModule& module) {
488 for (const HloComputation* computation : module.computations()) {
489 for (const HloInstruction* instruction : computation->instructions()) {
490 if (instruction->has_sharding() &&
491 !instruction->sharding().IsTileMaximal()) {
492 return true;
493 }
494 }
495 }
496 return false;
497 }
498
499 HloSharding GatherOutputSharding(const HloSharding& index_sharding,
500 const HloInstruction* hlo) {
501 if (index_sharding.IsTileMaximal()) {
502 return index_sharding;
503 }
504
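// The output keeps the index sharding on the batch (non-offset) dimensions
// and is left unpartitioned (tile size 1) on the offset dimensions.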
505 const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers();
506 std::vector<int64> output_tile_assignment_dims;
507 for (int64 i = 0, index_dim = 0; i < hlo->shape().rank(); ++i) {
508 if (absl::c_binary_search(dnums.offset_dims(), i)) {
509 output_tile_assignment_dims.push_back(1);
510 } else {
511 const int64 new_tile_dimension =
512 index_dim >= dnums.index_vector_dim() ? index_dim + 1 : index_dim;
513 output_tile_assignment_dims.push_back(
514 index_sharding.tile_assignment().dim(new_tile_dimension));
515 ++index_dim;
516 }
517 }
518
519 if (index_sharding.ReplicateOnLastTileDim()) {
520 output_tile_assignment_dims.push_back(
521 index_sharding.tile_assignment().dimensions().back());
522 }
523
524 Array<int64> new_tile_assignment = index_sharding.tile_assignment();
525 if (new_tile_assignment.num_elements() !=
526 Product(output_tile_assignment_dims)) {
527 return HloSharding::Replicate(index_sharding.metadata());
528 }
529 new_tile_assignment.Reshape(output_tile_assignment_dims);
530 return index_sharding.ReplicateOnLastTileDim()
531 ? HloSharding::PartialTile(new_tile_assignment,
532 index_sharding.metadata())
533 : HloSharding::Tile(new_tile_assignment,
534 index_sharding.metadata());
535 }
536
537 HloSharding GatherIndexSharding(const HloSharding& output_sharding,
538 const HloInstruction* hlo) {
539 CHECK(hlo->opcode() == HloOpcode::kGather);
540 if (output_sharding.IsTileMaximal()) {
541 return output_sharding;
542 }
543
544 const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers();
545 std::vector<int64> index_tile_assignment_dims;
546 for (int64 i = 0; i < hlo->shape().rank(); ++i) {
547 if (!absl::c_binary_search(dnums.offset_dims(), i)) {
548 index_tile_assignment_dims.push_back(
549 output_sharding.tile_assignment().dim(i));
550 }
551 }
552 int64 index_rank = hlo->operand(1)->shape().rank();
553
554 // Vector indices sharding is not supported yet.
555 if (index_rank > index_tile_assignment_dims.size()) {
556 index_tile_assignment_dims.insert(
557 index_tile_assignment_dims.begin() + dnums.index_vector_dim(), 1);
558 }
559
560 int64 partial_replication_size = 1;
561 if (output_sharding.ReplicateOnLastTileDim()) {
562 partial_replication_size *=
563 output_sharding.tile_assignment().dimensions().back();
564 }
565
566 Array<int64> new_tile_assignment = output_sharding.tile_assignment();
567 const int64 index_tile_elements =
568 Product(index_tile_assignment_dims) * partial_replication_size;
569 if (new_tile_assignment.num_elements() != index_tile_elements) {
570 if (new_tile_assignment.num_elements() % index_tile_elements == 0) {
571 partial_replication_size *=
572 (new_tile_assignment.num_elements() / index_tile_elements);
573 } else {
574 return HloSharding::Replicate(output_sharding.metadata());
575 }
576 }
577 if (partial_replication_size > 1) {
578 index_tile_assignment_dims.push_back(partial_replication_size);
579 }
580 new_tile_assignment.Reshape(index_tile_assignment_dims);
581 return partial_replication_size > 1
582 ? HloSharding::PartialTile(new_tile_assignment,
583 output_sharding.metadata())
584 : HloSharding::Tile(new_tile_assignment,
585 output_sharding.metadata());
586 }
587
588 HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) {
589 if (hlo.sharding().IsTileMaximal()) {
590 return hlo.sharding();
591 }
592
593 const GatherDimensionNumbers& dnums = hlo.gather_dimension_numbers();
594 std::vector<int64> tile_assignment_dims(hlo.shape().rank());
595 int64 num_elements = 1;
596 for (int64 i = 0; i < hlo.shape().rank(); ++i) {
597 if (!absl::c_binary_search(dnums.offset_dims(), i)) {
598 tile_assignment_dims[i] = hlo.sharding().tile_assignment().dim(i);
599 num_elements *= hlo.sharding().tile_assignment().dim(i);
600 } else {
601 tile_assignment_dims[i] = 1;
602 }
603 }
604 if (num_elements == hlo.sharding().tile_assignment().num_elements()) {
605 // Output sharding is only on non offset dimensions. We use output sharding
606 // to shard this gather op directly.
607 return hlo.sharding();
608 }
609
610 if (num_elements == 1) {
611 // Output sharding is only on offset dimensions. We do not shard this gather
612 // op. Return a tile maximal sharding with the first device in output
613 // sharding tile assignment.
614 return HloSharding::AssignDevice(*hlo.sharding().tile_assignment().begin(),
615 hlo.sharding().metadata());
616 }
617
618 // Output sharding is on both offset and non offset dimensions. We shard the
619 // gather op only on non offset dimensions.
620 // For example:
621 // - the gather op has sharding [2,2]{0,1,2,3},
622 // - first dimension is non offset dimension,
623 // - second dimension is offset dimension,
624 // Then the result sharding will be [2,1]{0,2}.
625 std::vector<int64> slice_starts(hlo.shape().rank(), 0LL),
626 slice_limits(hlo.shape().rank());
627 for (int64 i = 0; i < hlo.shape().rank(); ++i) {
628 if (!absl::c_binary_search(dnums.offset_dims(), i)) {
629 slice_limits[i] = hlo.sharding().tile_assignment().dim(i);
630 } else {
631 slice_limits[i] = 1;
632 }
633 }
634 Array<int64> tile_assignment =
635 hlo.sharding().tile_assignment().Slice(slice_starts, slice_limits);
636 return HloSharding::Tile(tile_assignment, hlo.sharding().metadata());
637 }
638
639 HloSharding ScatterIndexSharding(const HloSharding& data_sharding,
640 const HloInstruction* hlo) {
641 if (data_sharding.IsTileMaximal()) {
642 return data_sharding;
643 }
644
645 const ScatterDimensionNumbers& dnums = hlo->scatter_dimension_numbers();
646 std::vector<int64> index_tile_assignment_dims;
647 for (int64 i = 0; i < hlo->shape().rank(); ++i) {
648 if (!absl::c_binary_search(dnums.update_window_dims(), i)) {
649 index_tile_assignment_dims.push_back(
650 data_sharding.tile_assignment().dim(i));
651 }
652 }
653 if (index_tile_assignment_dims.size() < hlo->operand(1)->shape().rank()) {
654 index_tile_assignment_dims.push_back(1);
655 }
656 if (data_sharding.ReplicateOnLastTileDim()) {
657 index_tile_assignment_dims.push_back(
658 data_sharding.tile_assignment().dimensions().back());
659 }
660 Array<int64> new_tile_assignment = data_sharding.tile_assignment();
661 if (new_tile_assignment.num_elements() !=
662 Product(index_tile_assignment_dims)) {
663 return HloSharding::Replicate(data_sharding.metadata());
664 }
665 new_tile_assignment.Reshape(index_tile_assignment_dims);
666 return data_sharding.ReplicateOnLastTileDim()
667 ? HloSharding::PartialTile(new_tile_assignment,
668 data_sharding.metadata())
669 : HloSharding::Tile(new_tile_assignment, data_sharding.metadata());
670 }
671
672 HloSharding ScatterDataSharding(const HloSharding& index_sharding,
673 const HloInstruction* hlo) {
674 if (index_sharding.IsTileMaximal()) {
675 return index_sharding;
676 }
677
678 const ScatterDimensionNumbers& dnums = hlo->scatter_dimension_numbers();
679 std::vector<int64> data_tile_assignment_dims;
680 for (int64 i = 0, index_dim = 0; i < hlo->shape().rank(); ++i) {
681 if (absl::c_binary_search(dnums.update_window_dims(), i)) {
682 data_tile_assignment_dims.push_back(1);
683 } else {
684 data_tile_assignment_dims.push_back(
685 index_sharding.tile_assignment().dim(index_dim));
686 index_dim++;
687 }
688 }
689 if (index_sharding.ReplicateOnLastTileDim()) {
690 data_tile_assignment_dims.push_back(
691 index_sharding.tile_assignment().dimensions().back());
692 }
693 Array<int64> new_tile_assignment = index_sharding.tile_assignment();
694 if (new_tile_assignment.num_elements() !=
695 Product(data_tile_assignment_dims)) {
696 return HloSharding::Replicate(index_sharding.metadata());
697 }
698 new_tile_assignment.Reshape(data_tile_assignment_dims);
699 return index_sharding.ReplicateOnLastTileDim()
700 ? HloSharding::PartialTile(new_tile_assignment,
701 index_sharding.metadata())
702 : HloSharding::Tile(new_tile_assignment,
703 index_sharding.metadata());
704 }
705
706 HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding,
707 const HloInstruction& hlo) {
708 if (index_sharding.IsTileMaximal()) {
709 return index_sharding;
710 }
711
712 // Only shard on first "number of scatter_window_dims" dimensions.
713 const ScatterDimensionNumbers& dnums = hlo.scatter_dimension_numbers();
714 int64 num_elements = 1;
715 int64 index_dim = 0;
716 for (int64 i = 0; i < hlo.shape().rank(); ++i) {
717 if (absl::c_binary_search(dnums.inserted_window_dims(), i)) {
718 num_elements *= index_sharding.tile_assignment().dim(index_dim);
719 index_dim++;
720 }
721 }
722 if (num_elements == index_sharding.tile_assignment().num_elements()) {
723 // Index sharding is only on scatter_window_dims. We use this index sharding
724 // directly.
725 return index_sharding;
726 }
727
728 // Index sharding is only on update_window_dims. We do not shard this scatter
729 // op. Return a tile maximal sharding with the first device in index sharding
730 // tile assignment.
731 if (num_elements == 1) {
732 return HloSharding::AssignDevice(*index_sharding.tile_assignment().begin(),
733 index_sharding.metadata());
734 }
735
736 const int64 index_rank = hlo.operand(1)->shape().rank();
737 std::vector<int64> slice_starts(index_rank, 0LL), slice_limits(index_rank);
738 for (int64 i = 0; i < index_rank; ++i) {
739 if (i < index_dim) {
740 slice_limits[i] = index_sharding.tile_assignment().dim(i);
741 } else {
742 slice_limits[i] = 1;
743 }
744 }
745 Array<int64> tile_assignment =
746 index_sharding.tile_assignment().Slice(slice_starts, slice_limits);
747 return HloSharding::Tile(tile_assignment, index_sharding.metadata());
748 }
749
750 HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
751 const HloInstruction& hlo) {
752 if (data_sharding.IsTileMaximal()) {
753 return data_sharding;
754 }
755
756 const ScatterDimensionNumbers& dnums = hlo.scatter_dimension_numbers();
757 const int64 data_rank = hlo.operand(2)->shape().rank();
758 std::vector<int64> tile_assignment_dims(data_rank, 1LL);
759 int64 num_elements = 1;
760 for (int64 i = 0; i < hlo.shape().rank(); ++i) {
761 if (absl::c_binary_search(dnums.inserted_window_dims(), i)) {
762 CHECK_LT(i, data_rank);
763 tile_assignment_dims[i] = data_sharding.tile_assignment().dim(i);
764 num_elements *= data_sharding.tile_assignment().dim(i);
765 }
766 }
767 if (num_elements == data_sharding.tile_assignment().num_elements()) {
768 // Data sharding is only on scatter_window_dims. We use this data sharding
769 // directly.
770 return data_sharding;
771 }
772
773 if (num_elements == 1) {
774 // Data sharding is only on update_window_dims. We do not shard this
775 // scatter op. Return a tile maximal sharding with the first device in
776 // data sharding tile assignment.
777 return HloSharding::AssignDevice(*data_sharding.tile_assignment().begin(),
778 data_sharding.metadata());
779 }
780
781 // Data sharding is on both update_window_dims and scatter_window_dims. We
782 // shard the scatter op only on scatter_window_dims. For example:
783 // - the scatter data has sharding [2,2]{0,1,2,3},
784 // - first dimension is scatter_window_dims,
785 // - second dimension is update_window_dims,
786 // Then the result sharding will be [2,1]{0,2}.
787 std::vector<int64> slice_starts(data_rank, 0LL);
788 Array<int64> tile_assignment =
789 data_sharding.tile_assignment().Slice(slice_starts, tile_assignment_dims);
790 return HloSharding::Tile(tile_assignment, data_sharding.metadata());
791 }
792
793 namespace {
794
795 // If partitioning of the operand only happens in passthrough dimensions
796 // (offset dimensions in the gather output (or scatter update) that have the
797 // same size as in the operand), returns the corresponding output (or update)
798 // sharding by passing through the input sharding.
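// For example (illustrative): a gather whose operand f32[8,16] is sharded as
// {devices=[1,2]0,1}, with collapsed_slice_dims={0}, start_index_map={0},
// offset_dims={1} and slice_sizes={1,16}, passes the sharding through to the
// output as {devices=[1,2]0,1} on the offset dimension.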
799 absl::optional<HloSharding> PassthroughOperandToGatherOutputOrScatterUpdate(
800 const Shape& operand_shape, const HloSharding& operand_sharding,
801 const Shape& update_or_gather_shape,
802 absl::Span<const int64> collapsed_or_inserted_dims,
803 absl::Span<const int64> index_map,
804 absl::Span<const int64> offset_or_window_dims,
805 absl::Span<const int64> slice_size) {
806 if (operand_sharding.IsTileMaximal()) {
807 return operand_sharding;
808 }
809 std::vector<int64> passthrough_tile(update_or_gather_shape.rank(), 1);
810 int64 collapsed = 0;
811 for (int64 i = 0; i < operand_shape.rank(); ++i) {
812 int64 dim_partitions = operand_sharding.tile_assignment().dim(i);
813 if (absl::c_linear_search(collapsed_or_inserted_dims, i) ||
814 absl::c_linear_search(index_map, i)) {
815 if (dim_partitions > 1) {
816 return absl::nullopt;
817 }
818 collapsed++;
819 continue;
820 }
821 if (slice_size[i] != operand_shape.dimensions(i) && dim_partitions > 1) {
822 return absl::nullopt;
823 }
824 int64 offset_dim = offset_or_window_dims[i - collapsed];
825 if (i - collapsed > 0 &&
826 offset_dim < offset_or_window_dims[i - collapsed - 1]) {
827 // Output offsets are transposed; we do not support this case.
828 return absl::nullopt;
829 }
830 passthrough_tile[offset_dim] = dim_partitions;
831 }
832 if (operand_sharding.ReplicateOnLastTileDim()) {
833 passthrough_tile.push_back(
834 operand_sharding.tile_assignment().dimensions().back());
835 }
836 Array<int64> tile_assignment = operand_sharding.tile_assignment();
837 tile_assignment.Reshape(passthrough_tile);
838 return operand_sharding.ReplicateOnLastTileDim()
839 ? HloSharding::PartialTile(tile_assignment,
840 operand_sharding.metadata())
841 : HloSharding::Tile(tile_assignment, operand_sharding.metadata());
842 }
843
844 // Inverse of PassthroughOperandToGatherOutputOrScatterUpdate.
845 absl::optional<HloSharding> PassthroughGatherOutputOrScatterUpdateToOperand(
846 const Shape& operand_shape, const HloSharding& update_or_gather_sharding,
847 absl::Span<const int64> collapsed_or_inserted_dims,
848 absl::Span<const int64> index_map,
849 absl::Span<const int64> offset_or_window_dims,
850 absl::Span<const int64> slice_size) {
851 if (update_or_gather_sharding.IsTileMaximal()) {
852 return update_or_gather_sharding;
853 }
854 std::vector<int64> passthrough_tile(operand_shape.rank(), 1);
855 int64 collapsed = 0;
856 for (int64 i = 0; i < operand_shape.rank(); ++i) {
857 if (absl::c_linear_search(collapsed_or_inserted_dims, i) ||
858 absl::c_linear_search(index_map, i)) {
859 collapsed++;
860 continue;
861 }
862 int64 offset_dim = offset_or_window_dims[i - collapsed];
863 int64 dim_partitions =
864 update_or_gather_sharding.tile_assignment().dim(offset_dim);
865 if (slice_size[i] != operand_shape.dimensions(i) && dim_partitions > 1) {
866 return absl::nullopt;
867 }
868 if (i - collapsed > 0 &&
869 offset_dim < offset_or_window_dims[i - collapsed - 1]) {
870 // Output offsets are transposed; we do not support this case.
871 return absl::nullopt;
872 }
873 passthrough_tile[i] = dim_partitions;
874 }
875
876 if (update_or_gather_sharding.ReplicateOnLastTileDim()) {
877 passthrough_tile.push_back(
878 update_or_gather_sharding.tile_assignment().dimensions().back());
879 }
880 Array<int64> tile_assignment = update_or_gather_sharding.tile_assignment();
881 if (tile_assignment.num_elements() != Product(passthrough_tile)) {
882 return absl::nullopt;
883 }
884 tile_assignment.Reshape(passthrough_tile);
885 return update_or_gather_sharding.ReplicateOnLastTileDim()
886 ? HloSharding::PartialTile(tile_assignment,
887 update_or_gather_sharding.metadata())
888 : HloSharding::Tile(tile_assignment,
889 update_or_gather_sharding.metadata());
890 }
891
892 // Collect data operand sharding for a gather with parallel dimensions from
893 // the output.
894 absl::optional<HloSharding> GatherParallelDataOperandSharding(
895 const HloSharding& output_sharding, const HloInstruction& gather,
896 const GatherParallelDims& parallel_dims) {
897 if (output_sharding.IsTileMaximal()) {
898 return output_sharding;
899 }
900 auto output_parallel_dims = GatherParallelOutputDims(gather, parallel_dims);
901 auto output_aligned_operand_parallel_dims =
902 GatherOutputAlignedOperandParallelDims(gather, parallel_dims);
903 const Shape gather_shape = gather.shape();
904 CHECK_EQ(output_parallel_dims.size(),
905 output_aligned_operand_parallel_dims.size());
906 std::vector<int64> operand_tile_assignment(gather.operand(0)->shape().rank(),
907 1);
908 for (int i = 0, parallel_idx = 0; i < gather_shape.rank(); ++i) {
909 if (parallel_idx >= output_parallel_dims.size() ||
910 output_parallel_dims[parallel_idx] != i) {
911 continue;
912 }
913 const int64 operand_dim =
914 output_aligned_operand_parallel_dims[parallel_idx++];
915 operand_tile_assignment[operand_dim] =
916 output_sharding.tile_assignment().dim(i);
917 }
918 int64 partially_replicated_size = 1;
919 if (output_sharding.ReplicateOnLastTileDim()) {
920 partially_replicated_size *=
921 output_sharding.tile_assignment().dimensions().back();
922 }
923 Array<int64> tile_assignment = output_sharding.tile_assignment();
924 const int64 operand_tile_elements =
925 Product(operand_tile_assignment) * partially_replicated_size;
926 if (tile_assignment.num_elements() != operand_tile_elements) {
927 if (tile_assignment.num_elements() % operand_tile_elements == 0) {
928 partially_replicated_size *=
929 (tile_assignment.num_elements() / operand_tile_elements);
930 } else {
931 return absl::nullopt;
932 }
933 }
934 if (partially_replicated_size > 1) {
935 operand_tile_assignment.push_back(partially_replicated_size);
936 }
937 tile_assignment.Reshape(operand_tile_assignment);
938 return partially_replicated_size > 1
939 ? HloSharding::PartialTile(tile_assignment,
940 output_sharding.metadata())
941 : HloSharding::Tile(tile_assignment, output_sharding.metadata());
942 }
943
944 } // namespace
945
946 absl::optional<HloSharding> GatherOutputShardingFromDataOperand(
947 const HloSharding& data_operand_sharding, const HloInstruction& hlo,
948 const Shape& output_shape, const Shape& operand_shape) {
949 const auto& dnums = hlo.gather_dimension_numbers();
950 std::vector<int64> collapsed_slice_dims(dnums.collapsed_slice_dims().begin(),
951 dnums.collapsed_slice_dims().end());
952 std::vector<int64> start_index_map(dnums.start_index_map().begin(),
953 dnums.start_index_map().end());
954 std::vector<int64> offset_dims(dnums.offset_dims().begin(),
955 dnums.offset_dims().end());
956 return PassthroughOperandToGatherOutputOrScatterUpdate(
957 operand_shape, data_operand_sharding, output_shape, collapsed_slice_dims,
958 start_index_map, offset_dims, hlo.gather_slice_sizes());
959 }
960
961 absl::optional<HloSharding> GatherDataOperandShardingFromOutput(
962 const HloSharding& output_sharding, const HloInstruction& hlo) {
963 const auto& dnums = hlo.gather_dimension_numbers();
964 std::vector<int64> collapsed_slice_dims(dnums.collapsed_slice_dims().begin(),
965 dnums.collapsed_slice_dims().end());
966 std::vector<int64> start_index_map(dnums.start_index_map().begin(),
967 dnums.start_index_map().end());
968 std::vector<int64> offset_dims(dnums.offset_dims().begin(),
969 dnums.offset_dims().end());
970
971 absl::optional<HloSharding> parallel_sharding;
972 auto parallel_dims = GetGatherBatchParallelDims(hlo);
973 absl::Span<const int64> operand_parallel_dims;
974 if (parallel_dims) {
975 // Prioritize the parallel sharding, as this matches the behavior of the
976 // spmd_partitioner.
977 parallel_sharding =
978 GatherParallelDataOperandSharding(hlo.sharding(), hlo, *parallel_dims);
979 operand_parallel_dims = parallel_dims->operand_parallel_dims;
980 }
981 HloSharding filtered_output_sharding = PartiallyReplicateTiledShardingOnDims(
982 output_sharding, operand_parallel_dims);
983 absl::optional<HloSharding> passthrough_sharding =
984 PassthroughGatherOutputOrScatterUpdateToOperand(
985 hlo.operand(0)->shape(), filtered_output_sharding,
986 collapsed_slice_dims, start_index_map, offset_dims,
987 hlo.gather_slice_sizes());
988 // Try to merge the two shardings or return the one that is present if only
989 // one of the two is.
990 if (!passthrough_sharding) {
991 return parallel_sharding;
992 }
993 if (!parallel_sharding) {
994 return passthrough_sharding;
995 }
996 if (MergeSharding(*parallel_sharding, &*passthrough_sharding,
997 /*may_combine_partial_sharding=*/true)) {
998 return passthrough_sharding;
999 }
1000 if (MergeSharding(*passthrough_sharding, &*parallel_sharding,
1001 /*may_combine_partial_sharding=*/true)) {
1002 return parallel_sharding;
1003 }
1004 return absl::nullopt;
1005 }
1006
1007 absl::optional<HloSharding> ScatterOutputShardingFromUpdate(
1008 const HloSharding& update_sharding, const HloInstruction& hlo) {
1009 const auto& dnums = hlo.scatter_dimension_numbers();
1010 std::vector<int64> inserted_window_dims(dnums.inserted_window_dims().begin(),
1011 dnums.inserted_window_dims().end());
1012 std::vector<int64> scatter_dims_to_operand_dims(
1013 dnums.scatter_dims_to_operand_dims().begin(),
1014 dnums.scatter_dims_to_operand_dims().end());
1015 std::vector<int64> update_window_dims(dnums.update_window_dims().begin(),
1016 dnums.update_window_dims().end());
1017 std::vector<int64> slice_size(hlo.shape().rank(), 1);
1018 int64 num_update_window_dims = 0;
1019 for (int64 i = 0; i < hlo.shape().rank(); ++i) {
1020 if (absl::c_linear_search(dnums.inserted_window_dims(), i)) {
1021 continue;
1022 }
1023 slice_size[i] = hlo.operand(2)->shape().dimensions(
1024 dnums.update_window_dims(num_update_window_dims++));
1025 }
1026 return PassthroughGatherOutputOrScatterUpdateToOperand(
1027 hlo.shape(), update_sharding, inserted_window_dims,
1028 scatter_dims_to_operand_dims, update_window_dims, slice_size);
1029 }
1030
1031 absl::optional<HloSharding> ScatterUpdateShardingFromOutput(
1032 const HloSharding& output_sharding, const HloInstruction& hlo) {
1033 const auto& dnums = hlo.scatter_dimension_numbers();
1034 std::vector<int64> inserted_window_dims(dnums.inserted_window_dims().begin(),
1035 dnums.inserted_window_dims().end());
1036 std::vector<int64> scatter_dims_to_operand_dims(
1037 dnums.scatter_dims_to_operand_dims().begin(),
1038 dnums.scatter_dims_to_operand_dims().end());
1039 std::vector<int64> update_window_dims(dnums.update_window_dims().begin(),
1040 dnums.update_window_dims().end());
1041 std::vector<int64> slice_size(hlo.shape().rank(), 1);
1042 int64 num_update_window_dims = 0;
1043 for (int64 i = 0; i < hlo.shape().rank(); ++i) {
1044 if (absl::c_linear_search(dnums.inserted_window_dims(), i)) {
1045 continue;
1046 }
1047 slice_size[i] = hlo.operand(2)->shape().dimensions(
1048 dnums.update_window_dims(num_update_window_dims++));
1049 }
1050 return PassthroughOperandToGatherOutputOrScatterUpdate(
1051 hlo.shape(), output_sharding, hlo.operand(2)->shape(),
1052 inserted_window_dims, scatter_dims_to_operand_dims, update_window_dims,
1053 slice_size);
1054 }
1055
1056 StatusOr<std::pair<std::unique_ptr<HloInstruction>, HloOpcode>>
1057 IdentityValueAndHloOpcodeForScatterReduceComputation(
1058 const HloScatterInstruction& scatter) {
1059 auto computation = scatter.to_apply();
1060 // We only handle computations with 2 parameters and only 1 calculation.
1061 if (computation->instruction_count() != 3) {
1062 return Status(
1063 tensorflow::error::Code::INVALID_ARGUMENT,
1064 "Expected scatter reduce computation with 2 parameters and only 1 "
1065 "calculation");
1066 }
1067
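// The identity value depends on the reduction: add/or use zero, multiply/and
// use one, max uses the minimum value of the element type, and min uses the
// maximum value.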
1068 auto root_instruction = computation->root_instruction();
1069 if (root_instruction->opcode() == HloOpcode::kAdd ||
1070 root_instruction->opcode() == HloOpcode::kOr) {
1071 return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::Zero(
1072 scatter.shape().element_type())),
1073 root_instruction->opcode());
1074 } else if (root_instruction->opcode() == HloOpcode::kMultiply ||
1075 root_instruction->opcode() == HloOpcode::kAnd) {
1076 return std::make_pair(HloInstruction::CreateConstant(
1077 LiteralUtil::One(scatter.shape().element_type())),
1078 root_instruction->opcode());
1079 } else if (root_instruction->opcode() == HloOpcode::kMaximum) {
1080 return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MinValue(
1081 scatter.shape().element_type())),
1082 root_instruction->opcode());
1083 } else if (root_instruction->opcode() == HloOpcode::kMinimum) {
1084 return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MaxValue(
1085 scatter.shape().element_type())),
1086 root_instruction->opcode());
1087 }
1088
1089 return Status(tensorflow::error::Code::INVALID_ARGUMENT,
1090 "Expected scatter reduce computation which is "
1091 "add/or/multiply/add/min/max");
1092 }
1093
1094 namespace {
1095
1096 void DevicesForShardingInternal(
1097 const HloSharding& sharding,
1098 const absl::flat_hash_set<int64>& available_devices,
1099 absl::flat_hash_set<int64>* used) {
1100 if (sharding.IsTuple()) {
1101 for (const auto& subsharding : sharding.tuple_elements()) {
1102 DevicesForShardingInternal(subsharding, available_devices, used);
1103 }
1104 return;
1105 }
1106
1107 if (sharding.IsReplicated()) {
1108 for (int64 device : available_devices) {
1109 if (!HloSharding::IsReservedDevice(device)) {
1110 used->insert(device);
1111 }
1112 }
1113 return;
1114 }
1115
1116 DCHECK(std::all_of(
1117 sharding.tile_assignment().begin(), sharding.tile_assignment().end(),
1118 [&](int64 device) { return available_devices.contains(device); }));
1119 sharding.tile_assignment().Each([&](absl::Span<const int64> /*indices*/,
1120 int64 device) { used->insert(device); });
1121 }
1122
1123 } // namespace
1124
1125 std::vector<int64> DevicesForSharding(
1126 const HloSharding& sharding, const std::vector<int64>& available_devices) {
1127 absl::flat_hash_set<int64> available_set;
1128 for (int64 device : available_devices) {
1129 available_set.insert(device);
1130 }
1131 absl::flat_hash_set<int64> used_set;
1132 DevicesForShardingInternal(sharding, available_set, &used_set);
1133 std::vector<int64> devices;
1134 for (int64 device : available_devices) {
1135 if (used_set.contains(device)) {
1136 devices.push_back(device);
1137 }
1138 }
1139 return devices;
1140 }
1141
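// Groups the tiles along `dims_to_replicate` into a trailing replication
// dimension. For example (illustrative): partially replicating
// {devices=[2,2]0,1,2,3} on dimension 1 produces
// {devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}.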
1142 HloSharding PartiallyReplicateTiledShardingOnDims(
1143 const HloSharding& sharding, absl::Span<const int64> dims_to_replicate) {
1144 if (sharding.IsTileMaximal()) {
1145 return sharding;
1146 }
1147 int64 group_count = 1;
1148 for (int64 dim : dims_to_replicate) {
1149 if (sharding.ReplicateOnLastTileDim()) {
1150 CHECK_LT(dim, sharding.tile_assignment().num_dimensions());
1151 }
1152 group_count *= sharding.tile_assignment().dim(dim);
1153 }
1154 if (group_count == 1) {
1155 return sharding;
1156 }
1157 if (group_count == sharding.NumTiles()) {
1158 return HloSharding::Replicate(sharding.metadata());
1159 }
1160 std::vector<int64> dim_permutation(
1161 sharding.tile_assignment().num_dimensions());
1162 std::iota(dim_permutation.begin(), dim_permutation.end(), 0);
1163 absl::c_sort(dim_permutation, [&](const int64 a, const int64 b) {
1164 return absl::c_linear_search(dims_to_replicate, a) <
1165 absl::c_linear_search(dims_to_replicate, b);
1166 });
1167 auto transposed = TransposeSharding(sharding, dim_permutation);
1168 auto new_tile = transposed.tile_assignment();
1169 std::vector<int64> new_tile_shape(
1170 sharding.tile_assignment().dimensions().begin(),
1171 sharding.tile_assignment().dimensions().end());
1172 for (int64 dim : dims_to_replicate) {
1173 new_tile_shape[dim] = 1;
1174 }
1175 if (sharding.ReplicateOnLastTileDim()) {
1176 new_tile_shape.back() *= group_count;
1177 } else {
1178 new_tile_shape.push_back(group_count);
1179 }
1180 new_tile.Reshape(new_tile_shape);
1181 return HloSharding::PartialTile(new_tile, sharding.metadata());
1182 }
1183
1184 HloSharding RemoveShapeDimensions(const HloSharding& sharding,
1185 const std::vector<int64>& dims_to_remove) {
1186 if (sharding.IsTileMaximal() || dims_to_remove.empty()) {
1187 return sharding;
1188 }
1189 std::vector<int64> new_tile_shape;
1190 new_tile_shape.reserve(sharding.tile_assignment().num_dimensions() -
1191 dims_to_remove.size());
1192 for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
1193 if (absl::c_linear_search(dims_to_remove, i)) {
1194 CHECK_EQ(sharding.tile_assignment().dim(i), 1);
1195 } else {
1196 new_tile_shape.push_back(sharding.tile_assignment().dim(i));
1197 }
1198 }
1199 auto new_tile = sharding.tile_assignment();
1200 new_tile.Reshape(new_tile_shape);
1201 return sharding.ReplicateOnLastTileDim()
1202 ? HloSharding::PartialTile(new_tile, sharding.metadata())
1203 : HloSharding::Tile(new_tile, sharding.metadata());
1204 }
1205
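// src_to_tgt maps each source dimension to its target dimension (-1 for a
// collapsed source dimension, which must be unsharded), and tgt_to_src maps
// each target dimension back to its source dimension (-1 for a newly added
// dimension, which ends up unsharded).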
1206 absl::optional<HloSharding> TransposeShardingWithCollapsedDims(
1207 const HloSharding& source, absl::Span<int64 const> src_to_tgt,
1208 absl::Span<int64 const> tgt_to_src) {
1209 if (source.IsTileMaximal()) {
1210 return source;
1211 }
1212 if (source.ReplicateOnLastTileDim() &&
1213 src_to_tgt.size() < source.tile_assignment().num_dimensions()) {
1214 std::vector<int64> new_src_to_tgt(src_to_tgt.begin(), src_to_tgt.end());
1215 new_src_to_tgt.push_back(tgt_to_src.size());
1216 std::vector<int64> new_tgt_to_src(tgt_to_src.begin(), tgt_to_src.end());
1217 new_tgt_to_src.push_back(src_to_tgt.size());
1218 return TransposeShardingWithCollapsedDims(source, new_src_to_tgt,
1219 new_tgt_to_src);
1220 }
1221 std::vector<int64> tgt_dims_skipping_new(tgt_to_src.size(), -1);
1222 int64 skipped_tgt_dims = 0;
1223 for (int64 i = 0; i < tgt_to_src.size(); ++i) {
1224 if (tgt_to_src[i] < 0) {
1225 skipped_tgt_dims++;
1226 } else {
1227 tgt_dims_skipping_new[i] = i - skipped_tgt_dims;
1228 }
1229 }
1230 int64 skipped_src_dims = absl::c_count(src_to_tgt, -1);
1231 std::vector<int64> perm(src_to_tgt.size());
1232 for (int64 i = 0; i < src_to_tgt.size(); ++i) {
1233 if (src_to_tgt[i] < 0) {
1234 if (source.tile_assignment().dim(i) > 1) {
1235 return absl::nullopt;
1236 }
1237 perm[src_to_tgt.size() - skipped_src_dims] = i;
1238 skipped_src_dims--;
1239 } else {
1240 perm[tgt_dims_skipping_new[src_to_tgt[i]]] = i;
1241 }
1242 }
1243 auto tgt_sharding = hlo_sharding_util::TransposeSharding(source, perm);
1244 auto reshape_tiles = tgt_sharding.tile_assignment();
1245 std::vector<int64> tgt_tiles(tgt_to_src.size(), 1);
1246 for (int64 i = 0; i < tgt_tiles.size(); ++i) {
1247 if (tgt_to_src[i] >= 0) {
1248 tgt_tiles[i] = reshape_tiles.dim(tgt_dims_skipping_new[i]);
1249 }
1250 }
1251 reshape_tiles.Reshape(tgt_tiles);
1252 return source.ReplicateOnLastTileDim()
1253 ? HloSharding::PartialTile(reshape_tiles, source.metadata())
1254 : HloSharding::Tile(reshape_tiles, source.metadata());
1255 }
1256
1257 absl::optional<GatherParallelDims> GetGatherBatchParallelDims(
1258 const HloInstruction& hlo) {
1259 const auto& dnums = hlo.gather_dimension_numbers();
1260 int64 index_dim = dnums.index_vector_dim();
1261 // Try to identify if there's a dimension in the indices that is monotonically
1262 // increasing with an iota across a certain dimension. This would mean that
1263 // the access in the corresponding operand dimension indexed by this index is
1264 // parallelizable and that we can shard the operand (and the index/output)
1265 // across that dimension.
1266 // For example the pattern:
1267 // %iota.1 = iota()
1268 // %indices = concatenate(..., %iota.1, ...)
1269 // ... = gather(..., %indices)
1270 // is common for tf.reverse_sequence and would match this case.
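// When such dimensions are found, the returned GatherParallelDims records, for
// each index component, the index dimension it is parallel in (or -1), along
// with the matching parallel dimensions of the indices and of the operand.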
1271 absl::InlinedVector<const HloIotaInstruction*, 4> iotas;
1272 const HloInstruction* indices = hlo.operand(1);
1273 const int num_indices = dnums.start_index_map_size();
1274 std::vector<int64> index_parallel_in_dim(num_indices, -1);
1275 // Handle cases where we concatenate pieces of the indices one at a time.
1276 if (indices->opcode() == HloOpcode::kConcatenate &&
1277 indices->concatenate_dimension() == index_dim) {
1278 int concatenated_dims = 0;
1279 for (int i = 0; i < indices->operand_count(); ++i) {
1280 const HloInstruction* op = indices->operand(i);
1281 const int64 num_indices_from_element =
1282 op->shape().dimensions_size() > index_dim
1283 ? op->shape().dimensions(index_dim)
1284 : 1;
1285 if (auto* iota = DynCast<HloIotaInstruction>(op)) {
1286 if (iota->iota_dimension() != index_dim) {
1287 for (int j = 0; j < num_indices_from_element; ++j) {
1288 index_parallel_in_dim[concatenated_dims + j] =
1289 iota->iota_dimension();
1290 }
1291 }
1292 }
1293 concatenated_dims += num_indices_from_element;
1294 }
1295 } else if (auto* iota = DynCast<HloIotaInstruction>(indices)) {
1296 if (iota->iota_dimension() != index_dim) {
1297 // This is a case of a single iota with index_dim being out of bounds.
1298 const int64 num_indices_from_element =
1299 iota->shape().dimensions_size() > index_dim
1300 ? iota->shape().dimensions(index_dim)
1301 : 1;
1302 index_parallel_in_dim.assign(num_indices_from_element,
1303 iota->iota_dimension());
1304 }
1305 }
1306 absl::InlinedVector<int64, 1> indices_parallel_dims;
1307 absl::InlinedVector<int64, 1> operand_parallel_dims;
1308 // Map the parallelizable dimension from the iota to the dimensions of the
1309 // output and the operand. These dimensions are interconnected, but they can
1310 // sit at different positions in the operand and index shapes because the
1311 // position of the index dimension in the operand is determined by
1312 // start_index_map.
1313 for (int i = 0; i < index_parallel_in_dim.size(); ++i) {
1314 int index_parallel_dim = index_parallel_in_dim[i];
1315 if (index_parallel_dim == -1) {
1316 continue;
1317 }
1318 if (absl::c_linear_search(indices_parallel_dims, index_parallel_dim)) {
1319 return absl::nullopt;
1320 }
1321 // The dimension is considered parallel only if the slice over the operand is of size 1.
1322 if (hlo.gather_slice_sizes()[dnums.start_index_map(i)] == 1) {
1323 indices_parallel_dims.push_back(index_parallel_dim);
1324 operand_parallel_dims.push_back(dnums.start_index_map(i));
1325 } else {
1326 index_parallel_in_dim[i] = -1;
1327 }
1328 }
1329 absl::c_sort(indices_parallel_dims);
1330 if (!indices_parallel_dims.empty()) {
1331 return GatherParallelDims{indices_parallel_dims, operand_parallel_dims,
1332 index_parallel_in_dim};
1333 }
1334 return absl::nullopt;
1335 }
1336
1337 absl::InlinedVector<int64, 1> GatherParallelOutputDims(
1338 const HloInstruction& gather, const GatherParallelDims& parallel_dim) {
1339 absl::InlinedVector<int64, 1> output_parallel_dims;
1340 auto indices_parallel_dims = parallel_dim.indices_parallel_dims;
1341 const Shape gather_shape = gather.shape();
1342 auto dnums = gather.gather_dimension_numbers();
1343 for (int i = 0, idx_dim = 0; i < gather_shape.dimensions_size(); ++i) {
1344 if (absl::c_linear_search(dnums.offset_dims(), i)) {
1345 continue;
1346 }
1347 const int index_dim =
1348 idx_dim < dnums.index_vector_dim() ? idx_dim : idx_dim + 1;
1349 if (absl::c_binary_search(indices_parallel_dims, index_dim)) {
1350 output_parallel_dims.push_back(i);
1351 }
1352 ++idx_dim;
1353 }
1354 return output_parallel_dims;
1355 }
1356
1357 absl::InlinedVector<int64, 1> GatherOutputAlignedOperandParallelDims(
1358 const HloInstruction& gather, const GatherParallelDims& parallel_dims) {
1359 absl::InlinedVector<int64, 1> operand_parallel_dim_to_output(
1360 parallel_dims.operand_parallel_dims.size(), -1);
1361 auto dnums = gather.gather_dimension_numbers();
1362 CHECK_LE(parallel_dims.indices_parallel_dims.size(),
1363 parallel_dims.operand_parallel_dims.size());
1364 for (int i = 0; i < parallel_dims.index_parallel_in_dim.size(); ++i) {
1365 // This is the equivalent batch dimension of the indices that corresponds
1366 // to this index dimension.
1367 const int64 index_parallel_dim = parallel_dims.index_parallel_in_dim[i];
1368 // If it's not an index that is parallel skip.
1369 if (index_parallel_dim == -1) {
1370 continue;
1371 }
1372 // This is small so just look linearly. Populate the operand parallel
1373 // dimensions based on the order of the index batch dims (which is the same
1374 // order as the output).
1375 for (int j = 0; j < parallel_dims.indices_parallel_dims.size(); ++j) {
1376 if (parallel_dims.indices_parallel_dims[j] == index_parallel_dim) {
1377 const int64 operand_parallel_dim = dnums.start_index_map(i);
1378 if (operand_parallel_dim_to_output[j] == -1) {
1379 operand_parallel_dim_to_output[j] = operand_parallel_dim;
1380 }
1381 break;
1382 }
1383 }
1384 }
1385 return operand_parallel_dim_to_output;
1386 }
1387
1388 } // namespace hlo_sharding_util
1389 } // namespace xla
1390