1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Shapes are protobuf messages, so this utility header offers a bunch of
17 // functionality for querying / poking at them.
18 
19 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
20 #define TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
21 
22 #include <initializer_list>
23 #include <string>
24 
25 #include "absl/base/macros.h"
26 #include "absl/container/inlined_vector.h"
27 #include "absl/types/optional.h"
28 #include "absl/types/span.h"
29 #include "tensorflow/compiler/xla/layout_util.h"
30 #include "tensorflow/compiler/xla/primitive_util.h"
31 #include "tensorflow/compiler/xla/shape.h"
32 #include "tensorflow/compiler/xla/status_macros.h"
33 #include "tensorflow/compiler/xla/statusor.h"
34 #include "tensorflow/compiler/xla/types.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/compiler/xla/xla_data.pb.h"
37 #include "tensorflow/core/lib/core/threadpool.h"
38 #include "tensorflow/core/platform/cpu_info.h"
39 #include "tensorflow/core/platform/env.h"
40 #include "tensorflow/core/platform/macros.h"
41 #include "tensorflow/core/platform/mutex.h"
42 #include "tensorflow/core/platform/types.h"
43 
44 namespace xla {
45 
46 // An index for specifying a particular nested subshape within a shape. Used in
47 // ShapeUtil::GetSubshape and other interfaces. Shapes are recursive data
48 // structures (trees) and ShapeIndex defines a path through the tree where each
49 // element of ShapeIndex indexes into a tuple (or nested tuple) within the
50 // shape. For a non-nested tuple, an index has a single element. For example,
51 // given a 3-element tuple (a, b, c) containing arrays a, b, and c, the index
52 // {1} corresponds to array b. For a nested tuple, the index can have more than
53 // one element. For the nested tuple (a, (b, c, d), e) below are the values
54 // corresponding to the given indices:
55 //
56 //   index {0}    : array a
57 //   index {1, 2} : array d
58 //   index {2}    : array e
59 //   index {0, 0} : invalid index (element at {0} is an array not a tuple)
60 //
// For indexing into array shapes, the index is always trivially empty, i.e. {}.
62 //
// ShapeIndex is a thin wrapper around absl::InlinedVector (see container_type)
// with a minimal set of methods implemented.
65 class ShapeIndex {
66  public:
67   ShapeIndex() = default;
ShapeIndex(std::initializer_list<int64> init)68   ShapeIndex(std::initializer_list<int64> init) : indices_(init) {}
69   template <typename InputIt>
ShapeIndex(InputIt start,InputIt end)70   ShapeIndex(InputIt start, InputIt end) : indices_(start, end) {}
71 
empty()72   bool empty() const { return indices_.empty(); }
size()73   size_t size() const { return indices_.size(); }
push_back(int64 value)74   void push_back(int64 value) { indices_.push_back(value); }
pop_back()75   void pop_back() { indices_.pop_back(); }
76 
77   // push_front is O(n), but shapes don't usually have a ton of dimensions.
push_front(int64 value)78   void push_front(int64 value) { indices_.insert(indices_.begin(), value); }
79 
80   using container_type = absl::InlinedVector<int64, 2>;
81 
begin()82   container_type::const_iterator begin() const { return indices_.begin(); }
end()83   container_type::const_iterator end() const { return indices_.end(); }
begin()84   container_type::iterator begin() { return indices_.begin(); }
end()85   container_type::iterator end() { return indices_.end(); }
86 
data()87   const int64* data() const { return indices_.data(); }
88 
back()89   int64 back() const { return indices_.back(); }
back()90   int64& back() { return indices_.back(); }
91 
92   const int64& operator[](size_t i) const { return indices_[i]; }
93   int64& operator[](size_t i) { return indices_[i]; }
94 
95   bool operator==(const ShapeIndex& other) const {
96     return indices_ == other.indices_;
97   }
98   bool operator!=(const ShapeIndex& other) const { return !(*this == other); }
99   bool operator<(const ShapeIndex& other) const {
100     return indices_ < other.indices_;
101   }
102 
103   string ToString() const;
104 
105   template <typename H>
AbslHashValue(H h,const ShapeIndex & index)106   friend H AbslHashValue(H h, const ShapeIndex& index) {
107     return H::combine(std::move(h), index.indices_);
108   }
109 
110  private:
111   container_type indices_;
112 };
113 
114 // A view into a ShapeIndex as above, with the cheap/easy ability to consume the
115 // value at the front of the view.
116 //
117 // NB! ShapeIndexView does not own the memory backing the index array.
118 // The memory backing the index array should be owned by an object
119 // that lives longer than the ShapeIndexView instances pointing into
120 // it.
121 class ShapeIndexView {
122  public:
123   ShapeIndexView(const ShapeIndex& shape_index, int64 offset = 0)
124       : indices_(shape_index.data() + offset, shape_index.size() - offset) {
125     CHECK_LE(offset, shape_index.size());
126   }
ShapeIndexView(std::initializer_list<int64> indices)127   ShapeIndexView(std::initializer_list<int64> indices) : indices_(indices) {}
128   ShapeIndexView(const ShapeIndexView& other) = default;
129 
130   using iterator = const int64*;
131 
begin()132   iterator begin() const { return indices_.begin(); }
end()133   iterator end() const { return indices_.end(); }
size()134   int64 size() const { return indices_.size(); }
empty()135   bool empty() const { return indices_.empty(); }
front()136   int64 front() const {
137     CHECK(!empty());
138     return indices_.front();
139   }
ConsumeFront()140   ShapeIndexView ConsumeFront() const {
141     ShapeIndexView result = *this;
142     result.indices_.remove_prefix(1);
143     return result;
144   }
ConsumeBack()145   ShapeIndexView ConsumeBack() const {
146     ShapeIndexView result = *this;
147     result.indices_.remove_suffix(1);
148     return result;
149   }
ToShapeIndex()150   ShapeIndex ToShapeIndex() const { return ShapeIndex(begin(), end()); }
151 
152   bool operator==(const ShapeIndexView& other) const;
153   bool operator!=(const ShapeIndexView& other) const;
154 
155   string ToString() const;
156 
157   // Returns true if this shape index starts with 'prefix'.
158   bool StartsWith(ShapeIndexView prefix) const;
159 
160  private:
161   absl::Span<const int64> indices_;
162 };
163 
// Stream-insertion operators for ShapeIndex and ShapeIndexView. Both are
// defined out of line; presumably they write the same text as the
// corresponding ToString() (NOTE(review): confirm against the .cc file).
std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index);
std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index);
166 
167 // Namespaced collection of (static) shape utilities.
168 //
169 // These are all effectively convenience functions for testing/tweaking proto
170 // properties, which do invariant checks before / after the operation.
171 class ShapeUtil {
172  public:
173   // Data structure which describes the coordinates and the shape, of a tuple
174   // shaped sub-shape.
175   struct IndexedShape {
176     IndexedShape() = default;
IndexedShapeIndexedShape177     IndexedShape(ShapeIndex index, Shape shape)
178         : index(std::move(index)), shape(std::move(shape)) {}
179     ShapeIndex index;
180     Shape shape;
181   };
182 
  // Returns the number of elements contained within the provided shape;
184   // e.g. for rank 0 (scalars) the result is always 1. Note that sparse shapes
185   // may not actually be able to store this number of elements. See
186   // LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of
187   // elements that can be stored in a sparse shape.
188   // Precondition: shape.IsArray()
189   static int64 ElementsIn(const Shape& shape);
190 
191   // As ElementsIn(), but recurses through tuples.
192   static int64 ElementsInRecursive(const Shape& shape);
193 
194   // Returns true if shape has the primitive type, recurses through tuples.
195   static bool HasPrimitiveType(const Shape& shape,
196                                PrimitiveType primitive_type);
197 
198   // Returns true if 'shape' is an array with zero elements.
199   static bool IsZeroElementArray(const Shape& shape);
200 
201   // Returns the number of bytes required for an allocation of shape.  The
202   // |pointer_size| parameter is used for calculating the size of tuple
203   // shapes. This includes only the size of the top-level buffer. For example, a
204   // tuple is stored as an array of pointers to other buffers. In this case,
205   // this method only returns the size of the pointer array.
206   static int64 ByteSizeOf(const Shape& shape, int64 pointer_size = -1);
207 
  // Returns the number of bytes used to store the primitive_type.
  //
  // Precondition: primitive_type is an array element type (see
  // IsArrayPrimitiveType).
211   static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type);
212 
213   // Returns the number of bytes required to store the tuple member pointers for
  // an allocation of shape. The `shape` must be a TUPLE shape, and
