1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/shape_util.h"
17 
18 #include <algorithm>
19 #include <functional>
20 #include <numeric>
21 #include <unordered_map>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/container/inlined_vector.h"
26 #include "absl/strings/ascii.h"
27 #include "absl/strings/numbers.h"
28 #include "absl/strings/str_cat.h"
29 #include "absl/strings/str_join.h"
30 #include "absl/strings/str_split.h"
31 #include "absl/strings/string_view.h"
32 #include "absl/strings/strip.h"
33 #include "absl/types/optional.h"
34 #include "tensorflow/compiler/xla/index_util.h"
35 #include "tensorflow/compiler/xla/layout_util.h"
36 #include "tensorflow/compiler/xla/overflow_util.h"
37 #include "tensorflow/compiler/xla/primitive_util.h"
38 #include "tensorflow/compiler/xla/status_macros.h"
39 #include "tensorflow/compiler/xla/types.h"
40 #include "tensorflow/compiler/xla/util.h"
41 #include "tensorflow/core/lib/core/errors.h"
42 #include "tensorflow/core/lib/gtl/iterator_range.h"
43 #include "tensorflow/core/lib/hash/hash.h"
44 #include "tensorflow/core/lib/strings/numbers.h"
45 #include "tensorflow/core/platform/logging.h"
46 #include "tensorflow/core/platform/protobuf.h"
47 #include "tensorflow/core/platform/regexp.h"
48 
49 namespace xla {
50 
51 using absl::StrAppend;
52 using absl::StrCat;
53 
ToString() const54 string ShapeIndex::ToString() const { return ShapeIndexView(*this).ToString(); }
55 
ToString() const56 string ShapeIndexView::ToString() const {
57   return StrCat("{", absl::StrJoin(indices_, ","), "}");
58 }
59 
operator ==(const ShapeIndexView & other) const60 bool ShapeIndexView::operator==(const ShapeIndexView& other) const {
61   return indices_ == other.indices_;
62 }
63 
operator !=(const ShapeIndexView & other) const64 bool ShapeIndexView::operator!=(const ShapeIndexView& other) const {
65   return !(*this == other);
66 }
67 
operator <<(std::ostream & out,const ShapeIndex & shape_index)68 std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) {
69   out << shape_index.ToString();
70   return out;
71 }
72 
operator <<(std::ostream & out,const ShapeIndexView & shape_index)73 std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) {
74   out << shape_index.ToString();
75   return out;
76 }
77 
StartsWith(ShapeIndexView prefix) const78 bool ShapeIndexView::StartsWith(ShapeIndexView prefix) const {
79   return size() >= prefix.size() &&
80          indices_.subspan(0, prefix.size()) == prefix.indices_;
81 }
82 
IsArrayPrimitiveType(PrimitiveType primitive_type)83 /* static */ bool ShapeUtil::IsArrayPrimitiveType(
84     PrimitiveType primitive_type) {
85   return primitive_util::IsArrayType(primitive_type);
86 }
87 
88 namespace {
89 // Constructs and returns the new shape with the given minor_to_major order in
90 // its Layout.
MakeShapeWithLayoutInternal(PrimitiveType element_type,absl::Span<const int64> dimensions,absl::Span<const int64> minor_to_major,absl::Span<const Tile> tiles,int64 element_size_in_bits)91 StatusOr<Shape> MakeShapeWithLayoutInternal(
92     PrimitiveType element_type, absl::Span<const int64> dimensions,
93     absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles,
94     int64 element_size_in_bits) {
95   if (dimensions.size() != minor_to_major.size()) {
96     return InvalidArgument("Dimensions size is %ld, but layout size is %ld.",
97                            dimensions.size(), minor_to_major.size());
98   }
99   if (element_type == OPAQUE || element_type == TUPLE) {
100     return InvalidArgument("Unsupported element type: %s",
101                            PrimitiveType_Name(element_type));
102   }
103   TF_ASSIGN_OR_RETURN(Shape shape,
104                       ShapeUtil::MakeValidatedShape(element_type, dimensions));
105   *shape.mutable_layout() =
106       LayoutUtil::MakeLayout(minor_to_major, tiles, element_size_in_bits);
107   if (!shape.has_layout()) {
108     return InvalidArgument("Shape has no layout.");
109   }
110   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
111   return shape;
112 }
113 }  // namespace
114 
Equal(const Shape & lhs,const Shape & rhs)115 /* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) {
116   bool equal = Shape::Equal()(lhs, rhs);
117 
118   if (!equal && VLOG_IS_ON(3)) {
119     VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString()
120             << ", rhs = " << rhs.ShortDebugString();
121   }
122 
123   return equal;
124 }
125 
EqualIgnoringFpPrecision(const Shape & lhs,const Shape & rhs)126 /* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs,
127                                                       const Shape& rhs) {
128   bool equal = Shape::Equal().IgnoreFpPrecision()(lhs, rhs);
129   if (!equal && VLOG_IS_ON(3)) {
130     VLOG(3) << "ShapeUtil::EqualIgnoringFpPrecision differ: lhs = "
131             << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString();
132   }
133 
134   return equal;
135 }
136 
TrueRank(const Shape & shape)137 /* static */ int64 ShapeUtil::TrueRank(const Shape& shape) {
138   int64 accum = 0;
139   for (int64 dimension : shape.dimensions()) {
140     // We do not count zero dimensions.
141     if (dimension != 1) {
142       accum += 1;
143     }
144   }
145   return accum;
146 }
147 
MakeProgramShape(std::initializer_list<Shape> parameters,Shape result)148 /* static */ ProgramShape ShapeUtil::MakeProgramShape(
149     std::initializer_list<Shape> parameters, Shape result) {
150   ProgramShape program_shape;
151   for (const Shape& shape : parameters) {
152     *program_shape.add_parameters() = shape;
153   }
154   *program_shape.mutable_result() = std::move(result);
155   return program_shape;
156 }
157 
MakeShape(PrimitiveType element_type,absl::Span<const int64> dimensions)158 /* static */ Shape ShapeUtil::MakeShape(PrimitiveType element_type,
159                                         absl::Span<const int64> dimensions) {
160   return MakeValidatedShape(element_type, dimensions).ValueOrDie();
161 }
162 
MakeShape(PrimitiveType element_type,absl::Span<const int64> dimensions,const std::vector<bool> & dynamic_dimensions)163 /* static */ Shape ShapeUtil::MakeShape(
164     PrimitiveType element_type, absl::Span<const int64> dimensions,
165     const std::vector<bool>& dynamic_dimensions) {
166   return MakeValidatedShape(element_type, dimensions, dynamic_dimensions)
167       .ValueOrDie();
168 }
169 
MakeValidatedShape(PrimitiveType element_type,absl::Span<const int64> dimensions)170 /* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
171     PrimitiveType element_type, absl::Span<const int64> dimensions) {
172   CHECK(IsArrayPrimitiveType(element_type)) << element_type;
173   Shape result;
174   TF_RETURN_IF_ERROR(PopulateShape(element_type, dimensions, &result));
175   return result;
176 }
177 
MakeValidatedShape(PrimitiveType element_type,absl::Span<const int64> dimensions,const std::vector<bool> & dynamic_dimensions)178 /* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
179     PrimitiveType element_type, absl::Span<const int64> dimensions,
180     const std::vector<bool>& dynamic_dimensions) {
181   TF_ASSIGN_OR_RETURN(Shape shape,
182                       MakeValidatedShape(element_type, dimensions));
183   for (int i = 0; i < dynamic_dimensions.size(); ++i) {
184     shape.set_dynamic_dimension(i, dynamic_dimensions[i]);
185   }
186   return shape;
187 }
188 
MakeShapeWithLayout(PrimitiveType element_type,absl::Span<const int64> dimensions,absl::Span<const int64> minor_to_major,absl::Span<const Tile> tiles,int64 element_size_in_bits)189 /* static */ Shape ShapeUtil::MakeShapeWithLayout(
190     PrimitiveType element_type, absl::Span<const int64> dimensions,
191     absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles,
192     int64 element_size_in_bits) {
193   return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major,
194                                      tiles, element_size_in_bits)
195       .ValueOrDie();
196 }
197 
MakeShapeWithDescendingLayout(PrimitiveType element_type,absl::Span<const int64> dimensions)198 /* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout(
199     PrimitiveType element_type, absl::Span<const int64> dimensions) {
200   std::vector<int64> layout(dimensions.size());
201   std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0));
202   return MakeShapeWithLayout(element_type, dimensions, layout);
203 }
204 
MakeShapeWithSparseLayout(PrimitiveType element_type,absl::Span<const int64> dimensions,int64 max_sparse_elements)205 /* static */ Shape ShapeUtil::MakeShapeWithSparseLayout(
206     PrimitiveType element_type, absl::Span<const int64> dimensions,
207     int64 max_sparse_elements) {
208   CHECK(IsArrayPrimitiveType(element_type));
209   Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
210   *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements);
211   TF_DCHECK_OK(ShapeUtil::ValidateShape(shape));
212   return shape;
213 }
214 
215 /* static */ Shape
MakeShapeWithDescendingLayoutAndSamePhysicalLayout(const Shape & shape)216 ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
217     const Shape& shape) {
218   std::vector<int64> dims(shape.dimensions_size());
219   for (int i = 0; i < shape.dimensions_size(); ++i) {
220     dims[i] = shape.dimensions(LayoutUtil::Major(shape.layout(), i));
221   }
222   return MakeShapeWithDescendingLayout(shape.element_type(), dims);
223 }
224 
PopulateShape(PrimitiveType element_type,absl::Span<const int64> dimensions,Shape * shape)225 /* static */ Status ShapeUtil::PopulateShape(PrimitiveType element_type,
226                                              absl::Span<const int64> dimensions,
227                                              Shape* shape) {
228   shape->Clear();
229   shape->set_element_type(element_type);
230   for (int64 dimension : dimensions) {
231     shape->add_dimensions(dimension);
232   }
233   LayoutUtil::SetToDefaultLayout(shape);
234   return ValidateShape(*shape);
235 }
236 
MakeTupleShape(absl::Span<const Shape> shapes)237 /* static */ Shape ShapeUtil::MakeTupleShape(absl::Span<const Shape> shapes) {
238   Shape result;
239   result.set_element_type(TUPLE);
240   result.mutable_tuple_shapes()->reserve(shapes.size());
241   for (const auto& shape : shapes) {
242     AppendShapeToTuple(shape, &result);
243   }
244   TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
245   return result;
246 }
247 
MakeOpaqueShape()248 /* static */ Shape ShapeUtil::MakeOpaqueShape() {
249   Shape result;
250   result.set_element_type(OPAQUE);
251   TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
252   return result;
253 }
254 
MakeTokenShape()255 /* static */ Shape ShapeUtil::MakeTokenShape() {
256   Shape result;
257   result.set_element_type(TOKEN);
258   TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
259   return result;
260 }
261 
AppendShapeToTuple(const Shape & shape,Shape * tuple_shape)262 /* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape,
263                                                 Shape* tuple_shape) {
264   TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape));
265   *tuple_shape->add_tuple_shapes() = shape;
266 }
267 
AppendMajorDimension(int bound,Shape * shape)268 /* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) {
269   CHECK(LayoutUtil::IsDenseArray(*shape));
270   shape->mutable_layout()->add_minor_to_major(shape->rank());
271   shape->add_dimensions(bound);
272   TF_DCHECK_OK(ValidateShape(*shape));
273 }
274 
ElementIsIntegral(const Shape & shape)275 /* static */ bool ShapeUtil::ElementIsIntegral(const Shape& shape) {
276   return primitive_util::IsIntegralType(shape.element_type());
277 }
278 
ElementIsIntegralWithBits(const Shape & shape,int32 bits)279 /* static */ bool ShapeUtil::ElementIsIntegralWithBits(const Shape& shape,
280                                                        int32 bits) {
281   return ElementIsIntegral(shape) && ElementHasBitWidth(shape, bits);
282 }
283 
ElementHasBitWidth(const Shape & shape,int bits)284 /* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) {
285   if (!shape.IsArray()) {
286     return false;
287   }
288   return primitive_util::BitWidth(shape.element_type()) == bits;
289 }
290 
ElementIsSigned(const Shape & shape)291 /* static */ bool ShapeUtil::ElementIsSigned(const Shape& shape) {
292   switch (shape.element_type()) {
293     case S8:
294     case S16:
295     case S32:
296     case S64:
297     case F16:
298     case BF16:
299     case F32:
300     case F64:
301       return true;
302 
303     case PRED:
304     case U8:
305     case U16:
306     case U32:
307     case U64:
308     case C64:
309     case C128:
310     case TUPLE:
311     case OPAQUE:
312     case TOKEN:
313       return false;
314 
315     default:
316       LOG(FATAL) << "Unhandled element type " << shape.element_type();
317   }
318 }
319 
ElementIsComplex(const Shape & shape)320 /* static */ bool ShapeUtil::ElementIsComplex(const Shape& shape) {
321   return primitive_util::IsComplexType(shape.element_type());
322 }
323 
ElementIsFloating(const Shape & shape)324 /* static */ bool ShapeUtil::ElementIsFloating(const Shape& shape) {
325   return primitive_util::IsFloatingPointType(shape.element_type());
326 }
327 
IsNestedTuple(const Shape & shape)328 /* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) {
329   return shape.IsTuple() &&
330          absl::c_any_of(shape.tuple_shapes(),
331                         [](const Shape& s) { return s.IsTuple(); });
332 }
333 
IsEmptyTuple(const Shape & shape)334 /* static */ bool ShapeUtil::IsEmptyTuple(const Shape& shape) {
335   return shape.IsTuple() && TupleElementCount(shape) == 0;
336 }
337 
TupleElementCount(const Shape & shape)338 /* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) {
339   CHECK(shape.IsTuple()) << HumanString(shape);
340   return shape.tuple_shapes_size();
341 }
342 
// Returns a reference to element `index` of the tuple `shape`.
// CHECK-fails when `shape` is not a tuple or when `index` is not below the
// element count. The element is validated (layout optional) in debug builds.
/* static */ const Shape& ShapeUtil::GetTupleElementShape(const Shape& shape,
                                                          int64 index) {
  CHECK(shape.IsTuple());
  CHECK_GT(TupleElementCount(shape), index);
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape.tuple_shapes(index)));
  const Shape& element = shape.tuple_shapes(index);
  return element;
}
350 
SubshapeCount(const Shape & shape)351 /* static */ int64 ShapeUtil::SubshapeCount(const Shape& shape) {
352   int64 n = 0;
353   ForEachSubshape(shape, [&](const Shape& literal_subshape,
354                              const ShapeIndex& index) { ++n; });
355   return n;
356 }
357 
// Returns a new tuple shape made of elements [start, limit) of `tuple`.
//
// CHECK-fails when `tuple` is not a tuple, or when the requested range is not
// a valid sub-range of its elements, i.e. unless
// 0 <= start <= limit <= TupleElementCount(tuple). The lower-bound and
// ordering checks guard the iterator arithmetic below, which would otherwise
// form an invalid range.
/* static */ Shape ShapeUtil::SliceTuple(const Shape& tuple, int64 start,
                                         int64 limit) {
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(tuple));
  CHECK(tuple.IsTuple());
  CHECK_GE(start, 0);
  CHECK_LE(start, limit);
  CHECK_LE(start, TupleElementCount(tuple));
  CHECK_LE(limit, TupleElementCount(tuple));

  std::vector<Shape> new_elements(tuple.tuple_shapes().begin() + start,
                                  tuple.tuple_shapes().begin() + limit);
  return MakeTupleShape(new_elements);
}
369 
370 // Returns the shape of a real or imaginary component.
ComplexComponentShape(const Shape & complex_shape)371 /* static */ Shape ShapeUtil::ComplexComponentShape(
372     const Shape& complex_shape) {
373   CHECK(ElementIsComplex(complex_shape)) << HumanString(complex_shape);
374   return ChangeElementType(complex_shape, primitive_util::ComplexComponentType(
375                                               complex_shape.element_type()));
376 }
377 
ElementsIn(const Shape & shape)378 /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) {
379   DCHECK(shape.IsArray()) << ShapeUtil::HumanString(shape);
380   DCHECK_EQ(shape.dimensions_size(), shape.rank());
381   if (shape.dimensions().size() == 1) {
382     return shape.dimensions()[0];
383   }
384   return std::accumulate<decltype(shape.dimensions().begin()), int64>(
385       shape.dimensions().begin(), shape.dimensions().end(), 1LL,
386       std::multiplies<int64>());
387 }
388 
ElementsInRecursive(const Shape & shape)389 /* static */ int64 ShapeUtil::ElementsInRecursive(const Shape& shape) {
390   CHECK(shape.IsArray() || shape.IsTuple());
391   if (shape.IsArray()) {
392     return ElementsIn(shape);
393   }
394   int64 count = 0;
395   for (const Shape& element_shape : shape.tuple_shapes()) {
396     count += ElementsInRecursive(element_shape);
397   }
398   return count;
399 }
400 
HasPrimitiveType(const Shape & shape,PrimitiveType primitive_type)401 /* static */ bool ShapeUtil::HasPrimitiveType(const Shape& shape,
402                                               PrimitiveType primitive_type) {
403   if (shape.element_type() == primitive_type) {
404     return true;
405   }
406   for (const Shape& element_shape : shape.tuple_shapes()) {
407     if (HasPrimitiveType(element_shape, primitive_type)) {
408       return true;
409     }
410   }
411   return false;
412 }
413 
IsZeroElementArray(const Shape & shape)414 /* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) {
415   return shape.IsArray() && ElementsIn(shape) == 0;
416 }
417 
IsScalarWithElementType(const Shape & shape,PrimitiveType element_type)418 /* static */ bool ShapeUtil::IsScalarWithElementType(
419     const Shape& shape, PrimitiveType element_type) {
420   return IsScalar(shape) && shape.element_type() == element_type;
421 }
422 
HumanString(const Shape & shape)423 /* static */ string ShapeUtil::HumanString(const Shape& shape) {
424   if (shape.IsTuple()) {
425     string text = "(";
426     const char* prefix = "";
427     for (const Shape& elem_shape : shape.tuple_shapes()) {
428       StrAppend(&text, prefix, HumanString(elem_shape));
429       prefix = ", ";
430     }
431     text += ")";
432     return text;
433   }
434   std::vector<string> dim_elements;
435   for (int i = 0; i < shape.dimensions_size(); ++i) {
436     if (shape.is_dynamic_dimension(i)) {
437       dim_elements.push_back(StrCat("<=", shape.dimensions(i)));
438     } else {
439       dim_elements.push_back(StrCat(shape.dimensions(i)));
440     }
441   }
442   return StrCat(
443       primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "[",
444       absl::StrJoin(dim_elements, ","), "]");
445 }
446 
HumanStringWithLayout(const Shape & shape)447 /* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) {
448   if (shape.IsTuple()) {
449     string text = "(";
450     const char* prefix = "";
451     for (const Shape& elem_shape : shape.tuple_shapes()) {
452       StrAppend(&text, prefix, HumanStringWithLayout(elem_shape));
453       prefix = ", ";
454     }
455     text += ")";
456     return text;
457   }
458   string result = StrCat(
459       primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "[");
460   for (int i = 0; i < shape.dimensions().size(); i++) {
461     StrAppend(&result, (i > 0) ? "," : "",
462               shape.is_dynamic_dimension(i) ? "<=" : "", shape.dimensions(i));
463   }
464   result += "]";
465   if (IsScalar(shape)) {
466     string layout_str = LayoutUtil::HumanString(shape.layout());
467     // Don't print "{}" as layout for scalars.
468     if (layout_str != "{}") {
469       StrAppend(&result, layout_str);
470     }
471   } else if (shape.IsArray() && LayoutUtil::HasLayout(shape)) {
472     StrAppend(&result, LayoutUtil::HumanString(shape.layout()));
473   }
474   return result;
475 }
476 
HumanString(const ProgramShape & program_shape)477 /* static */ string ShapeUtil::HumanString(const ProgramShape& program_shape) {
478   std::vector<string> parameters;
479   for (auto& shape : program_shape.parameters()) {
480     const int i = parameters.size();
481     parameters.push_back(StrCat(i < program_shape.parameter_names_size()
482                                     ? program_shape.parameter_names(i)
483                                     : "(unknown)",
484                                 ": ", HumanString(shape)));
485   }
486   return StrCat("(", absl::StrJoin(parameters, ", "), ") -> ",
487                 HumanString(program_shape.result()));
488 }
489 
SameDimensions(const Shape & lhs,const Shape & rhs)490 /* static */ bool ShapeUtil::SameDimensions(const Shape& lhs,
491                                             const Shape& rhs) {
492   CHECK(lhs.IsArray());
493   CHECK(rhs.IsArray());
494   return absl::c_equal(lhs.dimensions(), rhs.dimensions());
495 }
496 
Compatible(const Shape & lhs,const Shape & rhs)497 /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
498   return Shape::Equal().IgnoreLayout()(lhs, rhs);
499 }
500 
CompatibleIgnoringElementType(const Shape & lhs,const Shape & rhs)501 /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,
502                                                            const Shape& rhs) {
503   return Shape::Equal().IgnoreElementType().IgnoreLayout()(lhs, rhs);
504 }
505 
CompatibleIgnoringFpPrecision(const Shape & lhs,const Shape & rhs)506 /* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
507                                                            const Shape& rhs) {
508   return Shape::Equal().IgnoreFpPrecision().IgnoreLayout()(lhs, rhs);
509 }
510 
GetDimension(const Shape & shape,int64 dimension_number)511 /* static */ int64 ShapeUtil::GetDimension(const Shape& shape,
512                                            int64 dimension_number) {
513   return shape.dimensions(GetDimensionNumber(shape, dimension_number));
514 }
515 
GetDimensionNumber(const Shape & shape,int64 dimension_number)516 /* static */ int64 ShapeUtil::GetDimensionNumber(const Shape& shape,
517                                                  int64 dimension_number) {
518   if (dimension_number < 0) {
519     dimension_number += shape.rank();
520   }
521   CHECK_GE(dimension_number, 0);
522   return dimension_number;
523 }
524 
ByteSizeOfPrimitiveType(PrimitiveType primitive_type)525 /* static */ int64 ShapeUtil::ByteSizeOfPrimitiveType(
526     PrimitiveType primitive_type) {
527   switch (primitive_type) {
528     case PRED:
529       return sizeof(int8);
530     case S8:
531       return sizeof(int8);
532     case S16:
533       return sizeof(int16);
534     case S32:
535       return sizeof(int32);
536     case S64:
537       return sizeof(int64);
538     case U8:
539       return sizeof(uint8);
540     case U16:
541       return sizeof(uint16);
542     case U32:
543       return sizeof(uint32);
544     case U64:
545       return sizeof(uint64);
546     case BF16:
547       return sizeof(float) / 2;
548     case F16:
549       return sizeof(float) / 2;
550     case F32:
551       return sizeof(float);
552     case F64:
553       return sizeof(double);
554     case C64:
555       return sizeof(complex64);
556     case C128:
557       return sizeof(complex128);
558     case TOKEN:
559       // Tokens require no space.
560       return 0;
561     case TUPLE:
562     case OPAQUE:
563       LOG(FATAL) << PrimitiveType_Name(primitive_type)
564                  << " primitive type has no definitive size";
565     default:
566       LOG(FATAL) << "Unhandled primitive type " << primitive_type;
567   }
568 }
569 
ByteSizeOf(const Shape & shape,int64 pointer_size)570 /* static */ int64 ShapeUtil::ByteSizeOf(const Shape& shape,
571                                          int64 pointer_size) {
572   TF_DCHECK_OK(ValidateShape(shape));
573   if (shape.element_type() == TUPLE) {
574     return ByteSizeOfTupleIndexTable(shape, pointer_size);
575   } else if (shape.IsArray()) {
576     int64 byte_size = ByteSizeOfElements(shape);
577     if (LayoutUtil::IsSparseArray(shape)) {
578       byte_size += ByteSizeOfSparseIndices(shape);
579     }
580     return byte_size;
581   } else if (shape.element_type() == TOKEN) {
582     return 0;
583   } else if (shape.element_type() == OPAQUE) {
584     CHECK_GT(pointer_size, 0);
585     return pointer_size;
586   }
587   LOG(FATAL) << PrimitiveType_Name(shape.element_type())
588              << " primitive type has no definitive size";
589 }
590 
ByteSizeOfTupleIndexTable(const Shape & shape,int64 pointer_size)591 /* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape,
592                                                         int64 pointer_size) {
593   TF_DCHECK_OK(ValidateShape(shape));
594   CHECK_EQ(TUPLE, shape.element_type());
595   CHECK_GT(pointer_size, 0);
596   return pointer_size * shape.tuple_shapes_size();
597 }
598 
ByteSizeOfElements(const Shape & shape)599 /* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) {
600   TF_DCHECK_OK(ValidateShape(shape));
601   CHECK(shape.IsArray());
602   int64 allocated_element_count;
603 
604   if (LayoutUtil::IsSparseArray(shape)) {
605     allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout());
606   } else {
607     CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString();
608     allocated_element_count = ElementsIn(shape);
609   }
610   return allocated_element_count *
611          ByteSizeOfPrimitiveType(shape.element_type());
612 }
613 
ByteSizeOfSparseIndices(const Shape & shape)614 /* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) {
615   TF_DCHECK_OK(ValidateShape(shape));
616   CHECK(LayoutUtil::IsSparseArray(shape));
617   return LayoutUtil::MaxSparseElements(shape.layout()) * shape.rank() *
618          sizeof(int64);
619 }
620 
ValidateShapeWithOptionalLayoutInternal(const Shape & shape)621 /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
622     const Shape& shape) {
623   if (shape.element_type() == PRIMITIVE_TYPE_INVALID ||
624       !PrimitiveType_IsValid(shape.element_type())) {
625     return InvalidArgument("shape has invalid element type: %s",
626                            shape.ShortDebugString());
627   }
628   if (shape.element_type() == TUPLE) {
629     if (shape.dimensions_size() != 0) {
630       return InvalidArgument("tuples must not have dimensions specified");
631     }
632     for (auto& element_shape : shape.tuple_shapes()) {
633       TF_RETURN_IF_ERROR(
634           ValidateShapeWithOptionalLayoutInternal(element_shape));
635     }
636     return Status::OK();
637   }
638 
639   // Non-tuple shape.
640   if (shape.tuple_shapes_size() > 0) {
641     return InvalidArgument("non-tuple shape has tuple_shapes field");
642   }
643 
644   // Tokens and opaques can should not have layout or dimensions.
645   if (shape.element_type() == TOKEN || shape.element_type() == OPAQUE) {
646     if (shape.dimensions_size() != 0) {
647       return InvalidArgument(
648           "shape has %s element type, but has dimensions field: %s",
649           primitive_util::LowercasePrimitiveTypeName(shape.element_type()),
650           shape.ShortDebugString());
651     }
652     if (shape.has_layout()) {
653       return InvalidArgument(
654           "shape has %s element type, but has layout field: %s",
655           primitive_util::LowercasePrimitiveTypeName(shape.element_type()),
656           shape.ShortDebugString());
657     }
658     return Status::OK();
659   }
660 
661   if (LayoutUtil::IsSparseArray(shape) && shape.rank() == 0) {
662     return InvalidArgument("sparse arrays must have rank > 0");
663   }
664   for (int64 i = 0; i < shape.rank(); ++i) {
665     int64 dimension = shape.dimensions(i);
666     if (dimension < 0) {
667       return InvalidArgument(
668           "shape's dimensions must not be < 0; dimension at index %d was %d", i,
669           dimension);
670     }
671   }
672 
673   TF_RETURN_IF_ERROR(ValidateShapeSize(shape));
674   return Status::OK();
675 }
676 
ValidateShapeSize(const Shape & shape)677 /* static */ Status ShapeUtil::ValidateShapeSize(const Shape& shape) {
678   VLOG(3) << "Validating shape size: " << ShapeUtil::HumanString(shape);
679 
680   if (!shape.IsArray()) {
681     return Status::OK();
682   }
683 
684   // We can only reason about some aspects of array's shape if it has a valid
685   // layout, these aspects will be ignored otherwise.
686   bool shape_has_valid_layout = LayoutUtil::HasLayout(shape) &&
687                                 LayoutUtil::ValidateLayoutInShape(shape).ok();
688 
689   int64 shape_size = [&]() {
690     if (shape_has_valid_layout && LayoutUtil::IsSparseArray(shape)) {
691       int64 max_sparse_elements = LayoutUtil::MaxSparseElements(shape.layout());
692       if (max_sparse_elements < 0) {
693         return max_sparse_elements;
694       }
695       int64 sparse_elements_size = MultiplyWithoutOverflow(
696           max_sparse_elements, ByteSizeOfPrimitiveType(shape.element_type()));
697       if (sparse_elements_size < 0) {
698         return sparse_elements_size;
699       }
700       int64 sparse_indices_size =
701           MultiplyWithoutOverflow(max_sparse_elements, shape.rank());
702       if (sparse_indices_size < 0) {
703         return sparse_indices_size;
704       }
705       sparse_indices_size =
706           MultiplyWithoutOverflow(sparse_indices_size, sizeof(int64));
707       if (sparse_indices_size < 0) {
708         return sparse_indices_size;
709       }
710       // At this point, both sparse_indices_size and sparse_elements_size are
711       // non-negative, so we can easily check if adding them wraps.
712       if (static_cast<uint64>(sparse_elements_size) +
713               static_cast<uint64>(sparse_indices_size) >
714           INT64_MAX) {
715         return static_cast<int64>(-1);
716       }
717     }
718 
719     // This is intentionally unconditional: even if the shape is sparse, we want
720     // to verify the densified version has a reasonable size.
721     int64 dense_shape_size = 1;
722     if (shape.dimensions().empty()) {
723       return dense_shape_size;
724     }
725 
726     absl::Span<const int64> shape_max_dimensions =
727         AsInt64Slice(shape.dimensions());
728     for (int64 dim : shape_max_dimensions) {
729       dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, dim);
730       if (dense_shape_size < 0) {
731         return dense_shape_size;
732       }
733     }
734     dense_shape_size = MultiplyWithoutOverflow(
735         dense_shape_size, ByteSizeOfPrimitiveType(shape.element_type()));
736     return dense_shape_size;
737   }();
738 
739   if (shape_size < 0) {
740     return InvalidArgument("Shape %s size may overflow int64.",
741                            ShapeUtil::HumanString(shape));
742   }
743 
744   VLOG(3) << "Shape size is valid: " << shape_size;
745   return Status::OK();
746 }
747 
ValidateShapeWithOptionalLayout(const Shape & shape)748 /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayout(
749     const Shape& shape) {
750   TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape));
751 
752   return LayoutUtil::ValidateLayoutInShape(shape,
753                                            /*allow_missing_layouts=*/true);
754 }
755 
ValidateShape(const Shape & shape)756 /* static */ Status ShapeUtil::ValidateShape(const Shape& shape) {
757   TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape));
758 
759   return LayoutUtil::ValidateLayoutInShape(shape);
760 }
761 
// Returns a copy of `original` whose element type has been replaced by `type`.
// The argument itself is left untouched.
/* static */ Shape ShapeUtil::ChangeElementType(const Shape& original,
                                                PrimitiveType type) {
  Shape retyped(original);
  retyped.set_element_type(type);
  return retyped;
}
768 
IndexIsValid(const Shape & shape,ShapeIndexView index)769 /* static */ bool ShapeUtil::IndexIsValid(const Shape& shape,
770                                           ShapeIndexView index) {
771   const Shape* subshape = &shape;
772   for (auto i : index) {
773     if (!subshape->IsTuple() || i >= subshape->tuple_shapes_size() || i < 0) {
774       return false;
775     }
776     subshape = &subshape->tuple_shapes(i);
777   }
778   return true;
779 }
780 
// Returns a reference to the subshape of `shape` addressed by `index`.
// CHECK-fails if a component of `index` is applied to a non-tuple shape. An
// empty index yields `shape` itself.
/* static */ const Shape& ShapeUtil::GetSubshape(const Shape& shape,
                                                 ShapeIndexView index) {
  const Shape* current = &shape;
  for (auto element : index) {
    CHECK(current->IsTuple())
        << "Invalid index " << index << " for shape " << shape;
    // Step down into the addressed tuple element.
    current = &current->tuple_shapes(element);
  }
  return *current;
}
791 
TryGetSubshape(const Shape & shape,ShapeIndexView index)792 /* static */ StatusOr<const Shape*> ShapeUtil::TryGetSubshape(
793     const Shape& shape, ShapeIndexView index) {
794   const Shape* return_shape = &shape;
795   for (auto i : index) {
796     if (!return_shape->IsTuple() || i < 0 ||
797         i >= return_shape->tuple_shapes_size()) {
798       return InvalidArgument(
799           "Shape index %s not a valid subshape index for tuple with shape %s",
800           index.ToString(), shape.DebugString());
801     }
802     return_shape = &return_shape->tuple_shapes(i);
803   }
804   return return_shape;
805 }
806 
GetMutableSubshape(Shape * shape,ShapeIndexView index)807 /* static */ Shape* ShapeUtil::GetMutableSubshape(Shape* shape,
808                                                   ShapeIndexView index) {
809   Shape* return_shape = shape;
810   for (auto i : index) {
811     CHECK(return_shape->IsTuple());
812     return_shape = return_shape->mutable_tuple_shapes(i);
813   }
814   return return_shape;
815 }
816 
817 /* static */
IsLeafIndex(const Shape & shape,const ShapeIndex & index)818 bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
819   return !GetSubshape(shape, index).IsTuple();
820 }
821 
GetLeafCount(const Shape & shape)822 /* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) {
823   if (!shape.IsTuple()) {
824     return 1;
825   }
826   int64 count = 0;
827   for (const Shape& subshape : shape.tuple_shapes()) {
828     count += GetLeafCount(subshape);
829   }
830   return count;
831 }
832 
GetLeafShapes(const Shape & shape)833 /* static */ std::vector<ShapeUtil::IndexedShape> ShapeUtil::GetLeafShapes(
834     const Shape& shape) {
835   std::vector<IndexedShape> leaves;
836   ForEachSubshape(shape, [&](const Shape& sub_shape, const ShapeIndex& index) {
837     if (IsLeafIndex(shape, index)) {
838       leaves.emplace_back(index, sub_shape);
839     }
840   });
841   return leaves;
842 }
843 
HasDegenerateDimensions(const Shape & shape)844 /* static */ bool ShapeUtil::HasDegenerateDimensions(const Shape& shape) {
845   CHECK(shape.IsArray());
846   return absl::c_linear_search(shape.dimensions(), 1);
847 }
848 
// Returns a copy of `shape` from which every dimension of size 1 has been
// removed; the removal itself is delegated to FilterDimensions.
/* static */ Shape ShapeUtil::DropDegenerateDimensions(const Shape& shape) {
  // Predicate over dimension numbers of `shape`: keep everything whose size is
  // not 1.
  const auto keep_dimension = [&shape](int64 dim) -> bool {
    return shape.dimensions()[dim] != 1;
  };
  return FilterDimensions(keep_dimension, shape);
}
853 
854 namespace {
855 
856 // Helper for ForEachSubshape which visits the subshapes of the given shape in
857 // DFS pre-order starting with the index.
ForEachSubshapeHelper(const Shape & shape,const ShapeUtil::StatusVisitorFunction & func,ShapeIndex * index)858 Status ForEachSubshapeHelper(const Shape& shape,
859                              const ShapeUtil::StatusVisitorFunction& func,
860                              ShapeIndex* index) {
861   TF_RETURN_IF_ERROR(func(shape, *index));
862   if (shape.IsTuple()) {
863     for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
864       index->push_back(i);
865       TF_RETURN_IF_ERROR(ForEachSubshapeHelper(
866           ShapeUtil::GetTupleElementShape(shape, i), func, index));
867       index->pop_back();
868     }
869   }
870   return Status::OK();
871 }
872 
873 // Helper for ForEachMutableSubshape which visits the subshapes of the given
874 // shape in DFS pre-order starting with the index.
ForEachMutableSubshapeHelper(Shape * shape,const ShapeUtil::MutatingStatusVisitorFunction & func,ShapeIndex * index)875 Status ForEachMutableSubshapeHelper(
876     Shape* shape, const ShapeUtil::MutatingStatusVisitorFunction& func,
877     ShapeIndex* index) {
878   TF_RETURN_IF_ERROR(func(shape, *index));
879   if (shape->IsTuple()) {
880     for (int64 i = 0; i < ShapeUtil::TupleElementCount(*shape); ++i) {
881       index->push_back(i);
882       TF_RETURN_IF_ERROR(ForEachMutableSubshapeHelper(
883           shape->mutable_tuple_shapes(i), func, index));
884       index->pop_back();
885     }
886   }
887   return Status::OK();
888 }
889 
890 }  // namespace
891 
ForEachSubshape(const Shape & shape,const VisitorFunction & func)892 /* static */ void ShapeUtil::ForEachSubshape(const Shape& shape,
893                                              const VisitorFunction& func) {
894   ShapeIndex index;
895   ForEachSubshapeHelper(
896       shape,
897       [&func](const Shape& subshape, const ShapeIndex& index) {
898         func(subshape, index);
899         return Status::OK();
900       },
901       &index)
902       .IgnoreError();
903 }
904 
ForEachMutableSubshape(Shape * shape,const MutatingVisitorFunction & func)905 /* static */ void ShapeUtil::ForEachMutableSubshape(
906     Shape* shape, const MutatingVisitorFunction& func) {
907   ShapeIndex index;
908   ForEachMutableSubshapeHelper(
909       shape,
910       [&func](Shape* subshape, const ShapeIndex& index) {
911         func(subshape, index);
912         return Status::OK();
913       },
914       &index)
915       .IgnoreError();
916 }
917 
ForEachSubshapeWithStatus(const Shape & shape,const StatusVisitorFunction & func)918 /* static */ Status ShapeUtil::ForEachSubshapeWithStatus(
919     const Shape& shape, const StatusVisitorFunction& func) {
920   ShapeIndex index;
921   return ForEachSubshapeHelper(shape, func, &index);
922 }
923 
ForEachMutableSubshapeWithStatus(Shape * shape,const MutatingStatusVisitorFunction & func)924 /* static */ Status ShapeUtil::ForEachMutableSubshapeWithStatus(
925     Shape* shape, const MutatingStatusVisitorFunction& func) {
926   ShapeIndex index;
927   return ForEachMutableSubshapeHelper(shape, func, &index);
928 }
929 
PermuteDimensions(absl::Span<const int64> permutation,const Shape & shape)930 /* static */ Shape ShapeUtil::PermuteDimensions(
931     absl::Span<const int64> permutation, const Shape& shape) {
932   Shape new_shape = shape;
933   new_shape.clear_dimensions();
934   for (auto dim : Permute(permutation, shape.dimensions())) {
935     new_shape.add_dimensions(dim);
936   }
937   for (int64 i = 0; i < shape.rank(); i++) {
938     new_shape.set_dynamic_dimension(permutation[i],
939                                     shape.is_dynamic_dimension(i));
940   }
941 
942   // If `shape` has a layout, by contract we choose a new layout such that the
943   // transpose defined by this permutation is a bitcast.
944   //
945   // Some formalism helps to understand the correct way to do this.  We're going
946   // to do algebra in the group of permutations of the dimensions of `shape`.
947   //
948   // Since the order of `shape`'s dimensions is not permuted relative to itself,
949   // `shape`'s list of dimensions is isomorphic to the identity I.
950   //
951   // Let `shape`'s layout be L.  A layout is a permutation which maps a
952   // minor-to-major physical layout to the order of a shape's logical dims.
953   // Therefore inverse of a layout maps from logical to physical dims, and so
954   // the physical layout of I is simply L'.I = L', where L' is the inverse of L.
955   //
956   // Let the argument `permutation` be P.  This is a permutation over `shape`'s
957   // dimensions, so our return value will be a shape with dims P.I = P.  Our
958   // goal is to construct a layout permutation L* that we can apply to P such
959   // that the physical dimension ordering of the returned shape is the same
960   // as that of the original shape, namely L'.
961   //
962   // Our returned shape has dims P and layout L*, so its in-memory layout is
963   // L*'.P.  Setting this equal to L' and solving for L*, we get:
964   //
965   //   L*'.P = L'    =>
966   //   L*'   = L'P'  =>
967   //   L*    = P.L
968   //
969   if (shape.has_layout()) {
970     CHECK(LayoutUtil::IsDenseArray(shape));
971     Layout* new_layout = new_shape.mutable_layout();
972     new_layout->set_format(DENSE);
973     new_layout->clear_minor_to_major();
974     for (auto index : ComposePermutations(
975              permutation, AsInt64Slice(shape.layout().minor_to_major()))) {
976       new_layout->add_minor_to_major(index);
977     }
978     // The permutation accepted by TransposeIsBitcast is the inverse of the
979     // permutation here.
980     CHECK(TransposeIsBitcast(shape, new_shape, InversePermutation(permutation)))
981         << "shape=" << HumanStringWithLayout(shape)
982         << ", new_shape=" << HumanStringWithLayout(new_shape)
983         << ", permutation={" << absl::StrJoin(permutation, ",") << "}";
984   }
985   return new_shape;
986 }
987 
988 /* static */ std::tuple<bool, std::vector<int64>, std::vector<int64>>
InsertedOrDeleted1SizedDimensions(const Shape & shape_pre,const Shape & shape_post)989 ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
990                                              const Shape& shape_post) {
991   CHECK(shape_pre.IsArray());
992   CHECK(shape_post.IsArray());
993 
994   auto nil = std::make_tuple(false, std::vector<int64>(), std::vector<int64>());
995 
996   std::vector<int64> deleted_indices;
997   std::vector<int64> inserted_indices;
998   // Returns false if any input/output index between prior_unmodified_dim_pair
999   // and unmodified_dim_pair have size >1. Otherwise, returns true and appends
1000   // the degerenate input/output dimensions in the gap to
1001   // deleted_indices/inserted_indices respectively.
1002   auto check_modified_dims =
1003       [&shape_pre, &shape_post, &deleted_indices, &inserted_indices](
1004           std::pair<int64, int64> prior_unmodified_dim_pair,
1005           std::pair<int64, int64> unmodified_dim_pair) {
1006         for (int64 modified_input_dim = prior_unmodified_dim_pair.first + 1;
1007              modified_input_dim < unmodified_dim_pair.first;
1008              ++modified_input_dim) {
1009           if (shape_pre.dimensions(modified_input_dim) > 1) {
1010             return false;
1011           }
1012           deleted_indices.push_back(modified_input_dim);
1013         }
1014         for (int64 modified_output_dim = prior_unmodified_dim_pair.second + 1;
1015              modified_output_dim < unmodified_dim_pair.second;
1016              ++modified_output_dim) {
1017           if (shape_post.dimensions(modified_output_dim) > 1) {
1018             return false;
1019           }
1020           inserted_indices.push_back(modified_output_dim);
1021         }
1022         return true;
1023       };
1024 
1025   std::vector<std::pair<int64, int64>> unmodified_dims =
1026       DimensionsUnmodifiedByReshape(shape_pre, shape_post);
1027   // Returns nil if the reshape modifies any non-degenerate input/output
1028   // dimension. DimensionsUnmodifiedByReshape gives us all unmodified
1029   // dimensions, so we only need to check whether dimensions in the gaps (thus
1030   // modified) have size >1.
1031   for (size_t i = 0; i <= unmodified_dims.size(); ++i) {
1032     // Check (modified) dimensions between unmodified_dims[i-1] and
1033     // unmodified_dims[i].
1034     auto prior_unmodified_dim_pair =
1035         i > 0 ? unmodified_dims[i - 1] : std::make_pair(-1LL, -1LL);
1036     auto unmodified_dim_pair =
1037         i < unmodified_dims.size()
1038             ? unmodified_dims[i]
1039             : std::make_pair(shape_pre.rank(), shape_post.rank());
1040     if (!check_modified_dims(prior_unmodified_dim_pair, unmodified_dim_pair)) {
1041       return nil;
1042     }
1043   }
1044 
1045   return std::make_tuple(true, deleted_indices, inserted_indices);
1046 }
1047 
1048 /* static */ std::vector<std::pair<int64, int64>>
DimensionsUnmodifiedByReshape(const Shape & input_shape,const Shape & output_shape)1049 ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
1050                                          const Shape& output_shape) {
1051   CHECK(input_shape.IsArray());
1052   CHECK(output_shape.IsArray());
1053 
1054   // Unmodified dimensions are merely common factors of rank 1.
1055   auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()),
1056                                       AsInt64Slice(output_shape.dimensions()));
1057   for (size_t i = 0; i < common_factors.size() - 1;) {
1058     if (1 != common_factors[i + 1].first - common_factors[i].first ||
1059         1 != common_factors[i + 1].second - common_factors[i].second) {
1060       common_factors.erase(common_factors.begin() + i);
1061     } else {
1062       ++i;
1063     }
1064   }
1065   // `CommonFactors(a, b).back() == (a.rank, b.rank)` so we must pop it.
1066   common_factors.pop_back();
1067   return common_factors;
1068 }
1069 
TransposeIsBitcast(const Shape & input_shape,const Shape & output_shape,absl::Span<const int64> dimension_mapping)1070 /* static */ bool ShapeUtil::TransposeIsBitcast(
1071     const Shape& input_shape, const Shape& output_shape,
1072     absl::Span<const int64> dimension_mapping) {
1073   CHECK(LayoutUtil::HasLayout(input_shape) &&
1074         LayoutUtil::HasLayout(output_shape));
1075 
1076   if (!SameElementType(input_shape, output_shape)) {
1077     return false;
1078   }
1079 
1080   // Check the reshape permutes the positions of each dimension in the
1081   // minor-to-major order. positions[i]=k means dimension `i` is k-th minor.
1082   //   input_positions = apply(dimension_mapping, output_positions)
1083   //
1084   // Because the positions of each dimension are the inverse permutation of the
1085   // minor-to-major order, the above check is equivalent to
1086   //   inverse(input_dimensions) =
1087   //       apply(dimension_mapping, inverse(output_dimensions))
1088   //   # `I` indicates identity permutation.
1089   //   apply(input_dimensions, I) =
1090   //       apply(dimension_mapping, apply(output_dimensions, I))
1091   //   apply(input_dimensions, I) =
1092   //       apply((dimension_mapping * output_dimensions), I)
1093   //   input_dimensions = dimension_mapping * output_dimensions
1094   return absl::c_equal(
1095       ComposePermutations(dimension_mapping,
1096                           AsInt64Slice(output_shape.layout().minor_to_major())),
1097       input_shape.layout().minor_to_major());
1098 }
1099 
// Returns true iff reshaping `input_shape` into `output_shape` is a bitcast,
// i.e. does not change the physical location of any element.
//
// Both shapes must be arrays with layouts and must contain the same number of
// elements (all CHECKed). Shapes with different element types are never
// bitcast-compatible; shapes with zero elements always are.
/* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape,
                                              const Shape& output_shape) {
  CHECK(input_shape.IsArray());
  CHECK(output_shape.IsArray());
  CHECK(LayoutUtil::HasLayout(input_shape));
  CHECK(LayoutUtil::HasLayout(output_shape));

  if (!SameElementType(input_shape, output_shape)) {
    return false;
  }

  CHECK_EQ(ElementsIn(input_shape), ElementsIn(output_shape));
  // With no elements there is nothing whose location could change.
  if (ElementsIn(input_shape) == 0) {
    return true;
  }

  // TL;DR: The rest of the method checks that the reshape does not change the
  // physical location of any unit input or output index. Unit indices have
  // exactly one dimension that equals 1 and other dimensions 0. This condition
  // is necessary for the reshape to be a bitcast, because a bitcast-equivalent
  // reshape shouldn't change the physical location of any element. It is also a
  // sufficient condition as is proved below (note: many details are omitted for
  // space).
  //
  // Definitions:
  //
  // * Denote the input shape by IS and output shape by OS. IS[i] or OS[i] means
  // the size of i-th least significant dimension of IS or OS (this is opposite
  // to how we define the index of Shape::dimensions()).
  //
  // * Given an input or output index I, denote by p(I) I's physical linear
  // index (or physical index for short) and l(I) I's logical linear index (or
  // logical index for short).
  //
  // * Given a logical index k, denote by II(k) the input index whose linear
  // index is k, and OI(k) the corresponding output index.
  //
  // * Denote by IT[i] the increment of physical index if i-th dimension of the
  // input index is increased by 1. Similarly, OT[i] means the increment if i-th
  // dimension of the output index is increased by 1. Note that IT[i] or OT[i]
  // is a function of IS or OS and the layout, and not dependent on the specific
  // input or output index.
  //
  // To prove the reshape from IS to OS is a bitcast, it is sufficient to prove
  // that, for any linear index k, p(II(k))=p(OI(k)). We prove this by
  // induction. We know p(II(0))=p(OI(0)) is trivially true, so what's left is
  // to prove, with every increment on k, the above formula still holds.
  //
  // First, suppose reshaping from IS to OS is non-factorizable (we discuss
  // factorizable reshapes later). A reshape from IS to OS is factorizable, if
  // there exists (i,j) such that
  //
  //   0<=i<=|IS|
  //   0<=j<=|OS|
  //   |IS|-i+|OS|-j > 0 (i.e., i,j mustn't both point to the end)
  //   product(IS[i], IS[i+1], ..., IS[|IS|-1])
  //     = product(OS[j], OS[j+1], ..., OS[|OS|-1])
  //
  // p(II(k))=p(OI(k)) is trivially true for k=0 because p(II(0)) and p(OI(0))
  // are both 0. It's also trivially true for k=1, because II(1) and OI(1) are
  // unit indices which are already tested. This also means IT[0]=OT[0]
  // because p(II(1))=IT[0] and p(OI(1))=OT[0].
  //
  // Furthermore, p(II(k))=p(OI(k)) for k<min(IS[0],OS[0]), because each
  // increment of k adds IT[0] to the input physical and OT[0] (same as IT[0])
  // to the output physical.
  //
  // When k=min(IS[0],OS[0]), the first wrap happens. Without losing generality,
  // suppose IS[0]<OS[0] and thus k=IS[0]. Similar proof applies to IS[0]>OS[0].
  // Note that IS[0]!=OS[0] because the reshape is non-factorizable. From
  // logical index k-1 to logical index k, dimension 1 of the input index
  // is increased by 1 and dimension 0 is reset to 0 thus decreased by
  // IS[0]-1. Therefore, the physical input index is increased by
  //
  //   p(II(k)) - p(II(k-1)) = IT[1] - (IS[0]-1) * IT[0]
  //
  // Because IS[0]<OS[0], the only change to the output index is that its
  // dimension 0 is increased by one. Therefore,
  //
  //   p(OI(k)) - p(OI(k-1)) = OT[0] = IT[0]
  //
  // Because II(k) is a unit index -- (0,..,0,1,0), we already tested that
  // p(II(k))=p(OI(k)). Therefore,
  //   IT[1] - (IS[0]-1) * IT[0] = IT[0]
  //   IT[1] = IS[0] * IT[0]
  // In other words, input dimension 1 is immediately more major than input
  // dimension 0. We can now conceptually collapse these two dimensions because
  // an increment in the logical index affecting only these two dimensions maps
  // to IT[0] in the physical index.
  //
  // By induction (omitted here), we can prove IT[i]=IS[i-1]*IT[i-1] and
  // OT[i]=OS[i-1]*OT[i-1]. Therefore, both IS and OS are row-major and bitwise
  // identical.
  //
  // A factorizable reshape can be factorized into a list of non-factorizable
  // sub-reshapes, each of which can be handled similarly to the proof above.
  // For example,
  //
  //   [7x9x2x15] -> [63x6x5]
  //
  // can be factorized into
  //
  //   [7x9] -> [63] and [2x15] -> [6x5].
  //
  // Suppose input index I=(x3,x2,x1,x0) and output index O=(y2,y1,y0) have the
  // same logical linear index. According to the factorization, we know
  // l(x3,x2,0,0)=l(y2,0,0) and l(0,0,x1,x0)=l(0,y1,y0). Using the proof for
  // non-factorizable reshapes, we can prove p(0,0,x1,x0)=p(0,y1,y0). Using a
  // similar proof, with the increment of the logical index set to
  // IS[1]*IS[0]=OS[1]*OS[0]=30 instead of 1, we can prove
  // p(x3,x2,0,0)=p(y2,0,0) too. Therefore,
  //
  //   p(x3,x2,x1,x0) = p(x3,x2,0,0) + p(0,0,x1,x0)
  //                  = p(y2,0,0) + p(0,0,y1,y0)
  //                  = p(y2,y1,y0)
  //
  // check_input_unit_indices checks one way of the condition: each input unit
  // index is mapped to an output index with the same physical location. This
  // lambda will be called again with input_shape and output_shape reversed to
  // check the other way. Note that the lambda's parameters deliberately shadow
  // the function's parameters of the same names.
  auto check_input_unit_indices = [](const Shape& input_shape,
                                     const Shape& output_shape) {
    // input_shape_dim0_major/output_shape_dim0_major has the same "dimensions"
    // as input_shape/output_shape and the dimension-0-major layout. These two
    // shapes are used for conversion between logical linear indices and
    // multi-dimensional indices.
    Shape input_shape_dim0_major = MakeShapeWithDescendingLayout(
        input_shape.element_type(), AsInt64Slice(input_shape.dimensions()));
    Shape output_shape_dim0_major = MakeShapeWithDescendingLayout(
        output_shape.element_type(), AsInt64Slice(output_shape.dimensions()));

    for (int64 input_dim = 0; input_dim < input_shape.rank(); ++input_dim) {
      // A dimension of size <= 1 has no unit index (index 1 would be out of
      // bounds), so there is nothing to check for it.
      if (input_shape.dimensions(input_dim) <= 1) {
        continue;
      }

      // Unit index along `input_dim`: 1 in that dimension, 0 everywhere else.
      std::vector<int64> input_unit_index(input_shape.rank(), 0);
      input_unit_index[input_dim] = 1;
      int64 logical_linear_index =
          IndexUtil::MultidimensionalIndexToLinearIndex(input_shape_dim0_major,
                                                        input_unit_index);
      // output_index has the same logical linear index as input_unit_index.
      std::vector<int64> output_index =
          IndexUtil::LinearIndexToMultidimensionalIndex(output_shape_dim0_major,
                                                        logical_linear_index);
      // Check input_unit_index and output_index have the same physical linear
      // index.
      if (IndexUtil::MultidimensionalIndexToLinearIndex(input_shape,
                                                        input_unit_index) !=
          IndexUtil::MultidimensionalIndexToLinearIndex(output_shape,
                                                        output_index)) {
        return false;
      }
    }
    return true;
  };
  return check_input_unit_indices(input_shape, output_shape) &&
         check_input_unit_indices(output_shape, input_shape);
}
1259 
// Tries to find a layout for `output_shape` under which reshaping
// `input_shape` (with its existing layout) into `output_shape` is a bitcast.
//
// Returns `output_shape` with that layout set, or absl::nullopt if the
// dimensions of the two shapes cannot be aligned or the input layout does not
// permit such an output layout. Both shapes must be arrays (CHECKed). For
// shapes without size-1 dimensions the result is CHECKed with
// ReshapeIsBitcast; shapes with size-1 dimensions are handled by recursing on
// the shapes with those dimensions dropped.
/* static */ absl::optional<Shape> ShapeUtil::AlignLayouts(
    const Shape& input_shape, const Shape& output_shape) {
  CHECK(input_shape.IsArray());
  CHECK(output_shape.IsArray());
  // Removing trivial dimensions from the shape simplifies the alignment
  // algorithm since ones can go in any position.
  if (HasDegenerateDimensions(input_shape) ||
      HasDegenerateDimensions(output_shape)) {
    auto simple_output_shape =
        AlignLayouts(DropDegenerateDimensions(input_shape),
                     DropDegenerateDimensions(output_shape));
    if (!simple_output_shape) {
      return absl::nullopt;
    }

    auto layout = simple_output_shape->layout().minor_to_major();
    // The layout computed above refers to dimension numbers of the simplified
    // output shape. dim_map[k] is the dimension number in `output_shape` of
    // the k-th dimension whose size is not 1, so applying it translates the
    // layout back to dimension numbers of the full output shape.
    absl::InlinedVector<int64, 8> dim_map;
    dim_map.reserve(simple_output_shape->rank());
    for (int64 i = 0; i < output_shape.rank(); ++i) {
      if (output_shape.dimensions(i) != 1) {
        dim_map.push_back(i);
      }
    }
    for (int64& d : layout) {
      d = dim_map[d];
    }

    // Add the ones in descending order to the layout. Descending layouts tend
    // to reduce the number of copies inserted in layout assignment.
    for (int64 i = output_shape.rank() - 1; i >= 0; --i) {
      if (output_shape.dimensions(i) == 1) {
        layout.push_back(i);
      }
    }
    Shape output_shape_with_layout = output_shape;
    *output_shape_with_layout.mutable_layout()->mutable_minor_to_major() =
        layout;
    return output_shape_with_layout;
  }

  int64 input_rank = input_shape.rank();
  int64 output_rank = output_shape.rank();

  // First, calculate an alignment of the dimensions. A consecutive sequence of
  // input dimensions and output dimensions belong to the same alignment part if
  // the products of their dimension bounds are the same. In the easiest case,
  // an alignment part consists of one input dimension and one output dimension
  // which both have the same dimension bound. An alignment part specifies which
  // dimensions need to be kept together in a physical layout if we want a
  // reshape to be a bitcast. The order of the alignment parts is defined by the
  // physical layout of the input shape, so when we construct the layout for the
  // output shape we just process the alignment parts in this order, and then
  // layout the dimensions belonging to each part in descending (major to minor)
  // order.

  // Stores the input and output dimension numbers where each alignment part
  // starts.
  std::vector<std::pair<int64, int64>> alignment;
  alignment.push_back({0, 0});

  // Stores a mapping from the input dimension to the alignment part it belongs
  // to.
  std::vector<int64> dimension_to_alignment_index(input_rank);
  int64 input_dimension_product = 1, output_dimension_product = 1;
  // `i` walks the input dimensions and `j` the output dimensions; each
  // iteration consumes exactly one dimension from one of the two sides.
  for (int64 i = 0, j = 0; i < input_rank || j < output_rank;) {
    // Check if we have reached the end of an alignment part.
    if (input_dimension_product == output_dimension_product &&
        input_dimension_product > 1) {
      alignment.push_back({i, j});
      input_dimension_product = output_dimension_product = 1;
    }
    // Advance on the side whose running product is behind (the input side also
    // when the output dimensions are exhausted).
    if (input_dimension_product < output_dimension_product ||
        j == output_rank) {
      if (i == input_rank) {
        // The input side has to catch up but has no dimensions left.
        return absl::nullopt;
      }
      dimension_to_alignment_index[i] = alignment.size() - 1;
      input_dimension_product *= input_shape.dimensions(i);
      ++i;
    } else {
      output_dimension_product *= output_shape.dimensions(j);
      ++j;
    }
  }
  // The last alignment part must be balanced as well.
  if (input_dimension_product != output_dimension_product) {
    return absl::nullopt;
  }

  // We also need to store an end element so that we know where the last
  // alignment part ends.
  alignment.push_back({input_rank, output_rank});
  // Now check if the physical layout can potentially be aligned to the output
  // shape by changing the physical layout of the output shape. We need to check
  // that all dimension numbers that belong to the same alignment part appear
  // consecutively, and are in descending order. However we can ignore any
  // trivial dimension bounds of 1, because they can be placed anywhere.
  auto input_dimension_numbers = input_shape.layout().minor_to_major();
  std::vector<int64> output_layout;
  output_layout.reserve(output_rank);
  // `i` indexes into the input's minor-to-major list; it is advanced by the
  // inner loop below, one whole alignment part at a time.
  for (int64 i = 0; i < input_rank;) {
    int64 current_dimension_number = input_dimension_numbers[i];

    // Trivial dimensions are stripped.
    CHECK_NE(input_shape.dimensions(current_dimension_number), 1);
    const int64 current_alignment_index =
        dimension_to_alignment_index[current_dimension_number];
    // Because of the special end element that we added, we can be sure that
    // 'current_alignment_index' is < alignment.size() - 1.
    CHECK_LT(current_alignment_index, alignment.size() - 1);

    // Check that the next dimension numbers in the input layout -- as many as
    // there are input dimensions in the current alignment part -- are in
    // descending order and all belong to the current alignment part.
    for (int64 j = 0; j < alignment[current_alignment_index + 1].first -
                              alignment[current_alignment_index].first;
         ++i, ++j) {
      if (i == input_rank) {
        return absl::nullopt;
      }
      // If the current dimension number belongs to a different alignment part,
      // or the dimension numbers are not in descending order, we can return
      // early.
      if (dimension_to_alignment_index[input_dimension_numbers[i]] !=
              current_alignment_index ||
          input_dimension_numbers[i] > current_dimension_number) {
        return absl::nullopt;
      }
      current_dimension_number = input_dimension_numbers[i];
    }
    // The output dimension numbers that belong to the current alignment part
    // need to appear in the same descending order as in the input.
    for (int64 j = alignment[current_alignment_index + 1].second - 1;
         j >= alignment[current_alignment_index].second; --j) {
      output_layout.push_back(j);
    }
  }
  CHECK_EQ(output_layout.size(), output_rank);
  Shape output_shape_with_layout = MakeShapeWithLayout(
      output_shape.element_type(), AsInt64Slice(output_shape.dimensions()),
      output_layout);
  // Sanity check: the layout constructed above must indeed make the reshape a
  // bitcast.
  CHECK(ReshapeIsBitcast(input_shape, output_shape_with_layout))
      << "reshape is not a bitcast for input_shape: "
      << ShapeUtil::HumanStringWithLayout(input_shape)
      << " and output_shape_with_layout: "
      << ShapeUtil::HumanStringWithLayout(output_shape_with_layout);
  return output_shape_with_layout;
}
1409 
// Returns `shape` with dimension number `dim_to_delete` removed. `shape` is
// taken by value, so the caller's object is not modified. CHECK-fails unless
// `shape` is an array.
/* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete,
                                              Shape shape) {
  CHECK(shape.IsArray());
  // The actual removal is delegated to Shape::DeleteDimension on the local
  // copy.
  shape.DeleteDimension(dim_to_delete);
  return shape;
}
1416 
FilterDimensions(const std::function<bool (int64)> & p,Shape shape)1417 /* static */ Shape ShapeUtil::FilterDimensions(
1418     const std::function<bool(int64)>& p, Shape shape) {
1419   CHECK(shape.IsArray());
1420   std::vector<int64> dims_to_delete;
1421   for (int64 i = shape.dimensions().size() - 1; i >= 0; --i) {
1422     if (!p(i)) {
1423       dims_to_delete.push_back(i);
1424     }
1425   }
1426   for (int64 dim : dims_to_delete) {
1427     shape = DeleteDimension(dim, shape);
1428   }
1429   return shape;
1430 }
1431 
Hash(const Shape & shape)1432 /*static*/ size_t ShapeUtil::Hash(const Shape& shape) {
1433   using tensorflow::hash;
1434   using tensorflow::Hash64Combine;
1435 
1436   size_t hash_value = hash<PrimitiveType>()(shape.element_type());
1437 
1438   if (shape.tuple_shapes().empty()) {
1439     for (int i = 0; i < shape.dimensions_size(); ++i) {
1440       hash_value =
1441           Hash64Combine(hash_value, hash<int64>()(shape.dimensions(i)));
1442       hash_value = Hash64Combine(hash_value,
1443                                  hash<bool>()(shape.is_dynamic_dimension(i)));
1444     }
1445 
1446     hash_value = Hash64Combine(hash_value, LayoutUtil::Hash(shape.layout()));
1447   } else {
1448     hash_value = 0;
1449     for (const Shape& subshape : shape.tuple_shapes()) {
1450       hash_value = Hash64Combine(hash_value, ShapeUtil::Hash(subshape));
1451     }
1452   }
1453 
1454   return hash_value;
1455 }
1456 
1457 }  // namespace xla
1458