/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_sharding.h"

#include <algorithm>
#include <map>
#include <numeric>
#include <set>
#include <string>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/overflow_util.h"
#include "tensorflow/compiler/xla/service/hlo_op_metadata.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {

using absl::StrCat;
using absl::StrJoin;

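// Creates a maximal sharding that assigns all of the data to device_id.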
HloSharding HloSharding::AssignDevice(int64 device_id,
                                      absl::Span<const OpMetadata> metadata) {
  return HloSharding(device_id, metadata);
}

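// Creates a sharding that tiles a rank-1 shape across num_tiles devices,
// assigning tile i to device i.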
HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles,
                                absl::Span<const OpMetadata> metadata) {
  CHECK_EQ(1, input_shape.rank());
  CHECK_GT(num_tiles, 1);
  std::vector<int64> dimensions(1, num_tiles);
  Array<int64> assignment(dimensions);
  std::iota(assignment.begin(), assignment.end(), 0);
  return HloSharding(assignment, /*replicate_on_last_tile_dim=*/false,
                     metadata);
}

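// Expands a tile assignment over replication groups into a per-device tile
// assignment: each entry of group_tile_assignment names a replication group,
// and every device in that group holds a replica of the corresponding tile.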
HloSharding HloSharding::PartialTile(
    const Array<int64>& group_tile_assignment,
    absl::Span<const absl::Span<const int64>> replication_groups,
    absl::Span<const OpMetadata> metadata) {
  CHECK_EQ(group_tile_assignment.num_elements(), replication_groups.size());
  if (replication_groups.size() == 1) {
    return Replicate(metadata);
  }
  auto new_tile_dims = group_tile_assignment.dimensions();
  new_tile_dims.push_back(replication_groups[0].size());
  auto new_tile_assignment = Array<int64>(new_tile_dims);
  new_tile_assignment.Each([&](absl::Span<const int64> indices, int64* device) {
    std::vector<int64> group_index(indices.begin(), indices.end());
    group_index.pop_back();
    int64 group = group_tile_assignment(group_index);
    *device = replication_groups[group][indices.back()];
  });
  return PartialTile(new_tile_assignment, metadata);
}

HloSharding HloSharding::PartialTile(
    const Array<int64>& tile_assignment_last_dim_replicate,
    absl::Span<const OpMetadata> metadata) {
  if (tile_assignment_last_dim_replicate.num_dimensions() == 1 ||
      tile_assignment_last_dim_replicate.dimensions().back() ==
          tile_assignment_last_dim_replicate.num_elements()) {
    return Replicate(metadata);
  }
  if (tile_assignment_last_dim_replicate.dimensions().back() == 1) {
    auto new_tile_dims = tile_assignment_last_dim_replicate.dimensions();
    new_tile_dims.pop_back();
    auto fully_tiled = tile_assignment_last_dim_replicate;
    fully_tiled.Reshape(new_tile_dims);
    return HloSharding(fully_tiled, /*replicate_on_last_tile_dim=*/false,
                       metadata);
  }
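  // Canonicalize the tile assignment by sorting the devices within each
  // replication group (the last tile dimension), so that equivalent
  // shardings get identical representations.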
  std::vector<std::set<int64>> sorted_groups(
      tile_assignment_last_dim_replicate.num_elements() /
      tile_assignment_last_dim_replicate.dimensions().back());
  auto get_group_id = [&](absl::Span<const int64> indices) {
    int64 group_id = 0;
    for (int64 i = 0; i < indices.size() - 1; ++i) {
      group_id *= tile_assignment_last_dim_replicate.dim(i);
      group_id += indices[i];
    }
    return group_id;
  };
  tile_assignment_last_dim_replicate.Each(
      [&](absl::Span<const int64> indices, const int64 device) {
        sorted_groups[get_group_id(indices)].insert(device);
      });
  Array<int64> sorted_tile(tile_assignment_last_dim_replicate.dimensions());
  sorted_tile.Each([&](absl::Span<const int64> indices, int64* device) {
    auto begin = sorted_groups[get_group_id(indices)].begin();
    *device = *begin;
    sorted_groups[get_group_id(indices)].erase(begin);
  });
  return HloSharding(sorted_tile, /*replicate_on_last_tile_dim=*/true,
                     metadata);
}

HloSharding HloSharding::Tuple(const ShapeTree<HloSharding>& sub_shardings) {
  std::vector<HloSharding> flattened_list;
  flattened_list.reserve(sub_shardings.leaf_count());
  for (const auto& index_to_sharding : sub_shardings.leaves()) {
    flattened_list.push_back(index_to_sharding.second);
  }
  if (flattened_list.empty()) {
    // Empty tuple sharding ends up having no leaves, but we want to allow
    // empty tuple HLO instruction results to have sharding, so we fetch the
    // root ({}) sharding value from the ShapeTree.
    // A ShapeTree created with ShapeTree<HloSharding>(shape, init) will have
    // init as value at its root.
    flattened_list.push_back(sub_shardings.element(ShapeIndex({})));
  }
  return HloSharding(flattened_list);
}

HloSharding HloSharding::Tuple(const Shape& tuple_shape,
                               absl::Span<const HloSharding> shardings) {
  CHECK(tuple_shape.IsTuple()) << ShapeUtil::HumanString(tuple_shape);
  for (auto& sharding : shardings) {
    CHECK(!sharding.IsTuple()) << sharding.ToString();
  }
  std::vector<HloSharding> flattened_list(shardings.begin(), shardings.end());
  CHECK_EQ(flattened_list.size(), RequiredLeaves(tuple_shape))
      << "Flat list has " << flattened_list.size() << ", required "
      << RequiredLeaves(tuple_shape);
  return HloSharding(flattened_list);
}

HloSharding HloSharding::SingleTuple(const Shape& tuple_shape,
                                     const HloSharding& sharding) {
  CHECK(tuple_shape.IsTuple()) << ShapeUtil::HumanString(tuple_shape);
  CHECK(!sharding.IsTuple()) << sharding.ToString();
  int64 leaf_count = RequiredLeaves(tuple_shape);
  std::vector<HloSharding> flattened_list;
  flattened_list.resize(leaf_count, sharding);
  return HloSharding(flattened_list);
}

HloSharding HloSharding::Single(const Shape& shape,
                                const HloSharding& sharding) {
  return shape.IsTuple() ? SingleTuple(shape, sharding) : sharding;
}

string HloSharding::ToString(bool include_metadata) const {
  if (IsTuple()) {
    CHECK(metadata_.empty());
    std::vector<string> parts;
    parts.reserve(tuple_elements_.size());
    for (const HloSharding& element : tuple_elements_) {
      parts.push_back(element.ToString(include_metadata));
    }
    return StrCat("{", StrJoin(parts, ", "), "}");
  }

  std::string metadata;
  if (include_metadata) {
    if (metadata_.size() == 1) {
      metadata =
          StrCat(" metadata={", OpMetadataToString(metadata_.front()), "}");
    } else if (metadata_.size() > 1) {
      std::vector<std::string> metadata_strings;
      metadata_strings.reserve(metadata_.size());
      for (const auto& single_metadata : metadata_) {
        metadata_strings.push_back(
            StrCat("{", OpMetadataToString(single_metadata), "}"));
      }
      metadata = StrCat(" metadata={", StrJoin(metadata_strings, ", "), "}");
    }
  }

  if (replicated_) {
    return StrCat("{replicated", metadata, "}");
  }

  if (manual_) {
    return StrCat("{manual", metadata, "}");
  }

  if (maximal_) {
    return StrCat("{maximal device=",
                  static_cast<int64>(*tile_assignment_.begin()), metadata, "}");
  }
  return StrCat("{devices=[", StrJoin(tile_assignment_.dimensions(), ","), "]",
                StrJoin(tile_assignment_, ","),
                replicate_on_last_tile_dim_ ? " last_tile_dim_replicate" : "",
                metadata, "}");
}

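// Returns true if this sharding assigns any data to the given device.
// Replicated and manual shardings are considered to use every device.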
bool HloSharding::UsesDevice(int64 device) const {
  if (IsTuple()) {
    return absl::c_any_of(tuple_elements_, [&](const HloSharding& s) {
      return s.UsesDevice(device);
    });
  }
  const auto& devices = tile_assignment_;
  return replicated_ || manual_ || absl::c_linear_search(devices, device);
}

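// Returns a map from device id to the number of leaf shardings that are
// maximal on that device; elements without a unique device contribute
// nothing. If count is non-null, it receives the number of leaf elements
// inspected.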
std::map<int64, int64> HloSharding::UsedDevices(int64* count) const {
  int64 element_count = 1;
  std::map<int64, int64> device_map;
  if (IsTuple()) {
    for (auto& tuple_element_sharding : tuple_elements()) {
      auto unique_device = tuple_element_sharding.UniqueDevice();
      if (unique_device) {
        device_map[*unique_device] += 1;
      }
    }
    element_count = tuple_elements().size();
  } else {
    auto unique_device = UniqueDevice();
    if (unique_device) {
      device_map[*unique_device] += 1;
    }
  }
  if (count != nullptr) {
    *count = element_count;
  }
  return device_map;
}

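// Returns the index into the tile assignment occupied by the given device,
// excluding the trailing replication dimension if present. Requires a tiled,
// non-tuple sharding.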
std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
  CHECK(!maximal_);
  CHECK(!manual_);
  CHECK(!IsTuple());
  std::vector<int64> ret_index;
  tile_assignment_.Each([&](absl::Span<const int64> index, int64 d) {
    if (d == device) {
      ret_index = {index.begin(), index.end()};
    }
  });
  CHECK(!ret_index.empty());
  if (replicate_on_last_tile_dim_) {
    ret_index.pop_back();
  }
  return ret_index;
}

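// Returns the device holding the tile at the given index. For partially
// replicated shardings the index may omit the trailing replication
// dimension, in which case the first device of the replication group is
// returned.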
int64 HloSharding::DeviceForTileIndex(absl::Span<const int64> index) const {
  CHECK(!replicated_);
  CHECK(!manual_);
  CHECK(!IsTuple());
  if (maximal_) {
    return *tile_assignment_.begin();
  }
  if (replicate_on_last_tile_dim_ &&
      index.size() < tile_assignment().num_dimensions()) {
    std::vector<int64> first_replicated_index(index.begin(), index.end());
    first_replicated_index.push_back(0);
    return tile_assignment_(first_replicated_index);
  }
  return tile_assignment_(index);
}

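// Returns the per-dimension start offsets of the tile assigned to the given
// device, clamped to the shape bounds.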
std::vector<int64> HloSharding::TileOffsetForDevice(const Shape& shape,
                                                    int64 device) const {
  CHECK(!IsTuple());
  CHECK(!manual_);

  if (maximal_) {
    return std::vector<int64>(shape.dimensions_size(), 0);
  }
  if (replicate_on_last_tile_dim_) {
    CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions() - 1);
  } else {
    CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions());
  }
  std::vector<int64> index = TileIndexForDevice(device);
  for (int64 i = 0; i < index.size(); ++i) {
    const int64 shape_dim = shape.dimensions(i);
    index[i] = std::min(
        index[i] * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), shape_dim);
  }
  return index;
}

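// Returns the per-dimension (exclusive) end offsets of the tile assigned to
// the given device, clamped to the shape bounds.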
std::vector<int64> HloSharding::TileLimitForDevice(const Shape& shape,
                                                   int64 device) const {
  CHECK(!IsTuple());
  CHECK(!manual_);

  if (maximal_) {
    return std::vector<int64>(shape.dimensions().begin(),
                              shape.dimensions().end());
  }

  CHECK_EQ(shape.dimensions_size() + (ReplicateOnLastTileDim() ? 1 : 0),
           tile_assignment_.num_dimensions());
  std::vector<int64> index = TileIndexForDevice(device);
  for (int64 i = 0; i < index.size(); ++i) {
    const int64 shape_dim = shape.dimensions(i);
    index[i] = std::min(
        (index[i] + 1) * CeilOfRatio(shape_dim, tile_assignment_.dim(i)),
        shape_dim);
  }
  return index;
}

int64 HloSharding::RequiredLeaves(const Shape& shape) {
  // Empty tuples (with arbitrary nesting) have no leaf nodes as far as
  // ShapeUtil and ShapeTree are concerned, but they do have a single
  // tuple_elements_ entry since we want to allow empty tuple results to
  // have sharding.
  const int64 leaf_count = ShapeUtil::GetLeafCount(shape);
  return (leaf_count == 0) ? 1 : leaf_count;
}

Status HloSharding::CheckLeafCount(const Shape& shape) const {
  int64 shape_leaves = RequiredLeaves(shape);
  TF_RET_CHECK(shape_leaves == tuple_elements_.size())
      << "Shape " << ShapeUtil::HumanString(shape) << " has " << shape_leaves
      << " leaf nodes while this sharding has " << tuple_elements_.size();
  return Status::OK();
}

StatusOr<ShapeTree<HloSharding>> HloSharding::AsShapeTree(
    const Shape& shape) const {
  if (IsTuple()) {
    ShapeTree<HloSharding> result(shape, HloSharding::Replicate());
    TF_RETURN_IF_ERROR(CheckLeafCount(shape));
    auto it = tuple_elements_.begin();
    for (auto& index_to_sharding : result.leaves()) {
      index_to_sharding.second = *it++;
    }
    if (ShapeUtil::IsEmptyTuple(shape)) {
      // Empty tuples have no leaves, but we want to assign them a sharding
      // anyway, so we use the root element sharding.
      *result.mutable_element(ShapeIndex({})) = *it;
    }
    return std::move(result);
  } else {
    return ShapeTree<HloSharding>(shape, *this);
  }
}

StatusOr<HloSharding> HloSharding::GetTupleSharding(const Shape& shape) const {
  if (IsTuple()) {
    TF_RETURN_IF_ERROR(CheckLeafCount(shape));
    return *this;
  }
  return Tuple(ShapeTree<HloSharding>(shape, *this));
}

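// Returns the single device this sharding runs on, or nullopt if it spans
// multiple devices. Only maximal shardings (and tuples whose elements all
// share the same maximal device) have a unique device.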
absl::optional<int64> HloSharding::UniqueDevice() const {
  if (IsTuple()) {
    if (tuple_elements_.empty()) {
      return absl::nullopt;
    }
    absl::optional<int64> unique_device;
    for (auto& tuple_sharding : tuple_elements_) {
      auto device = tuple_sharding.UniqueDevice();
      if (!device || (unique_device && *device != *unique_device)) {
        return absl::nullopt;
      }
      unique_device = device;
    }
    return unique_device;
  }
  if (!replicated_ && maximal_) {
    return static_cast<int64>(*tile_assignment_.begin());
  }
  return absl::nullopt;
}

int64 HloSharding::GetUniqueDevice() const {
  auto device = UniqueDevice();
  CHECK(device) << "Sharding does not have a unique device: " << *this;
  return *device;
}

Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const {
  if (!shape.IsTuple()) {
    return tensorflow::errors::InvalidArgument(
        StrCat("Sharding is tuple-shaped but validation shape is not."));
  }
  TF_RETURN_IF_ERROR(CheckLeafCount(shape));

  // Now we've validated the number of tuple elements, it's safe to request a
  // shape tree.
  ShapeTree<HloSharding> shape_tree = GetAsShapeTree(shape);
  for (const auto& index_to_sharding : shape_tree.leaves()) {
    Status status = index_to_sharding.second.ValidateNonTuple(
        ShapeUtil::GetSubshape(shape, index_to_sharding.first), num_devices);
    if (!status.ok()) {
      tensorflow::errors::AppendToMessage(
          &status, StrCat("Note: While validating sharding tuple element ",
                          index_to_sharding.first.ToString(), " which is ",
                          index_to_sharding.second.ToString()));
      return status;
    }
  }
  return Status::OK();
}

Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
  if (shape.IsToken()) {
    return Status::OK();
  }
  Status status = IsTuple() ? ValidateTuple(shape, num_devices)
                            : ValidateNonTuple(shape, num_devices);
  if (!status.ok()) {
    tensorflow::errors::AppendToMessage(
        &status, StrCat("Note: While validating sharding ", ToString(),
                        " against shape ", ShapeUtil::HumanString(shape)));
  }
  return status;
}

Status HloSharding::ValidateNonTuple(const Shape& shape,
                                     int64 num_devices) const {
  if (shape.IsTuple()) {
    return tensorflow::errors::InvalidArgument(
        StrCat("Validation shape is a tuple but sharding is not."));
  }
  if (replicated_) {
    return Status::OK();
  }

  // All entries in the tile assignment must be unique and less than the
  // number of available devices.
  Status status = Status::OK();
  absl::flat_hash_set<int64> seen_cores;
  tile_assignment_.Each(
      [&](absl::Span<const int64> indices, int64 core) {
        // Don't overwrite a bad status, so we report the first error.
        if (status.ok()) {
          if (core >= num_devices) {
            status = tensorflow::errors::InvalidArgument(StrCat(
                "core ", core, " >= ", num_devices, " in tile assignment"));
          } else if (seen_cores.contains(core)) {
            status = tensorflow::errors::InvalidArgument(
                StrCat("core ", core, " is not unique in tile assignment"));
          }
          seen_cores.insert(core);
        }
      });
  if (!status.ok()) {
    return status;
  }

  if (IsTileMaximal() || IsManual()) {
    return Status::OK();
  }

  // The tile assignment tensor must have the same rank as the input, or input
  // rank + 1 for replicate_on_last_tile_dim_.
  if (shape.rank() + (replicate_on_last_tile_dim_ ? 1 : 0) !=
      tile_assignment_.num_dimensions()) {
    return tensorflow::errors::InvalidArgument(
        "Number of tile assignment dimensions is different from the input "
        "rank. sharding=",
        ToString(), ", input_shape=", ShapeUtil::HumanString(shape));
  }

  // The correct constructor has to be used to create tile maximal shardings.
  if (tile_assignment_.num_elements() == 1) {
    return tensorflow::errors::InvalidArgument(
        "Tile assignment only contains a single device. If a replicated "
        "sharding was intended, use HloSharding::Replicate(). If a device "
        "placement was intended, use HloSharding::AssignDevice()");
  }
  return Status::OK();
}

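// Parses an OpSharding proto into an HloSharding, verifying that the tile
// assignment, if any, is well formed.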
/*static*/ StatusOr<HloSharding> HloSharding::FromProto(
    const OpSharding& proto) {
  std::vector<OpMetadata> metadata(proto.metadata().begin(),
                                   proto.metadata().end());
  if (proto.type() == OpSharding::TUPLE) {
    TF_RET_CHECK(metadata.empty())
        << "Tuple sharding is expected to have no metadata.";
    std::vector<HloSharding> tuple_shardings;
    tuple_shardings.reserve(proto.tuple_shardings().size());
    for (const OpSharding& tuple_sharding_proto : proto.tuple_shardings()) {
      TF_ASSIGN_OR_RETURN(HloSharding sharding,
                          HloSharding::FromProto(tuple_sharding_proto));
      tuple_shardings.push_back(sharding);
    }
    return HloSharding(tuple_shardings);
  } else if (proto.type() == OpSharding::REPLICATED) {
    return Replicate(metadata);
  } else if (proto.type() == OpSharding::MANUAL) {
    return Manual(metadata);
  } else if (proto.tile_assignment_devices().size() == 1) {
    return HloSharding(proto.tile_assignment_devices(0), metadata);
  }

  TF_RET_CHECK(proto.type() != OpSharding::MAXIMAL)
      << "Maximal sharding is expected to have a single device assignment, "
      << "but " << proto.tile_assignment_devices().size() << " were provided.";

  TF_RET_CHECK(proto.tile_assignment_devices().size() > 1);
  TF_RET_CHECK(!proto.tile_assignment_dimensions().empty());

  // The product of the tile assignment tensor dimensions must equal
  // tile_assignment_devices.size().
  int64 product_of_dimensions = 1;
  for (auto dimension : proto.tile_assignment_dimensions()) {
    TF_RET_CHECK(dimension > 0);
    product_of_dimensions =
        MultiplyWithoutOverflow(product_of_dimensions, dimension);
    TF_RET_CHECK(product_of_dimensions > 0);
  }
  TF_RET_CHECK(product_of_dimensions == proto.tile_assignment_devices().size());

  // Some versions of gcc cannot infer the TileAssignment constructor from a
  // braced initializer-list, so create one manually.
  std::vector<int64> devices(proto.tile_assignment_devices().begin(),
                             proto.tile_assignment_devices().end());
  Array<int64> tile_assignment(
      std::vector<int64>(proto.tile_assignment_dimensions().begin(),
                         proto.tile_assignment_dimensions().end()));
  std::copy(devices.begin(), devices.end(), tile_assignment.begin());
  return proto.replicate_on_last_tile_dim()
             ? PartialTile(tile_assignment, metadata)
             : HloSharding(tile_assignment,
                           /*replicate_on_last_tile_dim=*/false, metadata);
}

OpSharding HloSharding::ToProto() const {
  OpSharding result;

  if (IsTuple()) {
    CHECK(metadata_.empty());
    for (const HloSharding& element : tuple_elements_) {
      *result.add_tuple_shardings() = element.ToProto();
    }
    result.set_type(OpSharding::TUPLE);
    return result;
  }

  result.mutable_metadata()->Reserve(metadata_.size());
  for (const auto& metadata : metadata_) {
    *result.add_metadata() = metadata;
  }

  for (int64 dim : tile_assignment_.dimensions()) {
    result.add_tile_assignment_dimensions(dim);
  }
  for (auto device : tile_assignment_) {
    result.add_tile_assignment_devices(device);
  }
  if (IsReplicated()) {
    result.set_type(OpSharding::REPLICATED);
    result.clear_tile_assignment_dimensions();
  } else if (IsTileMaximal()) {
    result.set_type(OpSharding::MAXIMAL);
  } else if (IsManual()) {
    result.set_type(OpSharding::MANUAL);
    result.clear_tile_assignment_dimensions();
  } else {
    result.set_type(OpSharding::OTHER);
    result.set_replicate_on_last_tile_dim(ReplicateOnLastTileDim());
  }
  return result;
}

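// Returns the shape of a full tile: each dimension of shape is divided by
// the tile count along that dimension, rounded up. Boundary tiles may be
// smaller; use the per-device overload below for exact shapes.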
Shape HloSharding::TileShape(const Shape& shape) const {
  if (IsTileMaximal() || IsManual()) {
    return shape;
  }
  Shape result_shape = shape;
  for (int64 i = 0; i < shape.dimensions_size(); ++i) {
    result_shape.set_dimensions(
        i, CeilOfRatio<int64>(shape.dimensions(i), tile_assignment_.dim(i)));
  }
  return result_shape;
}

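// Returns the exact shape of the tile assigned to the given device,
// accounting for smaller boundary tiles.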
Shape HloSharding::TileShape(const Shape& shape, int64 device) const {
  if (IsTileMaximal() || IsManual()) {
    return shape;
  }

  std::vector<int64> index = TileIndexForDevice(device);
  Shape result_shape = shape;
  for (int64 i = 0; i < index.size(); ++i) {
    const int64 shape_dim = shape.dimensions(i);
    int64 offset = std::min(
        index[i] * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), shape_dim);
    int64 limit = std::min(
        (index[i] + 1) * CeilOfRatio(shape_dim, tile_assignment_.dim(i)),
        shape_dim);
    result_shape.set_dimensions(i, limit - offset);
  }
  return result_shape;
}

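// Returns the number of distinct tiles, excluding the trailing replication
// dimension of partially replicated shardings.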
int64 HloSharding::NumTiles() const {
  if (IsTileMaximal()) {
    return 1;
  }
  CHECK(!IsManual());
  if (ReplicateOnLastTileDim()) {
    return tile_assignment().num_elements() /
           tile_assignment().dimensions().back();
  }
  return tile_assignment().num_elements();
}

int64 HloSharding::NumTiles(absl::Span<const int64> dims) const {
  if (IsTileMaximal()) {
    return 1;
  }
  CHECK(!IsManual());
  CHECK(!ReplicateOnLastTileDim() ||
        !absl::c_linear_search(dims, tile_assignment().num_dimensions() - 1));
  int64 num_tiles = 1;
  for (auto d : dims) {
    CHECK(d < tile_assignment().num_dimensions());
    num_tiles *= tile_assignment().dim(d);
  }
  return num_tiles;
}

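// Returns the sharding of the sub-shape at index. Tuple leaf shardings are
// stored as a flat list, so this walks the shape to compute the offset of
// the first leaf under index.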
HloSharding HloSharding::GetSubSharding(const Shape& shape,
                                        const ShapeIndex& index) const {
  CHECK(IsTuple());
  int64 sharding_index = 0;
  const Shape* sub_shape = &shape;
  for (int64 idx : index) {
    for (int64 i = 0; i < idx; ++i) {
      sharding_index +=
          ShapeUtil::GetLeafCount(ShapeUtil::GetSubshape(*sub_shape, {i}));
    }
    sub_shape = &ShapeUtil::GetSubshape(*sub_shape, {idx});
  }
  if (sub_shape->IsTuple()) {
    auto begin_it = tuple_elements_.begin() + sharding_index;
    std::vector<HloSharding> sub_shardings(
        begin_it, begin_it + ShapeUtil::GetLeafCount(*sub_shape));
    return HloSharding::Tuple(*sub_shape, sub_shardings);
  } else {
    return tuple_elements_[sharding_index];
  }
}

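// If all tuple elements share a single sharding, returns that sharding;
// otherwise returns nullopt. Non-tuple shardings are returned unchanged.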
absl::optional<HloSharding> HloSharding::ExtractSingleSharding() const {
  if (!IsTuple()) {
    return *this;
  }
  if (tuple_elements_.empty()) {
    return absl::nullopt;
  }
  for (int64 i = 1; i < tuple_elements_.size(); ++i) {
    if (tuple_elements_[0] != tuple_elements_[i]) {
      return absl::nullopt;
    }
  }
  return tuple_elements_.front();
}

HloSharding HloSharding::WithMetadata(absl::Span<const OpMetadata> metadata,
                                      bool overwrite) const {
  auto assign_metadata = [&](HloSharding& sharding) {
    if (sharding.metadata_.empty() || overwrite) {
      sharding.metadata_.assign(metadata.begin(), metadata.end());
    }
  };

  HloSharding sharding = *this;
  if (sharding.IsTuple()) {
    for (HloSharding& sub_sharding : sharding.tuple_elements()) {
      assign_metadata(sub_sharding);
    }
  } else {
    assign_metadata(sharding);
  }
  return sharding;
}

HloSharding HloSharding::WithoutMetadata() const {
  HloSharding sharding = *this;
  sharding.metadata_.clear();
  for (HloSharding& sub_sharding : sharding.tuple_elements()) {
    sub_sharding.metadata_.clear();
  }
  return sharding;
}

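// Hashes the sharding: 0 for replicated, 1 for manual, otherwise a
// combination of the tile assignment entries plus the partial replication
// bit.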
size_t HloSharding::Hash() const {
  if (tuple_) {
    size_t h = 0;
    for (const auto& element : tuple_elements_) {
      h = tensorflow::Hash64Combine(h, element.Hash());
    }
    return h;
  }
  if (replicated_) {
    return 0;
  }
  if (manual_) {
    return 1;
  }
  size_t h = 0;
  for (uint32 v : tile_assignment_) {
    h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
  }
  if (replicate_on_last_tile_dim_) {
    h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(1));
  }
  return h;
}

std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) {
  out << sharding.ToString();
  return out;
}

}  // namespace xla