1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
17 
18 #include "absl/container/flat_hash_set.h"
19 #include "absl/strings/str_cat.h"
20 #include "absl/strings/str_join.h"
21 #include "tensorflow/compiler/xla/overflow_util.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 
24 namespace xla {
25 
26 using absl::StrCat;
27 using absl::StrJoin;
28 
AssignDevice(int64 device_id)29 HloSharding HloSharding::AssignDevice(int64 device_id) {
30   return HloSharding(device_id);
31 }
32 
Tile1D(const Shape & input_shape,int64 num_tiles)33 HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) {
34   CHECK_EQ(1, input_shape.rank());
35   CHECK_GT(num_tiles, 1);
36   std::vector<int64> dimensions(1, num_tiles);
37   Array<int64> assignment(dimensions);
38   std::iota(assignment.begin(), assignment.end(), 0);
39   return HloSharding(assignment);
40 }
41 
Tuple(const ShapeTree<HloSharding> & sub_shardings)42 HloSharding HloSharding::Tuple(const ShapeTree<HloSharding>& sub_shardings) {
43   std::vector<HloSharding> flattened_list;
44   flattened_list.reserve(sub_shardings.leaf_count());
45   for (const auto& index_to_sharding : sub_shardings.leaves()) {
46     flattened_list.push_back(index_to_sharding.second);
47   }
48   if (flattened_list.empty()) {
49     // Empty tuple sharding ends up having no leaves, but we want to allow
50     // empty tuple HLO instruction results to have sharding, so we fetch the
51     // root ({}) sharding value from the ShapeTree.
52     // A ShapeTree created with ShapeTree<HloSharding>(shape, init) will have
53     // init as value at its root.
54     flattened_list.push_back(sub_shardings.element(ShapeIndex({})));
55   }
56   return HloSharding(flattened_list);
57 }
58 
Tuple(const Shape & tuple_shape,absl::Span<const HloSharding> shardings)59 HloSharding HloSharding::Tuple(const Shape& tuple_shape,
60                                absl::Span<const HloSharding> shardings) {
61   CHECK(tuple_shape.IsTuple()) << ShapeUtil::HumanString(tuple_shape);
62   for (auto& sharding : shardings) {
63     CHECK(!sharding.IsTuple()) << sharding.ToString();
64   }
65   std::vector<HloSharding> flattened_list(shardings.begin(), shardings.end());
66   CHECK_EQ(flattened_list.size(), RequiredLeaves(tuple_shape))
67       << "Flat list has " << flattened_list.size() << ", required "
68       << RequiredLeaves(tuple_shape);
69   return HloSharding(flattened_list);
70 }
71 
// Builds a tuple sharding for `tuple_shape` in which every one of the
// RequiredLeaves(tuple_shape) entries is a copy of `sharding`.
// CHECK-fails if `tuple_shape` is not a tuple or `sharding` is a tuple
// sharding.
HloSharding HloSharding::SingleTuple(const Shape& tuple_shape,
                                     const HloSharding& sharding) {
  CHECK(tuple_shape.IsTuple()) << ShapeUtil::HumanString(tuple_shape);
  CHECK(!sharding.IsTuple()) << sharding.ToString();
  const int64 num_leaves = RequiredLeaves(tuple_shape);
  const std::vector<HloSharding> per_leaf(num_leaves, sharding);
  return HloSharding(per_leaf);
}
81 
// Returns `sharding` unchanged for a non-tuple `shape`; for a tuple `shape`
// returns SingleTuple(shape, sharding).
HloSharding HloSharding::Single(const Shape& shape,
                                const HloSharding& sharding) {
  if (!shape.IsTuple()) {
    return sharding;
  }
  return SingleTuple(shape, sharding);
}
86 
ToString() const87 string HloSharding::ToString() const {
88   if (IsTuple()) {
89     std::vector<string> parts;
90     parts.reserve(tuple_elements_.size());
91     for (const HloSharding& element : tuple_elements_) {
92       parts.push_back(element.ToString());
93     }
94     return StrCat("{", absl::StrJoin(parts, ", "), "}");
95   }
96 
97   if (replicated_) {
98     return "{replicated}";
99   }
100   if (maximal_) {
101     return StrCat(
102         "{maximal device=", static_cast<int64>(*tile_assignment_.begin()), "}");
103   }
104   return StrCat("{devices=[", StrJoin(tile_assignment_.dimensions(), ","), "]",
105                 StrJoin(tile_assignment_, ","), "}");
106 }
107 
UsesDevice(int64 device) const108 bool HloSharding::UsesDevice(int64 device) const {
109   if (IsTuple()) {
110     return absl::c_any_of(tuple_elements_, [&](const HloSharding& s) {
111       return s.UsesDevice(device);
112     });
113   }
114   const auto& devices = tile_assignment_;
115   return replicated_ || absl::c_linear_search(devices, device);
116 }
117 
UsedDevices(int64 * count) const118 std::map<int64, int64> HloSharding::UsedDevices(int64* count) const {
119   int64 element_count = 1;
120   std::map<int64, int64> device_map;
121   if (IsTuple()) {
122     for (auto& tuple_element_sharding : tuple_elements()) {
123       auto unique_device = tuple_element_sharding.UniqueDevice();
124       if (unique_device) {
125         device_map[*unique_device] += 1;
126       }
127     }
128     element_count = tuple_elements().size();
129   } else {
130     auto unique_device = UniqueDevice();
131     if (unique_device) {
132       device_map[*unique_device] += 1;
133     }
134   }
135   if (count != nullptr) {
136     *count = element_count;
137   }
138   return device_map;
139 }
140 
TileIndexForDevice(int64 device) const141 std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
142   CHECK(!maximal_);
143   CHECK(!IsTuple());
144   std::vector<int64> ret_index;
145   tile_assignment_.Each([&](absl::Span<const int64> index, int64 d) {
146     if (d == device) {
147       ret_index = {index.begin(), index.end()};
148     }
149   });
150   CHECK(!ret_index.empty());
151   return ret_index;
152 }
153 
DeviceForTileIndex(absl::Span<const int64> index) const154 int64 HloSharding::DeviceForTileIndex(absl::Span<const int64> index) const {
155   CHECK(!replicated_);
156   CHECK(!IsTuple());
157   if (maximal_) {
158     return *tile_assignment_.begin();
159   }
160   return tile_assignment_(index);
161 }
162 
TileOffsetForDevice(const Shape & shape,int64 device) const163 std::vector<int64> HloSharding::TileOffsetForDevice(const Shape& shape,
164                                                     int64 device) const {
165   CHECK(!IsTuple());
166 
167   if (maximal_) {
168     return std::vector<int64>(shape.dimensions_size(), 0);
169   }
170 
171   CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions());
172   std::vector<int64> index = TileIndexForDevice(device);
173   for (int64 i = 0; i < index.size(); ++i) {
174     const int64 shape_dim = shape.dimensions(i);
175     index[i] = std::min(
176         index[i] * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), shape_dim);
177   }
178   return index;
179 }
180 
TileLimitForDevice(const Shape & shape,int64 device) const181 std::vector<int64> HloSharding::TileLimitForDevice(const Shape& shape,
182                                                    int64 device) const {
183   CHECK(!IsTuple());
184 
185   if (maximal_) {
186     return std::vector<int64>(shape.dimensions().begin(),
187                               shape.dimensions().end());
188   }
189 
190   CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions());
191   std::vector<int64> index = TileIndexForDevice(device);
192   for (int64 i = 0; i < index.size(); ++i) {
193     const int64 shape_dim = shape.dimensions(i);
194     index[i] = std::min(
195         (index[i] + 1) * CeilOfRatio(shape_dim, tile_assignment_.dim(i)),
196         shape_dim);
197   }
198   return index;
199 }
200 
RequiredLeaves(const Shape & shape)201 int64 HloSharding::RequiredLeaves(const Shape& shape) {
202   // Empty tuples have no leaf nodes as far as ShapeUtil and ShapeTree are
203   // concerned, but they do have a single tuple_elements_ entry since we want
204   // to allow empty tuple results to have sharding.
205   return ShapeUtil::IsEmptyTuple(shape) ? 1 : ShapeUtil::GetLeafCount(shape);
206 }
207 
CheckLeafCount(const Shape & shape) const208 Status HloSharding::CheckLeafCount(const Shape& shape) const {
209   int64 shape_leaves = RequiredLeaves(shape);
210   TF_RET_CHECK(shape_leaves == tuple_elements_.size())
211       << "Shape " << ShapeUtil::HumanString(shape) << " has " << shape_leaves
212       << " leaf nodes while this sharding has " << tuple_elements_.size();
213   return Status::OK();
214 }
215 
AsShapeTree(const Shape & shape) const216 StatusOr<ShapeTree<HloSharding>> HloSharding::AsShapeTree(
217     const Shape& shape) const {
218   if (IsTuple()) {
219     ShapeTree<HloSharding> result(shape, HloSharding::Replicate());
220     TF_RETURN_IF_ERROR(CheckLeafCount(shape));
221     auto it = tuple_elements_.begin();
222     for (auto& index_to_sharding : result.leaves()) {
223       index_to_sharding.second = *it++;
224     }
225     if (ShapeUtil::IsEmptyTuple(shape)) {
226       // Empty tuples have no leaves, but we want to assign them a sharding
227       // anyway, so we use the root element sharding.
228       *result.mutable_element(ShapeIndex({})) = *it;
229     }
230     return std::move(result);
231   } else {
232     return ShapeTree<HloSharding>(shape, *this);
233   }
234 }
235 
GetTupleSharding(const Shape & shape) const236 StatusOr<HloSharding> HloSharding::GetTupleSharding(const Shape& shape) const {
237   if (IsTuple()) {
238     TF_RETURN_IF_ERROR(CheckLeafCount(shape));
239     return *this;
240   }
241   return Tuple(ShapeTree<HloSharding>(shape, *this));
242 }
243 
UniqueDevice() const244 absl::optional<int64> HloSharding::UniqueDevice() const {
245   if (IsTuple()) {
246     if (tuple_elements_.empty()) {
247       return absl::nullopt;
248     }
249     absl::optional<int64> unique_device;
250     for (auto& tuple_sharding : tuple_elements_) {
251       auto device = tuple_sharding.UniqueDevice();
252       if (!device || (unique_device && *device != *unique_device)) {
253         return absl::nullopt;
254       }
255       unique_device = device;
256     }
257     return unique_device;
258   }
259   if (!replicated_ && maximal_) {
260     return static_cast<int64>(*tile_assignment_.begin());
261   }
262   return absl::nullopt;
263 }
264 
GetUniqueDevice() const265 int64 HloSharding::GetUniqueDevice() const {
266   auto device = UniqueDevice();
267   CHECK(device) << "Sharding does not have a unique device: " << *this;
268   return *device;
269 }
270 
ValidateTuple(const Shape & shape,int64 num_devices) const271 Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const {
272   if (!shape.IsTuple()) {
273     return tensorflow::errors::InvalidArgument(
274         StrCat("Sharding is tuple-shaped but validation shape is not."));
275   }
276   TF_RETURN_IF_ERROR(CheckLeafCount(shape));
277 
278   // Now we've validated the number of tuple elements, it's safe to request a
279   // shape tree.
280   ShapeTree<HloSharding> shape_tree = GetAsShapeTree(shape);
281   for (const auto& index_to_sharding : shape_tree.leaves()) {
282     Status status = index_to_sharding.second.ValidateNonTuple(
283         ShapeUtil::GetSubshape(shape, index_to_sharding.first), num_devices);
284     if (!status.ok()) {
285       tensorflow::errors::AppendToMessage(
286           &status, StrCat("Note: While validating sharding tuple element ",
287                           index_to_sharding.first.ToString(), " which is ",
288                           index_to_sharding.second.ToString()));
289       return status;
290     }
291   }
292   return Status::OK();
293 }
294 
Validate(const Shape & shape,int64 num_devices) const295 Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
296   Status status = IsTuple() ? ValidateTuple(shape, num_devices)
297                             : ValidateNonTuple(shape, num_devices);
298   if (!status.ok()) {
299     tensorflow::errors::AppendToMessage(
300         &status, StrCat("Note: While validating sharding ", ToString(),
301                         " against shape ", ShapeUtil::HumanString(shape)));
302   }
303   return status;
304 }
305 
ValidateNonTuple(const Shape & shape,int64 num_devices) const306 Status HloSharding::ValidateNonTuple(const Shape& shape,
307                                      int64 num_devices) const {
308   if (shape.IsTuple()) {
309     return tensorflow::errors::InvalidArgument(
310         StrCat("Validation shape is a tuple but sharding is not."));
311   }
312   if (replicated_) {
313     return Status::OK();
314   }
315 
316   // All tile assignments must be less than the number of available cores and
317   // unique.
318   Status status = Status::OK();
319   absl::flat_hash_set<int64> seen_cores;
320   tile_assignment_.Each(
321       [&](absl::Span<const int64> indices, int32 core) {
322         // Don't overwrite a bad status, so we report the first error.
323         if (status.ok()) {
324           if (core >= num_devices) {
325             status = tensorflow::errors::InvalidArgument(StrCat(
326                 "core ", core, " > ", num_devices, " in tile assignment"));
327           } else if (seen_cores.contains(core)) {
328             status = tensorflow::errors::InvalidArgument(
329                 StrCat("core ", core, " is not unique in tile assignment"));
330           }
331           seen_cores.insert(core);
332         }
333       });
334   if (!status.ok()) {
335     return status;
336   }
337 
338   if (IsTileMaximal()) {
339     return Status::OK();
340   }
341 
342   // The tile assignment tensor must have the same rank as the input.
343   if (shape.rank() != tile_assignment_.num_dimensions()) {
344     return tensorflow::errors::InvalidArgument(
345         "Number of tile assignment dimensions is different to the input rank. "
346         "sharding=",
347         ToString(), ", input_shape=", ShapeUtil::HumanString(shape));
348   }
349 
350   // The correct constructor has to be used to create tile maximal shardings.
351   if (tile_assignment_.num_elements() == 1) {
352     return tensorflow::errors::InvalidArgument(
353         "Tile assignment only contains a single device. If a replicated "
354         "sharding was intended, use HloSharding::Replicated(). If a device "
355         "placement was intended, use HloSharding::AssignDevice()");
356   }
357   return Status::OK();
358 }
359 
FromProto(const OpSharding & proto)360 /*static*/ StatusOr<HloSharding> HloSharding::FromProto(
361     const OpSharding& proto) {
362   if (proto.type() == OpSharding::Type::OpSharding_Type_TUPLE) {
363     std::vector<HloSharding> tuple_shardings;
364     tuple_shardings.reserve(proto.tuple_shardings().size());
365     for (const OpSharding& tuple_sharding_proto : proto.tuple_shardings()) {
366       TF_ASSIGN_OR_RETURN(HloSharding sharding,
367                           HloSharding::FromProto(tuple_sharding_proto));
368       tuple_shardings.push_back(sharding);
369     }
370     return HloSharding(tuple_shardings);
371   } else if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) {
372     return Replicate();
373   } else if (proto.tile_assignment_devices().size() == 1) {
374     return HloSharding(proto.tile_assignment_devices(0));
375   }
376 
377   TF_RET_CHECK(proto.type() != OpSharding::Type::OpSharding_Type_MAXIMAL)
378       << "Maximal sharding is expected to have single device assignment, but "
379       << proto.tile_assignment_devices().size() << " has provided.";
380 
381   TF_RET_CHECK(proto.tile_assignment_devices().size() > 1);
382   TF_RET_CHECK(!proto.tile_assignment_dimensions().empty());
383 
384   // RE: the product of tile assignment tensor dimensions must be
385   // equal to tile_assignment_devices.size().
386   int64 product_of_dimensions = 1;
387   for (auto dimension : proto.tile_assignment_dimensions()) {
388     TF_RET_CHECK(dimension > 0);
389     product_of_dimensions =
390         MultiplyWithoutOverflow(product_of_dimensions, dimension);
391     TF_RET_CHECK(product_of_dimensions > 0);
392   }
393   TF_RET_CHECK(product_of_dimensions == proto.tile_assignment_devices().size());
394 
395   // Some versions of gcc cannot infer the TileAssignment constructor from a
396   // braced initializer-list, so create one manually.
397   std::vector<int64> devices(proto.tile_assignment_devices().begin(),
398                              proto.tile_assignment_devices().end());
399   Array<int64> tile_assignment(
400       std::vector<int64>(proto.tile_assignment_dimensions().begin(),
401                          proto.tile_assignment_dimensions().end()));
402   std::copy(proto.tile_assignment_devices().begin(),
403             proto.tile_assignment_devices().end(), tile_assignment.begin());
404   return HloSharding(tile_assignment);
405 }
406 
ToProto() const407 OpSharding HloSharding::ToProto() const {
408   OpSharding result;
409 
410   if (IsTuple()) {
411     for (const HloSharding& element : tuple_elements_) {
412       *result.add_tuple_shardings() = element.ToProto();
413     }
414     result.set_type(OpSharding::Type::OpSharding_Type_TUPLE);
415     return result;
416   }
417 
418   for (int64 dim : tile_assignment_.dimensions()) {
419     result.add_tile_assignment_dimensions(dim);
420   }
421   for (auto device : tile_assignment_) {
422     result.add_tile_assignment_devices(device);
423   }
424   if (IsReplicated()) {
425     result.set_type(OpSharding::Type::OpSharding_Type_REPLICATED);
426   } else if (IsTileMaximal()) {
427     result.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
428   } else {
429     result.set_type(OpSharding::Type::OpSharding_Type_OTHER);
430   }
431   return result;
432 }
433 
// Returns the shape of one tile of `shape` under this sharding: each
// dimension becomes ceil(shape_dim / tiles_in_dim). A tile-maximal sharding
// returns `shape` unchanged.
Shape HloSharding::TileShape(const Shape& shape) const {
  if (IsTileMaximal()) {
    return shape;
  }
  Shape tiled = shape;
  const int64 rank = shape.dimensions_size();
  for (int64 dim = 0; dim < rank; ++dim) {
    const int64 tile_extent =
        CeilOfRatio<int64>(shape.dimensions(dim), tile_assignment_.dim(dim));
    tiled.set_dimensions(dim, tile_extent);
  }
  return tiled;
}
445 
GetSubSharding(const Shape & shape,const ShapeIndex & index) const446 HloSharding HloSharding::GetSubSharding(const Shape& shape,
447                                         const ShapeIndex& index) const {
448   CHECK(IsTuple());
449   int64 sharding_index = 0;
450   const Shape* sub_shape = &shape;
451   for (int64 idx : index) {
452     for (int64 i = 0; i < idx; ++i) {
453       sharding_index +=
454           ShapeUtil::GetLeafCount(ShapeUtil::GetSubshape(*sub_shape, {i}));
455     }
456     sub_shape = &ShapeUtil::GetSubshape(*sub_shape, {idx});
457   }
458   if (sub_shape->IsTuple()) {
459     auto begin_it = tuple_elements_.begin() + sharding_index;
460     std::vector<HloSharding> sub_shardings(
461         begin_it, begin_it + ShapeUtil::GetLeafCount(*sub_shape));
462     return HloSharding::Tuple(*sub_shape, sub_shardings);
463   } else {
464     return tuple_elements_[sharding_index];
465   }
466 }
467 
ExtractSingleSharding() const468 absl::optional<HloSharding> HloSharding::ExtractSingleSharding() const {
469   if (!IsTuple()) {
470     return *this;
471   }
472   if (tuple_elements_.empty()) {
473     return absl::nullopt;
474   }
475   for (int64 i = 1; i < tuple_elements_.size(); ++i) {
476     if (tuple_elements_[0] != tuple_elements_[i]) {
477       return absl::nullopt;
478     }
479   }
480   return tuple_elements_.front();
481 }
482 
Hash() const483 size_t HloSharding::Hash() const {
484   if (tuple_) {
485     size_t h = 0;
486     for (const auto& element : tuple_elements_) {
487       h = tensorflow::Hash64Combine(h, element.Hash());
488     }
489     return h;
490   }
491   if (replicated_) {
492     return 0;
493   }
494   size_t h = 0;
495   for (uint32 v : tile_assignment_) {
496     h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
497   }
498   return h;
499 }
500 
operator <<(std::ostream & out,const HloSharding & sharding)501 std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) {
502   out << sharding.ToString();
503   return out;
504 }
505 
506 }  // namespace xla
507