215   // `pointer_size` must be larger than zero.
216   static int64 ByteSizeOfTupleIndexTable(const Shape& shape,
217                                          int64 pointer_size);
218 
219   // Returns the number of bytes required for the elements in an allocation of
220   // `shape`, which must be an array shape. The return value does not include
221   // the bytes needed to store sparse indices. Dense shapes use a separate
222   // memory location for each element, and so for these shapes,
223   // `ByteSizeOf(shape) == ByteSizeOfElements(shape)`. For dense shapes, this
224   // size also includes padding if present in the layout. For sparse shapes,
225   // `ByteSizeOf(shape) == ByteSizeOfElements(shape) +
  // ByteSizeOfSparseIndices(shape)`.
227   static int64 ByteSizeOfElements(const Shape& shape);
228 
  // Returns the number of bytes required for the sparse indices in an
  // allocation of shape. The shape must be an array shape. The return value
  // does not include the bytes needed to store sparse elements.
232   static int64 ByteSizeOfSparseIndices(const Shape& shape);
233 
234   // Returns a human-readable string that represents the given shape, with or
235   // without layout. e.g. "f32[42x12] {0, 1}" or "f32[64]".
236   static string HumanString(const Shape& shape);
237   static string HumanStringWithLayout(const Shape& shape);
238 
239   // As above, but for program shapes, returns a string for the form:
240   //
241   // (param_name: f32[42x12], ...) -> f32[24x42]
242   static string HumanString(const ProgramShape& program_shape);
243 
244   // Returns whether the LHS and RHS shapes have the same dimensions; note: does
245   // not check element type.
246   // Precondition: IsArray(lhs) && IsArray(rhs)
247   static bool SameDimensions(const Shape& lhs, const Shape& rhs);
248 
249   // Returns whether the lhs and rhs shapes have the same element type.
SameElementType(const Shape & lhs,const Shape & rhs)250   static bool SameElementType(const Shape& lhs, const Shape& rhs) {
251     return lhs.element_type() == rhs.element_type();
252   }
253 
254   // As SameElementType, but allows floating point types to have different
255   // precisions.
SameElementTypeIgnoringFpPrecision(const Shape & a,const Shape & b)256   static bool SameElementTypeIgnoringFpPrecision(const Shape& a,
257                                                  const Shape& b) {
258     if (ElementIsFloating(a) && ElementIsFloating(b)) {
259       return true;
260     }
261     return ShapeUtil::SameElementType(a, b);
262   }
263 
264   // Returns the higher-precision element type if a and b are both floating
265   // point types; otherwise, checks that they have the same element type
266   // and returns it.
HigherPrecisionElementType(const Shape & a,const Shape & b)267   static PrimitiveType HigherPrecisionElementType(const Shape& a,
268                                                   const Shape& b) {
269     if (SameElementType(a, b)) {
270       return a.element_type();
271     }
272     CHECK(SameElementTypeIgnoringFpPrecision(a, b));
273     return primitive_util::BitWidth(a.element_type()) <
274                    primitive_util::BitWidth(b.element_type())
275                ? b.element_type()
276                : a.element_type();
277   }
278 
279   // Returns true if the rank, dimension sizes, and element type are
280   // identical. Layout is ignored. Tuple elements are compared recursively for
281   // compatibility.
282   static bool Compatible(const Shape& lhs, const Shape& rhs);
283 
284   // Returns true if the rank and dimension sizes are identical. Element type
285   // and layout are ignored. Tuple elements are compared recursively for
286   // compatibility.
287   static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs);
288 
  // As Compatible, but allow one of lhs and rhs to be BF16 while the other
  // is F32. Tuple elements are compared recursively for compatibility.
