/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_

#include <memory>
#include <string>

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"

namespace xla {
namespace spmd {

struct GatherParallelDimSharding {
  HloSharding indices_sharding;
  HloSharding operand_sharding;
};

// Returns true if the given sharding contains any replicated sharding.
bool HasReplicatedSharding(const HloSharding& sharding);

// Creates a constant value instruction of the given shape. The literal must be
// a scalar and is broadcast to the given shape.
HloInstruction* CreateConstant(const Shape& shape, Literal value,
                               SpmdBuilder* b);
// Creates a zero value instruction of the given shape.
HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b);

// Creates a one value instruction of the given shape.
HloInstruction* CreateOne(const Shape& shape, SpmdBuilder* b);

template <typename NativeT>
HloInstruction* CreateR0WithType(PrimitiveType type, NativeT value,
                                 SpmdBuilder* b) {
  auto literal = LiteralUtil::CreateR0(value)
                     .ConvertToShape(ShapeUtil::MakeShape(type, {}))
                     .ValueOrDie();
  return b->AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
}
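
// A minimal usage sketch (the value and type here are illustrative
// assumptions, not taken from a real caller):
//   HloInstruction* one = CreateR0WithType(S32, 1, b);  // scalar s32 constant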

// Creates a scalar constant that orders before every value of the given type
// under a NaN-safe comparison: -NaN for F32, otherwise the type's minimum
// value.
inline HloInstruction* CreateFirstWithType(PrimitiveType type, SpmdBuilder* b) {
  if (type == F32) {
    auto float_pad_value = std::numeric_limits<float>::quiet_NaN();
    return CreateR0WithType(type, -float_pad_value, b);
  }
  auto literal = LiteralUtil::MinValue(type);
  return b->AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
}

// Creates a scalar constant that orders after every value of the given type
// under a NaN-safe comparison: NaN for F32, otherwise the type's maximum
// value.
inline HloInstruction* CreateLastWithType(PrimitiveType type, SpmdBuilder* b) {
  if (type == F32) {
    auto float_pad_value = std::numeric_limits<float>::quiet_NaN();
    return CreateR0WithType(type, float_pad_value, b);
  }
  auto literal = LiteralUtil::MaxValue(type);
  return b->AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
}

// Creates a binary add computation of the given type and adds it to the
// module.
HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module);

// Returns true if the shape can be evenly partitioned for the given sharding.
// All tile-sharded dimensions must be evenly divisible, and there must be no
// single-device sharding. Replicated sharding is considered evenly
// partitionable.
bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding);

// Returns the shard shape of the given shape when it is partitioned for the
// target sharding.
Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding);
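
// For example (an illustrative sketch, not from a real caller): a base shape
// f32[9] tiled across two devices as {devices=[2]0,1} yields the per-shard
// shape f32[5], where the second shard covers one element of right padding.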

// Similar to ShapeUtil::ByteSizeOf(), but does not check that the shape has a
// dense layout, since this can run before layout assignment.
int64 ShapeSizeInBytes(const Shape& shape);

// Returns the shard shape for a partition without padding due to uneven
// sharding.
Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape,
                                          const HloSharding& sharding,
                                          int64 partition_id);

// Generates the HLO instructions that represent the dimension offsets on any
// device. The size of the returned vector is the rank of the given shape.
// If `dims` is non-empty, the generated offsets will only be non-zero for those
// dimensions.
std::vector<HloInstruction*> MakePartitionOffsets(
    const Shape& shape, const HloSharding& sharding,
    HloInstruction* partition_id, SpmdBuilder* b,
    absl::Span<const int64> dims = {});
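
// As an illustrative sketch (shapes and shardings are assumptions for the
// example): for base shape f32[8,8] with sharding {devices=[2,1]0,1}, the
// returned vector holds two S32 instructions that evaluate, on each device,
// to {partition_id * 4, 0}.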

// Returns the offsets of the partition in the tile assignment.
std::vector<HloInstruction*> MakeTiledPartitionOrdinals(
    const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b);

// Pads hlo to the desired shape using high padding. Either a builder or a
// computation needs to be supplied, but not both.
HloInstruction* PadToShape(HloInstruction* hlo, const Shape& padded_shape,
                           SpmdBuilder* b,
                           HloComputation* computation = nullptr);

// Returns the padded shape when combining all partitions.
Shape GetPaddedShapeForUnevenPartitioning(const Shape& base_shape,
                                          const HloSharding& sharding);

// Pads the HLO (with base shape) for uneven tiled partition to make it evenly
// partitionable.
HloInstruction* PadBaseShapeBeforeUnevenTiledSharding(
    HloInstruction* hlo, const HloSharding& sharding, SpmdBuilder* b);

// Returns the index of the unique tile dimension. Returns absl::nullopt if the
// given sharding is not tiled or tiled along multiple dimensions.
absl::optional<int64> UniqueTiledDim(const HloSharding& sharding);

// Utilities for symbolic offset calculation and halo exchange.
class OffsetCalculation;

// Represents a calculation over integers:
//   (shard_ordinal * multiplier + offset) / divisor
class MultiplyAddDivideOffsetCalculation {
 public:
  MultiplyAddDivideOffsetCalculation()
      : multiplier_(0), offset_(0), divisor_(1) {}
  MultiplyAddDivideOffsetCalculation(int64 multiplier, int64 offset,
                                     int64 divisor);

  OffsetCalculation operator-(
      const MultiplyAddDivideOffsetCalculation& other) const;

  bool operator==(const MultiplyAddDivideOffsetCalculation& other) const {
    return multiplier_ == other.multiplier_ && offset_ == other.offset_ &&
           divisor_ == other.divisor_;
  }

  bool IsConstant() const { return multiplier_ == 0; }
  void Simplify();
  int64 Calculate(int64 shard_ordinal) const;
  HloInstruction* Calculate(HloInstruction* shard_ordinal,
                            SpmdBuilder* b) const;

  // Returns the maximum result for shard ordinals in the range
  // [start_ordinal, limit_ordinal).
  int64 MaxInRange(int64 start_ordinal, int64 limit_ordinal) const;

 private:
  int64 multiplier_;
  int64 offset_;
  int64 divisor_;
};
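
// A worked example to make the formula concrete (the constants are
// illustrative): MultiplyAddDivideOffsetCalculation(1, 2, 3) computes
// (shard_ordinal + 2) / 3 under integer division, so shard ordinals 0..4
// map to 0, 1, 1, 1, 2.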

// Represents a calculation over integers based on results of other calculations
// defined by an opcode. If the opcode is kCopy, it simply wraps a
// MultiplyAddDivideOffsetCalculation.
class OffsetCalculation {
 public:
  OffsetCalculation() : opcode_(HloOpcode::kCopy), copy_from_() {}
  explicit OffsetCalculation(
      const MultiplyAddDivideOffsetCalculation& copy_from)
      : opcode_(HloOpcode::kCopy), copy_from_(copy_from) {}
  OffsetCalculation(const OffsetCalculation& copy_from) { *this = copy_from; }
  OffsetCalculation(HloOpcode opcode,
                    const MultiplyAddDivideOffsetCalculation& lhs,
                    const MultiplyAddDivideOffsetCalculation& rhs)
      : opcode_(opcode),
        lhs_(absl::make_unique<OffsetCalculation>(lhs)),
        rhs_(absl::make_unique<OffsetCalculation>(rhs)) {}
  OffsetCalculation(HloOpcode opcode, const OffsetCalculation& lhs,
                    const OffsetCalculation& rhs)
      : opcode_(opcode),
        lhs_(absl::make_unique<OffsetCalculation>(lhs)),
        rhs_(absl::make_unique<OffsetCalculation>(rhs)) {}

  OffsetCalculation& operator=(const OffsetCalculation& other);

  // Returns whether the calculation returns the same value for all shards. This
  // is conservative and could return false even if it is actually constant.
  bool IsConstant() const;

  OffsetCalculation operator-(const OffsetCalculation& other) const;
  bool operator==(const OffsetCalculation& other) const;
  int64 Calculate(int64 shard_ordinal) const;
  HloInstruction* Calculate(HloInstruction* shard_ordinal,
                            SpmdBuilder* b) const;

  // Returns the maximum result for shard ordinals in the range
  // [start_ordinal, limit_ordinal).
  int64 MaxInRange(int64 start_ordinal, int64 limit_ordinal) const;

 private:
  HloOpcode opcode_;
  std::unique_ptr<OffsetCalculation> lhs_;
  std::unique_ptr<OffsetCalculation> rhs_;
  MultiplyAddDivideOffsetCalculation copy_from_;
};
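
// An illustrative sketch of composing calculations (the constants are
// assumptions, not from a real caller): the difference of two linear
// functions of the shard ordinal is itself an OffsetCalculation.
//   OffsetCalculation left(MultiplyAddDivideOffsetCalculation(2, 1, 1));
//   OffsetCalculation right(MultiplyAddDivideOffsetCalculation(1, 0, 1));
//   OffsetCalculation diff = left - right;  // diff.Calculate(o) == o + 1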

// Performs halo exchange on the given dimension based on the provided
// left/right halo size functions. Returns nullopt if the halo is beyond the
// direct neighbor of the shard.
absl::optional<HloInstruction*> ExchangeHalo(
    HloInstruction* hlo, const OffsetCalculation& left_halo_size_function,
    const OffsetCalculation& right_halo_size_function, int64 dim,
    const HloSharding& target,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, SpmdBuilder* b);

// Exchanges halos on all dimensions of the HLO. Returns nullopt if any one of
// the dimensions fails to exchange halo (the halo is beyond the neighbor
// shard).
absl::optional<HloInstruction*> ExchangeHalo(
    HloInstruction* hlo,
    std::vector<OffsetCalculation> left_halo_size_functions,
    std::vector<OffsetCalculation> right_halo_size_functions,
    const HloSharding& target,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, SpmdBuilder* b);

// Exchanges halos and performs pad/dynamic-slice on the concatenated data such
// that the result starts with the first needed element on each shard. It also
// masks off invalid data due to padding.
// Arguments:
//  hlo: the HLO op before halo exchange
//  explicit_left_padding_on_full_shape: the amount of left padding to be added
//   explicitly by this function on the base shape before partitioning. Without
//   base dilation, this is usually set to the window's padding_low so that the
//   sharded op does not need to add padding_low on the window; however, with
//   base dilation, this can only be set to a custom size.
//  padded_full_shape_size: the size of the padded full shape on the given
//   dimension, which includes explicit_left_padding_on_full_shape and required
//   right padding to make the shape evenly shardable.
//  shard_size_with_halo: the shard size on the dimension after halo exchange.
//   If different shards have different sizes, use the maximum size.
//  offset_on_padded_shape: the offset HLO (S32) that represents the start of
//   each shard on the padded full shape.
//  pad_value: the padding value used on the full shape.
absl::optional<HloInstruction*> ExchangeHaloAndGetValidData(
    HloInstruction* hlo, const Shape& base_shape,
    const OffsetCalculation& left_halo_size_function,
    const OffsetCalculation& right_halo_size_function,
    int64 explicit_left_padding_on_full_shape, int64 padded_full_shape_size,
    int64 shard_size_with_halo, int64 dim, const HloSharding& target,
    HloInstruction* offset_on_padded_shape, HloInstruction* pad_value,
    HloInstruction* partition_ordinal,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, SpmdBuilder* b, bool mask_invalid_region = true);
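
// A hedged sketch of typical halo size functions (the window parameters are
// illustrative assumptions): for a size-3, stride-1 window on an evenly
// partitioned dimension, each shard needs one element from each neighbor,
// i.e. constant left/right halo sizes:
//   OffsetCalculation left(MultiplyAddDivideOffsetCalculation(0, 1, 1));
//   OffsetCalculation right(MultiplyAddDivideOffsetCalculation(0, 1, 1));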

// Uses halo exchange to change from right-padding to left-padding for uneven
// tiled sharding on the given dimensions. Tiled sharding always pads unevenly
// partitioned data on the right, but we need to swap it to the left for
// kReverse or kConvolution with window reversal.
HloInstruction* HaloExchangeToPadOnLeft(PartitionedHlo& original,
                                        absl::Span<const int64> dims);

// Checks if the computation is a GT comparison that is safe for NaNs.
bool IsNanSafeGt(HloComputation* computation);

// Returns k in TopK when the input value is partitioned in the sort dimension.
absl::optional<int64> GetKValueInTopKWhenPartitionSortDim(HloInstruction* hlo);

// Slices the first k elements at the slice dimension.
HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder,
                            int64 slice_dim, int64 k);

// Returns the number of shards along the given dimension.
int64 ShardCountAtDim(const HloSharding& sharding, int64 dim);

// Returns the list of source-target pairs of dimensions to swap during
// resharding via all-to-all. Reshard can be done by swapping each pair at a
// time.
absl::optional<std::vector<std::pair<int64, int64>>>
GetReshardAllToAllSourceTargetDims(const HloSharding& source,
                                   const HloSharding& target);
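
// For example (an illustrative sketch): resharding from {devices=[4,1]0,1,2,3}
// to {devices=[1,4]0,1,2,3} can be done with a single all-to-all that moves
// the sharding from dimension 0 to dimension 1, so the result would be the
// single pair {0, 1}.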

// Returns whether the resharding can be done via collective-permute.
bool CanReshardWithCollectivePermute(const HloSharding& source,
                                     const HloSharding& target);

// Represents grouping devices in a tiled sharding along certain dimensions.
// Elements in group dimensions define different device groups, and the sharding
// represents the in-group sharding.
struct GroupedSharding {
  GroupedSharding(std::vector<std::vector<int64>> device_groups,
                  std::vector<int64> group_dims,
                  std::vector<int64> group_dim_sizes, int64 data_rank,
                  HloSharding grouped_sharding)
      : device_groups(std::move(device_groups)),
        group_dims(std::move(group_dims)),
        group_dim_sizes(std::move(group_dim_sizes)),
        data_rank(data_rank),
        sharding(std::move(grouped_sharding)) {}
  std::vector<std::vector<int64>> device_groups;
  std::vector<int64> group_dims;
  std::vector<int64> group_dim_sizes;
  int64 data_rank;
  HloSharding sharding;
};

// Creates a GroupedSharding for a tiled sharding with group dim shard sizes.
GroupedSharding GroupShardingOnDims(const HloSharding& sharding,
                                    absl::Span<const int64> group_dims,
                                    absl::Span<const int64> group_dim_shards);

// Creates a GroupedSharding for a tiled sharding.
GroupedSharding GroupShardingOnDims(const HloSharding& sharding,
                                    absl::Span<const int64> group_dims);
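
// An illustrative sketch (the sharding is an assumption for the example):
// grouping {devices=[2,2]0,1,2,3} on dimension 0 yields device groups
// {0, 1} and {2, 3}, with the in-group sharding tiling only the remaining
// dimension across the two devices of each group.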

// Reconstructs the ungrouped sharding from a GroupedSharding.
HloSharding UngroupSharding(const GroupedSharding& grouped_sharding);

// Returns a new GroupedSharding that has the same group definition as
// `reference`.
GroupedSharding AlignGroupsWith(GroupedSharding grouped_sharding,
                                const GroupedSharding& reference,
                                bool ignore_group_order = false);

// Aligns device groups between the two shardings. Equivalent to calling
// GroupShardingOnDims on the two shardings, then AlignGroupsWith, and then
// UngroupSharding.
HloSharding AlignShardingOnDims(const HloSharding& sharding,
                                absl::Span<const int64> sharding_dims,
                                const HloSharding& reference,
                                absl::Span<const int64> reference_dims);

// Returns the per-group base shape, i.e., before applying the in-group
// sharding.
Shape GetPerGroupBaseShape(const GroupedSharding& grouped_sharding,
                           const Shape& original_base_shape);

// Creates the nested partitioner state for in-group partitioning.
PartitionedHlo::PartitioningState CreatePerGroupPartitioningState(
    const PartitionedHlo::PartitioningState& state,
    const std::vector<std::vector<int64>>& device_groups, SpmdBuilder* b);

// Partially shards a replicated HLO into groups along the group dimensions,
// and within each group data is still replicated.
HloInstruction* PerGroupSliceFromReplicated(
    HloInstruction* replicated, HloInstruction* partition_id,
    const std::vector<std::vector<int64>>& device_groups,
    absl::Span<const int64> group_dims, absl::Span<const int64> group_dim_sizes,
    SpmdBuilder* b);

// Returns the opcode if `reduction_comp` represents a simple binary elementwise
// computation on the two operands.
absl::optional<HloOpcode> ParseReductionComputation(
    const HloComputation* reduction_comp);

// Pads the shape from a partial replicate shape for `dst_sharding`.
// If dst_sharding needs more padding and the per-shard size increases in
// dst_sharding, a halo exchange on the right side is needed.
absl::optional<HloInstruction*> PadFromPartialReplicateShape(
    HloInstruction* hlo, const Shape& base_shape,
    const HloSharding& src_sharding, const HloSharding& dst_sharding,
    const std::vector<int64>& expand_tile_dims,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b);

// Gets the compatible sharding from a partial replicate sharding to a desired
// target tiled sharding.
// Compatible means the replicate sharding can transform to the target tile
// dimensions by dynamic slice.
// For example, if partial_sharding is
// {devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
// and the target sharding is {devices=[2,2]0,1,2,3}, the returned compatible
// sharding will be sharding={devices=[2,2]0,2,1,3}.
// If partial_sharding is not partial replicate or can't reshard to
// target_tile_dims by dynamic slice, returns absl::nullopt.
// If target_sharding is already compatible, returns it.
absl::optional<HloSharding> PartialReplicateReshardCompatibleSharding(
    const HloSharding& partial_sharding, const HloSharding& target_sharding);

// Does a left halo exchange if an all-reduce directly from tile sharding to
// partial replicate sharding would remove useful data from the source.
absl::optional<HloInstruction*> TileToPartialReplicateHaloExchange(
    HloInstruction* hlo, const Shape& base_shape,
    const HloSharding& src_sharding, const HloSharding& dst_sharding,
    const std::vector<int64>& replicate_dims,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b);

// Finds a list of dimensions that can be grouped on such that it will have the
// specified device groups. Group order and dimension order are ignored.
absl::optional<std::vector<int64>> FindMatchingPartitionedDimsForGrouping(
    const HloSharding& sharding,
    const std::vector<std::vector<int64>>& device_groups);

// Creates a sharding that matches the provided source sharding on the
// specified dimensions. 'target_dims' and 'source_dims' represent the
// dimensions for which the sharding should match in their respective shape.
// If some devices from the source sharding are left over (because not all the
// devices are allocated to 'source_dims' dimensions), then partial replication
// is employed to make sure the number of devices for the two shardings
// matches.
HloSharding CreateMatchingShardingOnDims(const Shape& target_shape,
                                         const HloSharding& source_sharding,
                                         absl::Span<const int64> target_dims,
                                         absl::Span<const int64> source_dims);

// Returns whether the shardings of a gather's operand and indices are across
// parallel dimensions and match what the SPMD partitioner supports.
absl::optional<GatherParallelDimSharding>
GatherOperandsShardedAcrossParallelDims(
    const HloInstruction& operand, const HloInstruction& indices,
    const hlo_sharding_util::GatherParallelDims& parallel_dims);

}  // namespace spmd
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_