1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_
18 
19 #include <map>
20 #include <vector>
21 
22 #include "absl/container/inlined_vector.h"
23 #include "absl/types/optional.h"
24 #include "tensorflow/compiler/xla/service/hlo_computation.h"
25 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
26 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
27 #include "tensorflow/compiler/xla/service/hlo_module.h"
28 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
29 
30 namespace xla {
31 namespace hlo_sharding_util {
32 
// Parallel ("batch") dimensions identified for a gather op. Produced by
// GetGatherBatchParallelDims() below and consumed by the
// GatherParallel*Dims() helpers.
struct GatherParallelDims {
  // Parallel dimensions on the gather indices operand.
  absl::InlinedVector<int64, 1> indices_parallel_dims;
  // Corresponding parallel dimensions on the gather data operand.
  absl::InlinedVector<int64, 1> operand_parallel_dims;
  // Per index dimension, the parallel dimension it participates in
  // (sentinel/ordering details live in the implementation -- see the .cc).
  std::vector<int64> index_parallel_in_dim;
};
38 
// Returns true if the lhs sharding is preferable over the rhs sharding.
// The most specific sharding is tile maximal followed by single device tile
// maximal and finally replicated. This order aims to primarily reduce memory
// usage and secondly reduce total compute.
// Note: This does NOT provide a total ordering as we can have two different
// shardings with the same preference level.
bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs);

// Tries to refine `to_merge` by combining with `old`. Returns true if the
// final `to_merge` is more specific than `old`.
bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
                   bool may_combine_partial_sharding);
51 
// Given a map<device, occurrence_count>, selects the device with the highest
// occurrence count (if any). If top_count is not nullptr, it will receive the
// count of the dominant device returned.
absl::optional<int64> SelectDominantDevice(
    const std::map<int64, int64>& device_map, int64* top_count);

// Assigns all the instructions of a computation, to a given device.
// This API does not recurse into called computations, and does not assign
// instructions which already have sharding.
Status AssignComputationDevice(HloComputation* computation, int64 device);

// Given an instruction container, returns the device which is most commonly
// occurring among the instructions.
absl::optional<int64> GetMostOccurringDevice(
    absl::Span<HloInstruction* const> instructions);

// Given a set of computations, tries to extract the dominant device. A device
// is dominant if its combined occurrence among all the instructions of the
// input computations is greater than or equal to dominant_factor (real number
// from 0 to 1).
// This API does not recurse into called computations.
// If no device exists that satisfies the condition, the returned optional will
// hold no value.
StatusOr<absl::optional<int64>> GetDominantDevice(
    absl::Span<HloComputation* const> computations, double dominant_factor);
77 
// Returns the HloSharding with the tile dimensions and tile assignment
// transposed based on the specified dimension numbers. In case of a tile
// maximal sharding returns the original sharding.
HloSharding TransposeSharding(const HloSharding& sharding,
                              const std::vector<int64>& dimensions);

// Returns the HloSharding with the tile shape reshaped based on the source and
// target shapes and the tile assignment adjusted to correspond to the new tile
// shape, or absl::nullopt if the resulting reshape would create an invalid
// sharding (non-contiguous or non-uniformly sized tiles). In case of a tile
// maximal sharding returns the original sharding.
absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
                                            const Shape& target_shape,
                                            const HloSharding& sharding);

// Returns the HloSharding with the tile dimensions and tile assignment
// reversed based on the specified dimension numbers. In case of a tile
// maximal sharding returns the original sharding.
HloSharding ReverseSharding(const HloSharding& sharding,
                            absl::Span<const int64> dimensions);

// Returns a sharding tiled on unique dimension dim by reshaping the tile
// assignment of the sharding argument. Only dimensions in the dims span
// argument are considered for reshaping, the others are ignored.
// Assumptions: sharding is tile sharded, and dim must be included in dims.
HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim,
                                   absl::Span<const int64> dims);
105 
// Returns true if the provided module includes one or more instructions with
// a tile sharding.
bool ContainsTileSharding(const HloModule& module);

// Returns the preferred output sharding for a gather op based on the sharding
// of the indices.
HloSharding GatherOutputSharding(const HloSharding& index_sharding,
                                 const HloInstruction* hlo);

// Returns the preferred index sharding for a gather op based on the sharding
// of the output.
HloSharding GatherIndexSharding(const HloSharding& output_sharding,
                                const HloInstruction* hlo);

// Returns a new HloSharding for a gather op so that only non offset dimensions
// are sharded. Assume "result" is returned by this function. It is ensured that
// "GetIndexSharding(result, hlo)" will have the same number of elements as
// "result".
HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo);

