/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_

#include <memory>
#include <string>

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"

namespace xla {
namespace spmd {

struct GatherParallelDimSharding {
  HloSharding indices_sharding;
  HloSharding operand_sharding;
};

// Returns true if the given sharding contains any replicated sharding.
bool HasReplicatedSharding(const HloSharding& sharding);
// Creates a constant instruction of the given shape. The literal must have a
// scalar shape and is broadcast to the given shape.
HloInstruction* CreateConstant(const Shape& shape, Literal value,
                               SpmdBuilder* b);
// Creates a zero constant instruction of the given shape.
HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b);

// Creates a one constant instruction of the given shape.
HloInstruction* CreateOne(const Shape& shape, SpmdBuilder* b);

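// Creates an R0 (scalar) constant of the given primitive type from a native
// value. Illustrative usage (not from the original header docs):
//   CreateR0WithType(F32, 1.0f, b)  // emits an f32 scalar constant 1.0f.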
template <typename NativeT>
HloInstruction* CreateR0WithType(PrimitiveType type, NativeT value,
                                 SpmdBuilder* b) {
  auto literal = LiteralUtil::CreateR0(value)
                     .ConvertToShape(ShapeUtil::MakeShape(type, {}))
                     .ValueOrDie();
  return b->AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
}

inline HloInstruction* CreateFirstWithType(PrimitiveType type, SpmdBuilder* b) {
  if (type == F32) {
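    // NOTE: presumably -NaN is used because the NaN-safe comparator (see
    // IsNanSafeGt below) orders it below every other f32 value, making it a
    // suitable "first" (smallest) padding value.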
    auto float_pad_value = std::numeric_limits<float>::quiet_NaN();
    return CreateR0WithType(type, -float_pad_value, b);
  }
  auto literal = LiteralUtil::MinValue(type);
  return b->AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
}

inline HloInstruction* CreateLastWithType(PrimitiveType type, SpmdBuilder* b) {
  if (type == F32) {
    auto float_pad_value = std::numeric_limits<float>::quiet_NaN();
    return CreateR0WithType(type, float_pad_value, b);
  }
  auto literal = LiteralUtil::MaxValue(type);
  return b->AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
}

// Creates a binary add computation of the given type and adds it to the
// module.
HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module);

// Returns true if the shape can be evenly partitioned for the given sharding.
// All tile-sharded dimensions must be evenly divisible, and there must be no
// single-device sharding. Replicated sharding is considered evenly
// partitioned.
bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding);

// Returns the shard shape of the given shape when it is partitioned for the
// target sharding.
Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding);

// Similar to ShapeUtil::ByteSizeOf(), but does not check that the shape has a
// dense layout, since this can be called before layout assignment.
int64 ShapeSizeInBytes(const Shape& shape);

// Returns the shard shape for a partition without padding due to uneven
// sharding.
Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape,
                                          const HloSharding& sharding,
                                          int64 partition_id);

// Generates the HLO instructions that represent the dimension offsets on any
// device. The size of the returned vector is the rank of the given shape.
// If `dims` is non-empty, the generated offsets will only be non-zero for
// those dimensions.
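// For example (an illustrative sketch): for shape f32[8] with sharding
// {devices=[4]0,1,2,3}, the dim-0 offset is the partition ordinal * 2, i.e.
// 0, 2, 4, or 6 depending on the device.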
std::vector<HloInstruction*> MakePartitionOffsets(
    const Shape& shape, const HloSharding& sharding,
    HloInstruction* partition_id, SpmdBuilder* b,
    absl::Span<const int64> dims = {});

// Returns the offsets of the partition in the tile assignment.
std::vector<HloInstruction*> MakeTiledPartitionOrdinals(
    const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b);

// Pads hlo to the desired shape using high padding. Either a builder or a
// computation needs to be supplied, but not both.
HloInstruction* PadToShape(HloInstruction* hlo, const Shape& padded_shape,
                           SpmdBuilder* b,
                           HloComputation* computation = nullptr);

// Returns the padded shape when combining all partitions.
Shape GetPaddedShapeForUnevenPartitioning(const Shape& base_shape,
                                          const HloSharding& sharding);

// Pads the HLO (with base shape) for uneven tiled partition to make it evenly
// partitionable.
HloInstruction* PadBaseShapeBeforeUnevenTiledSharding(
    HloInstruction* hlo, const HloSharding& sharding, SpmdBuilder* b);

// Returns the index of the unique tile dimension. Returns absl::nullopt if the
// given sharding is not tiled or tiled along multiple dimensions.
absl::optional<int64> UniqueTiledDim(const HloSharding& sharding);

// Utilities for symbolic offset calculation and halo exchange.
class OffsetCalculation;

// Represents a calculation over integers:
//   (shard_ordinal * multiplier + offset) / divisor
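// For example (illustrative): with multiplier=2, offset=1, divisor=2, shard
// ordinal 3 maps to (3 * 2 + 1) / 2 = 3 under integer division.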
class MultiplyAddDivideOffsetCalculation {
 public:
  MultiplyAddDivideOffsetCalculation()
      : multiplier_(0), offset_(0), divisor_(1) {}
  MultiplyAddDivideOffsetCalculation(int64 multiplier, int64 offset,
                                     int64 divisor);

  OffsetCalculation operator-(
      const MultiplyAddDivideOffsetCalculation& other) const;

  bool operator==(const MultiplyAddDivideOffsetCalculation& other) const {
    return multiplier_ == other.multiplier_ && offset_ == other.offset_ &&
           divisor_ == other.divisor_;
  }

  bool IsConstant() const { return multiplier_ == 0; }
  void Simplify();
  int64 Calculate(int64 shard_ordinal) const;
  HloInstruction* Calculate(HloInstruction* shard_ordinal,
                            SpmdBuilder* b) const;

  // Returns the maximum result for shard ordinals in the range
  // [start_ordinal, limit_ordinal).
  int64 MaxInRange(int64 start_ordinal, int64 limit_ordinal) const;

 private:
  int64 multiplier_;
  int64 offset_;
  int64 divisor_;
};

// Represents a calculation over integers based on the results of other
// calculations, combined by an opcode. If the opcode is kCopy, it simply wraps
// a MultiplyAddDivideOffsetCalculation.
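// For example (an illustrative assumption, not from the original docs):
// subtracting two calculations with different divisors generally cannot be
// folded into a single MultiplyAddDivideOffsetCalculation, so the difference
// is kept as a binary node (e.g. opcode kSubtract) over the two operands.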
class OffsetCalculation {
 public:
  OffsetCalculation() : opcode_(HloOpcode::kCopy), copy_from_() {}
  explicit OffsetCalculation(
      const MultiplyAddDivideOffsetCalculation& copy_from)
      : opcode_(HloOpcode::kCopy), copy_from_(copy_from) {}
  OffsetCalculation(const OffsetCalculation& copy_from) { *this = copy_from; }
  OffsetCalculation(HloOpcode opcode,
                    const MultiplyAddDivideOffsetCalculation& lhs,
                    const MultiplyAddDivideOffsetCalculation& rhs)
      : opcode_(opcode),
        lhs_(absl::make_unique<OffsetCalculation>(lhs)),
        rhs_(absl::make_unique<OffsetCalculation>(rhs)) {}
  OffsetCalculation(HloOpcode opcode, const OffsetCalculation& lhs,
                    const OffsetCalculation& rhs)
      : opcode_(opcode),
        lhs_(absl::make_unique<OffsetCalculation>(lhs)),
        rhs_(absl::make_unique<OffsetCalculation>(rhs)) {}

  OffsetCalculation& operator=(const OffsetCalculation& other);

  // Returns whether the calculation returns the same value for all shards.
  // This is conservative and could return false even if it is actually
  // constant.
  bool IsConstant() const;

  OffsetCalculation operator-(const OffsetCalculation& other) const;
  bool operator==(const OffsetCalculation& other) const;
  int64 Calculate(int64 shard_ordinal) const;
  HloInstruction* Calculate(HloInstruction* shard_ordinal,
                            SpmdBuilder* b) const;

  // Returns the maximum result for shard ordinals in the range
  // [start_ordinal, limit_ordinal).
  int64 MaxInRange(int64 start_ordinal, int64 limit_ordinal) const;

 private:
  HloOpcode opcode_;
  std::unique_ptr<OffsetCalculation> lhs_;
  std::unique_ptr<OffsetCalculation> rhs_;
  MultiplyAddDivideOffsetCalculation copy_from_;
};

// Performs halo exchange on the given dimension based on the provided
// left/right halo size functions. Returns nullopt if the halo is beyond the
// direct neighbor of the shard.
absl::optional<HloInstruction*> ExchangeHalo(
    HloInstruction* hlo, const OffsetCalculation& left_halo_size_function,
    const OffsetCalculation& right_halo_size_function, int64 dim,
    const HloSharding& target,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, SpmdBuilder* b);

// Exchanges halos on all dimensions of the HLO. Returns nullopt if any of the
// dimensions fails to exchange halo (the halo is beyond the neighbor shard).
absl::optional<HloInstruction*> ExchangeHalo(
    HloInstruction* hlo,
    std::vector<OffsetCalculation> left_halo_size_functions,
    std::vector<OffsetCalculation> right_halo_size_functions,
    const HloSharding& target,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, SpmdBuilder* b);

// Exchanges halos and performs pad/dynamic-slice on the concatenated data such
// that the result starts with the first needed element on each shard. It also
// masks off invalid data due to padding.
// Arguments:
//  hlo: the HLO op before halo exchange
//  explicit_left_padding_on_full_shape: the amount of left padding to be added
//   explicitly by this function on the base shape before partitioning. Without
//   base dilation, this is usually set to the window's padding_low so that the
//   sharded op does not need to add padding_low on the window; however, with
//   base dilation, this can only be set to a custom size.
//  padded_full_shape_size: the size of the padded full shape on the given
//   dimension, which includes explicit_left_padding_on_full_shape and required
//   right padding to make the shape evenly shardable.
//  shard_size_with_halo: the shard size on the dimension after halo exchange.
//   If different shards have different sizes, use the maximum size.
//  offset_on_padded_shape: the offset HLO (S32) that represents the start of
//   each shard on the padded full shape.
//  pad_value: the padding value used on the full shape.
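//
// Worked example (an illustrative sketch): for a base shape [16] tiled across
// 4 partitions (shard size 4) and a size-3 window with stride 1 and
// padding_low = padding_high = 1, each shard needs a left and a right halo of
// size 1, so shard_size_with_halo = 6, explicit_left_padding_on_full_shape = 1
// and padded_full_shape_size = 18.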
absl::optional<HloInstruction*> ExchangeHaloAndGetValidData(
    HloInstruction* hlo, const Shape& base_shape,
    const OffsetCalculation& left_halo_size_function,
    const OffsetCalculation& right_halo_size_function,
    int64 explicit_left_padding_on_full_shape, int64 padded_full_shape_size,
    int64 shard_size_with_halo, int64 dim, const HloSharding& target,
    HloInstruction* offset_on_padded_shape, HloInstruction* pad_value,
    HloInstruction* partition_ordinal,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, SpmdBuilder* b, bool mask_invalid_region = true);

// Uses halo exchange to change from right-padding to left-padding for uneven
// tiled sharding on the given dimensions. Tiled sharding always pads unevenly
// partitioned data on the right, but we need to swap it to the left for
// kReverse or kConvolution with window reversal.
HloInstruction* HaloExchangeToPadOnLeft(PartitionedHlo& original,
                                        absl::Span<const int64> dims);

// Checks if the computation is a GT comparison that is safe for NaNs.
bool IsNanSafeGt(HloComputation* computation);

// Returns k in TopK when the input value is partitioned in the sort dimension.
absl::optional<int64> GetKValueInTopKWhenPartitionSortDim(HloInstruction* hlo);

// Slices the first k elements at the slice dimension.
HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder,
                            int64 slice_dim, int64 k);

// Returns the number of shards along the given dimension.
int64 ShardCountAtDim(const HloSharding& sharding, int64 dim);

// Returns the list of source-target pairs of dimensions to swap during
// resharding via all-to-all. Resharding can be done by swapping each pair one
// at a time.
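// For example (an illustrative assumption): resharding from
// {devices=[4,1]0,1,2,3} to {devices=[1,4]0,1,2,3} can be done with a single
// all-to-all that swaps dimensions 0 and 1, so the result would be {{0, 1}}.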
absl::optional<std::vector<std::pair<int64, int64>>>
GetReshardAllToAllSourceTargetDims(const HloSharding& source,
                                   const HloSharding& target);

// Returns whether the resharding can be done via collective-permute.
bool CanReshardWithCollectivePermute(const HloSharding& source,
                                     const HloSharding& target);

// Represents grouping devices in a tiled sharding along certain dimensions.
// Elements in group dimensions define different device groups, and the
// sharding represents the in-group sharding.
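// For example (an illustrative sketch): grouping {devices=[2,2]0,1,2,3} on
// dimension 0 yields device_groups {{0,1},{2,3}}, and the in-group sharding
// tiles the remaining dimension across the 2 devices of each group.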
struct GroupedSharding {
  GroupedSharding(std::vector<std::vector<int64>> device_groups,
                  std::vector<int64> group_dims,
                  std::vector<int64> group_dim_sizes, int64 data_rank,
                  HloSharding grouped_sharding)
      : device_groups(std::move(device_groups)),
        group_dims(std::move(group_dims)),
        group_dim_sizes(std::move(group_dim_sizes)),
        data_rank(data_rank),
        sharding(std::move(grouped_sharding)) {}
  std::vector<std::vector<int64>> device_groups;
  std::vector<int64> group_dims;
  std::vector<int64> group_dim_sizes;
  int64 data_rank;
  HloSharding sharding;
};

// Creates a GroupedSharding for a tiled sharding with group dim shard sizes.
GroupedSharding GroupShardingOnDims(const HloSharding& sharding,
                                    absl::Span<const int64> group_dims,
                                    absl::Span<const int64> group_dim_shards);

// Creates a GroupedSharding for a tiled sharding.
GroupedSharding GroupShardingOnDims(const HloSharding& sharding,
                                    absl::Span<const int64> group_dims);

// Reconstructs the ungrouped sharding from a GroupedSharding.
HloSharding UngroupSharding(const GroupedSharding& grouped_sharding);

// Returns a new GroupedSharding that has the same group definition as
// `reference`.
GroupedSharding AlignGroupsWith(GroupedSharding grouped_sharding,
                                const GroupedSharding& reference,
                                bool ignore_group_order = false);

// Aligns device groups between the two shardings. Equivalent to calling
// GroupShardingOnDims on the two shardings, then AlignGroupsWith, and then
// UngroupSharding.
HloSharding AlignShardingOnDims(const HloSharding& sharding,
                                absl::Span<const int64> sharding_dims,
                                const HloSharding& reference,
                                absl::Span<const int64> reference_dims);

// Returns the per-group base shape, i.e., before applying the in-group
// sharding.
Shape GetPerGroupBaseShape(const GroupedSharding& grouped_sharding,
                           const Shape& original_base_shape);

// Creates the nested partitioner state for in-group partitioning.
PartitionedHlo::PartitioningState CreatePerGroupPartitioningState(
    const PartitionedHlo::PartitioningState& state,
    const std::vector<std::vector<int64>>& device_groups, SpmdBuilder* b);

// Partially shards a replicated HLO into groups along the group dimensions;
// within each group the data is still replicated.
HloInstruction* PerGroupSliceFromReplicated(
    HloInstruction* replicated, HloInstruction* partition_id,
    const std::vector<std::vector<int64>>& device_groups,
    absl::Span<const int64> group_dims, absl::Span<const int64> group_dim_sizes,
    SpmdBuilder* b);

// Returns the opcode if `reduction_comp` represents a simple binary
// elementwise computation on the two operands.
absl::optional<HloOpcode> ParseReductionComputation(
    const HloComputation* reduction_comp);

// Pads the shape from a partially replicated shape for `dst_sharding`.
// If dst_sharding needs more padding and the per-shard size increases in
// dst_sharding, a halo exchange on the right side is needed.
absl::optional<HloInstruction*> PadFromPartialReplicateShape(
    HloInstruction* hlo, const Shape& base_shape,
    const HloSharding& src_sharding, const HloSharding& dst_sharding,
    const std::vector<int64>& expand_tile_dims,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b);

// Gets the compatible sharding from a partially replicated sharding to a
// desired target tiled sharding. Compatible means the replicated sharding can
// be transformed to the target tile dimensions by dynamic-slice.
// For example, if partial_sharding is
// {devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
// and the target sharding is {devices=[2,2]0,1,2,3}, the returned compatible
// sharding will be sharding={devices=[2,2]0,2,1,3}.
// If the given sharding is not partially replicated or cannot be resharded to
// target_tile_dims by dynamic-slice, returns absl::nullopt.
// If target_sharding is already compatible, returns it.
absl::optional<HloSharding> PartialReplicateReshardCompatibleSharding(
    const HloSharding& partial_sharding, const HloSharding& target_sharding);

// Performs a left halo exchange if all-reducing directly from tiled sharding
// to partially replicated sharding would remove useful data from the source.
absl::optional<HloInstruction*> TileToPartialReplicateHaloExchange(
    HloInstruction* hlo, const Shape& base_shape,
    const HloSharding& src_sharding, const HloSharding& dst_sharding,
    const std::vector<int64>& replicate_dims,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b);

// Finds a list of dimensions that can be grouped on such that it will have the
// specified device groups. Group order and dimension order are ignored.
absl::optional<std::vector<int64>> FindMatchingPartitionedDimsForGrouping(
    const HloSharding& sharding,
    const std::vector<std::vector<int64>>& device_groups);

// Creates a sharding that matches the provided source sharding on the
// specified dimensions. 'target_dims' and 'source_dims' represent the
// dimensions for which the sharding should match in their respective shape.
// If some devices from the source sharding are left over (because not all the
// devices are allocated to 'source_dims' dimensions), then partial replication
// is employed to make sure the device counts of the two shardings match.
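// For example (an illustrative assumption): with source sharding
// {devices=[4]0,1,2,3}, source_dims {0} and target_dims {1}, the result
// shards dimension 1 of the target shape 4 ways: {devices=[1,4]0,1,2,3}.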
HloSharding CreateMatchingShardingOnDims(const Shape& target_shape,
                                         const HloSharding& source_sharding,
                                         absl::Span<const int64> target_dims,
                                         absl::Span<const int64> source_dims);

// Returns whether the sharding across the operand and indices of a gather is
// along parallel dimensions and matches what the SPMD partitioner supports.
absl::optional<GatherParallelDimSharding>
GatherOperandsShardedAcrossParallelDims(
    const HloInstruction& operand, const HloInstruction& indices,
    const hlo_sharding_util::GatherParallelDims& parallel_dims);

}  // namespace spmd
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_