291   static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);
292 
293   // Returns whether the lhs and rhs shapes are identical.
294   static bool Equal(const Shape& lhs, const Shape& rhs);
295 
296   // As Equal, but allow one of lhs and rhs to be F16 while the other is F32.
297   static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);
298 
299   // Returns the number of dimensions for which the dimension is not (trivially)
300   // 1. e.g., f32[2x1x1] has a true rank of 1D, the other dimensions are just
301   // fluff. Note that zero dimensions are included in the true rank, e.g.,
302   // f32[3,0,1] has a true rank of 2D.
303   static int64 TrueRank(const Shape& shape);
304 
305   static ProgramShape MakeProgramShape(std::initializer_list<Shape> parameters,
306                                        Shape result);
307 
308   ////////////////////
309   // Scalar-specific
310 
IsScalar(const Shape & shape)311   static bool IsScalar(const Shape& shape) {
312     return shape.IsArray() && shape.rank() == 0;
313   }
IsEffectiveScalar(const Shape & shape)314   static bool IsEffectiveScalar(const Shape& shape) {
315     return shape.IsArray() && TrueRank(shape) == 0;
316   }
317 
318   // Returns whether "shape" is a scalar (array) with the given element_type.
319   static bool IsScalarWithElementType(const Shape& shape,
320                                       PrimitiveType element_type);
321 
322   // Extracts the size of the shape's dimension at dimension number
323   // GetDimensionNumber(dimension_number).
324   static int64 GetDimension(const Shape& shape, int64 dimension_number);
325 
326   // Resolves a dimension number, supporting negative indexing.
327   //
328   // Negative indexing has similar semantics to Python. For an N-dimensional
329   // array, dimension -1 is equivalent to dimension N-1, -2 is equivalent to
330   // N-2, and so on.
331   //
  // This function always returns a non-negative dimension number for any
  // given dimension_number (which itself can be negative).
