1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
17 
18 #include <algorithm>
19 #include <map>
20 #include <vector>
21 
22 #include "absl/algorithm/container.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/types/optional.h"
25 #include "tensorflow/compiler/xla/array.h"
26 #include "tensorflow/compiler/xla/literal_util.h"
27 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
28 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
29 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/util.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 
34 namespace xla {
35 namespace hlo_sharding_util {
36 
IsShardingMoreSpecific(const HloSharding & lhs,const HloSharding & rhs)37 bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs) {
38   CHECK_EQ(lhs.IsTuple(), rhs.IsTuple());
39   if (lhs.IsTuple()) {
40     // For tuples we consider lhs to have a better sharding if none of the
41     // elements are worse and at least one element is better then in rhs
42     // sharding.
43     const auto& lhs_shardings = lhs.tuple_elements();
44     const auto& rhs_shardings = rhs.tuple_elements();
45     CHECK_EQ(lhs_shardings.size(), rhs_shardings.size());
46     bool is_better = false;
47     for (int64 i = 0; i < lhs_shardings.size(); ++i) {
48       if (IsShardingMoreSpecific(rhs_shardings[i], lhs_shardings[i])) {
49         return false;
50       }
51       if (IsShardingMoreSpecific(lhs_shardings[i], rhs_shardings[i])) {
52         is_better = true;
53       }
54     }
55     return is_better;
56   }
57   if (!rhs.IsTileMaximal()) {
58     return lhs.NumTiles() > rhs.NumTiles();
59   } else if (!rhs.IsReplicated()) {
60     // If we are not replicated then only tiled (not tile maximal) shardings
61     // can improve us.
62     return !lhs.IsTileMaximal();
63   } else {
64     // If we are replicated then any non-replicated sharding can improve us.
65     return !lhs.IsReplicated();
66   }
67 }
68 
MergeSharding(const HloSharding & old,HloSharding * to_merge,bool may_combine_partial_sharding)69 bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
70                    bool may_combine_partial_sharding) {
71   if (old.IsTuple()) {
72     CHECK(to_merge->IsTuple());
73     bool changed = false;
74     for (int64 i = 0; i < old.tuple_elements().size(); ++i) {
75       changed |=
76           MergeSharding(old.tuple_elements()[i], &to_merge->tuple_elements()[i],
77                         may_combine_partial_sharding);
78     }
79     return changed;
80   }
81   if (!may_combine_partial_sharding || !old.ReplicateOnLastTileDim() ||
82       !to_merge->ReplicateOnLastTileDim() ||
83       old.tile_assignment().num_elements() !=
84           to_merge->tile_assignment().num_elements()) {
85     return IsShardingMoreSpecific(*to_merge, old);
86   }
87   // Combine the tile dimension sizes from new and old.
88   int64 num_devices = old.tile_assignment().num_elements();
89   std::vector<int64> new_tile_dims;
90   bool compatible = true;
91   new_tile_dims.reserve(to_merge->tile_assignment().num_dimensions());
92   for (int64 i = 0; i < to_merge->tile_assignment().num_dimensions() - 1; ++i) {
93     int64 new_dim = to_merge->tile_assignment().dim(i);
94     int64 old_dim = old.tile_assignment().dim(i);
95     if (new_dim == 1) {
96       new_tile_dims.push_back(old_dim);
97     } else if (old_dim == 1) {
98       new_tile_dims.push_back(new_dim);
99     } else if (new_dim == old_dim) {
100       new_tile_dims.push_back(new_dim);
101     } else {
102       compatible = false;
103       break;
104     }
105   }
106   int64 replication = num_devices / Product(new_tile_dims);
107   if (!compatible || num_devices % Product(new_tile_dims) != 0 ||
108       replication >= old.tile_assignment().dimensions().back()) {
109     return IsShardingMoreSpecific(*to_merge, old);
110   }
111   new_tile_dims.push_back(replication);
112   Array<int64> new_tile(new_tile_dims);
113   // Maps from replication group ID to sorted members.
114   absl::flat_hash_map<int64, std::set<int64>> old_group_members;
115   absl::flat_hash_map<int64, std::set<int64>> new_group_members;
116   auto get_group_index = [&](absl::Span<const int64> tile_indices,
117                              const HloSharding& sharding) {
118     int64 group_id = 0;
119     for (int64 i = 0; i < tile_indices.size() - 1; ++i) {
120       group_id *= to_merge->tile_assignment().dim(i);
121       group_id += tile_indices[i];
122     }
123     return group_id;
124   };
125   old.tile_assignment().Each(
126       [&](absl::Span<const int64> indices, int64 device) {
127         old_group_members[get_group_index(indices, old)].insert(device);
128       });
129   to_merge->tile_assignment().Each(
130       [&](absl::Span<const int64> indices, int64 device) {
131         new_group_members[get_group_index(indices, *to_merge)].insert(device);
132       });
133   // Try to find the intersection of old and new replication groups, in
134   // order to determine the merged tile assignment.
135   new_tile.Each([&](absl::Span<const int64> indices, int64* device) {
136     if (!compatible) {
137       return;
138     }
139     std::vector<int64> old_index(indices.begin(), indices.end());
140     std::vector<int64> new_index = old_index;
141     for (int64 i = 0; i < indices.size() - 1; ++i) {
142       if (old.tile_assignment().dim(i) == 1) {
143         old_index[i] = 0;
144       }
145       if (to_merge->tile_assignment().dim(i) == 1) {
146         new_index[i] = 0;
147       }
148     }
149     int64 old_group_id = get_group_index(old_index, old);
150     int64 new_group_id = get_group_index(new_index, *to_merge);
151     if (old_group_members[old_group_id].empty() ||
152         new_group_members[new_group_id].empty()) {
153       compatible = false;
154       return;
155     }
156 
157     int64 smallest_old = *old_group_members[old_group_id].begin();
158     int64 smallest_new = *new_group_members[new_group_id].begin();
159     if (smallest_old < smallest_new) {
160       if (old_group_members[old_group_id].count(smallest_new) == 0) {
161         compatible = false;
162         return;
163       }
164       *device = smallest_new;
165     } else {
166       if (new_group_members[new_group_id].count(smallest_old) == 0) {
167         compatible = false;
168         return;
169       }
170       *device = smallest_old;
171     }
172     old_group_members[old_group_id].erase(*device);
173     new_group_members[new_group_id].erase(*device);
174   });
175   if (compatible) {
176     std::vector<OpMetadata> merged_metadata(std::move(to_merge->metadata()));
177     merged_metadata.reserve(merged_metadata.size() + old.metadata().size());
178     const absl::flat_hash_set<OpMetadata, protobuf_util::ProtobufHashWrapper,
179                               protobuf_util::ProtobufEqualsWrapper>
180         metadata_set(merged_metadata.begin(), merged_metadata.end());
181     absl::c_copy_if(old.metadata(), std::back_inserter(merged_metadata),
182                     [&metadata_set](const OpMetadata& data) {
183                       return !ContainsKey(metadata_set, data);
184                     });
185     if (replication == 1) {
186       new_tile_dims.pop_back();
187       new_tile.Reshape(new_tile_dims);
188       *to_merge = HloSharding::Tile(new_tile, merged_metadata);
189     } else {
190       *to_merge = HloSharding::PartialTile(new_tile, merged_metadata);
191     }
192     return true;
193   }
194   return IsShardingMoreSpecific(*to_merge, old);
195 }
196 
SelectDominantDevice(const std::map<int64,int64> & device_map,int64 * top_count)197 absl::optional<int64> SelectDominantDevice(
198     const std::map<int64, int64>& device_map, int64* top_count) {
199   int64 device = 0;
200   int64 count = 0;
201   for (auto& it : device_map) {
202     if (it.second > count) {
203       count = it.second;
204       device = it.first;
205     }
206   }
207   if (top_count != nullptr) {
208     *top_count = count;
209   }
210   return count > 0 ? absl::optional<int64>(device) : absl::optional<int64>();
211 }
212 
AssignComputationDevice(HloComputation * computation,int64 device)213 Status AssignComputationDevice(HloComputation* computation, int64 device) {
214   VLOG(4) << "Assigning device " << device << " to " << computation->name()
215           << " computation";
216   for (HloInstruction* instruction : computation->instructions()) {
217     if (!instruction->has_sharding()) {
218       VLOG(4) << "Assigning device " << device << " to " << instruction->name();
219       instruction->set_device_sharding(device);
220     }
221   }
222   return Status::OK();
223 }
224 
GetMostOccurringDevice(absl::Span<HloInstruction * const> instructions)225 absl::optional<int64> GetMostOccurringDevice(
226     absl::Span<HloInstruction* const> instructions) {
227   std::map<int64, int64> device_map;
228   for (HloInstruction* instruction : instructions) {
229     if (instruction->has_sharding()) {
230       for (auto& it : instruction->sharding().UsedDevices(nullptr)) {
231         // The UsedDevices() API returns a map<device, occurrence_count>.
232         device_map[it.first] += it.second;
233       }
234     }
235   }
236   return SelectDominantDevice(device_map, nullptr);
237 }
238 
GetDominantDevice(absl::Span<HloComputation * const> computations,double dominant_factor)239 StatusOr<absl::optional<int64>> GetDominantDevice(
240     absl::Span<HloComputation* const> computations, double dominant_factor) {
241   int64 instruction_count = 0;
242   std::map<int64, int64> device_map;
243   for (HloComputation* computation : computations) {
244     for (HloInstruction* instruction : computation->instructions()) {
245       int64 count = 1;
246       if (instruction->has_sharding()) {
247         for (auto& it : instruction->sharding().UsedDevices(&count)) {
248           // The UsedDevices() API returns a map<device, occurrence_count>.
249           device_map[it.first] += it.second;
250         }
251       }
252       instruction_count += count;
253     }
254   }
255   int64 count;
256   absl::optional<int64> device = SelectDominantDevice(device_map, &count);
257   absl::optional<int64> dominant_device;
258   if (device) {
259     double factor =
260         static_cast<double>(count) / static_cast<double>(instruction_count);
261     if (factor >= dominant_factor) {
262       dominant_device = device;
263     }
264   }
265   return dominant_device;
266 }
267 
TransposeSharding(const HloSharding & sharding,const std::vector<int64> & dimensions)268 HloSharding TransposeSharding(const HloSharding& sharding,
269                               const std::vector<int64>& dimensions) {
270   if (sharding.IsTileMaximal()) {
271     return sharding;
272   }
273   auto perm_dimensions = dimensions;
274   if (sharding.ReplicateOnLastTileDim() &&
275       dimensions.size() < sharding.tile_assignment().num_dimensions()) {
276     perm_dimensions.push_back(dimensions.size());
277   }
278   const int64 rank = perm_dimensions.size();
279   std::vector<int64> tile_assignment_dim(rank);
280   for (int64 i = 0; i < rank; ++i) {
281     tile_assignment_dim[i] = sharding.tile_assignment().dim(perm_dimensions[i]);
282   }
283   Array<int64> tile_assignment = sharding.tile_assignment();
284   tile_assignment.Reshape(tile_assignment_dim);
285   tile_assignment.Each([&](absl::Span<const int64> indices, int64* value) {
286     std::vector<int64> src_indices(indices.size(), -1);
287     for (int64 i = 0; i < indices.size(); ++i) {
288       src_indices[perm_dimensions[i]] = indices[i];
289     }
290     *value = sharding.tile_assignment()(src_indices);
291   });
292   return sharding.ReplicateOnLastTileDim()
293              ? HloSharding::PartialTile(tile_assignment, sharding.metadata())
294              : HloSharding::Tile(tile_assignment, sharding.metadata());
295 }
296 
ReshapeSharding(const Shape & source_shape,const Shape & target_shape,const HloSharding & sharding)297 absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
298                                             const Shape& target_shape,
299                                             const HloSharding& sharding) {
300   if (sharding.IsTileMaximal()) {
301     return sharding;
302   }
303 
304   // In case of a tiled sharding the reshaped sharding will be a valid if the
305   // reshape is composed from the following operations:
306   // * Adding or removing dimensions with size 1.
307   // * Merging consecutive dimensions where only the most major is sharded.
308   // * Splitting a dimension to consecutive dimensions.
309   // * Any reshaping of unsharded dimensions.
310   // Note that merge and split can happen consecutively on the same dimension,
311   // e.g., f32[1024,256,1024] to f32[128,2048,1024] can be considered that 1024
312   // gets split into 128 and 8, but 8 then gets merged with 256. We use stacks
313   // to make supporting such cases easy.
314   const Shape tile_shape = sharding.TileShape(source_shape);
315   std::vector<int64> target_tile_assignment_dimensions;
316   std::vector<int64> source_dims_stack(source_shape.rank());
317   std::vector<int64> target_dims_stack(target_shape.rank());
318   std::vector<int64> sharding_tile_dims_stack(source_shape.rank());
319   for (int64 i = 0; i < source_shape.rank(); ++i) {
320     source_dims_stack[i] = source_shape.dimensions(source_shape.rank() - 1 - i);
321     sharding_tile_dims_stack[i] =
322         sharding.tile_assignment().dim(source_shape.rank() - 1 - i);
323   }
324   for (int64 i = 0; i < target_shape.rank(); ++i) {
325     target_dims_stack[i] = target_shape.dimensions(target_shape.rank() - 1 - i);
326   }
327   while (!source_dims_stack.empty() || !target_dims_stack.empty()) {
328     if (target_dims_stack.empty()) {
329       if (Product(sharding_tile_dims_stack) != 1) {
330         return absl::nullopt;
331       }
332       break;
333     }
334     int64 s_size = 1;
335     int64 t_size = 1;
336     int64 s_partitions = 1;
337     if (!source_dims_stack.empty()) {
338       s_size = source_dims_stack.back();
339       source_dims_stack.pop_back();
340       s_partitions = sharding_tile_dims_stack.back();
341       sharding_tile_dims_stack.pop_back();
342     }
343     t_size = target_dims_stack.back();
344     target_dims_stack.pop_back();
345     if (s_partitions * Product(sharding_tile_dims_stack) == 1) {
346       // No more partitions left.
347       target_tile_assignment_dimensions.push_back(1);
348       continue;
349     }
350     if (s_size == t_size) {
351       // Same dimension.
352       target_tile_assignment_dimensions.push_back(s_partitions);
353     } else if (t_size == 1) {
354       // Trivial dimension added.
355       target_tile_assignment_dimensions.push_back(1);
356       source_dims_stack.push_back(s_size);
357       sharding_tile_dims_stack.push_back(s_partitions);
358     } else if (s_size == 1) {
359       // Trivial dimension removed.
360       if (s_partitions != 1) {
361         return absl::nullopt;
362       }
363       target_dims_stack.push_back(t_size);
364     } else if (s_size > t_size) {
365       // Dimension split.
366       if (s_size % t_size != 0 || s_size % s_partitions != 0) {
367         return absl::nullopt;
368       }
369       if (t_size % s_partitions == 0) {
370         target_tile_assignment_dimensions.push_back(s_partitions);
371         // We have part of the s_size unprocessed, so put it back to stack.
372         source_dims_stack.push_back(s_size / t_size);
373         sharding_tile_dims_stack.push_back(1);
374       } else if (s_partitions % t_size == 0) {
375         target_tile_assignment_dimensions.push_back(t_size);
376         // We have part of the s_size unprocessed, so put it back to stack.
377         source_dims_stack.push_back(s_size / t_size);
378         sharding_tile_dims_stack.push_back(s_partitions / t_size);
379       } else {
380         return absl::nullopt;
381       }
382     } else {
383       // Dimension merge. Also merge the source dimension with the next, and
384       // process it next time.
385       if (s_size % s_partitions != 0) {
386         return absl::nullopt;
387       }
388       CHECK(!source_dims_stack.empty());
389       if (sharding_tile_dims_stack.back() != 1 && s_size != s_partitions) {
390         // If the next dimension to combine is sharded, we require that the
391         // current dimension's shard size to be 1. Otherwise, the new shard
392         // would be non-contiguous.
393         return absl::nullopt;
394       }
395       source_dims_stack.back() *= s_size;
396       sharding_tile_dims_stack.back() *= s_partitions;
397       target_dims_stack.push_back(t_size);
398     }
399   }
400   Array<int64> new_tile_assignment = sharding.tile_assignment();
401   if (sharding.ReplicateOnLastTileDim()) {
402     target_tile_assignment_dimensions.push_back(
403         sharding.tile_assignment().dimensions().back());
404   }
405   new_tile_assignment.Reshape(target_tile_assignment_dimensions);
406   return sharding.ReplicateOnLastTileDim()
407              ? HloSharding::PartialTile(new_tile_assignment,
408                                         sharding.metadata())
409              : HloSharding::Tile(new_tile_assignment, sharding.metadata());
410 }
411 
ReverseSharding(const HloSharding & sharding,absl::Span<const int64> dimensions)412 HloSharding ReverseSharding(const HloSharding& sharding,
413                             absl::Span<const int64> dimensions) {
414   if (sharding.IsTileMaximal() || dimensions.empty()) {
415     return sharding;
416   }
417 
418   Array<int64> new_tile_assignment(sharding.tile_assignment().dimensions());
419   new_tile_assignment.Each([&](absl::Span<const int64> indices, int64* device) {
420     std::vector<int64> original_indices(indices.begin(), indices.end());
421     for (int64 d : dimensions) {
422       original_indices[d] =
423           new_tile_assignment.dim(d) - 1 - original_indices[d];
424     }
425     *device = sharding.tile_assignment()(original_indices);
426   });
427   return sharding.ReplicateOnLastTileDim()
428              ? HloSharding::PartialTile(new_tile_assignment,
429                                         sharding.metadata())
430              : HloSharding::Tile(new_tile_assignment, sharding.metadata());
431 }
432 
ReshapeToTileDimension(const HloSharding & sharding,int64 dim,absl::Span<const int64> dims)433 HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim,
434                                    absl::Span<const int64> dims) {
435   CHECK(!sharding.IsTuple() && !sharding.IsTileMaximal());
436   CHECK_NE(absl::c_find(dims, dim), dims.end()) << "dim is not in dims";
437   // We optimize the tile assignment on the single dimension dim in a way to
438   // minimize communication among devices caused by the reshard:
439   // +---+---+               +---+---+              +-+-+-+-+
440   // |   |   |               |   0   |              | | | | |
441   // | 0 | 1 |               +-------+              | | | | |
442   // |   |   |  reshape on   |   1   |  reshape on  | | | | |
443   // +---+---+   dim 0  =>   +-------+   dim 1  =>  |0|2|1|3|
444   // |   |   |               |   2   |              | | | | |
445   // | 2 | 3 |               +-------+              | | | | |
446   // |   |   |               |   3   |              | | | | |
447   // +---+---+               +---+---+              +-+-+-+-+
448 
449   std::vector<int64> tile_dims(sharding.tile_assignment().num_dimensions(), 1);
450   // Handle ignore dimensions.
451   std::vector<int64> ignore_sizes;
452   int64 ignore_size = 1;
453   for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
454     if (absl::c_find(dims, i) == dims.end()) {
455       int64 size = sharding.tile_assignment().dim(i);
456       ignore_sizes.push_back(size);
457       tile_dims[i] = size;
458       ignore_size *= size;
459     }
460   }
461 
462   using Buckets = std::vector<std::vector<int64>>;
463   Array<Buckets> buckets(ignore_sizes,
464                          Buckets(sharding.tile_assignment().dim(dim)));
465   sharding.tile_assignment().Each(
466       [&](absl::Span<const int64> index, int64 device) {
467         std::vector<int64> ignore_index;
468         for (int64 i = 0; i < index.size(); ++i) {
469           if (absl::c_find(dims, i) == dims.end()) {
470             ignore_index.push_back(index[i]);
471           }
472         }
473         buckets(ignore_index)[index[dim]].push_back(device);
474       });
475   std::vector<int64> devices;
476   buckets.Each([&](absl::Span<const int64> index, const Buckets& buckets) {
477     for (auto& bucket : buckets) {
478       devices.insert(devices.end(), bucket.begin(), bucket.end());
479     }
480   });
481   tile_dims[dim] = devices.size() / ignore_size;
482   Array<int64> tile_assignment(tile_dims);
483   tile_assignment.SetValues(devices);
484   return HloSharding::Tile(tile_assignment, sharding.metadata());
485 }
486 
ContainsTileSharding(const HloModule & module)487 bool ContainsTileSharding(const HloModule& module) {
488   for (const HloComputation* computation : module.computations()) {
489     for (const HloInstruction* instruction : computation->instructions()) {
490       if (instruction->has_sharding() &&
491           !instruction->sharding().IsTileMaximal()) {
492         return true;
493       }
494     }
495   }
496   return false;
497 }
498 
GatherOutputSharding(const HloSharding & index_sharding,const HloInstruction * hlo)499 HloSharding GatherOutputSharding(const HloSharding& index_sharding,
500                                  const HloInstruction* hlo) {
501   if (index_sharding.IsTileMaximal()) {
502     return index_sharding;
503   }
504 
505   const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers();
506   std::vector<int64> output_tile_assignment_dims;
507   for (int64 i = 0, index_dim = 0; i < hlo->shape().rank(); ++i) {
508     if (absl::c_binary_search(dnums.offset_dims(), i)) {
509       output_tile_assignment_dims.push_back(1);
510     } else {
511       const int64 new_tile_dimension =
512           index_dim >= dnums.index_vector_dim() ? index_dim + 1 : index_dim;
513       output_tile_assignment_dims.push_back(
514           index_sharding.tile_assignment().dim(new_tile_dimension));
515       ++index_dim;
516     }
517   }
518 
519   if (index_sharding.ReplicateOnLastTileDim()) {
520     output_tile_assignment_dims.push_back(
521         index_sharding.tile_assignment().dimensions().back());
522   }
523 
524   Array<int64> new_tile_assignment = index_sharding.tile_assignment();
525   if (new_tile_assignment.num_elements() !=
526       Product(output_tile_assignment_dims)) {
527     return HloSharding::Replicate(index_sharding.metadata());
528   }
529   new_tile_assignment.Reshape(output_tile_assignment_dims);
530   return index_sharding.ReplicateOnLastTileDim()
531              ? HloSharding::PartialTile(new_tile_assignment,
532                                         index_sharding.metadata())
533              : HloSharding::Tile(new_tile_assignment,
534                                  index_sharding.metadata());
535 }
536 
GatherIndexSharding(const HloSharding & output_sharding,const HloInstruction * hlo)537 HloSharding GatherIndexSharding(const HloSharding& output_sharding,
538                                 const HloInstruction* hlo) {
539   CHECK(hlo->opcode() == HloOpcode::kGather);
540   if (output_sharding.IsTileMaximal()) {
541     return output_sharding;
542   }
543 
544   const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers();
545   std::vector<int64> index_tile_assignment_dims;
546   for (int64 i = 0; i < hlo->shape().rank(); ++i) {
547     if (!absl::c_binary_search(dnums.offset_dims(), i)) {
548       index_tile_assignment_dims.push_back(
549           output_sharding.tile_assignment().dim(i));
550     }
551   }
552   int64 index_rank = hlo->operand(1)->shape().rank();
553 
554   // Vector indices sharding is not supported yet.
555   if (index_rank > index_tile_assignment_dims.size()) {
556     index_tile_assignment_dims.insert(
557         index_tile_assignment_dims.begin() + dnums.index_vector_dim(), 1);
558   }
559 
560   int64 partial_replication_size = 1;
561   if (output_sharding.ReplicateOnLastTileDim()) {
562     partial_replication_size *=
563         output_sharding.tile_assignment().dimensions().back();
564   }
565 
566   Array<int64> new_tile_assignment = output_sharding.tile_assignment();
567   const int64 index_tile_elements =
568       Product(index_tile_assignment_dims) * partial_replication_size;
569   if (new_tile_assignment.num_elements() != index_tile_elements) {
570     if (new_tile_assignment.num_elements() % index_tile_elements == 0) {
571       partial_replication_size *=
572           (new_tile_assignment.num_elements() / index_tile_elements);
573     } else {
574       return HloSharding::Replicate(output_sharding.metadata());
575     }
576   }
577   if (partial_replication_size > 1) {
578     index_tile_assignment_dims.push_back(partial_replication_size);
579   }
580   new_tile_assignment.Reshape(index_tile_assignment_dims);
581   return partial_replication_size > 1
582              ? HloSharding::PartialTile(new_tile_assignment,
583                                         output_sharding.metadata())
584              : HloSharding::Tile(new_tile_assignment,
585                                  output_sharding.metadata());
586 }
587 
GatherEffectiveOutputSharding(const HloInstruction & hlo)588 HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) {
589   if (hlo.sharding().IsTileMaximal()) {
590     return hlo.sharding();
591   }
592 
593   const GatherDimensionNumbers& dnums = hlo.gather_dimension_numbers();
594   std::vector<int64> tile_assignment_dims(hlo.shape().rank());
595   int64 num_elements = 1;
596   for (int64 i = 0; i < hlo.shape().rank(); ++i) {
597     if (!absl::c_binary_search(dnums.offset_dims(), i)) {
598       tile_assignment_dims[i] = hlo.sharding().tile_assignment().dim(i);
599       num_elements *= hlo.sharding().tile_assignment().dim(i);
600     } else {
601       tile_assignment_dims[i] = 1;
602     }
603   }
604   if (num_elements == hlo.sharding().tile_assignment().num_elements()) {
605     // Output sharding is only on non offset dimensions. We use output sharding
606     // to shard this gather op directly.
607     return hlo.sharding();
608   }
609 
610   if (num_elements == 1) {
611     // Output sharding is only on offset dimensions. We do not shard this gather
612     // op. Return a tile maximal sharding with the first device in output
613     // sharding tile assignment.
614     return HloSharding::AssignDevice(*hlo.sharding().tile_assignment().begin(),
615                                      hlo.sharding().metadata());
616   }
617 
618   // Output sharding is on both offset and non offset dimensions. We shard the
619   // gather op only on non offset dimensions.
620   // For example:
621   // - the gather op has sharding [2,2]{0,1,2,3},
622   // - first dimension is non offset dimension,
623   // - second dimension is offset dimension,
624   // Then the result sharding will be [2,1]{0,2}.
625   std::vector<int64> slice_starts(hlo.shape().rank(), 0LL),
626       slice_limits(hlo.shape().rank());
627   for (int64 i = 0; i < hlo.shape().rank(); ++i) {
628     if (!absl::c_binary_search(dnums.offset_dims(), i)) {
629       slice_limits[i] = hlo.sharding().tile_assignment().dim(i);
630     } else {
631       slice_limits[i] = 1;
632     }
633   }
634   Array<int64> tile_assignment =
635       hlo.sharding().tile_assignment().Slice(slice_starts, slice_limits);
636   return HloSharding::Tile(tile_assignment, hlo.sharding().metadata());
637 }
638 
ScatterIndexSharding(const HloSharding & data_sharding,const HloInstruction * hlo)639 HloSharding ScatterIndexSharding(const HloSharding& data_sharding,
640                                  const HloInstruction* hlo) {
641   if (data_sharding.IsTileMaximal()) {
642     return data_sharding;
643   }
644 
645   const ScatterDimensionNumbers& dnums = hlo->scatter_dimension_numbers();
646   std::vector<int64> index_tile_assignment_dims;
647   for (int64 i = 0; i < hlo->shape().rank(); ++i) {
648     if (!absl::c_binary_search(dnums.update_window_dims(), i)) {
649       index_tile_assignment_dims.push_back(
650           data_sharding.tile_assignment().dim(i));
651     }
652   }
653   if (index_tile_assignment_dims.size() < hlo->operand(1)->shape().rank()) {
654     index_tile_assignment_dims.push_back(1);
655   }
656   if (data_sharding.ReplicateOnLastTileDim()) {
657     index_tile_assignment_dims.push_back(
658         data_sharding.tile_assignment().dimensions().back());
659   }
660   Array<int64> new_tile_assignment = data_sharding.tile_assignment();
661   if (new_tile_assignment.num_elements() !=
662       Product(index_tile_assignment_dims)) {
663     return HloSharding::Replicate(data_sharding.metadata());
664   }
665   new_tile_assignment.Reshape(index_tile_assignment_dims);
666   return data_sharding.ReplicateOnLastTileDim()
667              ? HloSharding::PartialTile(new_tile_assignment,
668                                         data_sharding.metadata())
669              : HloSharding::Tile(new_tile_assignment, data_sharding.metadata());
670 }
671 
ScatterDataSharding(const HloSharding & index_sharding,const HloInstruction * hlo)672 HloSharding ScatterDataSharding(const HloSharding& index_sharding,
673                                 const HloInstruction* hlo) {
674   if (index_sharding.IsTileMaximal()) {
675     return index_sharding;
676   }
677 
678   const ScatterDimensionNumbers& dnums = hlo->scatter_dimension_numbers();
679   std::vector<int64> data_tile_assignment_dims;
680   for (int64 i = 0, index_dim = 0; i < hlo->shape().rank(); ++i) {
681     if (absl::c_binary_search(dnums.update_window_dims(), i)) {
682       data_tile_assignment_dims.push_back(1);
683     } else {
684       data_tile_assignment_dims.push_back(
685           index_sharding.tile_assignment().dim(index_dim));
686       index_dim++;
687     }
688   }
689   if (index_sharding.ReplicateOnLastTileDim()) {
690     data_tile_assignment_dims.push_back(
691         index_sharding.tile_assignment().dimensions().back());
692   }
693   Array<int64> new_tile_assignment = index_sharding.tile_assignment();
694   if (new_tile_assignment.num_elements() !=
695       Product(data_tile_assignment_dims)) {
696     return HloSharding::Replicate(index_sharding.metadata());
697   }
698   new_tile_assignment.Reshape(data_tile_assignment_dims);
699   return index_sharding.ReplicateOnLastTileDim()
700              ? HloSharding::PartialTile(new_tile_assignment,
701                                         index_sharding.metadata())
702              : HloSharding::Tile(new_tile_assignment,
703                                  index_sharding.metadata());
704 }
705 
ScatterEffectiveIndexSharding(const HloSharding & index_sharding,const HloInstruction & hlo)706 HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding,
707                                           const HloInstruction& hlo) {
708   if (index_sharding.IsTileMaximal()) {
709     return index_sharding;
710   }
711 
712   // Only shard on first "number of scatter_window_dims" dimensions.
713   const ScatterDimensionNumbers& dnums = hlo.scatter_dimension_numbers();
714   int64 num_elements = 1;
715   int64 index_dim = 0;
716   for (int64 i = 0; i < hlo.shape().rank(); ++i) {
717     if (absl::c_binary_search(dnums.inserted_window_dims(), i)) {
718       num_elements *= index_sharding.tile_assignment().dim(index_dim);
719       index_dim++;
720     }
721   }
722   if (num_elements == index_sharding.tile_assignment().num_elements()) {
723     // Index sharding is only on scatter_window_dims. We use this index sharding
724     // directly.
725     return index_sharding;
726   }
727 
728   // Index sharding is only on update_window_dims. We do not shard this scatter
729   // op. Return a tile maximal sharding with the first device in index sharding
730   // tile assignment.
731   if (num_elements == 1) {
732     return HloSharding::AssignDevice(*index_sharding.tile_assignment().begin(),
733                                      index_sharding.metadata());
734   }
735 
736   const int64 index_rank = hlo.operand(1)->shape().rank();
737   std::vector<int64> slice_starts(index_rank, 0LL), slice_limits(index_rank);
738   for (int64 i = 0; i < index_rank; ++i) {
739     if (i < index_dim) {
740       slice_limits[i] = index_sharding.tile_assignment().dim(i);
741     } else {
742       slice_limits[i] = 1;
743     }
744   }
745   Array<int64> tile_assignment =
746       index_sharding.tile_assignment().Slice(slice_starts, slice_limits);
747   return HloSharding::Tile(tile_assignment, index_sharding.metadata());
748 }
749 
ScatterEffectiveDataSharding(const HloSharding & data_sharding,const HloInstruction & hlo)750 HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
751                                          const HloInstruction& hlo) {
752   if (data_sharding.IsTileMaximal()) {
753     return data_sharding;
754   }
755 
756   const ScatterDimensionNumbers& dnums = hlo.scatter_dimension_numbers();
757   const int64 data_rank = hlo.operand(2)->shape().rank();
758   std::vector<int64> tile_assignment_dims(data_rank, 1LL);
759   int64 num_elements = 1;
760   for (int64 i = 0; i < hlo.shape().rank(); ++i) {
761     if (absl::c_binary_search(dnums.inserted_window_dims(), i)) {
762       CHECK_LT(i, data_rank);
763       tile_assignment_dims[i] = data_sharding.tile_assignment().dim(i);
764       num_elements *= data_sharding.tile_assignment().dim(i);
765     }
766   }
767   if (num_elements == data_sharding.tile_assignment().num_elements()) {
768     // Data sharding is only on scatter_window_dims. We use this data sharding
769     // directly.
770     return data_sharding;
771   }
772 
773   if (num_elements == 1) {
774     // Data sharding is only on update_window_dims. We do not shard this
775     // scatter op. Return a tile maximal sharding with the first device in
776     // data sharding tile assignment.
777     return HloSharding::AssignDevice(*data_sharding.tile_assignment().begin(),
778                                      data_sharding.metadata());
779   }
780 
781   // Data sharding is on both update_window_dims and scatter_window_dims. We
782   // shard the scatter op only on scatter_window_dims. For example:
783   // - the scatter data has sharding [2,2]{0,1,2,3},
784   // - first dimension is scatter_window_dims,
785   // - second dimension is update_window_dims,
786   // Then the result sharding will be [2,1]{0,2}.
787   std::vector<int64> slice_starts(data_rank, 0LL);
788   Array<int64> tile_assignment =
789       data_sharding.tile_assignment().Slice(slice_starts, tile_assignment_dims);
790   return HloSharding::Tile(tile_assignment, data_sharding.metadata());
791 }
792 
793 namespace {
794 
795 // If partitioning in the operand only happens in dimensions in passthrough
796 // dimensions (offset dimensions in the gather output (or scatter update) that
797 // have the same size as the operand), returns the corresponding output (or
798 // update) sharding by passing through the input sharding.
PassthroughOperandToGatherOutputOrScatterUpdate(const Shape & operand_shape,const HloSharding & operand_sharding,const Shape & update_or_gather_shape,absl::Span<const int64> collapsed_or_inserted_dims,absl::Span<const int64> index_map,absl::Span<const int64> offset_or_window_dims,absl::Span<const int64> slice_size)799 absl::optional<HloSharding> PassthroughOperandToGatherOutputOrScatterUpdate(
800     const Shape& operand_shape, const HloSharding& operand_sharding,
801     const Shape& update_or_gather_shape,
802     absl::Span<const int64> collapsed_or_inserted_dims,
803     absl::Span<const int64> index_map,
804     absl::Span<const int64> offset_or_window_dims,
805     absl::Span<const int64> slice_size) {
806   if (operand_sharding.IsTileMaximal()) {
807     return operand_sharding;
808   }
809   std::vector<int64> passthrough_tile(update_or_gather_shape.rank(), 1);
810   int64 collapsed = 0;
811   for (int64 i = 0; i < operand_shape.rank(); ++i) {
812     int64 dim_partitions = operand_sharding.tile_assignment().dim(i);
813     if (absl::c_linear_search(collapsed_or_inserted_dims, i) ||
814         absl::c_linear_search(index_map, i)) {
815       if (dim_partitions > 1) {
816         return absl::nullopt;
817       }
818       collapsed++;
819       continue;
820     }
821     if (slice_size[i] != operand_shape.dimensions(i) && dim_partitions > 1) {
822       return absl::nullopt;
823     }
824     int64 offset_dim = offset_or_window_dims[i - collapsed];
825     if (i - collapsed > 0 &&
826         offset_dim < offset_or_window_dims[i - collapsed - 1]) {
827       // Output offsets are transposed, we do not support this case.
828       return absl::nullopt;
829     }
830     passthrough_tile[offset_dim] = dim_partitions;
831   }
832   if (operand_sharding.ReplicateOnLastTileDim()) {
833     passthrough_tile.push_back(
834         operand_sharding.tile_assignment().dimensions().back());
835   }
836   Array<int64> tile_assignment = operand_sharding.tile_assignment();
837   tile_assignment.Reshape(passthrough_tile);
838   return operand_sharding.ReplicateOnLastTileDim()
839              ? HloSharding::PartialTile(tile_assignment,
840                                         operand_sharding.metadata())
841              : HloSharding::Tile(tile_assignment, operand_sharding.metadata());
842 }
843 
844 // Inverse of PassthroughOperandToGatherOutputOrScatterUpdate.
PassthroughGatherOutputOrScatterUpdateToOperand(const Shape & operand_shape,const HloSharding & update_or_gather_sharding,absl::Span<const int64> collapsed_or_inserted_dims,absl::Span<const int64> index_map,absl::Span<const int64> offset_or_window_dims,absl::Span<const int64> slice_size)845 absl::optional<HloSharding> PassthroughGatherOutputOrScatterUpdateToOperand(
846     const Shape& operand_shape, const HloSharding& update_or_gather_sharding,
847     absl::Span<const int64> collapsed_or_inserted_dims,
848     absl::Span<const int64> index_map,
849     absl::Span<const int64> offset_or_window_dims,
850     absl::Span<const int64> slice_size) {
851   if (update_or_gather_sharding.IsTileMaximal()) {
852     return update_or_gather_sharding;
853   }
854   std::vector<int64> passthrough_tile(operand_shape.rank(), 1);
855   int64 collapsed = 0;
856   for (int64 i = 0; i < operand_shape.rank(); ++i) {
857     if (absl::c_linear_search(collapsed_or_inserted_dims, i) ||
858         absl::c_linear_search(index_map, i)) {
859       collapsed++;
860       continue;
861     }
862     int64 offset_dim = offset_or_window_dims[i - collapsed];
863     int64 dim_partitions =
864         update_or_gather_sharding.tile_assignment().dim(offset_dim);
865     if (slice_size[i] != operand_shape.dimensions(i) && dim_partitions > 1) {
866       return absl::nullopt;
867     }
868     if (i - collapsed > 0 &&
869         offset_dim < offset_or_window_dims[i - collapsed - 1]) {
870       // Output offsets are transposed, we do not support this case.
871       return absl::nullopt;
872     }
873     passthrough_tile[i] = dim_partitions;
874   }
875 
876   if (update_or_gather_sharding.ReplicateOnLastTileDim()) {
877     passthrough_tile.push_back(
878         update_or_gather_sharding.tile_assignment().dimensions().back());
879   }
880   Array<int64> tile_assignment = update_or_gather_sharding.tile_assignment();
881   if (tile_assignment.num_elements() != Product(passthrough_tile)) {
882     return absl::nullopt;
883   }
884   tile_assignment.Reshape(passthrough_tile);
885   return update_or_gather_sharding.ReplicateOnLastTileDim()
886              ? HloSharding::PartialTile(tile_assignment,
887                                         update_or_gather_sharding.metadata())
888              : HloSharding::Tile(tile_assignment,
889                                  update_or_gather_sharding.metadata());
890 }
891 
892 // Collect data operand sharding for a gather with parallel dimensions from
893 // the output.
GatherParallelDataOperandSharding(const HloSharding & output_sharding,const HloInstruction & gather,const GatherParallelDims & parallel_dims)894 absl::optional<HloSharding> GatherParallelDataOperandSharding(
895     const HloSharding& output_sharding, const HloInstruction& gather,
896     const GatherParallelDims& parallel_dims) {
897   if (output_sharding.IsTileMaximal()) {
898     return output_sharding;
899   }
900   auto output_parallel_dims = GatherParallelOutputDims(gather, parallel_dims);
901   auto output_aligned_operand_parallel_dims =
902       GatherOutputAlignedOperandParallelDims(gather, parallel_dims);
903   const Shape gather_shape = gather.shape();
904   CHECK_EQ(output_parallel_dims.size(),
905            output_aligned_operand_parallel_dims.size());
906   std::vector<int64> operand_tile_assignment(gather.operand(0)->shape().rank(),
907                                              1);
908   for (int i = 0, parallel_idx = 0; i < gather_shape.rank(); ++i) {
909     if (parallel_idx >= output_parallel_dims.size() ||
910         output_parallel_dims[parallel_idx] != i) {
911       continue;
912     }
913     const int64 operand_dim =
914         output_aligned_operand_parallel_dims[parallel_idx++];
915     operand_tile_assignment[operand_dim] =
916         output_sharding.tile_assignment().dim(i);
917   }
918   int64 partially_replicated_size = 1;
919   if (output_sharding.ReplicateOnLastTileDim()) {
920     partially_replicated_size *=
921         output_sharding.tile_assignment().dimensions().back();
922   }
923   Array<int64> tile_assignment = output_sharding.tile_assignment();
924   const int64 operand_tile_elements =
925       Product(operand_tile_assignment) * partially_replicated_size;
926   if (tile_assignment.num_elements() != operand_tile_elements) {
927     if (tile_assignment.num_elements() % operand_tile_elements == 0) {
928       partially_replicated_size *=
929           (tile_assignment.num_elements() / operand_tile_elements);
930     } else {
931       return absl::nullopt;
932     }
933   }
934   if (partially_replicated_size > 1) {
935     operand_tile_assignment.push_back(partially_replicated_size);
936   }
937   tile_assignment.Reshape(operand_tile_assignment);
938   return partially_replicated_size > 1
939              ? HloSharding::PartialTile(tile_assignment,
940                                         output_sharding.metadata())
941              : HloSharding::Tile(tile_assignment, output_sharding.metadata());
942 }
943 
944 }  // namespace
945 
GatherOutputShardingFromDataOperand(const HloSharding & data_operand_sharding,const HloInstruction & hlo,const Shape & output_shape,const Shape & operand_shape)946 absl::optional<HloSharding> GatherOutputShardingFromDataOperand(
947     const HloSharding& data_operand_sharding, const HloInstruction& hlo,
948     const Shape& output_shape, const Shape& operand_shape) {
949   const auto& dnums = hlo.gather_dimension_numbers();
950   std::vector<int64> collapsed_slice_dims(dnums.collapsed_slice_dims().begin(),
951                                           dnums.collapsed_slice_dims().end());
952   std::vector<int64> start_index_map(dnums.start_index_map().begin(),
953                                      dnums.start_index_map().end());
954   std::vector<int64> offset_dims(dnums.offset_dims().begin(),
955                                  dnums.offset_dims().end());
956   return PassthroughOperandToGatherOutputOrScatterUpdate(
957       operand_shape, data_operand_sharding, output_shape, collapsed_slice_dims,
958       start_index_map, offset_dims, hlo.gather_slice_sizes());
959 }
960 
GatherDataOperandShardingFromOutput(const HloSharding & output_sharding,const HloInstruction & hlo)961 absl::optional<HloSharding> GatherDataOperandShardingFromOutput(
962     const HloSharding& output_sharding, const HloInstruction& hlo) {
963   const auto& dnums = hlo.gather_dimension_numbers();
964   std::vector<int64> collapsed_slice_dims(dnums.collapsed_slice_dims().begin(),
965                                           dnums.collapsed_slice_dims().end());
966   std::vector<int64> start_index_map(dnums.start_index_map().begin(),
967                                      dnums.start_index_map().end());
968   std::vector<int64> offset_dims(dnums.offset_dims().begin(),
969                                  dnums.offset_dims().end());
970 
971   absl::optional<HloSharding> parallel_sharding;
972   auto parallel_dims = GetGatherBatchParallelDims(hlo);
973   absl::Span<const int64> operand_parallel_dims;
974   if (parallel_dims) {
975     // Prioritize parallel sharding first as this is how it is in
976     // spmd_partitioner.
977     parallel_sharding =
978         GatherParallelDataOperandSharding(hlo.sharding(), hlo, *parallel_dims);
979     operand_parallel_dims = parallel_dims->operand_parallel_dims;
980   }
981   HloSharding filtered_output_sharding = PartiallyReplicateTiledShardingOnDims(
982       output_sharding, operand_parallel_dims);
983   absl::optional<HloSharding> passthrough_sharding =
984       PassthroughGatherOutputOrScatterUpdateToOperand(
985           hlo.operand(0)->shape(), filtered_output_sharding,
986           collapsed_slice_dims, start_index_map, offset_dims,
987           hlo.gather_slice_sizes());
988   // Try to merge the two shardings or return the one that is present if only
989   // one of the two is.
990   if (!passthrough_sharding) {
991     return parallel_sharding;
992   }
993   if (!parallel_sharding) {
994     return passthrough_sharding;
995   }
996   if (MergeSharding(*parallel_sharding, &*passthrough_sharding,
997                     /*may_combine_partial_sharding=*/true)) {
998     return passthrough_sharding;
999   }
1000   if (MergeSharding(*passthrough_sharding, &*parallel_sharding,
1001                     /*may_combine_partial_sharding=*/true)) {
1002     return parallel_sharding;
1003   }
1004   return absl::nullopt;
1005 }
1006 
ScatterOutputShardingFromUpdate(const HloSharding & update_sharding,const HloInstruction & hlo)1007 absl::optional<HloSharding> ScatterOutputShardingFromUpdate(
1008     const HloSharding& update_sharding, const HloInstruction& hlo) {
1009   const auto& dnums = hlo.scatter_dimension_numbers();
1010   std::vector<int64> inserted_window_dims(dnums.inserted_window_dims().begin(),
1011                                           dnums.inserted_window_dims().end());
1012   std::vector<int64> scatter_dims_to_operand_dims(
1013       dnums.scatter_dims_to_operand_dims().begin(),
1014       dnums.scatter_dims_to_operand_dims().end());
1015   std::vector<int64> update_window_dims(dnums.update_window_dims().begin(),
1016                                         dnums.update_window_dims().end());
1017   std::vector<int64> slice_size(hlo.shape().rank(), 1);
1018   int64 num_update_window_dims = 0;
1019   for (int64 i = 0; i < hlo.shape().rank(); ++i) {
1020     if (absl::c_linear_search(dnums.inserted_window_dims(), i)) {
1021       continue;
1022     }
1023     slice_size[i] = hlo.operand(2)->shape().dimensions(
1024         dnums.update_window_dims(num_update_window_dims++));
1025   }
1026   return PassthroughGatherOutputOrScatterUpdateToOperand(
1027       hlo.shape(), update_sharding, inserted_window_dims,
1028       scatter_dims_to_operand_dims, update_window_dims, slice_size);
1029 }
1030 
ScatterUpdateShardingFromOutput(const HloSharding & output_sharding,const HloInstruction & hlo)1031 absl::optional<HloSharding> ScatterUpdateShardingFromOutput(
1032     const HloSharding& output_sharding, const HloInstruction& hlo) {
1033   const auto& dnums = hlo.scatter_dimension_numbers();
1034   std::vector<int64> inserted_window_dims(dnums.inserted_window_dims().begin(),
1035                                           dnums.inserted_window_dims().end());
1036   std::vector<int64> scatter_dims_to_operand_dims(
1037       dnums.scatter_dims_to_operand_dims().begin(),
1038       dnums.scatter_dims_to_operand_dims().end());
1039   std::vector<int64> update_window_dims(dnums.update_window_dims().begin(),
1040                                         dnums.update_window_dims().end());
1041   std::vector<int64> slice_size(hlo.shape().rank(), 1);
1042   int64 num_update_window_dims = 0;
1043   for (int64 i = 0; i < hlo.shape().rank(); ++i) {
1044     if (absl::c_linear_search(dnums.inserted_window_dims(), i)) {
1045       continue;
1046     }
1047     slice_size[i] = hlo.operand(2)->shape().dimensions(
1048         dnums.update_window_dims(num_update_window_dims++));
1049   }
1050   return PassthroughOperandToGatherOutputOrScatterUpdate(
1051       hlo.shape(), output_sharding, hlo.operand(2)->shape(),
1052       inserted_window_dims, scatter_dims_to_operand_dims, update_window_dims,
1053       slice_size);
1054 }
1055 
1056 StatusOr<std::pair<std::unique_ptr<HloInstruction>, HloOpcode>>
IdentityValueAndHloOpcodeForScatterReduceComputation(const HloScatterInstruction & scatter)1057 IdentityValueAndHloOpcodeForScatterReduceComputation(
1058     const HloScatterInstruction& scatter) {
1059   auto computation = scatter.to_apply();
1060   // We only handle computations with 2 parameters and only 1 calculation.
1061   if (computation->instruction_count() != 3) {
1062     return Status(
1063         tensorflow::error::Code::INVALID_ARGUMENT,
1064         "Expected scatter reduce computation with 2 parameters and only 1 "
1065         "calculation");
1066   }
1067 
1068   auto root_instruction = computation->root_instruction();
1069   if (root_instruction->opcode() == HloOpcode::kAdd ||
1070       root_instruction->opcode() == HloOpcode::kOr) {
1071     return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::Zero(
1072                               scatter.shape().element_type())),
1073                           root_instruction->opcode());
1074   } else if (root_instruction->opcode() == HloOpcode::kMultiply ||
1075              root_instruction->opcode() == HloOpcode::kAnd) {
1076     return std::make_pair(HloInstruction::CreateConstant(
1077                               LiteralUtil::One(scatter.shape().element_type())),
1078                           root_instruction->opcode());
1079   } else if (root_instruction->opcode() == HloOpcode::kMaximum) {
1080     return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MinValue(
1081                               scatter.shape().element_type())),
1082                           root_instruction->opcode());
1083   } else if (root_instruction->opcode() == HloOpcode::kMinimum) {
1084     return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MaxValue(
1085                               scatter.shape().element_type())),
1086                           root_instruction->opcode());
1087   }
1088 
1089   return Status(tensorflow::error::Code::INVALID_ARGUMENT,
1090                 "Expected scatter reduce computation which is "
1091                 "add/or/multiply/add/min/max");
1092 }
1093 
1094 namespace {
1095 
DevicesForShardingInternal(const HloSharding & sharding,const absl::flat_hash_set<int64> & available_devices,absl::flat_hash_set<int64> * used)1096 void DevicesForShardingInternal(
1097     const HloSharding& sharding,
1098     const absl::flat_hash_set<int64>& available_devices,
1099     absl::flat_hash_set<int64>* used) {
1100   if (sharding.IsTuple()) {
1101     for (const auto& subsharding : sharding.tuple_elements()) {
1102       DevicesForShardingInternal(subsharding, available_devices, used);
1103     }
1104     return;
1105   }
1106 
1107   if (sharding.IsReplicated()) {
1108     for (int64 device : available_devices) {
1109       if (!HloSharding::IsReservedDevice(device)) {
1110         used->insert(device);
1111       }
1112     }
1113     return;
1114   }
1115 
1116   DCHECK(std::all_of(
1117       sharding.tile_assignment().begin(), sharding.tile_assignment().end(),
1118       [&](int64 device) { return available_devices.contains(device); }));
1119   sharding.tile_assignment().Each([&](absl::Span<const int64> /*indices*/,
1120                                       int64 device) { used->insert(device); });
1121 }
1122 
1123 }  // namespace
1124 
DevicesForSharding(const HloSharding & sharding,const std::vector<int64> & available_devices)1125 std::vector<int64> DevicesForSharding(
1126     const HloSharding& sharding, const std::vector<int64>& available_devices) {
1127   absl::flat_hash_set<int64> available_set;
1128   for (int64 device : available_devices) {
1129     available_set.insert(device);
1130   }
1131   absl::flat_hash_set<int64> used_set;
1132   DevicesForShardingInternal(sharding, available_set, &used_set);
1133   std::vector<int64> devices;
1134   for (int64 device : available_devices) {
1135     if (used_set.contains(device)) {
1136       devices.push_back(device);
1137     }
1138   }
1139   return devices;
1140 }
1141 
PartiallyReplicateTiledShardingOnDims(const HloSharding & sharding,absl::Span<const int64> dims_to_replicate)1142 HloSharding PartiallyReplicateTiledShardingOnDims(
1143     const HloSharding& sharding, absl::Span<const int64> dims_to_replicate) {
1144   if (sharding.IsTileMaximal()) {
1145     return sharding;
1146   }
1147   int64 group_count = 1;
1148   for (int64 dim : dims_to_replicate) {
1149     if (sharding.ReplicateOnLastTileDim()) {
1150       CHECK_LT(dim, sharding.tile_assignment().num_dimensions());
1151     }
1152     group_count *= sharding.tile_assignment().dim(dim);
1153   }
1154   if (group_count == 1) {
1155     return sharding;
1156   }
1157   if (group_count == sharding.NumTiles()) {
1158     return HloSharding::Replicate(sharding.metadata());
1159   }
1160   std::vector<int64> dim_permutation(
1161       sharding.tile_assignment().num_dimensions());
1162   std::iota(dim_permutation.begin(), dim_permutation.end(), 0);
1163   absl::c_sort(dim_permutation, [&](const int64 a, const int64 b) {
1164     return absl::c_linear_search(dims_to_replicate, a) <
1165            absl::c_linear_search(dims_to_replicate, b);
1166   });
1167   auto transposed = TransposeSharding(sharding, dim_permutation);
1168   auto new_tile = transposed.tile_assignment();
1169   std::vector<int64> new_tile_shape(
1170       sharding.tile_assignment().dimensions().begin(),
1171       sharding.tile_assignment().dimensions().end());
1172   for (int64 dim : dims_to_replicate) {
1173     new_tile_shape[dim] = 1;
1174   }
1175   if (sharding.ReplicateOnLastTileDim()) {
1176     new_tile_shape.back() *= group_count;
1177   } else {
1178     new_tile_shape.push_back(group_count);
1179   }
1180   new_tile.Reshape(new_tile_shape);
1181   return HloSharding::PartialTile(new_tile, sharding.metadata());
1182 }
1183 
HloSharding RemoveShapeDimensions(const HloSharding& sharding,
                                  const std::vector<int64>& dims_to_remove) {
  if (sharding.IsTileMaximal() || dims_to_remove.empty()) {
    return sharding;
  }
  std::vector<int64> new_tile_shape;
  new_tile_shape.reserve(sharding.tile_assignment().num_dimensions() -
                         dims_to_remove.size());
  for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
    if (absl::c_linear_search(dims_to_remove, i)) {
      CHECK_EQ(sharding.tile_assignment().dim(i), 1);
    } else {
      new_tile_shape.push_back(sharding.tile_assignment().dim(i));
    }
  }
  auto new_tile = sharding.tile_assignment();
  new_tile.Reshape(new_tile_shape);
  return sharding.ReplicateOnLastTileDim()
             ? HloSharding::PartialTile(new_tile, sharding.metadata())
             : HloSharding::Tile(new_tile, sharding.metadata());
}

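// src_to_tgt maps each source dimension to its target dimension (-1 for a
// source dimension that is collapsed) and tgt_to_src is the inverse mapping
// (-1 for a target dimension that is newly added). Illustrative example: a
// [2, 1, 3] source tiling with src_to_tgt = {0, -1, 1} and tgt_to_src = {0, 2}
// becomes a [2, 3] tiling; returns absl::nullopt if a collapsed source
// dimension is actually sharded.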
absl::optional<HloSharding> TransposeShardingWithCollapsedDims(
    const HloSharding& source, absl::Span<int64 const> src_to_tgt,
    absl::Span<int64 const> tgt_to_src) {
  if (source.IsTileMaximal()) {
    return source;
  }
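  // For partially replicated shardings, extend both mappings with an entry
  // for the trailing replication dimension and recurse, so that dimension is
  // carried through the transpose unchanged.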
  if (source.ReplicateOnLastTileDim() &&
      src_to_tgt.size() < source.tile_assignment().num_dimensions()) {
    std::vector<int64> new_src_to_tgt(src_to_tgt.begin(), src_to_tgt.end());
    new_src_to_tgt.push_back(tgt_to_src.size());
    std::vector<int64> new_tgt_to_src(tgt_to_src.begin(), tgt_to_src.end());
    new_tgt_to_src.push_back(src_to_tgt.size());
    return TransposeShardingWithCollapsedDims(source, new_src_to_tgt,
                                              new_tgt_to_src);
  }
  std::vector<int64> tgt_dims_skipping_new(tgt_to_src.size(), -1);
  int64 skipped_tgt_dims = 0;
  for (int64 i = 0; i < tgt_to_src.size(); ++i) {
    if (tgt_to_src[i] < 0) {
      skipped_tgt_dims++;
    } else {
      tgt_dims_skipping_new[i] = i - skipped_tgt_dims;
    }
  }
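  // Build the permutation that sends each source dimension to its position in
  // the target; collapsed (size-1) source dimensions are moved to the end so
  // that the final reshape can drop them.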
  int64 skipped_src_dims = absl::c_count(src_to_tgt, -1);
  std::vector<int64> perm(src_to_tgt.size());
  for (int64 i = 0; i < src_to_tgt.size(); ++i) {
    if (src_to_tgt[i] < 0) {
      if (source.tile_assignment().dim(i) > 1) {
        return absl::nullopt;
      }
      perm[src_to_tgt.size() - skipped_src_dims] = i;
      skipped_src_dims--;
    } else {
      perm[tgt_dims_skipping_new[src_to_tgt[i]]] = i;
    }
  }
  auto tgt_sharding = hlo_sharding_util::TransposeSharding(source, perm);
  auto reshape_tiles = tgt_sharding.tile_assignment();
  std::vector<int64> tgt_tiles(tgt_to_src.size(), 1);
  for (int64 i = 0; i < tgt_tiles.size(); ++i) {
    if (tgt_to_src[i] >= 0) {
      tgt_tiles[i] = reshape_tiles.dim(tgt_dims_skipping_new[i]);
    }
  }
  reshape_tiles.Reshape(tgt_tiles);
  return source.ReplicateOnLastTileDim()
             ? HloSharding::PartialTile(reshape_tiles, source.metadata())
             : HloSharding::Tile(reshape_tiles, source.metadata());
}

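// Detects the dimensions along which a gather can be handled as a parallel
// (batch) operation, returning the matching indices and operand dimensions,
// or absl::nullopt if none are found.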
absl::optional<GatherParallelDims> GetGatherBatchParallelDims(
    const HloInstruction& hlo) {
  const auto& dnums = hlo.gather_dimension_numbers();
  int64 index_dim = dnums.index_vector_dim();
  // Try to identify whether there's a dimension in the indices that is
  // monotonically increasing, i.e. is an iota across a certain dimension.
  // This means that the accesses along the corresponding operand dimension
  // are parallelizable, so we can shard the operand (and the indices/output)
  // across that dimension.
  // For example the pattern:
  //   %iota.1 = iota()
  //   %indices = concatenate(..., %iota.1, ...)
  //   ... = gather(..., %indices)
  // is common for tf.reverse_sequence and would match this case.
  absl::InlinedVector<const HloIotaInstruction*, 4> iotas;
  const HloInstruction* indices = hlo.operand(1);
  const int num_indices = dnums.start_index_map_size();
  std::vector<int64> index_parallel_in_dim(num_indices, -1);
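  // index_parallel_in_dim[i] records, for the i-th index component, the
  // indices dimension along which that component is an iota (and therefore
  // parallelizable), or -1 otherwise.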
  // Handle cases where we concatenate pieces of the indices one at a time.
  if (indices->opcode() == HloOpcode::kConcatenate &&
      indices->concatenate_dimension() == index_dim) {
    int concatenated_dims = 0;
    for (int i = 0; i < indices->operand_count(); ++i) {
      const HloInstruction* op = indices->operand(i);
      const int64 num_indices_from_element =
          op->shape().dimensions_size() > index_dim
              ? op->shape().dimensions(index_dim)
              : 1;
      if (auto* iota = DynCast<HloIotaInstruction>(op)) {
        if (iota->iota_dimension() != index_dim) {
          for (int j = 0; j < num_indices_from_element; ++j) {
            index_parallel_in_dim[concatenated_dims + j] =
                iota->iota_dimension();
          }
        }
      }
      concatenated_dims += num_indices_from_element;
    }
  } else if (auto* iota = DynCast<HloIotaInstruction>(indices)) {
    if (iota->iota_dimension() != index_dim) {
      // This is the case of a single iota used directly as the indices, with
      // index_dim being out of bounds (an implicit trailing index vector
      // dimension of size 1).
      const int64 num_indices_from_element =
          iota->shape().dimensions_size() > index_dim
              ? iota->shape().dimensions(index_dim)
              : 1;
      index_parallel_in_dim.assign(num_indices_from_element,
                                   iota->iota_dimension());
    }
  }
  absl::InlinedVector<int64, 1> indices_parallel_dims;
  absl::InlinedVector<int64, 1> operand_parallel_dims;
  // Map the parallelizable dimension from the iota to the dimensions of the
  // output and the operand. These dimensions are connected, but they may sit
  // at different positions in the operand and indices shapes because the
  // position of the index dimension in the operand is determined by
  // start_index_map.
  for (int i = 0; i < index_parallel_in_dim.size(); ++i) {
    int index_parallel_dim = index_parallel_in_dim[i];
    if (index_parallel_dim == -1) {
      continue;
    }
    if (absl::c_linear_search(indices_parallel_dims, index_parallel_dim)) {
      return absl::nullopt;
    }
    // Considered parallel only if the slice is of size 1 over the operand.
    if (hlo.gather_slice_sizes()[dnums.start_index_map(i)] == 1) {
      indices_parallel_dims.push_back(index_parallel_dim);
      operand_parallel_dims.push_back(dnums.start_index_map(i));
    } else {
      index_parallel_in_dim[i] = -1;
    }
  }
  absl::c_sort(indices_parallel_dims);
  if (!indices_parallel_dims.empty()) {
    return GatherParallelDims{indices_parallel_dims, operand_parallel_dims,
                              index_parallel_in_dim};
  }
  return absl::nullopt;
}

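// Maps the parallel dimensions of the gather indices to the corresponding
// batch (non-offset) dimensions of the gather output; the returned dimensions
// are in output order.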
absl::InlinedVector<int64, 1> GatherParallelOutputDims(
    const HloInstruction& gather, const GatherParallelDims& parallel_dim) {
  absl::InlinedVector<int64, 1> output_parallel_dims;
  auto indices_parallel_dims = parallel_dim.indices_parallel_dims;
  const Shape gather_shape = gather.shape();
  auto dnums = gather.gather_dimension_numbers();
  for (int i = 0, idx_dim = 0; i < gather_shape.dimensions_size(); ++i) {
    if (absl::c_linear_search(dnums.offset_dims(), i)) {
      continue;
    }
    const int index_dim =
        idx_dim < dnums.index_vector_dim() ? idx_dim : idx_dim + 1;
    if (absl::c_binary_search(indices_parallel_dims, index_dim)) {
      output_parallel_dims.push_back(i);
    }
    ++idx_dim;
  }
  return output_parallel_dims;
}

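// Returns the operand parallel dimensions reordered so that they line up with
// the order of the corresponding indices (and hence output) parallel
// dimensions.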
absl::InlinedVector<int64, 1> GatherOutputAlignedOperandParallelDims(
    const HloInstruction& gather, const GatherParallelDims& parallel_dims) {
  absl::InlinedVector<int64, 1> operand_parallel_dim_to_output(
      parallel_dims.operand_parallel_dims.size(), -1);
  auto dnums = gather.gather_dimension_numbers();
  CHECK_LE(parallel_dims.indices_parallel_dims.size(),
           parallel_dims.operand_parallel_dims.size());
  for (int i = 0; i < parallel_dims.index_parallel_in_dim.size(); ++i) {
    // The indices batch dimension along which this index component is
    // parallel (-1 if it is not parallel).
    const int64 index_parallel_dim = parallel_dims.index_parallel_in_dim[i];
    // Skip index components that are not parallel.
    if (index_parallel_dim == -1) {
      continue;
    }
    // This is small so just look linearly. Populate the operand parallel
    // dimensions based on the order of the index batch dims (which is the same
    // order as the output).
    for (int j = 0; j < parallel_dims.indices_parallel_dims.size(); ++j) {
      if (parallel_dims.indices_parallel_dims[j] == index_parallel_dim) {
        const int64 operand_parallel_dim = dnums.start_index_map(i);
        if (operand_parallel_dim_to_output[j] == -1) {
          operand_parallel_dim_to_output[j] = operand_parallel_dim;
        }
        break;
      }
    }
  }
  return operand_parallel_dim_to_output;
}

}  // namespace hlo_sharding_util
}  // namespace xla