/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_sharding.h"

#include <algorithm>
#include <map>
#include <numeric>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/overflow_util.h"
#include "tensorflow/compiler/xla/service/hlo_op_metadata.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {

using absl::StrCat;
using absl::StrJoin;

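// Creates a "maximal" sharding: all of the data is placed on the single
// device `device_id`.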
HloSharding HloSharding::AssignDevice(int64 device_id,
                                      absl::Span<const OpMetadata> metadata) {
  return HloSharding(device_id, metadata);
}

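// Creates a 1D tiled sharding that splits the one dimension of `input_shape`
// into `num_tiles` contiguous tiles, assigning tile i to device i. For
// example, a shape f32[8] split into 4 tiles places elements [0, 2) on
// device 0, [2, 4) on device 1, and so on.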
HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles,
                                absl::Span<const OpMetadata> metadata) {
  CHECK_EQ(1, input_shape.rank());
  CHECK_GT(num_tiles, 1);
  std::vector<int64> dimensions(1, num_tiles);
  Array<int64> assignment(dimensions);
  std::iota(assignment.begin(), assignment.end(), 0);
  return HloSharding(assignment, /*replicate_on_last_tile_dim=*/false,
                     metadata);
}

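// Creates a partially replicated sharding from a tile assignment of group
// ids: tile t is replicated across every device listed in
// replication_groups[group_tile_assignment(t)]. For example, a 2x1 group
// assignment {{0}, {1}} with replication groups {{0, 1}, {2, 3}} replicates
// the first tile on devices 0 and 1 and the second tile on devices 2 and 3.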
HloSharding HloSharding::PartialTile(
    const Array<int64>& group_tile_assignment,
    absl::Span<const absl::Span<const int64>> replication_groups,
    absl::Span<const OpMetadata> metadata) {
  CHECK_EQ(group_tile_assignment.num_elements(), replication_groups.size());
  if (replication_groups.size() == 1) {
    return Replicate(metadata);
  }
  auto new_tile_dims = group_tile_assignment.dimensions();
  new_tile_dims.push_back(replication_groups[0].size());
  auto new_tile_assignment = Array<int64>(new_tile_dims);
  new_tile_assignment.Each([&](absl::Span<const int64> indices, int64* device) {
    std::vector<int64> group_index(indices.begin(), indices.end());
    group_index.pop_back();
    int64 group = group_tile_assignment(group_index);
    *device = replication_groups[group][indices.back()];
  });
  return PartialTile(new_tile_assignment, metadata);
}

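// Creates a partially replicated sharding from a tile assignment whose last
// dimension enumerates the replicas of each tile. The result is
// canonicalized: degenerate assignments collapse to Replicate() or to a
// fully tiled sharding, and the devices within each replication group are
// sorted so that equivalent shardings get identical representations.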
HloSharding HloSharding::PartialTile(
    const Array<int64>& tile_assignment_last_dim_replicate,
    absl::Span<const OpMetadata> metadata) {
  if (tile_assignment_last_dim_replicate.num_dimensions() == 1 ||
      tile_assignment_last_dim_replicate.dimensions().back() ==
          tile_assignment_last_dim_replicate.num_elements()) {
    return Replicate(metadata);
  }
  if (tile_assignment_last_dim_replicate.dimensions().back() == 1) {
    auto new_tile_dims = tile_assignment_last_dim_replicate.dimensions();
    new_tile_dims.pop_back();
    auto fully_tiled = tile_assignment_last_dim_replicate;
    fully_tiled.Reshape(new_tile_dims);
    return HloSharding(fully_tiled, /*replicate_on_last_tile_dim=*/false,
                       metadata);
  }
  std::vector<std::set<int64>> sorted_groups(
      tile_assignment_last_dim_replicate.num_elements() /
      tile_assignment_last_dim_replicate.dimensions().back());
  auto get_group_id = [&](absl::Span<const int64> indices) {
    int64 group_id = 0;
    for (int64 i = 0; i < indices.size() - 1; ++i) {
      group_id *= tile_assignment_last_dim_replicate.dim(i);
      group_id += indices[i];
    }
    return group_id;
  };
  tile_assignment_last_dim_replicate.Each(
      [&](absl::Span<const int64> indices, const int64 device) {
        sorted_groups[get_group_id(indices)].insert(device);
      });
  Array<int64> sorted_tile(tile_assignment_last_dim_replicate.dimensions());
  sorted_tile.Each([&](absl::Span<const int64> indices, int64* device) {
    auto begin = sorted_groups[get_group_id(indices)].begin();
    *device = *begin;
    sorted_groups[get_group_id(indices)].erase(begin);
  });
  return HloSharding(sorted_tile, /*replicate_on_last_tile_dim=*/true,
                     metadata);
}

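// Creates a tuple sharding from a ShapeTree of element shardings, flattening
// the leaves in pre-order.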
HloSharding HloSharding::Tuple(const ShapeTree<HloSharding>& sub_shardings) {
  std::vector<HloSharding> flattened_list;
  flattened_list.reserve(sub_shardings.leaf_count());
  for (const auto& index_to_sharding : sub_shardings.leaves()) {
    flattened_list.push_back(index_to_sharding.second);
  }
  if (flattened_list.empty()) {
    // Empty tuple sharding ends up having no leaves, but we want to allow
    // empty tuple HLO instruction results to have sharding, so we fetch the
    // root ({}) sharding value from the ShapeTree.
    // A ShapeTree created with ShapeTree<HloSharding>(shape, init) will have
    // init as value at its root.
    flattened_list.push_back(sub_shardings.element(ShapeIndex({})));
  }
  return HloSharding(flattened_list);
}

HloSharding HloSharding::Tuple(const Shape& tuple_shape,
                               absl::Span<const HloSharding> shardings) {
  CHECK(tuple_shape.IsTuple()) << ShapeUtil::HumanString(tuple_shape);
  for (auto& sharding : shardings) {
    CHECK(!sharding.IsTuple()) << sharding.ToString();
  }
  std::vector<HloSharding> flattened_list(shardings.begin(), shardings.end());
  CHECK_EQ(flattened_list.size(), RequiredLeaves(tuple_shape))
      << "Flat list has " << flattened_list.size() << ", required "
      << RequiredLeaves(tuple_shape);
  return HloSharding(flattened_list);
}

HloSharding HloSharding::SingleTuple(const Shape& tuple_shape,
                                     const HloSharding& sharding) {
  CHECK(tuple_shape.IsTuple()) << ShapeUtil::HumanString(tuple_shape);
  CHECK(!sharding.IsTuple()) << sharding.ToString();
  int64 leaf_count = RequiredLeaves(tuple_shape);
  std::vector<HloSharding> flattened_list;
  flattened_list.resize(leaf_count, sharding);
  return HloSharding(flattened_list);
}

HloSharding HloSharding::Single(const Shape& shape,
                                const HloSharding& sharding) {
  return shape.IsTuple() ? SingleTuple(shape, sharding) : sharding;
}

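// Renders the sharding in the textual form used by HLO, e.g. "{replicated}",
// "{maximal device=3}", or "{devices=[2,2]0,1,2,3}".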
string HloSharding::ToString(bool include_metadata) const {
  if (IsTuple()) {
    CHECK(metadata_.empty());
    std::vector<string> parts;
    parts.reserve(tuple_elements_.size());
    for (const HloSharding& element : tuple_elements_) {
      parts.push_back(element.ToString(include_metadata));
    }
    return StrCat("{", StrJoin(parts, ", "), "}");
  }

  std::string metadata;
  if (include_metadata) {
    if (metadata_.size() == 1) {
      metadata =
          StrCat(" metadata={", OpMetadataToString(metadata_.front()), "}");
    } else if (metadata_.size() > 1) {
      std::vector<std::string> metadata_strings;
      metadata_strings.reserve(metadata_.size());
      for (const auto& single_metadata : metadata_) {
        metadata_strings.push_back(
            StrCat("{", OpMetadataToString(single_metadata), "}"));
      }
      metadata = StrCat(" metadata={", StrJoin(metadata_strings, ", "), "}");
    }
  }

  if (replicated_) {
    return StrCat("{replicated", metadata, "}");
  }

  if (manual_) {
    return StrCat("{manual", metadata, "}");
  }
  if (maximal_) {
    return StrCat("{maximal device=",
                  static_cast<int64>(*tile_assignment_.begin()), metadata,
                  "}");
  }
  return StrCat("{devices=[", StrJoin(tile_assignment_.dimensions(), ","), "]",
                StrJoin(tile_assignment_, ","),
                replicate_on_last_tile_dim_ ? " last_tile_dim_replicate" : "",
                metadata, "}");
}

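// Returns true if this sharding may place data on `device`. Replicated and
// manual shardings are treated as using every device.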
bool HloSharding::UsesDevice(int64 device) const {
  if (IsTuple()) {
    return absl::c_any_of(tuple_elements_, [&](const HloSharding& s) {
      return s.UsesDevice(device);
    });
  }
  const auto& devices = tile_assignment_;
  return replicated_ || manual_ || absl::c_linear_search(devices, device);
}

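// Returns a map from device id to the number of elements (tuple leaves, or 1
// for a non-tuple sharding) that are placed entirely on that device. If
// `count` is non-null, it receives the total number of elements considered.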
std::map<int64, int64> HloSharding::UsedDevices(int64* count) const {
  int64 element_count = 1;
  std::map<int64, int64> device_map;
  if (IsTuple()) {
    for (auto& tuple_element_sharding : tuple_elements()) {
      auto unique_device = tuple_element_sharding.UniqueDevice();
      if (unique_device) {
        device_map[*unique_device] += 1;
      }
    }
    element_count = tuple_elements().size();
  } else {
    auto unique_device = UniqueDevice();
    if (unique_device) {
      device_map[*unique_device] += 1;
    }
  }
  if (count != nullptr) {
    *count = element_count;
  }
  return device_map;
}

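// Returns the multi-dimensional tile index assigned to `device`, excluding
// the trailing replication dimension of a partially replicated sharding.
// CHECK-fails if `device` does not appear in the tile assignment.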
std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
  CHECK(!maximal_);
  CHECK(!manual_);
  CHECK(!IsTuple());
  std::vector<int64> ret_index;
  tile_assignment_.Each([&](absl::Span<const int64> index, int64 d) {
    if (d == device) {
      ret_index = {index.begin(), index.end()};
    }
  });
  CHECK(!ret_index.empty());
  if (replicate_on_last_tile_dim_) {
    ret_index.pop_back();
  }
  return ret_index;
}

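// Returns the device that owns the tile at `index`. For a partially
// replicated sharding the index may omit the trailing replication dimension,
// in which case the first replica of the tile is returned.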
int64 HloSharding::DeviceForTileIndex(absl::Span<const int64> index) const {
  CHECK(!replicated_);
  CHECK(!manual_);
  CHECK(!IsTuple());
  if (maximal_) {
    return *tile_assignment_.begin();
  }
  if (replicate_on_last_tile_dim_ &&
      index.size() < tile_assignment().num_dimensions()) {
    std::vector<int64> first_replicated_index(index.begin(), index.end());
    first_replicated_index.push_back(0);
    return tile_assignment_(first_replicated_index);
  }
  return tile_assignment_(index);
}

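// Returns the inclusive start offsets of the tile owned by `device` within
// `shape`. A dimension of size n split into t tiles uses a per-tile extent
// of CeilOfRatio(n, t); e.g. n=10, t=4 yields offsets 0, 3, 6 and 9.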
std::vector<int64> HloSharding::TileOffsetForDevice(const Shape& shape,
                                                    int64 device) const {
  CHECK(!IsTuple());
  CHECK(!manual_);

  if (maximal_) {
    return std::vector<int64>(shape.dimensions_size(), 0);
  }
  if (replicate_on_last_tile_dim_) {
    CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions() - 1);
  } else {
    CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions());
  }
  std::vector<int64> index = TileIndexForDevice(device);
  for (int64 i = 0; i < index.size(); ++i) {
    const int64 shape_dim = shape.dimensions(i);
    index[i] = std::min(
        index[i] * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), shape_dim);
  }
  return index;
}

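// Returns the exclusive end offsets of the tile owned by `device` within
// `shape`; together with TileOffsetForDevice this delimits the tile. For
// n=10 split into 4 tiles the limits are 3, 6, 9 and 10, so the last tile is
// smaller than the others.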
std::vector<int64> HloSharding::TileLimitForDevice(const Shape& shape,
                                                   int64 device) const {
  CHECK(!IsTuple());
  CHECK(!manual_);

  if (maximal_) {
    return std::vector<int64>(shape.dimensions().begin(),
                              shape.dimensions().end());
  }

  CHECK_EQ(shape.dimensions_size() + (ReplicateOnLastTileDim() ? 1 : 0),
           tile_assignment_.num_dimensions());
  std::vector<int64> index = TileIndexForDevice(device);
  for (int64 i = 0; i < index.size(); ++i) {
    const int64 shape_dim = shape.dimensions(i);
    index[i] = std::min(
        (index[i] + 1) * CeilOfRatio(shape_dim, tile_assignment_.dim(i)),
        shape_dim);
  }
  return index;
}

int64 HloSharding::RequiredLeaves(const Shape& shape) {
  // Empty tuples (with arbitrary nesting) have no leaf nodes as far as
  // ShapeUtil and ShapeTree are concerned, but they do have a single
  // tuple_elements_ entry since we want to allow empty tuple results to
  // have sharding.
  const int64 leaf_count = ShapeUtil::GetLeafCount(shape);
  return (leaf_count == 0) ? 1 : leaf_count;
}

Status HloSharding::CheckLeafCount(const Shape& shape) const {
  int64 shape_leaves = RequiredLeaves(shape);
  TF_RET_CHECK(shape_leaves == tuple_elements_.size())
      << "Shape " << ShapeUtil::HumanString(shape) << " has " << shape_leaves
      << " leaf nodes while this sharding has " << tuple_elements_.size();
  return Status::OK();
}

StatusOr<ShapeTree<HloSharding>> HloSharding::AsShapeTree(
    const Shape& shape) const {
  if (IsTuple()) {
    ShapeTree<HloSharding> result(shape, HloSharding::Replicate());
    TF_RETURN_IF_ERROR(CheckLeafCount(shape));
    auto it = tuple_elements_.begin();
    for (auto& index_to_sharding : result.leaves()) {
      index_to_sharding.second = *it++;
    }
    if (ShapeUtil::IsEmptyTuple(shape)) {
      // Empty tuples have no leaves, but we want to assign them a sharding
      // anyway, so we use the root element sharding.
      *result.mutable_element(ShapeIndex({})) = *it;
    }
    return std::move(result);
  } else {
    return ShapeTree<HloSharding>(shape, *this);
  }
}

StatusOr<HloSharding> HloSharding::GetTupleSharding(const Shape& shape) const {
  if (IsTuple()) {
    TF_RETURN_IF_ERROR(CheckLeafCount(shape));
    return *this;
  }
  return Tuple(ShapeTree<HloSharding>(shape, *this));
}

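// Returns the single device this sharding maps all data to, or nullopt if
// more than one device is involved. A tuple sharding has a unique device
// only if every element resolves to the same one.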
absl::optional<int64> HloSharding::UniqueDevice() const {
  if (IsTuple()) {
    if (tuple_elements_.empty()) {
      return absl::nullopt;
    }
    absl::optional<int64> unique_device;
    for (auto& tuple_sharding : tuple_elements_) {
      auto device = tuple_sharding.UniqueDevice();
      if (!device || (unique_device && *device != *unique_device)) {
        return absl::nullopt;
      }
      unique_device = device;
    }
    return unique_device;
  }
  if (!replicated_ && maximal_) {
    return static_cast<int64>(*tile_assignment_.begin());
  }
  return absl::nullopt;
}

int64 HloSharding::GetUniqueDevice() const {
  auto device = UniqueDevice();
  CHECK(device) << "Sharding does not have a unique device: " << *this;
  return *device;
}

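// Validates a tuple sharding against `shape`: the leaf counts must match,
// and every leaf sharding must be valid for its subshape on `num_devices`
// devices.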
Status HloSharding::ValidateTuple(const Shape& shape,
                                  int64 num_devices) const {
  if (!shape.IsTuple()) {
    return tensorflow::errors::InvalidArgument(
        StrCat("Sharding is tuple-shaped but validation shape is not."));
  }
  TF_RETURN_IF_ERROR(CheckLeafCount(shape));

  // Now we've validated the number of tuple elements, it's safe to request a
  // shape tree.
  ShapeTree<HloSharding> shape_tree = GetAsShapeTree(shape);
  for (const auto& index_to_sharding : shape_tree.leaves()) {
    Status status = index_to_sharding.second.ValidateNonTuple(
        ShapeUtil::GetSubshape(shape, index_to_sharding.first), num_devices);
    if (!status.ok()) {
      tensorflow::errors::AppendToMessage(
          &status, StrCat("Note: While validating sharding tuple element ",
                          index_to_sharding.first.ToString(), " which is ",
                          index_to_sharding.second.ToString()));
      return status;
    }
  }
  return Status::OK();
}

Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
  if (shape.IsToken()) {
    return Status::OK();
  }
  Status status = IsTuple() ? ValidateTuple(shape, num_devices)
                            : ValidateNonTuple(shape, num_devices);
  if (!status.ok()) {
    tensorflow::errors::AppendToMessage(
        &status, StrCat("Note: While validating sharding ", ToString(),
                        " against shape ", ShapeUtil::HumanString(shape)));
  }
  return status;
}

Status HloSharding::ValidateNonTuple(const Shape& shape,
                                     int64 num_devices) const {
  if (shape.IsTuple()) {
    return tensorflow::errors::InvalidArgument(
        StrCat("Validation shape is a tuple but sharding is not."));
  }
  if (replicated_) {
    return Status::OK();
  }

  // All tile assignments must be unique and less than the number of
  // available cores.
  Status status = Status::OK();
  absl::flat_hash_set<int64> seen_cores;
  tile_assignment_.Each(
      [&](absl::Span<const int64> indices, int64 core) {
        // Don't overwrite a bad status, so we report the first error.
        if (status.ok()) {
          if (core >= num_devices) {
            status = tensorflow::errors::InvalidArgument(StrCat(
                "core ", core, " >= ", num_devices, " in tile assignment"));
          } else if (seen_cores.contains(core)) {
            status = tensorflow::errors::InvalidArgument(
                StrCat("core ", core, " is not unique in tile assignment"));
          }
          seen_cores.insert(core);
        }
      });
  if (!status.ok()) {
    return status;
  }

  if (IsTileMaximal() || IsManual()) {
    return Status::OK();
  }

  // The tile assignment tensor must have the same rank as the input, or input
  // rank + 1 for replicate_on_last_tile_dim_.
  if (shape.rank() + (replicate_on_last_tile_dim_ ? 1 : 0) !=
      tile_assignment_.num_dimensions()) {
    return tensorflow::errors::InvalidArgument(
        "Number of tile assignment dimensions is different from the input "
        "rank. sharding=",
        ToString(), ", input_shape=", ShapeUtil::HumanString(shape));
  }

  // The correct constructor has to be used to create tile maximal shardings.
  if (tile_assignment_.num_elements() == 1) {
    return tensorflow::errors::InvalidArgument(
        "Tile assignment only contains a single device. If a replicated "
        "sharding was intended, use HloSharding::Replicate(). If a device "
        "placement was intended, use HloSharding::AssignDevice()");
  }
  return Status::OK();
}

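// Rebuilds an HloSharding from its OpSharding proto form, validating that
// the tile assignment is well formed.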
/*static*/ StatusOr<HloSharding> HloSharding::FromProto(
    const OpSharding& proto) {
  std::vector<OpMetadata> metadata(proto.metadata().begin(),
                                   proto.metadata().end());
  if (proto.type() == OpSharding::TUPLE) {
    TF_RET_CHECK(metadata.empty())
        << "Tuple sharding is expected to have no metadata.";
    std::vector<HloSharding> tuple_shardings;
    tuple_shardings.reserve(proto.tuple_shardings().size());
    for (const OpSharding& tuple_sharding_proto : proto.tuple_shardings()) {
      TF_ASSIGN_OR_RETURN(HloSharding sharding,
                          HloSharding::FromProto(tuple_sharding_proto));
      tuple_shardings.push_back(sharding);
    }
    return HloSharding(tuple_shardings);
  } else if (proto.type() == OpSharding::REPLICATED) {
    return Replicate(metadata);
  } else if (proto.type() == OpSharding::MANUAL) {
    return Manual(metadata);
  } else if (proto.tile_assignment_devices().size() == 1) {
    return HloSharding(proto.tile_assignment_devices(0), metadata);
  }

  TF_RET_CHECK(proto.type() != OpSharding::MAXIMAL)
      << "Maximal sharding is expected to have a single device assignment, "
      << "but " << proto.tile_assignment_devices().size() << " were provided.";

  TF_RET_CHECK(proto.tile_assignment_devices().size() > 1);
  TF_RET_CHECK(!proto.tile_assignment_dimensions().empty());

  // The product of the tile assignment tensor dimensions must equal
  // tile_assignment_devices.size().
  int64 product_of_dimensions = 1;
  for (auto dimension : proto.tile_assignment_dimensions()) {
    TF_RET_CHECK(dimension > 0);
    product_of_dimensions =
        MultiplyWithoutOverflow(product_of_dimensions, dimension);
    TF_RET_CHECK(product_of_dimensions > 0);
  }
  TF_RET_CHECK(product_of_dimensions ==
               proto.tile_assignment_devices().size());

  // Some versions of gcc cannot infer the TileAssignment constructor from a
  // braced initializer-list, so create one manually.
  std::vector<int64> devices(proto.tile_assignment_devices().begin(),
                             proto.tile_assignment_devices().end());
  Array<int64> tile_assignment(
      std::vector<int64>(proto.tile_assignment_dimensions().begin(),
                         proto.tile_assignment_dimensions().end()));
  std::copy(devices.begin(), devices.end(), tile_assignment.begin());
  return proto.replicate_on_last_tile_dim()
             ? PartialTile(tile_assignment, metadata)
             : HloSharding(tile_assignment,
                           /*replicate_on_last_tile_dim=*/false, metadata);
}

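// Serializes this sharding into its OpSharding proto form; the inverse of
// FromProto().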
OpSharding HloSharding::ToProto() const {
  OpSharding result;

  if (IsTuple()) {
    CHECK(metadata_.empty());
    for (const HloSharding& element : tuple_elements_) {
      *result.add_tuple_shardings() = element.ToProto();
    }
    result.set_type(OpSharding::TUPLE);
    return result;
  }

  result.mutable_metadata()->Reserve(metadata_.size());
  for (const auto& metadata : metadata_) {
    *result.add_metadata() = metadata;
  }

  for (int64 dim : tile_assignment_.dimensions()) {
    result.add_tile_assignment_dimensions(dim);
  }
  for (auto device : tile_assignment_) {
    result.add_tile_assignment_devices(device);
  }
  if (IsReplicated()) {
    result.set_type(OpSharding::REPLICATED);
    result.clear_tile_assignment_dimensions();
  } else if (IsTileMaximal()) {
    result.set_type(OpSharding::MAXIMAL);
  } else if (IsManual()) {
    result.set_type(OpSharding::MANUAL);
    result.clear_tile_assignment_dimensions();
  } else {
    result.set_type(OpSharding::OTHER);
    result.set_replicate_on_last_tile_dim(ReplicateOnLastTileDim());
  }
  return result;
}

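// Returns the shape of a single tile: each dimension of `shape` is divided
// by the number of tiles along it, rounded up. For example, f32[15] split
// into 4 tiles has tile shape f32[4].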
Shape HloSharding::TileShape(const Shape& shape) const {
  if (IsTileMaximal() || IsManual()) {
    return shape;
  }
  Shape result_shape = shape;
  for (int64 i = 0; i < shape.dimensions_size(); ++i) {
    result_shape.set_dimensions(
        i, CeilOfRatio<int64>(shape.dimensions(i), tile_assignment_.dim(i)));
  }
  return result_shape;
}

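// Returns the shape of the tile actually owned by `device`. Tiles at the
// upper boundary can be smaller than TileShape(shape): for f32[15] split
// into 4 tiles, device 3 owns only f32[3].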
Shape HloSharding::TileShape(const Shape& shape, int64 device) const {
  if (IsTileMaximal() || IsManual()) {
    return shape;
  }

  std::vector<int64> index = TileIndexForDevice(device);
  Shape result_shape = shape;
  for (int64 i = 0; i < index.size(); ++i) {
    const int64 shape_dim = shape.dimensions(i);
    int64 offset = std::min(
        index[i] * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), shape_dim);
    int64 limit = std::min(
        (index[i] + 1) * CeilOfRatio(shape_dim, tile_assignment_.dim(i)),
        shape_dim);
    result_shape.set_dimensions(i, limit - offset);
  }
  return result_shape;
}

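// Returns the number of distinct tiles. The trailing replication dimension
// of a partially replicated sharding does not add tiles, so it is divided
// out.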
int64 HloSharding::NumTiles() const {
  if (IsTileMaximal()) {
    return 1;
  }
  CHECK(!IsManual());
  if (ReplicateOnLastTileDim()) {
    return tile_assignment().num_elements() /
           tile_assignment().dimensions().back();
  }
  return tile_assignment().num_elements();
}

int64 HloSharding::NumTiles(absl::Span<const int64> dims) const {
  if (IsTileMaximal()) {
    return 1;
  }
  CHECK(!IsManual());
  CHECK(!ReplicateOnLastTileDim() ||
        !absl::c_linear_search(dims, tile_assignment().num_dimensions() - 1));
  int64 num_tiles = 1;
  for (auto d : dims) {
    CHECK(d < tile_assignment().num_dimensions());
    num_tiles *= tile_assignment().dim(d);
  }
  return num_tiles;
}

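// Returns the sharding of the tuple element selected by `index`. The tuple
// elements are stored flattened, so this first computes the flat offset of
// `index` and then slices out as many leaves as the subshape requires.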
HloSharding HloSharding::GetSubSharding(const Shape& shape,
                                        const ShapeIndex& index) const {
  CHECK(IsTuple());
  int64 sharding_index = 0;
  const Shape* sub_shape = &shape;
  for (int64 idx : index) {
    for (int64 i = 0; i < idx; ++i) {
      sharding_index +=
          ShapeUtil::GetLeafCount(ShapeUtil::GetSubshape(*sub_shape, {i}));
    }
    sub_shape = &ShapeUtil::GetSubshape(*sub_shape, {idx});
  }
  if (sub_shape->IsTuple()) {
    auto begin_it = tuple_elements_.begin() + sharding_index;
    std::vector<HloSharding> sub_shardings(
        begin_it, begin_it + ShapeUtil::GetLeafCount(*sub_shape));
    return HloSharding::Tuple(*sub_shape, sub_shardings);
  } else {
    return tuple_elements_[sharding_index];
  }
}

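// If every element of a tuple sharding is identical, returns that single
// sharding; otherwise returns nullopt. A non-tuple sharding returns itself.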
absl::optional<HloSharding> HloSharding::ExtractSingleSharding() const {
  if (!IsTuple()) {
    return *this;
  }
  if (tuple_elements_.empty()) {
    return absl::nullopt;
  }
  for (int64 i = 1; i < tuple_elements_.size(); ++i) {
    if (tuple_elements_[0] != tuple_elements_[i]) {
      return absl::nullopt;
    }
  }
  return tuple_elements_.front();
}

HloSharding HloSharding::WithMetadata(absl::Span<const OpMetadata> metadata,
                                      bool overwrite) const {
  auto assign_metadata = [&](HloSharding& sharding) {
    if (sharding.metadata_.empty() || overwrite) {
      sharding.metadata_.assign(metadata.begin(), metadata.end());
    }
  };

  HloSharding sharding = *this;
  if (sharding.IsTuple()) {
    for (HloSharding& sub_sharding : sharding.tuple_elements()) {
      assign_metadata(sub_sharding);
    }
  } else {
    assign_metadata(sharding);
  }
  return sharding;
}

HloSharding HloSharding::WithoutMetadata() const {
  HloSharding sharding = *this;
  sharding.metadata_.clear();
  for (HloSharding& sub_sharding : sharding.tuple_elements()) {
    sub_sharding.metadata_.clear();
  }
  return sharding;
}

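// Hashes the sharding. Replicated and manual shardings hash to the fixed
// sentinels 0 and 1; tiled shardings combine the tile assignment and the
// partial-replication bit.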
size_t HloSharding::Hash() const {
  if (tuple_) {
    size_t h = 0;
    for (const auto& element : tuple_elements_) {
      h = tensorflow::Hash64Combine(h, element.Hash());
    }
    return h;
  }
  if (replicated_) {
    return 0;
  }
  if (manual_) {
    return 1;
  }
  size_t h = 0;
  for (uint32 v : tile_assignment_) {
    h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
  }
  if (replicate_on_last_tile_dim_) {
    h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(1));
  }
  return h;
}

std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) {
  out << sharding.ToString();
  return out;
}

}  // namespace xla