334   static int64 GetDimensionNumber(const Shape& shape, int64 dimension_number);
335 
336   // Returns a shape with the same dimensions as the original, but with the
337   // element type changed to type.
338   static Shape ChangeElementType(const Shape& original, PrimitiveType type);
339 
340   // Creates a tuple shape from a slice of element shapes within the tuple.
341   static Shape MakeTupleShape(absl::Span<const Shape> shapes);
342 
343   // Creates an opaque shape. These are generally used for threading a context
344   // into a custom operation.
345   static Shape MakeOpaqueShape();
346 
347   // Creates a token shape. Values of this shape are used for ordering
348   // side-effecting operations.
349   static Shape MakeTokenShape();
350 
351   // Appends a shape to the given tuple.
352   static void AppendShapeToTuple(const Shape& shape, Shape* tuple_shape);
353 
354   // Appends a major dimension to the shape with the given bound.
355   static void AppendMajorDimension(int bound, Shape* shape);
356 
357   // Returns an empty tuple shape. Can be used as a sentinel Shape value.
MakeNil()358   static Shape MakeNil() { return MakeTupleShape({}); }
359 
360   // Checks whether the shape is initialized.
IsInitialized(const Shape & shape)361   static bool IsInitialized(const Shape& shape) {
362     return shape.element_type() != PRIMITIVE_TYPE_INVALID;
363   }
364 
365   // Constructs a new shape with the given element type and sequence of
366   // dimensions.
367   static Shape MakeShape(PrimitiveType element_type,
368                          absl::Span<const int64> dimensions);
369 
370   // Constructs a new shape with the given element type and sequence of
371   // potentially dynamic dimensions. The argument 'dynamic_dimensions' indicates
372   // with a true value that the respective dimension is dynamic. If the
373   // dimension is dynamic then the respective value in 'dimension' is an upper
374   // bound on the dimension size. 'dimensions' and 'dynamic_dimensions' must be
375   // the same size.
376   static Shape MakeShape(PrimitiveType element_type,
377                          absl::Span<const int64> dimensions,
378                          const std::vector<bool>& dynamic_dimensions);
379 
380   // Constructs a new shape with the given element type and sequence of
381   // dimensions. Method checks if the element type is valid and the shape's
382   // size fits in std::numeric_limits<int64>::max().
383   static StatusOr<Shape> MakeValidatedShape(PrimitiveType element_type,
384                                             absl::Span<const int64> dimensions);
385   static StatusOr<Shape> MakeValidatedShape(
386       PrimitiveType element_type, absl::Span<const int64> dimensions,
387       const std::vector<bool>& dynamic_dimensions);
388 
389   // Creates a Shape with element type corresponding to T and the given
390   // dimensions
391   template <typename T>
MakeShapeWithType(absl::Span<const int64> dimensions)392   static Shape MakeShapeWithType(absl::Span<const int64> dimensions) {
393     return ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<T>(),
394                                 dimensions);
395   }
396 
397   // Constructs a new shape with the given minor_to_major order in its Layout.
398   // Returns a value shape such that shape.has_layout().
399   static Shape MakeShapeWithLayout(PrimitiveType element_type,
400                                    absl::Span<const int64> dimensions,
401                                    absl::Span<const int64> minor_to_major,
402                                    absl::Span<const Tile> tiles = {},
403                                    int64 element_size_in_bits = 0);
404 
405   static Shape MakeShapeWithSparseLayout(PrimitiveType element_type,
406                                          absl::Span<const int64> dimensions,
407                                          int64 max_sparse_elements);
408 
409   // Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}).
410   static Shape MakeShapeWithDescendingLayout(
411       PrimitiveType element_type, absl::Span<const int64> dimensions);
412 
413   // Returns a new Shape based on the given Shape with low-dimension-major
  // layout (i.e. {n, n-1, ..., 0}, row-major as in C), and with the dimensions