// Returns the preferred index sharding for a scatter op based on the sharding
// of the data.
HloSharding ScatterIndexSharding(const HloSharding& data_sharding,
                                 const HloInstruction* hlo);

// Returns the preferred data sharding for a scatter op based on the sharding
// of the indices.
HloSharding ScatterDataSharding(const HloSharding& index_sharding,
                                const HloInstruction* hlo);

// Returns a new index sharding for a scatter op so that we only shard on first
// "number of scatter_window_dims" dimensions. Assume "result" is returned by
// this function. It is ensured that "ScatterDataSharding(result, hlo)" will
// have the same number of elements as "result".
HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding,
                                          const HloInstruction& hlo);

// Returns a new data sharding for a scatter op so that we only shard on
// scatter_window_dims. Assume "result" is returned by this function. It is
// ensured that "ScatterIndexSharding(result, hlo)" will have the same number of
// elements as "result".
HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
                                         const HloInstruction& hlo);
149 
// The next four helpers propagate a sharding across a gather/scatter op.
// NOTE(review): an empty optional presumably means no compatible sharding
// could be derived -- confirm against the implementation.

// Returns an output sharding of gather by passing through the data operand's
// sharding.
absl::optional<HloSharding> GatherOutputShardingFromDataOperand(
    const HloSharding& data_operand_sharding, const HloInstruction& hlo,
    const Shape& output_shape, const Shape& operand_shape);

// Returns a data operand sharding of gather by passing through the output's
// sharding.
absl::optional<HloSharding> GatherDataOperandShardingFromOutput(
    const HloSharding& output_sharding, const HloInstruction& hlo);

// Returns an output sharding of scatter by passing through the update operand's
// sharding.
absl::optional<HloSharding> ScatterOutputShardingFromUpdate(
    const HloSharding& update_sharding, const HloInstruction& hlo);

// Returns an update operand sharding of scatter by passing through the output's
// sharding.
absl::optional<HloSharding> ScatterUpdateShardingFromOutput(
    const HloSharding& output_sharding, const HloInstruction& hlo);
170 
// Returns an identity value and an HloOpcode for the reduce computation of a
// scatter instruction.
// - If computation is add/or, return 0/false with corresponding op code.
// - If computation is multiply/and, return 1/true with corresponding op code.
// - If computation is min/max, return max value/min value with corresponding op
//   code.
// - Otherwise, return error status.
StatusOr<std::pair<std::unique_ptr<HloInstruction>, HloOpcode>>
IdentityValueAndHloOpcodeForScatterReduceComputation(
    const HloScatterInstruction& scatter);
181 
// Given a sharding and a list of devices in the topology, return a
// list of the devices that `sharding` applies to.
std::vector<int64> DevicesForSharding(
    const HloSharding& sharding, const std::vector<int64>& available_devices);

// Returns a sharding that replicates data across devices along the given
// dimensions in the original sharding.
HloSharding PartiallyReplicateTiledShardingOnDims(
    const HloSharding& sharding, absl::Span<const int64> dims_to_replicate);

// Returns a sharding that removes the given tile dimensions.
//
// Precondition: if not tile maximal, the size of each tile dimension must be 1.
HloSharding RemoveShapeDimensions(const HloSharding& sharding,
                                  const std::vector<int64>& dims_to_remove);
197 
// Similar to TransposeSharding(), but allows removing/adding non-partitioned
// dimensions. In src_to_tgt and tgt_to_src, -1 represents a non-existing
// dimension.
absl::optional<HloSharding> TransposeShardingWithCollapsedDims(
    const HloSharding& source, absl::Span<int64 const> src_to_tgt,
    absl::Span<int64 const> tgt_to_src);

// Returns identified parallel dimensions for Gather (see GatherParallelDims
// above), or absl::nullopt when none are identified.
absl::optional<GatherParallelDims> GetGatherBatchParallelDims(
    const HloInstruction& hlo);

// Returns the parallel dimensions of the output of a gather based on the
// parallel dimensions of the input.
absl::InlinedVector<int64, 1> GatherParallelOutputDims(
    const HloInstruction& gather, const GatherParallelDims& parallel_dim);

// Returns the parallel dimensions of the data operand of a gather with the
// order of the parallel dimensions matching that of the parallel dimensions
// of the output.
absl::InlinedVector<int64, 1> GatherOutputAlignedOperandParallelDims(
    const HloInstruction& gather, const GatherParallelDims& parallel_dims);
219 
220 }  // namespace hlo_sharding_util
221 }  // namespace xla
222 
223 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_
224