/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_

#include <map>
#include <memory>
#include <utility>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"

namespace xla {
namespace hlo_sharding_util {

struct GatherParallelDims {
  absl::InlinedVector<int64, 1> indices_parallel_dims;
  absl::InlinedVector<int64, 1> operand_parallel_dims;
  std::vector<int64> index_parallel_in_dim;
};

// Returns true if the lhs sharding is preferable over the rhs sharding.
// The most specific sharding is tiled, followed by single-device tile maximal,
// and finally replicated. This order aims primarily to reduce memory usage and
// secondarily to reduce total compute.
// Note: This does NOT provide a total ordering, as two different shardings can
// have the same preference level.
bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs);

// Tries to refine `to_merge` by combining it with `old`. Returns true if the
// final `to_merge` is more specific than `old`.
bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
                   bool may_combine_partial_sharding);

// Given a map<device, occurrence_count>, selects the device with the highest
// occurrence count (if any). If top_count is not nullptr, it receives the
// count of the dominant device returned.
absl::optional<int64> SelectDominantDevice(
    const std::map<int64, int64>& device_map, int64* top_count);

// Assigns all the instructions of a computation to a given device.
// This API does not recurse into called computations, and does not assign
// instructions which already have sharding.
Status AssignComputationDevice(HloComputation* computation, int64 device);

// Given an instruction container, returns the device which occurs most
// commonly among the instructions.
absl::optional<int64> GetMostOccurringDevice(
    absl::Span<HloInstruction* const> instructions);
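
// The sketch below is an editorial illustration (kept as a comment so it does
// not change this header's API or compilation). It assumes Array<int64> from
// xla/array.h and the HloSharding factory methods declared in hlo_sharding.h;
// the variable names are hypothetical.
//
//   Array<int64> tiles({2, 2});
//   tiles.FillIota(0);                       // Devices 0..3 in a 2x2 grid.
//   HloSharding tiled = HloSharding::Tile(tiles);
//   HloSharding replicated = HloSharding::Replicate();
//   // Per the ordering documented above, a tiled sharding is more specific
//   // than a replicated one.
//   bool more_specific = IsShardingMoreSpecific(tiled, replicated);
//   // MergeSharding may refine `to_merge` in place; it reports whether the
//   // final `to_merge` is more specific than `old`.
//   HloSharding to_merge = replicated;
//   bool refined = MergeSharding(/*old=*/tiled, &to_merge,
//                                /*may_combine_partial_sharding=*/true);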

// Given a set of computations, tries to extract the dominant device. A device
// is dominant if its fraction of occurrences among all the instructions of
// the input computations is greater than or equal to dominant_factor (a real
// number from 0 to 1).
// This API does not recurse into called computations.
// If no device exists that satisfies the condition, the returned optional
// holds no value.
StatusOr<absl::optional<int64>> GetDominantDevice(
    absl::Span<HloComputation* const> computations, double dominant_factor);

// Returns the HloSharding with the tile dimensions and tile assignment
// transposed based on the specified dimension numbers. In case of a tile
// maximal sharding returns the original sharding.
HloSharding TransposeSharding(const HloSharding& sharding,
                              const std::vector<int64>& dimensions);

// Returns the HloSharding with the tile shape reshaped based on the source and
// target shapes and the tile assignment adjusted to correspond to the new tile
// shape, or absl::nullopt if the resulting reshape would create an invalid
// sharding (non-contiguous or non-uniformly sized tiles). In case of a tile
// maximal sharding returns the original sharding.
absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
                                            const Shape& target_shape,
                                            const HloSharding& sharding);

// Returns the HloSharding with the tile dimensions and tile assignment
// reversed based on the specified dimension numbers. In case of a tile
// maximal sharding returns the original sharding.
HloSharding ReverseSharding(const HloSharding& sharding,
                            absl::Span<const int64> dimensions);

// Returns a sharding tiled on the unique dimension dim by reshaping the tile
// assignment of the sharding argument. Only dimensions in the dims span
// argument are considered for reshaping; the others are ignored.
// Assumptions: sharding is tile sharded, and dim must be included in dims.
HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim,
                                   absl::Span<const int64> dims);

// Returns true if the provided module includes one or more instructions with
// a tile sharding.
bool ContainsTileSharding(const HloModule& module);

// Returns the preferred output sharding for a gather op based on the sharding
// of the indices.
HloSharding GatherOutputSharding(const HloSharding& index_sharding,
                                 const HloInstruction* hlo);

// Returns the preferred index sharding for a gather op based on the sharding
// of the output.
HloSharding GatherIndexSharding(const HloSharding& output_sharding,
                                const HloInstruction* hlo);

// Returns a new HloSharding for a gather op so that only non-offset dimensions
// are sharded. Assume "result" is returned by this function. It is ensured
// that "GatherIndexSharding(result, hlo)" will have the same number of
// elements as "result".
HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo);

// Returns the preferred index sharding for a scatter op based on the sharding
// of the data.
HloSharding ScatterIndexSharding(const HloSharding& data_sharding,
                                 const HloInstruction* hlo);

// Returns the preferred data sharding for a scatter op based on the sharding
// of the index.
HloSharding ScatterDataSharding(const HloSharding& index_sharding,
                                const HloInstruction* hlo);
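
// Comment-only sketch of how the gather sharding preferences above compose
// (an editorial illustration, not part of the original header). Here `gather`
// and `index_sharding` are hypothetical stand-ins for an existing gather
// HloInstruction* and the sharding of its indices operand.
//
//   // Derive a preferred output sharding from the indices' sharding, then
//   // map an output sharding back to a preferred index sharding.
//   HloSharding output_sharding = GatherOutputSharding(index_sharding, gather);
//   HloSharding index_pref = GatherIndexSharding(output_sharding, gather);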

// Returns a new index sharding for a scatter op so that we only shard on the
// first "number of scatter_window_dims" dimensions. Assume "result" is
// returned by this function. It is ensured that "ScatterDataSharding(result,
// hlo)" will have the same number of elements as "result".
HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding,
                                          const HloInstruction& hlo);

// Returns a new data sharding for a scatter op so that we only shard on
// scatter_window_dims. Assume "result" is returned by this function. It is
// ensured that "ScatterIndexSharding(result, hlo)" will have the same number
// of elements as "result".
HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
                                         const HloInstruction& hlo);

// Returns an output sharding of gather by passing through the data operand's
// sharding.
absl::optional<HloSharding> GatherOutputShardingFromDataOperand(
    const HloSharding& data_operand_sharding, const HloInstruction& hlo,
    const Shape& output_shape, const Shape& operand_shape);

// Returns a data operand sharding of gather by passing through the output's
// sharding.
absl::optional<HloSharding> GatherDataOperandShardingFromOutput(
    const HloSharding& output_sharding, const HloInstruction& hlo);

// Returns an output sharding of scatter by passing through the update
// operand's sharding.
absl::optional<HloSharding> ScatterOutputShardingFromUpdate(
    const HloSharding& update_sharding, const HloInstruction& hlo);

// Returns an update operand sharding of scatter by passing through the
// output's sharding.
absl::optional<HloSharding> ScatterUpdateShardingFromOutput(
    const HloSharding& output_sharding, const HloInstruction& hlo);

// Returns an identity value and an HloOpcode for the reduce computation of a
// scatter instruction.
// - If the computation is add/or, returns 0/false with the corresponding op
//   code;
// - If the computation is multiply/and, returns 1/true with the corresponding
//   op code;
// - If the computation is min/max, returns the max value/min value with the
//   corresponding op code;
// - Otherwise, returns an error status.
StatusOr<std::pair<std::unique_ptr<HloInstruction>, HloOpcode>>
IdentityValueAndHloOpcodeForScatterReduceComputation(
    const HloScatterInstruction& scatter);

// Given a sharding and a list of devices in the topology, returns a list of
// the devices that `sharding` applies to.
std::vector<int64> DevicesForSharding(
    const HloSharding& sharding, const std::vector<int64>& available_devices);

// Returns a sharding that replicates data across devices along the given
// dimensions in the original sharding.
HloSharding PartiallyReplicateTiledShardingOnDims(
    const HloSharding& sharding, absl::Span<const int64> dims_to_replicate);

// Returns a sharding that removes the given tile dimensions.
//
// Precondition: if not tile maximal, the size of each tile dimension to remove
// must be 1.
HloSharding RemoveShapeDimensions(const HloSharding& sharding,
                                  const std::vector<int64>& dims_to_remove);

// Similar to TransposeSharding(), but allows removing/adding non-partitioned
// dimensions. In src_to_tgt and tgt_to_src, -1 represents a non-existing
// dimension.
absl::optional<HloSharding> TransposeShardingWithCollapsedDims(
    const HloSharding& source, absl::Span<int64 const> src_to_tgt,
    absl::Span<int64 const> tgt_to_src);

// Returns the identified parallel dimensions for a gather.
absl::optional<GatherParallelDims> GetGatherBatchParallelDims(
    const HloInstruction& hlo);
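
// Comment-only sketch of the replication and device helpers above (an
// editorial illustration; variable names are hypothetical and Array<int64>
// comes from xla/array.h).
//
//   Array<int64> tiles({2, 2});
//   tiles.FillIota(0);                       // Devices 0..3 in a 2x2 grid.
//   HloSharding sharding = HloSharding::Tile(tiles);
//   // Keep the tiling along dimension 0 but replicate along dimension 1.
//   HloSharding partial = PartiallyReplicateTiledShardingOnDims(
//       sharding, /*dims_to_replicate=*/{1});
//   // Devices that the original sharding actually uses, out of 0..3.
//   std::vector<int64> devices =
//       DevicesForSharding(sharding, /*available_devices=*/{0, 1, 2, 3});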

// Returns the parallel dimensions of the output of a gather based on the
// parallel dimensions of the input.
absl::InlinedVector<int64, 1> GatherParallelOutputDims(
    const HloInstruction& gather, const GatherParallelDims& parallel_dim);

// Returns the parallel dimensions of the data operand of a gather, with their
// order matching that of the parallel dimensions of the output.
absl::InlinedVector<int64, 1> GatherOutputAlignedOperandParallelDims(
    const HloInstruction& gather, const GatherParallelDims& parallel_dims);

}  // namespace hlo_sharding_util
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_