415   // rearranged so that it has the same in-memory layout as the given shape.
416   //
417   // For example, transforms f32[B,H,W,C]{0,3,2,1} to f32[H,W,C,B]{3,2,1,0}.
418   static Shape MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
419       const Shape& shape);
420 
421   // As MakeShape, but the object to write to is passed in.
422   static Status PopulateShape(PrimitiveType element_type,
423                               absl::Span<const int64> dimensions, Shape* shape);
424 
425   // Validates that the provided shape satisfies invariants.
426   static Status ValidateShape(const Shape& shape);
427 
428   // Validates the provided shape satisfies invariants, except those that
429   // pertain to layout.
430   //
431   // Layout is optional for client-provided shapes, so that the compiler may
432   // determine and assign an optimized layout.
433   static Status ValidateShapeWithOptionalLayout(const Shape& shape);
434 
435   // Returns whether the element type of the shape is integral (signed or
436   // unsigned). Note that predicates are not considered integral here, since
437   // they are logical values.
438   static bool ElementIsIntegral(const Shape& shape);
439 
440   // Returns whether the element type of the shape is floating point.
441   static bool ElementIsFloating(const Shape& shape);
442 
443   // Returns whether the element type of the shape is complex.
444   static bool ElementIsComplex(const Shape& shape);
445 
446   // Returns whether the element type has the given bit width.
447   static bool ElementHasBitWidth(const Shape& shape, int bits);
448 
449   // Returns whether the element type of the shape is integral and has
450   // the specified number of bits.
451   static bool ElementIsIntegralWithBits(const Shape& shape, int bits);
452 
453   // Returns whether the element type of the shape is signed. Note
454   // that floating point numbers are signed.
455   static bool ElementIsSigned(const Shape& shape);
456 
457   // Returns whether the given primitive type corresponds to an array shape.
458   static bool IsArrayPrimitiveType(PrimitiveType primitive_type);
459 
460   // Returns whether the shape is a tuple with at least one element which is
461   // also a tuple.
462   static bool IsNestedTuple(const Shape& shape);
463 
464   // Returns true if shape is an empty tuple.
465   static bool IsEmptyTuple(const Shape& shape);
466 
467   // Returns the number of elements in the given tuple shape.
468   // Precondition: IsTuple(shape)
469   static int64 TupleElementCount(const Shape& shape);
470 
471   // Returns the tuple element shape at given index.
472   // Precondition: IsTuple(shape) && TupleElementCount(shape) > index
473   static const Shape& GetTupleElementShape(const Shape& shape, int64 index);
474 
475   // Returns the number of elements, recursively, in the given shape.
476   static int64 SubshapeCount(const Shape& shape);
477 
478   // Slices tuple elements in the range [start, limit) and returns a new tuple
479   // shape. E.g. a tuple like (f32, s32, u32) would slice via 1,3 to (s32, u32).
480   static Shape SliceTuple(const Shape& tuple, int64 start, int64 limit);
481 
482   // Returns the shape of the real/imaginary components of the given complex
483   // shape.
484   static Shape ComplexComponentShape(const Shape& complex_shape);
485 
486   // Returns true if the given shape has a subshape at the given index.
487   static bool IndexIsValid(const Shape& shape, ShapeIndexView index);
488 
489   // GetSubshape and GetMutableSubshape return a particular nested Shape within
490   // the given Shape argument. The non-Try variants check fail if index is
491   // invalid.
492   static const Shape& GetSubshape(const Shape& shape, ShapeIndexView index);
493   static StatusOr<const Shape*> TryGetSubshape(const Shape& shape,
494                                                ShapeIndexView index);
495   static Shape* GetMutableSubshape(Shape* shape, ShapeIndexView index);
496 
497   // Returns whether the given index in the given shape is a leaf element of the
498   // shape.
499   static bool IsLeafIndex(const Shape& shape, const ShapeIndex& index);
500 
501   // Returns the number of leaves in the shape.
502   static int64 GetLeafCount(const Shape& shape);
503 
504   // Retrieves all the leaf shapes and their indexes, in the order walked by
505   // the ForEachSubshape() API.
506   static std::vector<IndexedShape> GetLeafShapes(const Shape& shape);
507 
508   // Calls the given visitor function for each subshape of the given shape.
509   // Subshapes are visited in DFS pre-order starting with the entire shape
510   // (index {}).
511   using VisitorFunction = std::function<void(const Shape& /*subshape*/,
512                                              const ShapeIndex& /*index*/)>;
513   static void ForEachSubshape(const Shape& shape, const VisitorFunction& func);
514   using MutatingVisitorFunction =
515       std::function<void(Shape* /*subshape*/, const ShapeIndex& /*index*/)>;
516   static void ForEachMutableSubshape(Shape* shape,
517                                      const MutatingVisitorFunction& func);
518 
519   // Variants of ForEach(Mutable)Subshape which propagate Status from the
520   // visitor function.
521   using StatusVisitorFunction = std::function<Status(
522       const Shape& /*subshape*/, const ShapeIndex& /*index*/)>;
523   static Status ForEachSubshapeWithStatus(const Shape& shape,
524                                           const StatusVisitorFunction& func);
525   using MutatingStatusVisitorFunction =
526       std::function<Status(Shape* /*subshape*/, const ShapeIndex& /*index*/)>;
527   static Status ForEachMutableSubshapeWithStatus(
528       Shape* shape, const MutatingStatusVisitorFunction& func);
529 
  // Returns true if `shape` (which must be an array) has degenerate dimensions
  // (dimensions with bound 1).
