/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// HLO shardings describe how an HLO instruction is split across multiple
// computations.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_

#include <map>
#include <string>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// HLO shardings describe how an HLO instruction is split across multiple
// computations.
class HloSharding {
 public:
  // Creates a trivial sharding that replicates a maximal tile across all
  // devices.
  static HloSharding Replicate(absl::Span<const OpMetadata> metadata = {}) {
    return HloSharding(/*manual=*/false, /*replicated=*/true, metadata);
  }

  // Creates a sharding that indicates the op is manually partitioned.
  static HloSharding Manual(absl::Span<const OpMetadata> metadata = {}) {
    return HloSharding(/*manual=*/true, /*replicated=*/false, metadata);
  }

  // Creates a sharding that emulates device placement: a single tile with
  // shape equal to the input shape, assigned to one device.
  static HloSharding AssignDevice(int64 device_id,
                                  absl::Span<const OpMetadata> metadata = {});

  // Creates a new sharding which splits a shape into tiles amongst the devices
  // specified by `tile_assignment`.
  static HloSharding Tile(const Array<int64>& tile_assignment,
                          absl::Span<const OpMetadata> metadata = {}) {
    return HloSharding(tile_assignment, /*replicate_on_last_tile_dim=*/false,
                       metadata);
  }
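  //
  // Illustrative usage sketch (not taken from this file; the 2x2 layout is an
  // assumption for the example): split a rank-2 HLO into a 2x2 grid of tiles
  // over devices 0..3, row-major.
  //
  //   Array<int64> assignment({2, 2});  // 2 tiles along dim 0, 2 along dim 1
  //   assignment(0, 0) = 0;
  //   assignment(0, 1) = 1;
  //   assignment(1, 0) = 2;
  //   assignment(1, 1) = 3;
  //   HloSharding tiled = HloSharding::Tile(assignment);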

  // Creates a new sharding where data is replicated within each replication
  // group, and sharded across replication groups according to
  // group_tile_assignment. Replication group members will be sorted.
  static HloSharding PartialTile(
      const Array<int64>& group_tile_assignment,
      absl::Span<const absl::Span<const int64>> replication_groups,
      absl::Span<const OpMetadata> metadata = {});

  // Creates a partially replicated tiled sharding with device-level tile
  // assignment, where the last dimension is the additional replication
  // dimension. Replication group members will be sorted.
  static HloSharding PartialTile(
      const Array<int64>& tile_assignment_last_dim_replicate,
      absl::Span<const OpMetadata> metadata = {});
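  //
  // Illustrative sketch (the [2, 1, 2] layout is an assumption for the
  // example): a device-level tile assignment whose last dimension is the
  // replication dimension, so data is split 2 ways along dim 0 and every tile
  // is replicated on 2 devices.
  //
  //   Array<int64> assignment({2, 1, 2});
  //   assignment.FillIota(0);  // devices 0..3, assuming sequential fill
  //   HloSharding partial = HloSharding::PartialTile(assignment);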

  // Creates a new sharding which splits a one-dimensional input shape into
  // `num_tiles` tiles.
  static HloSharding Tile1D(const Shape& input_shape, int64 num_tiles,
                            absl::Span<const OpMetadata> metadata = {});
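  //
  // Illustrative sketch (the f32[8] shape is an assumption for the example):
  // split a one-dimensional shape into 4 tiles, one per device 0..3.
  //
  //   Shape shape = ShapeUtil::MakeShape(F32, {8});
  //   HloSharding sharding = HloSharding::Tile1D(shape, /*num_tiles=*/4);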

  // Creates a new sharding for a tuple type. The given ShapeTree must have
  // elements for every leaf shape contained in the tuple.
  static HloSharding Tuple(const ShapeTree<HloSharding>& sub_shardings);

  // Creates a new sharding for a tuple type. The number of elements in
  // shardings must match the number of leaf nodes in tuple_shape. For
  // empty tuples, the shardings array must have one element.
  static HloSharding Tuple(const Shape& tuple_shape,
                           absl::Span<const HloSharding> shardings);

  // Creates a new sharding for a tuple type, with a single input sharding
  // repeated on each leaf.
  static HloSharding SingleTuple(const Shape& tuple_shape,
                                 const HloSharding& sharding);
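  //
  // Illustrative sketch (the tuple shape is an assumption for the example):
  // give every leaf of a tuple (f32[4], f32[8]) the same replicated sharding.
  //
  //   Shape tuple_shape = ShapeUtil::MakeTupleShape(
  //       {ShapeUtil::MakeShape(F32, {4}), ShapeUtil::MakeShape(F32, {8})});
  //   HloSharding tuple_sharding =
  //       HloSharding::SingleTuple(tuple_shape, HloSharding::Replicate());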

  // If shape is an array, returns `sharding`; otherwise returns a tuple-shaped
  // sharding with all the leaf nodes having the same input sharding.
  static HloSharding Single(const Shape& shape, const HloSharding& sharding);

  // Creates a new sharding from a protobuf OpSharding.
  static StatusOr<HloSharding> FromProto(const OpSharding& proto);

  // Checks whether device is a reserved device number. A reserved device
  // number usually has a special meaning, with dedicated handling logic.
  static bool IsReservedDevice(int64 device) { return device < 0; }

  OpSharding ToProto() const;

  // Note that this string canonically has outer curly braces, e.g.
  // "{replicated}".
  string ToString(bool include_metadata = false) const;

  // Validate that this sharding can be applied to a tensor with shape `shape`.
  Status Validate(const Shape& shape, int64 num_devices) const;

  // Returns true if the sharding has tuple type.
  bool IsTuple() const { return tuple_; }

  // Returns true if the sharding is trivial: replicate on all devices.
  bool IsReplicated() const {
    if (!IsTuple()) {
      return replicated_;
    }
    return absl::c_all_of(
        tuple_elements_, [](const HloSharding& s) { return s.IsReplicated(); });
  }

  // Returns true if the tile size is the same as the input size.
  bool IsTileMaximal() const {
    if (!IsTuple()) {
      return maximal_;
    }
    return absl::c_all_of(tuple_elements_, [](const HloSharding& s) {
      return s.IsTileMaximal();
    });
  }

  // Returns whether the sharding represents manual partitioning.
  bool IsManual() const {
    if (!IsTuple()) {
      return manual_;
    }
    return absl::c_all_of(tuple_elements_,
                          [](const HloSharding& s) { return s.IsManual(); });
  }

  // Returns whether the sharding has partial replication and partial sharding.
  // If true, data is sharded according to other dimensions of
  // tile_assignment(), but replicated across devices along the last dimension.
  bool ReplicateOnLastTileDim() const { return replicate_on_last_tile_dim_; }

  // Returns true if the sharding defines an operation on the given device.
  bool UsesDevice(int64 device) const;

  // Retrieves a histogram of the devices used by the sharding. The returned
  // map has the device number as key, and the occurrence count as value.
  // If a sharding does not have a device, it will not be included in the
  // histogram. The count argument, if not nullptr, will receive the total
  // number of elements this sharding is made of (one for array, N leaves for
  // tuples).
  std::map<int64, int64> UsedDevices(int64* count) const;
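  //
  // Illustrative example (the shardings are assumptions for the example): for
  // a tuple sharding whose two leaves are both AssignDevice(0),
  // UsedDevices(&count) returns {{0, 2}} and sets count to 2.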

  // Returns the tile that should be executed on the given device.
  // REQUIRES: !IsTuple()
  std::vector<int64> TileIndexForDevice(int64 device) const;

  // Returns the device that should execute the given tile.
  // It is an error to call this if IsReplicated() is true.
  // When ReplicateOnLastTileDim() == true, if index.size() == data rank, it
  // returns the first device in that replicated subgroup; otherwise,
  // index.size() should be the same as tile_assignment()'s rank and specifies
  // the member of the replication subgroup.
  // REQUIRES: !IsTuple()
  int64 DeviceForTileIndex(absl::Span<const int64> index) const;
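  //
  // Illustrative example (the assignments are assumptions for the example):
  // with a [2, 2] tile assignment {{0, 1}, {2, 3}} and no partial replication,
  // DeviceForTileIndex({1, 0}) returns 2. With the same assignment but
  // ReplicateOnLastTileDim() == true over rank-1 data, DeviceForTileIndex({1})
  // returns 2 (the first device of subgroup {2, 3}) and
  // DeviceForTileIndex({1, 1}) returns 3.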

  // Given a device ID, returns the offset within the specified shape of the
  // tile that should be executed on the given core. This returns the lower
  // extent of the tile in the input space.
  // REQUIRES: !IsTuple()
  std::vector<int64> TileOffsetForDevice(const Shape& shape,
                                         int64 device) const;

  // Given a device ID, returns the limit within the specified shape of the
  // tile that should be executed on the given core. This returns the upper
  // extent of the tile in the input space.
  // REQUIRES: !IsTuple()
  std::vector<int64> TileLimitForDevice(const Shape& shape, int64 device) const;
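  //
  // Illustrative worked example (the shape and assignment are assumptions for
  // the example): for an f32[4,4] shape tiled by a [2, 2] assignment
  // {{0, 1}, {2, 3}}, each tile covers 2x2 elements; device 3 owns tile index
  // [1, 1], so TileOffsetForDevice(shape, 3) is {2, 2} and
  // TileLimitForDevice(shape, 3) is {4, 4}.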

  // Returns the single device this op operates on. If the sharding does not
  // span a single device, the return value will be empty.
  // In order for a sharding to span a single device, every leaf sharding must
  // be maximal and not replicated, and the used device must match.
  absl::optional<int64> UniqueDevice() const;

  // Retrieves the unique device or fails with a CHECK.
  int64 GetUniqueDevice() const;

  // Returns true if this op only uses a single device.
  bool HasUniqueDevice() const { return UniqueDevice().has_value(); }

  // Returns the ShapeTree containing the shardings for each element of this
  // tuple, if IsTuple, or a ShapeTree with a single element containing this
  // sharding. Only the leaf elements are populated. This creates a new
  // ShapeTree object, so it is not cheap.
  StatusOr<ShapeTree<HloSharding>> AsShapeTree(const Shape& shape) const;
  ShapeTree<HloSharding> GetAsShapeTree(const Shape& shape) const {
    return AsShapeTree(shape).ValueOrDie();
  }

  // Retrieves the sub sharding at a given index, out of a tuple sharding.
  // REQUIRES: IsTuple()
  HloSharding GetSubSharding(const Shape& shape, const ShapeIndex& index) const;
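  //
  // Illustrative example (the tuple shape is an assumption for the example):
  // for a tuple sharding over (f32[4], f32[8]), GetSubSharding(tuple_shape,
  // {1}) returns the leaf sharding of the f32[8] element.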

  // If the current sharding is a tuple sharding, returns itself as result.
  // Otherwise returns a tuple sharding for the input shape, with all the
  // leaves having this object's sharding.
  StatusOr<HloSharding> GetTupleSharding(const Shape& shape) const;

  // Extracts the sharding that is common within the current sharding.
  // If the current sharding is not a tuple sharding, the current sharding will
  // be returned. If it is a tuple, and all the tuple elements are common, the
  // common element will be returned. Otherwise the optional will contain no
  // value.
  absl::optional<HloSharding> ExtractSingleSharding() const;

  // Returns a copy of the sharding with no metadata. If sharding is of tuple
  // type, sub shardings will have no metadata.
  HloSharding WithoutMetadata() const;

  // Returns a copy of the sharding with the specified metadata. If metadata is
  // already present, that metadata will not be replaced unless `overwrite` is
  // set to true. If the sharding is of tuple type, the sub shardings' metadata
  // will be assigned instead.
  HloSharding WithMetadata(absl::Span<const OpMetadata> metadata,
                           bool overwrite) const;

  bool operator==(const HloSharding& other) const {
    return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
           manual_ == other.manual_ &&
           tile_assignment_ == other.tile_assignment_ &&
           tuple_elements_ == other.tuple_elements_ &&
           replicate_on_last_tile_dim_ == other.replicate_on_last_tile_dim_;
  }
  bool operator!=(const HloSharding& other) const { return !(*this == other); }

  size_t Hash() const;

  struct Hasher {
    size_t operator()(const HloSharding& sharding) const {
      return sharding.Hash();
    }
  };

  // Gets the tile assignment tensor.
  // REQUIRES: !IsReplicated() && !IsTuple()
  const Array<int64>& tile_assignment() const { return tile_assignment_; }

  // Returns the flattened list of all the leaf shardings in a tuple shape, by
  // pre-order walk (ShapeTree iterator order).
  // REQUIRES: IsTuple().
  std::vector<HloSharding>& tuple_elements() { return tuple_elements_; }
  const std::vector<HloSharding>& tuple_elements() const {
    return tuple_elements_;
  }

  // Gets the tile shape.
  // REQUIRES: !IsTuple()
  Shape TileShape(const Shape& shape) const;

  // Gets the tile shape on the device.
  // REQUIRES: !IsTuple()
  Shape TileShape(const Shape& shape, int64 device) const;

  // Gets the number of tiles. If it has partial replication, this will not
  // equal the device count.
  int64 NumTiles() const;
  // Like NumTiles() but considers only the specific dimensions passed as
  // argument.
  int64 NumTiles(absl::Span<const int64> dims) const;
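  //
  // Illustrative example (the assignments are assumptions for the example):
  // for a [2, 2] tile assignment, NumTiles() is 4 and NumTiles({0}) is 2. With
  // a [2, 2, 2] assignment whose last dimension is replicated, NumTiles() is
  // still 4 even though 8 devices are used.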

  // Gets metadata from sharding.
  std::vector<OpMetadata>& metadata() { return metadata_; }
  const std::vector<OpMetadata>& metadata() const { return metadata_; }

 private:
  explicit HloSharding(bool manual, bool replicated,
                       absl::Span<const OpMetadata> metadata)
      : replicated_(replicated),
        maximal_(replicated),
        tuple_(false),
        manual_(manual),
        tile_assignment_({0}),
        replicate_on_last_tile_dim_(false),
        metadata_(metadata.begin(), metadata.end()) {}
  // device_id values:
  // -2: magic number to mean unassigned device, used by spatial partitioning
  // -1: the id of the host
  //  0 or positive: the id of a device
  // NOTE(dimvar): -1 is needed for outside compilation. It can be removed once
  // we have fully switched to the side-effect tokens.
  explicit HloSharding(int64 device_id, absl::Span<const OpMetadata> metadata)
      : replicated_(false),
        maximal_(true),
        tuple_(false),
        manual_(false),
        tile_assignment_({1}, device_id),
        replicate_on_last_tile_dim_(false),
        metadata_(metadata.begin(), metadata.end()) {}
  explicit HloSharding(const Array<int64>& tile_assignment,
                       bool replicate_on_last_tile_dim,
                       absl::Span<const OpMetadata> metadata = {})
      : replicated_(false),
        maximal_(false),
        tuple_(false),
        manual_(false),
        tile_assignment_(tile_assignment),
        replicate_on_last_tile_dim_(replicate_on_last_tile_dim),
        metadata_(metadata.begin(), metadata.end()) {}
  explicit HloSharding(const std::vector<HloSharding>& tuple_shardings)
      : replicated_(false),
        maximal_(false),
        tuple_(true),
        manual_(false),
        tile_assignment_({0}),
        tuple_elements_(tuple_shardings),
        replicate_on_last_tile_dim_(false) {}

  // Checks that the number of elements in tuple_elements_ is consistent with
  // the tuple shape passed as argument.
  Status CheckLeafCount(const Shape& shape) const;

  // Internal helper to validate a tuple sharding.
  Status ValidateTuple(const Shape& shape, int64 num_devices) const;

  // Internal helper to validate a non-tuple (leaf) sharding.
  Status ValidateNonTuple(const Shape& shape, int64 num_devices) const;

  // Returns the number of tuple_elements_ entries to fit the shape.
  static int64 RequiredLeaves(const Shape& shape);

  bool replicated_;
  bool maximal_;
  bool tuple_;
  bool manual_;
  // This field is only used if replicated_ is false. If maximal_ is true, then
  // the field contains a rank 1 array with a single element, which is the
  // device the HLO is assigned to. If maximal_ is false, the field contains an
  // array with the same rank as the corresponding HLO. The dimension sizes of
  // the array describe the number of ways the HLO is partitioned along each
  // dimension. The values of the array specify which device each tile of
  // the HLO is assigned to. The index of each value determines which tile it
  // takes.
  // For example, {{{2, 3}}, {{5, 7}}} (whose ToString representation is
  // "{devices=[2,1,2]2,3,5,7}") means that dimension 0 is split two ways and
  // dimension 2 is split two ways. Device 5, whose index is [1,0,0], will take
  // the tile that contains the 2nd half of dimension 0 and the 1st half of
  // dimension 2.
  Array<int64> tile_assignment_;
  // Only non-empty when tuple_ is true. If a tuple is empty then one entry is
  // present for the root. This is a flattened list of all the leaf shardings in
  // a tuple shape, by pre-order walk (ShapeTree iterator order).
  std::vector<HloSharding> tuple_elements_;
  // This flag is to support partial replication and partial sharding. If it is
  // true, tile_assignment_ will have an extra dimension in addition to the data
  // shape rank, and the added last dimension represents the subgroups of
  // replications, i.e., elements in slice [..., :] will be replicated.
  bool replicate_on_last_tile_dim_;
  // This field is used to track the source of this sharding, usually derived
  // from instructions. Multiple metadata may be populated if the sharding is
  // combined with other shardings. Metadata is not populated when
  // tuple_ == true; instead, metadata should be set on individual tuple
  // elements.
  std::vector<OpMetadata> metadata_;
};

std::ostream& operator<<(std::ostream& out, const HloSharding& sharding);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_