/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// HLO shardings describe how an HLO instruction is split across multiple
// computations.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_

#include <map>
#include <string>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// HLO shardings describe how an HLO instruction is split across multiple
// computations.
class HloSharding {
 public:
  // Creates a trivial sharding that replicates a maximal tile across all
  // devices.
  static HloSharding Replicate(absl::Span<const OpMetadata> metadata = {}) {
    return HloSharding(/*manual=*/false, /*replicated=*/true, metadata);
  }

  // Creates a sharding that represents that the op is manually partitioned.
  static HloSharding Manual(absl::Span<const OpMetadata> metadata = {}) {
    return HloSharding(/*manual=*/true, /*replicated=*/false, metadata);
  }

  // Creates a sharding that emulates device placement; a tile shape equal to
  // the input shape (one tile) assigned to a single device.
  static HloSharding AssignDevice(int64 device_id,
                                  absl::Span<const OpMetadata> metadata = {});

  // Creates a new sharding which splits a shape into tiles amongst the devices
  // specified by `tile_assignment`.
  static HloSharding Tile(const Array<int64>& tile_assignment,
                          absl::Span<const OpMetadata> metadata = {}) {
    return HloSharding(tile_assignment, /*replicate_on_last_tile_dim=*/false,
                       metadata);
  }

  // Creates a new sharding where data is replicated within each replication
  // group, and sharded across replication groups according to
  // group_tile_assignment. Replication group members will be sorted.
  static HloSharding PartialTile(
      const Array<int64>& group_tile_assignment,
      absl::Span<const absl::Span<const int64>> replication_groups,
      absl::Span<const OpMetadata> metadata = {});

  // Creates a partially replicated tiled sharding with device-level tile
  // assignment, where the last dimension is the additional replication
  // dimension. Replication group members will be sorted.
  static HloSharding PartialTile(
      const Array<int64>& tile_assignment_last_dim_replicate,
      absl::Span<const OpMetadata> metadata = {});
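
  // Illustrative usage sketch (not part of the original header; the variable
  // names below are hypothetical):
  //
  //   HloSharding replicated = HloSharding::Replicate();  // all devices
  //   HloSharding pinned = HloSharding::AssignDevice(3);  // device 3 only
  //
  //   // Split a rank-2 HLO into a 2x2 grid over devices 0..3.
  //   Array<int64> assignment({{0, 1}, {2, 3}});
  //   HloSharding tiled = HloSharding::Tile(assignment);
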
  // Creates a new sharding which splits a one-dimensional input shape into
  // `num_tiles` tiles.
  static HloSharding Tile1D(const Shape& input_shape, int64 num_tiles,
                            absl::Span<const OpMetadata> metadata = {});

  // Creates a new sharding for a tuple type. The given ShapeTree must have
  // elements for every leaf shape contained in the tuple.
  static HloSharding Tuple(const ShapeTree<HloSharding>& sub_shardings);

  // Creates a new sharding for a tuple type. The number of elements in
  // shardings must match the number of leaf nodes in tuple_shape. For
  // empty tuples, the shardings array must have one element.
  static HloSharding Tuple(const Shape& tuple_shape,
                           absl::Span<const HloSharding> shardings);

  // Creates a new sharding for a tuple type, with a single input sharding
  // repeated on each leaf.
  static HloSharding SingleTuple(const Shape& tuple_shape,
                                 const HloSharding& sharding);

  // If shape is an array, returns sharding, otherwise returns the tuple shaped
  // sharding with all the leaf nodes having the same input sharding.
  static HloSharding Single(const Shape& shape, const HloSharding& sharding);

  // Creates a new sharding from a protobuf OpSharding.
  static StatusOr<HloSharding> FromProto(const OpSharding& proto);

  // Checks whether device is a reserved device number. A reserved device
  // number usually has a special meaning, with dedicated handling logic.
  static bool IsReservedDevice(int64 device) { return device < 0; }

  OpSharding ToProto() const;

  // Note that this string canonically has outer curly braces, e.g.
  // "{replicated}".
  string ToString(bool include_metadata = false) const;

  // Validate that this sharding can be applied to a tensor with shape `shape`.
  Status Validate(const Shape& shape, int64 num_devices) const;

  // Returns true if the sharding has tuple type.
  bool IsTuple() const { return tuple_; }

  // Returns true if the sharding is trivial: replicate on all devices.
  bool IsReplicated() const {
    if (!IsTuple()) {
      return replicated_;
    }
    return absl::c_all_of(
        tuple_elements_, [](const HloSharding& s) { return s.IsReplicated(); });
  }

  // Returns true if the tile size is the same as the input size.
  bool IsTileMaximal() const {
    if (!IsTuple()) {
      return maximal_;
    }
    return absl::c_all_of(tuple_elements_, [](const HloSharding& s) {
      return s.IsTileMaximal();
    });
  }

  // Returns whether the sharding represents manual partitioning.
  bool IsManual() const {
    if (!IsTuple()) {
      return manual_;
    }
    return absl::c_all_of(tuple_elements_,
                          [](const HloSharding& s) { return s.IsManual(); });
  }

  // Returns whether the sharding has partial replication and partial sharding.
  // If true, data is sharded according to other dimensions of
  // tile_assignment(), but replicated across devices along the last dimension.
  bool ReplicateOnLastTileDim() const { return replicate_on_last_tile_dim_; }

  // Returns true if the sharding defines an operation on the given device.
  bool UsesDevice(int64 device) const;
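
  // Illustrative sketch (not part of the original header; `tuple_shape` is a
  // hypothetical example) of a tuple sharding and the predicates above:
  //
  //   Shape tuple_shape = ShapeUtil::MakeTupleShape(
  //       {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(F32, {4, 4})});
  //   HloSharding tuple_sharding =
  //       HloSharding::SingleTuple(tuple_shape, HloSharding::Replicate());
  //   // Here tuple_sharding.IsTuple() and tuple_sharding.IsReplicated() both
  //   // return true.
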
  // Retrieves a histogram of the devices used by the sharding. The returned
  // map has the device number as key, and the occurrence count as value.
  // If a sharding does not have a device, it will not be included in the
  // histogram. The count argument, if not nullptr, will receive the total
  // number of elements this sharding is made of (one for array, N leaves for
  // tuples).
  std::map<int64, int64> UsedDevices(int64* count) const;

  // Returns the tile that should be executed on the given device.
  // REQUIRES: !IsTuple()
  std::vector<int64> TileIndexForDevice(int64 device) const;

  // Returns the device that should execute the given tile.
  // It is an error to call this if IsReplicated() is true.
  // When ReplicateOnLastTileDim() == true, if index.size() == data rank, it
  // returns the first device in that replicated subgroup; otherwise,
  // index.size() should be the same as tile_assignment()'s rank and specifies
  // the member of the replication subgroup.
  // REQUIRES: !IsTuple()
  int64 DeviceForTileIndex(absl::Span<const int64> index) const;

  // Given a device ID, returns the offset within the specified shape of the
  // tile that should be executed on the given core. This returns the lower
  // extent of the tile in the input space.
  // REQUIRES: !IsTuple()
  std::vector<int64> TileOffsetForDevice(const Shape& shape,
                                         int64 device) const;

  // Given a device ID, returns the limit within the specified shape of the
  // tile that should be executed on the given core. This returns the upper
  // extent of the tile in the input space.
  // REQUIRES: !IsTuple()
  std::vector<int64> TileLimitForDevice(const Shape& shape, int64 device) const;

  // Returns the single device this op operates on. If the sharding does not
  // span a single device, the return value will be empty.
  // In order for a sharding to span a single device, every leaf sharding must
  // be maximal and not replicated, and the used device must match.
  absl::optional<int64> UniqueDevice() const;

  // Retrieves the unique device or fails with a CHECK.
  int64 GetUniqueDevice() const;

  // Returns true if this op only uses a single device.
  bool HasUniqueDevice() const { return UniqueDevice().has_value(); }

  // Returns the ShapeTree containing the shardings for each element of this
  // tuple, if IsTuple, or a ShapeTree with a single element containing this
  // sharding. Only the leaf elements are populated. This creates a new
  // ShapeTree object so is not cheap.
  StatusOr<ShapeTree<HloSharding>> AsShapeTree(const Shape& shape) const;
  ShapeTree<HloSharding> GetAsShapeTree(const Shape& shape) const {
    return AsShapeTree(shape).ValueOrDie();
  }

  // Retrieves the sub sharding at a given index, out of a tuple sharding.
  // REQUIRES: IsTuple()
  HloSharding GetSubSharding(const Shape& shape, const ShapeIndex& index) const;

  // If the current sharding is a tuple sharding, returns itself as the result.
  // Otherwise returns a tuple sharding for the input shape, with all the
  // leaves having this object's sharding.
  StatusOr<HloSharding> GetTupleSharding(const Shape& shape) const;

  // Extracts the sharding that is common within the current sharding.
  // If the current sharding is not a tuple sharding, the current sharding will
  // be returned. If it is a tuple, and all the tuple elements are common, the
  // common element will be returned. Otherwise the optional will contain no
  // value.
  absl::optional<HloSharding> ExtractSingleSharding() const;
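
  // Illustrative sketch (not part of the original header; the concrete values
  // are assumptions) of the per-device tile queries above. With the 2x2
  // tiling from the earlier sketch applied to an f32[4,4] shape, device 3
  // owns the tile at index {1, 1}:
  //
  //   TileIndexForDevice(3)          -> {1, 1}
  //   TileOffsetForDevice(shape, 3)  -> {2, 2}
  //   TileLimitForDevice(shape, 3)   -> {4, 4}
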
  // Returns a copy of the sharding with no metadata. If sharding is of tuple
  // type, sub shardings will have no metadata.
  HloSharding WithoutMetadata() const;

  // Returns a copy of the sharding with specified metadata. If metadata is
  // already present, that metadata will not be replaced unless `overwrite` is
  // set to true. If sharding is of tuple type, sub shardings' metadata will be
  // assigned instead.
  HloSharding WithMetadata(absl::Span<const OpMetadata> metadata,
                           bool overwrite) const;

  bool operator==(const HloSharding& other) const {
    return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
           manual_ == other.manual_ &&
           tile_assignment_ == other.tile_assignment_ &&
           tuple_elements_ == other.tuple_elements_ &&
           replicate_on_last_tile_dim_ == other.replicate_on_last_tile_dim_;
  }
  bool operator!=(const HloSharding& other) const { return !(*this == other); }

  size_t Hash() const;

  struct Hasher {
    size_t operator()(const HloSharding& sharding) const {
      return sharding.Hash();
    }
  };

  // Gets the tile assignment tensor.
  // REQUIRES: !IsReplicated() && !IsTuple()
  const Array<int64>& tile_assignment() const { return tile_assignment_; }

  // Returns the flattened list of all the leaf shardings in a tuple shape, by
  // pre-order walk (ShapeTree iterator order).
  // REQUIRES: IsTuple().
  std::vector<HloSharding>& tuple_elements() { return tuple_elements_; }
  const std::vector<HloSharding>& tuple_elements() const {
    return tuple_elements_;
  }

  // Gets the tile shape.
  // REQUIRES: !IsTuple()
  Shape TileShape(const Shape& shape) const;

  // Gets the tile shape on the device.
  // REQUIRES: !IsTuple()
  Shape TileShape(const Shape& shape, int64 device) const;

  // Gets the number of tiles. If it has partial replication, this will not
  // equal the device count.
  int64 NumTiles() const;
  // Like NumTiles() but considers only some specific dimensions passed as
  // argument.
  int64 NumTiles(absl::Span<const int64> dims) const;

  // Gets metadata from sharding.
  std::vector<OpMetadata>& metadata() { return metadata_; }
  const std::vector<OpMetadata>& metadata() const { return metadata_; }
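
  // Illustrative sketch (not part of the original header): Hasher makes
  // HloSharding usable as a key in unordered containers, e.g.
  //
  //   std::unordered_set<HloSharding, HloSharding::Hasher> seen;
  //   seen.insert(HloSharding::Replicate());
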
 private:
  explicit HloSharding(bool manual, bool replicated,
                       absl::Span<const OpMetadata> metadata)
      : replicated_(replicated),
        maximal_(replicated),
        tuple_(false),
        manual_(manual),
        tile_assignment_({0}),
        replicate_on_last_tile_dim_(false),
        metadata_(metadata.begin(), metadata.end()) {}
  // device_id values:
  // -2: magic number to mean unassigned device, used by spatial partitioning
  // -1: the id of the host
  // 0 or positive: the id of a device
  // NOTE(dimvar): -1 is needed for outside compilation. It can be removed once
  // we have fully switched to the side-effect tokens.
  explicit HloSharding(int64 device_id, absl::Span<const OpMetadata> metadata)
      : replicated_(false),
        maximal_(true),
        tuple_(false),
        manual_(false),
        tile_assignment_({1}, device_id),
        replicate_on_last_tile_dim_(false),
        metadata_(metadata.begin(), metadata.end()) {}
  explicit HloSharding(const Array<int64>& tile_assignment,
                       bool replicate_on_last_tile_dim,
                       absl::Span<const OpMetadata> metadata = {})
      : replicated_(false),
        maximal_(false),
        tuple_(false),
        manual_(false),
        tile_assignment_(tile_assignment),
        replicate_on_last_tile_dim_(replicate_on_last_tile_dim),
        metadata_(metadata.begin(), metadata.end()) {}
  explicit HloSharding(const std::vector<HloSharding>& tuple_shardings)
      : replicated_(false),
        maximal_(false),
        tuple_(true),
        manual_(false),
        tile_assignment_({0}),
        tuple_elements_(tuple_shardings),
        replicate_on_last_tile_dim_(false) {}

  // Checks that the number of elements in tuple_elements_ is consistent with
  // the tuple shape passed as an argument.
  Status CheckLeafCount(const Shape& shape) const;

  // Internal helper to validate a tuple sharding.
  Status ValidateTuple(const Shape& shape, int64 num_devices) const;

  // Internal helper to validate a non-tuple (leaf) sharding.
  Status ValidateNonTuple(const Shape& shape, int64 num_devices) const;

  // Returns the number of tuple_elements_ entries to fit the shape.
  static int64 RequiredLeaves(const Shape& shape);

  bool replicated_;
  bool maximal_;
  bool tuple_;
  bool manual_;
  // This field is only used if replicated_ is false. If maximal_ is true, then
  // the field contains a rank 1 array with a single element, which is the
  // device the HLO is assigned to. If maximal_ is false, the field contains an
  // array with the same rank as the corresponding HLO. The dimension sizes of
  // the array describe the number of ways the HLO is partitioned along each
  // dimension. The values of the array specify which device each tile of
  // the HLO is assigned to. The index of each value determines which tile it
  // takes.
  // For example, {{{2, 3}}, {{5, 7}}} (whose ToString representation is
  // "{devices=[2,1,2]2,3,5,7}") means that dimension 1 is split two ways and
  // dimension 3 is split two ways (dimensions are numbered from 1 in this
  // example). Core 5, whose index is [2,1,1], will take the tile that contains
  // the 2nd half of dimension 1 and the 1st half of dimension 3.
  Array<int64> tile_assignment_;
  // Only non-empty when tuple_ is true. If a tuple is empty then one entry is
  // present for the root. This is a flattened list of all the leaf shardings
  // in a tuple shape, by pre-order walk (ShapeTree iterator order).
  std::vector<HloSharding> tuple_elements_;
  // This flag is to support partial replication and partial sharding. If it is
  // true, tile_assignment_ will have an extra dimension in addition to the
  // data shape rank, and the added last dimension represents the subgroups of
  // replications, i.e., elements in slice [..., :] will be replicated.
  bool replicate_on_last_tile_dim_;
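
  // Illustrative sketch (not part of the original header; the device numbers
  // are hypothetical): for rank-1 data, a tile_assignment_ of {{0, 1}, {2, 3}}
  // with replicate_on_last_tile_dim_ == true shards dimension 0 two ways,
  // replicating the first half on devices {0, 1} and the second half on
  // devices {2, 3}. This is the layout that PartialTile builds.
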
  // This field is used to track the source of this sharding, usually derived
  // from instructions. Multiple metadata may be populated if sharding is
  // combined with other shardings. Metadata is not populated when
  // tuple_ == true; instead, metadata should be set on the individual tuple
  // elements.
  std::vector<OpMetadata> metadata_;
};

std::ostream& operator<<(std::ostream& out, const HloSharding& sharding);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_