532   static bool HasDegenerateDimensions(const Shape& shape);
533 
534   // Drops any degenerate dimensions (i.e. dimensions of size 1)
535   static Shape DropDegenerateDimensions(const Shape& shape);
536 
537   // Permutes the dimensions by the given permutation, so
538   // return_value.dimensions[permutation[i]] = argument.dimensions[i].
539   //
540   // Postcondition: For any valid permutation,
541   //
542   //   !HasLayout(shape) ||
543   //   TransposeIsBitcast(shape, PermuteDimensions(permutation, shape),
544   //                      InversePermutation(permutation)).
545   static Shape PermuteDimensions(absl::Span<const int64> permutation,
546                                  const Shape& shape);
547 
  // If we can go from `shape_pre` to `shape_post` by merely inserting or
  // deleting 1-sized dimensions, return the indices in `shape_pre` of the
  // deleted dimensions and the indices in `shape_post` of the inserted
  // dimensions.
552   // For example, if `shape_pre = {a_1, a_2, ..., a_m}` and
553   // `shape_post = {b_1, b_2, ..., b_n}` where we can find some sequence of `i`s
554   // and some sequence of `j`s so `a_i = 1` for each `i` and `b_j = 1` for each
555   // `j` and `a_(k-s) = b_(k-t)` where `s` and `t` are the number of `i`s and
556   // `j`s less than `k` for all other `k`, we return the `i`s and `j`s.
557   // For another example, if `shape_pre = shape_post = {}`, we return `{}`.
558   static std::tuple<bool, std::vector<int64>, std::vector<int64>>
559   InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
560                                     const Shape& shape_post);
561 
562   // Suppose a reshape transforms input_shape to output shape. Returns a vector
563   // of pairs that indicate the input and output dimensions that this reshape
564   // doesn't logically (i.e. ignoring the layout) modify. For each pair (I,O) in
565   // the returned vector, the reshape transforms any input index whose I-th
566   // dimension is x to an output index whose O-th dimension is x too.
567   //
568   // Post-condition: the returned vector is sorted (by both input and output
569   // dimensions because input and output dimensions have the same order).
570   //
571   // Example:
572   //   input  shape = T[a, b, x, y, cd]
573   //   output shape = T[ab, x, 1, y, c, d]
574   //   return value = {{2, 1}, {3, 3}}
575   //
576   //   The two pairs represent the input and output dimension of size x and
577   //   those of size y.
578   static std::vector<std::pair<int64, int64>> DimensionsUnmodifiedByReshape(
579       const Shape& input_shape, const Shape& output_shape);
580 
581   // Returns whether a transpose from input_shape to output_shape with dimension
582   // mapping "dimension_mapping" produces a result which is bit-wise identical
583   // to its input and thus may be replaced with a bitcast.
584   //
585   // Precondition: Both input_shape and output_shape have explicit layouts.
586   static bool TransposeIsBitcast(const Shape& input_shape,
587                                  const Shape& output_shape,
588                                  absl::Span<const int64> dimension_mapping);
589 
590   // Returns whether a reshape from "input_shape" to "output_shape" is a
591   // bitcast.
592   //
593   // Precondition: Both input_shape and output_shape have explicit layouts.
594   static bool ReshapeIsBitcast(const Shape& input_shape,
595                                const Shape& output_shape);
596 
597   // Find a physical layout for 'output_shape' such that
598   // ShapeUtil::ReshapeIsBitcast(input_shape, output_shape_with_layout) returns
599   // true (where 'output_shape_with_layout' is 'output_shape' with the found
  // layout). The layout of 'input_shape' is kept fixed. Returns
  // 'output_shape_with_layout' if such a layout can be found, and
  // absl::nullopt otherwise.
603   static absl::optional<Shape> AlignLayouts(const Shape& input_shape,
604                                             const Shape& output_shape);
605 
606   // Returns a shape with the given dimension deleted.
607   // For example:
608   // • `DeleteDimension(1, T[m, n, k]) = T[m, k]`
609   static Shape DeleteDimension(int64 dim_to_delete, Shape shape);
610 
611   // Returns a shape with all the dimensions of the input shape for which `p`
612   // returns true.
613   // For examples:
614   // • `FilterDimensions((< 2), T[m, n, k]) = T[m, n]`
615   // • `FilterDimensions(is_even_number, T[m, n, k]) = T[m, k]`
616   static Shape FilterDimensions(const std::function<bool(int64)>& p,
617                                 Shape shape);
618 
619   // Iterates through all the shape indexes, in minor to major order, starting
620   // from the base indexes, incrementing by the incr steps, up to count
621   // (index[i] < base[i] + count[i]), and calls the visitor_function with the
622   // current index.
623   // The visitor_function visitor function should return true if it wants to
624   // continue, or false otherwise.
625   //
626   // visitor_function must be a callable of type
627   // StatusOr<bool>(Span<int64>) or compatible.
628   template <typename FnType>
ForEachIndexWithStatus(const Shape & shape,absl::Span<const int64> base,absl::Span<const int64> count,absl::Span<const int64> incr,const FnType & visitor_function)629   static Status ForEachIndexWithStatus(const Shape& shape,
630                                        absl::Span<const int64> base,
631                                        absl::Span<const int64> count,
632                                        absl::Span<const int64> incr,
633                                        const FnType& visitor_function) {
634     return ForEachIndexInternal(shape, base, count, incr, visitor_function);
635   }
636 
  // Simple ergonomic wrapper around ShapeUtil::ForEachIndexWithStatus: bundles
  // the base / count / incr vectors that describe an iteration space so they
  // can be passed as a single argument.
  struct IndexIterationSpace {
    std::vector<int64> index_base;   // Starting index in each dimension.
    std::vector<int64> index_count;  // Extent past index_base per dimension.
    std::vector<int64> index_incr;   // Step size in each dimension.
  };
643 
644   template <typename FnTy>
ForEachIndexWithStatus(const Shape & shape,const IndexIterationSpace & iteration_space,FnTy && function)645   static Status ForEachIndexWithStatus(
646       const Shape& shape, const IndexIterationSpace& iteration_space,
647       FnTy&& function) {
648     return ShapeUtil::ForEachIndexWithStatus(
649         shape, iteration_space.index_base, iteration_space.index_count,
650         iteration_space.index_incr, std::forward<FnTy>(function));
651   }
652 
653   template <typename FnType>
ForEachIndex(const Shape & shape,absl::Span<const int64> base,absl::Span<const int64> count,absl::Span<const int64> incr,const FnType & visitor_function)654   static void ForEachIndex(const Shape& shape, absl::Span<const int64> base,
655                            absl::Span<const int64> count,
656                            absl::Span<const int64> incr,
657                            const FnType& visitor_function) {
658     ForEachIndexWithStatus(shape, base, count, incr,
659                            [&](absl::Span<const int64> indices) {
660                              return StatusOr<bool>(visitor_function(indices));
661                            })
662         .IgnoreError();
663   }
664 
665   // These convenience wrappers don't take `base`, `count` and `incr`
666   // explicitly, but iterate over every element in `shape` instead.
667 
668   template <typename FnType>
ForEachIndexWithStatus(const Shape & shape,const FnType & visitor_function)669   static Status ForEachIndexWithStatus(const Shape& shape,
670                                        const FnType& visitor_function) {
671     std::vector<int64> base(shape.dimensions_size());
672     std::vector<int64> incr(shape.dimensions_size(), 1);
673     return ForEachIndexWithStatus(shape, base,
674                                   /*count=*/AsInt64Slice(shape.dimensions()),
675                                   incr, visitor_function);
676   }
677 
678   template <typename FnType>
ForEachIndex(const Shape & shape,const FnType & visitor_function)679   static void ForEachIndex(const Shape& shape, const FnType& visitor_function) {
680     ForEachIndexWithStatus(shape, [&](absl::Span<const int64> indices) {
681       return StatusOr<bool>(visitor_function(indices));
682     }).IgnoreError();
683   }
684 
685   // A parallel version of ForEachIndex(WithStatus). This can only be used if
686   // the visitor_function is thread-safe and the order of iteration does not
687   // matter.
688   //
689   // visitor_function must be a callable of type
690   // void(Span<int64>) or compatible.
691   template <typename FnType>
ForEachIndexParallel(const Shape & shape,absl::Span<const int64> base,absl::Span<const int64> count,absl::Span<const int64> incr,const FnType & visitor_function)692   static void ForEachIndexParallel(const Shape& shape,
693                                    absl::Span<const int64> base,
694                                    absl::Span<const int64> count,
695                                    absl::Span<const int64> incr,
696                                    const FnType& visitor_function) {
697     // The parallel version of ForEachIndexInternal can never fail.
698     CHECK(ForEachIndexInternal(
699               shape, base, count, incr,
700               [&visitor_function](
701                   absl::Span<const int64> indexes) -> StatusOr<bool> {
702                 visitor_function(indexes);
703                 return true;
704               },
705               /*parallel=*/true)
706               .ok());
707   }
708 
  // Computes a hash for `shape`. Defined out of line.
  // NOTE(review): confirm in shape_util.cc which properties (e.g. layout)
  // participate in the hash before relying on it for equality-keyed maps.
  static size_t Hash(const Shape& shape);
711 
 private:
  // Validates that the shape size is sane, i.e. that size calculations on the
  // shape can be done in int64 without overflowing.
  static Status ValidateShapeSize(const Shape& shape);

  // Validates all of the non-layout properties of the shape. This is a helper
  // shared by the public validation methods for both the layout-optional and
  // the layout-required cases.
  static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape);
720 
721   template <typename FnType>
722   static Status ForEachIndexInternal(const Shape& shape,
723                                      absl::Span<const int64> base,
724                                      absl::Span<const int64> count,
725                                      absl::Span<const int64> incr,
726                                      const FnType& visitor_function,
727                                      bool parallel = false) {
728     if (ShapeUtil::IsZeroElementArray(shape)) {
729       return Status::OK();
730     }
731     CHECK_EQ(shape.rank(), base.size());
732     CHECK_EQ(incr.size(), base.size());
733     CHECK_EQ(count.size(), base.size());
734     const int64 rank = LayoutUtil::MinorToMajor(shape).size();
735     // Allows handling R0 arrays, such that the visitor function will be called
736     // once with the proper empty indexes.
737     int64 n = -1;
738     std::vector<int64> indexes(base.begin(), base.end());
739     const int kNumThreads = tensorflow::port::NumSchedulableCPUs();
740     absl::optional<tensorflow::thread::ThreadPool> pool;
741     if (parallel) {
742       pool.emplace(tensorflow::Env::Default(), "foreach", kNumThreads);
743     }
744 
745     tensorflow::mutex mu;
746     Status status;  // Guarded by mu
747 
748     while (n < rank) {
749       if (pool != absl::nullopt) {
750         pool->Schedule([indexes, &visitor_function, &mu, &status] {
751           StatusOr<bool> result = visitor_function(indexes);
752           if (!result.ok()) {
753             tensorflow::mutex_lock lock(mu);
754             status = status.ok() ? result.status() : status;
755           }
756         });
757       } else {
758         TF_ASSIGN_OR_RETURN(bool should_continue, visitor_function(indexes));
759         if (!should_continue) {
760           break;
761         }
762       }
763       // Increments dimensions in minor to major order.
764       for (n = 0; n < rank; ++n) {
765         int64 dim = LayoutUtil::Minor(shape.layout(), n);
766         indexes[dim] += incr[dim];
767         if (indexes[dim] < base[dim] + count[dim]) {
768           break;
769         }
770         indexes[dim] = base[dim];
771       }
772     }
773 
774     // Waits for the scheduled work to complete.
775     pool.reset();
776     return status;
777   }
778 
  // ShapeUtil is used only through its static members (every member visible
  // here is static). This macro disables copy construction and assignment.
  TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil);
780 };
781 
782 }  // namespace xla
783 
784 #endif  // TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
785