1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/shape_util.h"
17 
18 #include <algorithm>
19 #include <functional>
20 #include <numeric>
21 #include <unordered_map>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/algorithm/container.h"
26 #include "absl/container/inlined_vector.h"
27 #include "absl/strings/ascii.h"
28 #include "absl/strings/numbers.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/strings/str_join.h"
31 #include "absl/strings/str_split.h"
32 #include "absl/strings/string_view.h"
33 #include "absl/strings/strip.h"
34 #include "absl/types/optional.h"
35 #include "tensorflow/compiler/xla/index_util.h"
36 #include "tensorflow/compiler/xla/layout_util.h"
37 #include "tensorflow/compiler/xla/overflow_util.h"
38 #include "tensorflow/compiler/xla/permutation_util.h"
39 #include "tensorflow/compiler/xla/primitive_util.h"
40 #include "tensorflow/compiler/xla/status_macros.h"
41 #include "tensorflow/compiler/xla/types.h"
42 #include "tensorflow/compiler/xla/util.h"
43 #include "tensorflow/core/lib/core/errors.h"
44 #include "tensorflow/core/lib/gtl/iterator_range.h"
45 #include "tensorflow/core/lib/hash/hash.h"
46 #include "tensorflow/core/lib/strings/numbers.h"
47 #include "tensorflow/core/platform/logging.h"
48 #include "tensorflow/core/platform/protobuf.h"
49 #include "tensorflow/core/platform/regexp.h"
50 
51 namespace xla {
52 
53 using absl::StrAppend;
54 using absl::StrCat;
55 
56 namespace {
// An array indexed by PrimitiveType which gives the size in bytes of one
// element of that type, or 0 if the type does not have a fixed element size.
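// For example, primitive_byte_size[S32] is 4 and primitive_byte_size[TUPLE]
// is 0.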
60 constexpr uint8 primitive_byte_size[PrimitiveType_ARRAYSIZE] = {
61     0,                  // PRIMITIVE_TYPE_INVALID = 0,
62     sizeof(int8),       // PRED = 1
63     sizeof(int8),       // S8 = 2
64     sizeof(int16),      // S16 = 3
65     sizeof(int32),      // S32 = 4
66     sizeof(int64),      // S64 = 5
67     sizeof(uint8),      // U8 = 6
68     sizeof(uint16),     // U16 = 7
69     sizeof(uint32),     // U32 = 8
70     sizeof(uint64),     // U64 = 9
71     sizeof(float) / 2,  // F16 = 10
72     sizeof(float),      // F32 = 11
73     sizeof(double),     // F64 = 12
74     0,                  // TUPLE = 13
75     0,                  // OPAQUE_TYPE = 14
76     sizeof(complex64),  // C64 = 15
77     sizeof(float) / 2,  // BF16 = 16
78     0,                  // TOKEN = 17
79     sizeof(complex128)  // C128 = 18
80 };
81 }  // namespace
82 
string ShapeIndex::ToString() const { return ShapeIndexView(*this).ToString(); }
84 
string ShapeIndexView::ToString() const {
86   return StrCat("{", absl::StrJoin(indices_, ","), "}");
87 }
88 
bool ShapeIndexView::operator==(const ShapeIndexView& other) const {
90   return indices_ == other.indices_;
91 }
92 
bool ShapeIndexView::operator!=(const ShapeIndexView& other) const {
94   return !(*this == other);
95 }
96 
std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) {
98   out << shape_index.ToString();
99   return out;
100 }
101 
std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) {
103   out << shape_index.ToString();
104   return out;
105 }
106 
bool ShapeIndexView::StartsWith(ShapeIndexView prefix) const {
108   return size() >= prefix.size() &&
109          indices_.subspan(0, prefix.size()) == prefix.indices_;
110 }
111 
/* static */ bool ShapeUtil::IsArrayPrimitiveType(
113     PrimitiveType primitive_type) {
114   return primitive_util::IsArrayType(primitive_type);
115 }
116 
117 namespace {
118 // Constructs and returns the new shape with the given minor_to_major order in
119 // its Layout.
StatusOr<Shape> MakeShapeWithLayoutInternal(
121     PrimitiveType element_type, absl::Span<const int64> dimensions,
122     absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles,
123     int64 element_size_in_bits, int64 memory_space) {
124   if (dimensions.size() != minor_to_major.size()) {
125     return InvalidArgument("Dimensions size is %ld, but layout size is %ld.",
126                            dimensions.size(), minor_to_major.size());
127   }
128   if (element_type == OPAQUE_TYPE || element_type == TUPLE) {
129     return InvalidArgument("Unsupported element type: %s",
130                            PrimitiveType_Name(element_type));
131   }
132   TF_ASSIGN_OR_RETURN(Shape shape,
133                       ShapeUtil::MakeValidatedShape(element_type, dimensions));
134   if (element_size_in_bits ==
135       ShapeUtil::ByteSizeOfPrimitiveType(element_type) * 8) {
136     // Only set element_size_in_bits if it's different from the default value.
137     element_size_in_bits = 0;
138   }
139   *shape.mutable_layout() = LayoutUtil::MakeLayout(
140       minor_to_major, tiles, element_size_in_bits, memory_space);
141   if (!shape.has_layout()) {
142     return InvalidArgument("Shape has no layout.");
143   }
144   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
145   return shape;
146 }
147 }  // namespace
148 
/* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) {
150   bool equal = Shape::Equal()(lhs, rhs);
151 
152   if (!equal && VLOG_IS_ON(3)) {
153     VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString()
154             << ", rhs = " << rhs.ShortDebugString();
155   }
156 
157   return equal;
158 }
159 
/* static */ bool ShapeUtil::EqualIgnoringElementType(const Shape& lhs,
161                                                       const Shape& rhs) {
162   bool equal = Shape::Equal().IgnoreElementType()(lhs, rhs);
163   if (!equal && VLOG_IS_ON(3)) {
164     VLOG(3) << "ShapeUtil::EqualIgnoringElementType differ: lhs = "
165             << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString();
166   }
167 
168   return equal;
169 }
170 
/* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs,
172                                                       const Shape& rhs) {
173   bool equal = Shape::Equal().IgnoreFpPrecision()(lhs, rhs);
174   if (!equal && VLOG_IS_ON(3)) {
175     VLOG(3) << "ShapeUtil::EqualIgnoringFpPrecision differ: lhs = "
176             << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString();
177   }
178 
179   return equal;
180 }
181 
/* static */ bool ShapeUtil::EqualStructure(const Shape& lhs,
183                                             const Shape& rhs) {
184   bool equal = true;
185   ForEachSubshape(lhs, [&](const Shape& /*subshape*/, const ShapeIndex& index) {
186     equal &= IndexIsValid(rhs, index);
187   });
188   ForEachSubshape(rhs, [&](const Shape& /*subshape*/, const ShapeIndex& index) {
189     equal &= IndexIsValid(lhs, index);
190   });
191 
192   return equal;
193 }
194 
/* static */ int64 ShapeUtil::TrueRank(const Shape& shape) {
196   int64 accum = 0;
197   for (int64 dimension : shape.dimensions()) {
    // We do not count unit (size-1) dimensions.
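    // For example, a shape with dimensions {1, 2, 1, 3} has a true rank of 2.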
199     if (dimension != 1) {
200       accum += 1;
201     }
202   }
203   return accum;
204 }
205 
/* static */ bool ShapeUtil::FillNewShape(PrimitiveType element_type,
207                                           absl::Span<const int64> dimensions,
208                                           Shape* shape) {
209   const int eint = static_cast<int>(element_type);
210   int64 dense_shape_size = ((eint >= 0 && eint < PrimitiveType_ARRAYSIZE)
211                                 ? primitive_byte_size[eint]
212                                 : 0);  // Out of range: force a failure
213   if (dense_shape_size <= 0) {
214     return false;
215   }
216 
217   // Verify that array-based lookup is consistent with public API.
218   DCHECK_EQ(dense_shape_size, ByteSizeOfPrimitiveType(element_type))
219       << element_type;
220 
221   shape->set_element_type(element_type);
222   const int ndims = dimensions.size();
223   auto layout = shape->mutable_layout();
224   layout->set_format(DENSE);
225   auto* minor_to_major = layout->mutable_minor_to_major();
226   for (int i = 0; i < ndims; i++) {
227     const int64 d = dimensions[i];
228     if (d < 0) {
229       return false;
230     }
231     dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, d);
232     if (dense_shape_size < 0) {
233       return false;
234     }
235 
236     shape->add_dimensions(d);
237     minor_to_major->push_back(ndims - 1 - i);
238   }
239   return true;
240 }
241 
/* static */ ProgramShape ShapeUtil::MakeProgramShape(
243     std::initializer_list<Shape> parameters, Shape result) {
244   ProgramShape program_shape;
245   for (const Shape& shape : parameters) {
246     *program_shape.add_parameters() = shape;
247   }
248   *program_shape.mutable_result() = std::move(result);
249   return program_shape;
250 }
251 
/* static */ Shape ShapeUtil::MakeShape(PrimitiveType element_type,
253                                         absl::Span<const int64> dimensions) {
254   Shape shape;
255   CHECK(FillNewShape(element_type, dimensions, &shape));
256   return shape;
257 }
258 
/* static */ Shape ShapeUtil::MakeScalarShape(PrimitiveType element_type) {
260   return MakeShape(element_type, {});
261 }
262 
/* static */ Shape ShapeUtil::MakeShape(
264     PrimitiveType element_type, absl::Span<const int64> dimensions,
265     const std::vector<bool>& dynamic_dimensions) {
266   return MakeValidatedShape(element_type, dimensions, dynamic_dimensions)
267       .ValueOrDie();
268 }
269 
/* static */ Shape ShapeUtil::MakeShapeWithStaticDimensions(
271     const Shape& shape) {
272   Shape output = shape;
273   output.clear_dynamic_dimensions();
274   return output;
275 }
276 
/* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
278     PrimitiveType element_type, absl::Span<const int64> dimensions) {
279   Shape shape;
280   if (!FillNewShape(element_type, dimensions, &shape)) {
281     return InvalidArgument("invalid shape type=%d, dims=[%s]",
282                            static_cast<int>(element_type),
283                            absl::StrJoin(dimensions, ","));
284   }
285   return shape;
286 }
287 
/* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
289     PrimitiveType element_type, absl::Span<const int64> dimensions,
290     const std::vector<bool>& dynamic_dimensions) {
291   if (dynamic_dimensions.size() != dimensions.size()) {
292     return InvalidArgument(
293         "dynamic dimensions size %d did not match number of dimensions %d",
294         dynamic_dimensions.size(), dimensions.size());
295   }
296 
297   Shape shape;
298   if (!FillNewShape(element_type, dimensions, &shape)) {
299     return InvalidArgument("invalid shape type=%d, dims=[%s]",
300                            static_cast<int>(element_type),
301                            absl::StrJoin(dimensions, ","));
302   }
303   for (int i = 0, n = dimensions.size(); i < n; i++) {
304     shape.set_dynamic_dimension(i, dynamic_dimensions[i]);
305   }
306   return shape;
307 }
308 
/* static */ Shape ShapeUtil::MakeShapeWithLayout(
310     PrimitiveType element_type, absl::Span<const int64> dimensions,
311     absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles,
312     int64 element_size_in_bits, int64 memory_space) {
313   auto ret =
314       MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major,
315                                   tiles, element_size_in_bits, memory_space);
316   if (!ret.ok()) LOG(ERROR) << ret.status();
317   return ret.ValueOrDie();
318 }
319 
/* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout(
321     PrimitiveType element_type, absl::Span<const int64> dimensions) {
322   std::vector<int64> layout(dimensions.size());
323   std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0));
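  // For example, for dimensions {2, 3, 4} this produces minor_to_major
  // {2, 1, 0}, i.e. a row-major layout.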
324   return MakeShapeWithLayout(element_type, dimensions, layout);
325 }
326 
/* static */ Shape
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
329     const Shape& shape) {
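  // For example, f32[10,20]{0,1} becomes f32[20,10]{1,0}; both describe the
  // same physical ordering of the 200 elements.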
330   std::vector<int64> dims(shape.dimensions_size());
331   for (int i = 0; i < shape.dimensions_size(); ++i) {
332     dims[i] = shape.dimensions(LayoutUtil::Major(shape.layout(), i));
333   }
334   Shape new_shape = MakeShapeWithDescendingLayout(shape.element_type(), dims);
335   // Since the physical layout is kept the same, the tiles and element size are
336   // the same also.
337   new_shape.mutable_layout()->mutable_tiles()->assign(
338       shape.layout().tiles().begin(), shape.layout().tiles().end());
339   new_shape.mutable_layout()->set_element_size_in_bits(
340       shape.layout().element_size_in_bits());
341   for (int i = 0; i < shape.dimensions_size(); ++i) {
342     new_shape.set_dynamic_dimension(i, shape.is_dynamic_dimension(i));
343   }
344   return new_shape;
345 }
346 
/* static */ Status ShapeUtil::PopulateShape(PrimitiveType element_type,
348                                              absl::Span<const int64> dimensions,
349                                              Shape* shape) {
350   shape->Clear();
351   shape->set_element_type(element_type);
352   for (int64 dimension : dimensions) {
353     shape->add_dimensions(dimension);
354   }
355   LayoutUtil::SetToDefaultLayout(shape);
356   return ValidateShape(*shape);
357 }
358 
/* static */ Shape ShapeUtil::MakeStaticShape(const Shape& original) {
360   Shape result = original;
361   result.clear_dynamic_dimensions();
362   return result;
363 }
364 
/* static */ Shape ShapeUtil::MakeTupleShape(absl::Span<const Shape> shapes) {
366   Shape result;
367   result.set_element_type(TUPLE);
368   result.mutable_tuple_shapes()->reserve(shapes.size());
369   for (const auto& shape : shapes) {
370     AppendShapeToTuple(shape, &result);
371   }
372   TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
373   return result;
374 }
375 
/* static */ Shape ShapeUtil::MakeOpaqueShape() {
377   Shape result;
378   result.set_element_type(OPAQUE_TYPE);
379   TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
380   return result;
381 }
382 
/* static */ Shape ShapeUtil::MakeTokenShape() {
384   Shape result;
385   result.set_element_type(TOKEN);
386   TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
387   return result;
388 }
389 
/* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape,
391                                                 Shape* tuple_shape) {
392   TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape));
393   *tuple_shape->add_tuple_shapes() = shape;
394 }
395 
/* static */ void ShapeUtil::UpdateTupleShape(const Shape& shape, int64 index,
397                                               Shape* tuple_shape) {
398   CHECK(index < tuple_shape->tuple_shapes_size());
399   *tuple_shape->mutable_tuple_shapes(index) = shape;
400 }
401 
/* static */ void ShapeUtil::UpdateDynamicDimension(Shape* shape,
403                                                     ShapeIndexView index,
404                                                     int64 dim,
405                                                     bool is_dynamic) {
406   if (index.empty()) {
407     CHECK(!shape->IsTuple());
408     shape->set_dynamic_dimension(dim, is_dynamic);
409     return;
410   }
411 
412   UpdateDynamicDimension(shape->mutable_tuple_shapes(index.front()),
413                          index.ConsumeFront(), dim, is_dynamic);
414 }
415 
/* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) {
417   CHECK(LayoutUtil::IsDenseArray(*shape));
418   shape->mutable_layout()->add_minor_to_major(shape->rank());
419   shape->add_dimensions(bound);
420   TF_DCHECK_OK(ValidateShape(*shape));
421 }
422 
/* static */ void ShapeUtil::CopyDynamicDimensions(Shape* to,
424                                                    const Shape& from) {
425   CHECK_EQ(to->rank(), from.rank());
426   for (int64 i = 0; i < from.rank(); ++i) {
427     to->set_dynamic_dimension(i, from.is_dynamic_dimension(i));
428   }
429   TF_DCHECK_OK(ValidateShape(*to));
430 }
431 
/* static */ bool ShapeUtil::ElementIsIntegral(const Shape& shape) {
433   return primitive_util::IsIntegralType(shape.element_type());
434 }
435 
/* static */ bool ShapeUtil::ElementIsIntegralWithBits(const Shape& shape,
437                                                        int32 bits) {
438   return ElementIsIntegral(shape) && ElementHasBitWidth(shape, bits);
439 }
440 
/* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) {
442   if (!shape.IsArray()) {
443     return false;
444   }
445   return primitive_util::BitWidth(shape.element_type()) == bits;
446 }
447 
/* static */ bool ShapeUtil::ElementIsSigned(const Shape& shape) {
449   switch (shape.element_type()) {
450     case S8:
451     case S16:
452     case S32:
453     case S64:
454     case F16:
455     case BF16:
456     case F32:
457     case F64:
458       return true;
459 
460     case PRED:
461     case U8:
462     case U16:
463     case U32:
464     case U64:
465     case C64:
466     case C128:
467     case TUPLE:
468     case OPAQUE_TYPE:
469     case TOKEN:
470       return false;
471 
472     default:
473       LOG(FATAL) << "Unhandled element type " << shape.element_type();
474   }
475 }
476 
/* static */ bool ShapeUtil::ElementIsComplex(const Shape& shape) {
478   return primitive_util::IsComplexType(shape.element_type());
479 }
480 
/* static */ bool ShapeUtil::ElementIsFloating(const Shape& shape) {
482   return primitive_util::IsFloatingPointType(shape.element_type());
483 }
484 
/* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) {
486   return shape.IsTuple() &&
487          absl::c_any_of(shape.tuple_shapes(),
488                         [](const Shape& s) { return s.IsTuple(); });
489 }
490 
/* static */ bool ShapeUtil::IsEmptyTuple(const Shape& shape) {
492   return shape.IsTuple() && TupleElementCount(shape) == 0;
493 }
494 
/* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) {
496   CHECK(shape.IsTuple()) << HumanString(shape);
497   return shape.tuple_shapes_size();
498 }
499 
/* static */ const Shape& ShapeUtil::GetTupleElementShape(const Shape& shape,
501                                                           int64 index) {
502   CHECK(shape.IsTuple());
503   CHECK_GT(TupleElementCount(shape), index);
504   TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape.tuple_shapes(index)));
505   return shape.tuple_shapes(index);
506 }
507 
/* static */ int64 ShapeUtil::SubshapeCount(const Shape& shape) {
509   int64 n = 0;
510   ForEachSubshape(shape, [&](const Shape& literal_subshape,
511                              const ShapeIndex& index) { ++n; });
512   return n;
513 }
514 
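// For example, slicing the tuple shape (f32[], s32[], pred[]) with start=1 and
// limit=3 yields (s32[], pred[]).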
/* static */ Shape ShapeUtil::SliceTuple(const Shape& tuple, int64 start,
516                                          int64 limit) {
517   TF_DCHECK_OK(ValidateShapeWithOptionalLayout(tuple));
518   CHECK(tuple.IsTuple());
519   CHECK_LE(start, TupleElementCount(tuple));
520   CHECK_LE(limit, TupleElementCount(tuple));
521 
522   std::vector<Shape> new_elements(tuple.tuple_shapes().begin() + start,
523                                   tuple.tuple_shapes().begin() + limit);
524   return MakeTupleShape(new_elements);
525 }
526 
527 // Returns the shape of a real or imaginary component.
/* static */ Shape ShapeUtil::ComplexComponentShape(
529     const Shape& complex_shape) {
530   CHECK(ElementIsComplex(complex_shape)) << HumanString(complex_shape);
531   return ChangeElementType(complex_shape, primitive_util::ComplexComponentType(
532                                               complex_shape.element_type()));
533 }
534 
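// For example, ElementsIn(f32[2,3,4]) is 24.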
/* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) {
536   DCHECK(shape.IsArray()) << ShapeUtil::HumanString(shape);
537   DCHECK_EQ(shape.dimensions_size(), shape.rank());
538   if (shape.dimensions().size() == 1) {
539     return shape.dimensions()[0];
540   }
541   return std::accumulate<decltype(shape.dimensions().begin()), int64>(
542       shape.dimensions().begin(), shape.dimensions().end(), 1LL,
543       std::multiplies<int64>());
544 }
545 
/* static */ int64 ShapeUtil::ElementsInRecursive(const Shape& shape) {
547   CHECK(shape.IsArray() || shape.IsTuple());
548   if (shape.IsArray()) {
549     return ElementsIn(shape);
550   }
551   int64 count = 0;
552   for (const Shape& element_shape : shape.tuple_shapes()) {
553     count += ElementsInRecursive(element_shape);
554   }
555   return count;
556 }
557 
/* static */ bool ShapeUtil::HasPrimitiveType(const Shape& shape,
559                                               PrimitiveType primitive_type) {
560   if (shape.element_type() == primitive_type) {
561     return true;
562   }
563   for (const Shape& element_shape : shape.tuple_shapes()) {
564     if (HasPrimitiveType(element_shape, primitive_type)) {
565       return true;
566     }
567   }
568   return false;
569 }
570 
/* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) {
572   return shape.IsArray() && ElementsIn(shape) == 0;
573 }
574 
/* static */ bool ShapeUtil::IsScalarWithElementType(
576     const Shape& shape, PrimitiveType element_type) {
577   return IsScalar(shape) && shape.element_type() == element_type;
578 }
579 
/* static */ string ShapeUtil::HumanString(const Shape& shape) {
581   if (shape.IsTuple()) {
582     string text = "(";
583     const char* prefix = "";
584     for (const Shape& elem_shape : shape.tuple_shapes()) {
585       StrAppend(&text, prefix, HumanString(elem_shape));
586       prefix = ", ";
587     }
588     text += ")";
589     return text;
590   }
591   std::vector<string> dim_elements;
592   for (int i = 0; i < shape.dimensions_size(); ++i) {
593     if (shape.is_dynamic_dimension(i)) {
594       dim_elements.push_back(StrCat("<=", shape.dimensions(i)));
595     } else {
596       dim_elements.push_back(StrCat(shape.dimensions(i)));
597     }
598   }
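  // For example, a 2x3 f32 array whose second dimension is dynamic renders as
  // "f32[2,<=3]".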
599   return StrCat(
600       primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "[",
601       absl::StrJoin(dim_elements, ","), "]");
602 }
603 
/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) {
605   if (shape.IsTuple()) {
606     string text = "(";
607     const char* prefix = "";
608     for (const Shape& elem_shape : shape.tuple_shapes()) {
609       StrAppend(&text, prefix, HumanStringWithLayout(elem_shape));
610       prefix = ", ";
611     }
612     text += ")";
613     return text;
614   }
615   string result = HumanString(shape);
616   if (IsScalar(shape)) {
617     string layout_str = LayoutUtil::HumanString(shape.layout());
618     // Don't print "{}" as layout for scalars.
619     if (layout_str != "{}") {
620       StrAppend(&result, layout_str);
621     }
622   } else if (shape.IsArray() && LayoutUtil::HasLayout(shape)) {
623     StrAppend(&result, LayoutUtil::HumanString(shape.layout()));
624   }
625   return result;
626 }
627 
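// For example, a program shape with a single f32[2] parameter named "p0" and
// an s32[] result renders as "(p0: f32[2]) -> s32[]".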
/* static */ string ShapeUtil::HumanString(const ProgramShape& program_shape) {
629   std::vector<string> parameters;
630   for (auto& shape : program_shape.parameters()) {
631     const int i = parameters.size();
632     parameters.push_back(StrCat(i < program_shape.parameter_names_size()
633                                     ? program_shape.parameter_names(i)
634                                     : "(unknown)",
635                                 ": ", HumanString(shape)));
636   }
637   return StrCat("(", absl::StrJoin(parameters, ", "), ") -> ",
638                 HumanString(program_shape.result()));
639 }
640 
/* static */ bool ShapeUtil::SameDimensions(const Shape& lhs,
642                                             const Shape& rhs) {
643   CHECK(lhs.IsArray());
644   CHECK(rhs.IsArray());
645   return absl::c_equal(lhs.dimensions(), rhs.dimensions());
646 }
647 
/* static */ bool ShapeUtil::SameRank(const Shape& lhs, const Shape& rhs) {
649   CHECK(lhs.IsArray());
650   CHECK(rhs.IsArray());
651   return lhs.rank() == rhs.rank();
652 }
653 
/* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
655   return Shape::Equal().IgnoreDynamicDimension().IgnoreLayout()(lhs, rhs);
656 }
657 
/* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,
659                                                            const Shape& rhs) {
660   return Shape::Equal()
661       .IgnoreDynamicDimension()
662       .IgnoreElementType()
663       .IgnoreLayout()(lhs, rhs);
664 }
665 
/* static */ bool ShapeUtil::CompatibleKind(const Shape& lhs,
667                                             const Shape& rhs) {
668   return Shape::Equal()
669       .IgnoreElementType()
670       .IgnoreLayout()
671       .IgnoreDimensions()
672       .IgnoreDynamicDimension()(lhs, rhs);
673 }
674 
/* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
676                                                            const Shape& rhs) {
677   return Shape::Equal()
678       .IgnoreDynamicDimension()
679       .IgnoreFpPrecision()
680       .IgnoreLayout()(lhs, rhs);
681 }
682 
/* static */ int64 ShapeUtil::GetDimension(const Shape& shape,
684                                            int64 dimension_number) {
685   return shape.dimensions(GetDimensionNumber(shape, dimension_number));
686 }
687 
/* static */ int64 ShapeUtil::GetDimensionNumber(const Shape& shape,
689                                                  int64 dimension_number) {
690   if (dimension_number < 0) {
691     dimension_number += shape.rank();
692   }
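  // For example, for a rank-3 shape, dimension_number -1 resolves to 2.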
693   CHECK_GE(dimension_number, 0);
694   return dimension_number;
695 }
696 
/* static */ int64 ShapeUtil::ByteSizeOfPrimitiveType(
698     PrimitiveType primitive_type) {
699   switch (primitive_type) {
700     case PRED:
701       return sizeof(int8);
702     case S8:
703       return sizeof(int8);
704     case S16:
705       return sizeof(int16);
706     case S32:
707       return sizeof(int32);
708     case S64:
709       return sizeof(int64);
710     case U8:
711       return sizeof(uint8);
712     case U16:
713       return sizeof(uint16);
714     case U32:
715       return sizeof(uint32);
716     case U64:
717       return sizeof(uint64);
718     case BF16:
719       return sizeof(float) / 2;
720     case F16:
721       return sizeof(float) / 2;
722     case F32:
723       return sizeof(float);
724     case F64:
725       return sizeof(double);
726     case C64:
727       return sizeof(complex64);
728     case C128:
729       return sizeof(complex128);
730     case TOKEN:
731       // Tokens require no space.
732       return 0;
733     case TUPLE:
734     case OPAQUE_TYPE:
735       LOG(FATAL) << PrimitiveType_Name(primitive_type)
736                  << " primitive type has no definitive size";
737     default:
738       LOG(FATAL) << "Unhandled primitive type " << primitive_type;
739   }
740 }
741 
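// For example, ByteSizeOf(f32[2,3]) is 6 * 4 = 24 bytes of element data, while
// a tuple shape only accounts for its index table of per-element pointers.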
/* static */ int64 ShapeUtil::ByteSizeOf(const Shape& shape,
743                                          int64 pointer_size) {
744   TF_DCHECK_OK(ValidateShape(shape));
745   if (shape.element_type() == TUPLE) {
746     return ByteSizeOfTupleIndexTable(shape, pointer_size);
747   } else if (shape.IsArray()) {
748     return ByteSizeOfElements(shape);
749   } else if (shape.element_type() == TOKEN) {
750     return 0;
751   } else if (shape.element_type() == OPAQUE_TYPE) {
752     CHECK_GT(pointer_size, 0);
753     return pointer_size;
754   }
755   LOG(FATAL) << PrimitiveType_Name(shape.element_type())
756              << " primitive type has no definitive size";
757 }
758 
/* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape,
760                                                         int64 pointer_size) {
761   TF_DCHECK_OK(ValidateShape(shape));
762   CHECK_EQ(TUPLE, shape.element_type());
763   CHECK_GT(pointer_size, 0);
764   return pointer_size * shape.tuple_shapes_size();
765 }
766 
/* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) {
768   TF_DCHECK_OK(ValidateShape(shape));
769   CHECK(shape.IsArray());
770   int64 allocated_element_count;
771 
772   CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString();
773   allocated_element_count = ElementsIn(shape);
774   return allocated_element_count *
775          ByteSizeOfPrimitiveType(shape.element_type());
776 }
777 
/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
779     const Shape& shape) {
780   if (shape.element_type() == PRIMITIVE_TYPE_INVALID ||
781       !PrimitiveType_IsValid(shape.element_type())) {
782     return InvalidArgument("shape has invalid element type: %s",
783                            shape.ShortDebugString());
784   }
785   if (shape.element_type() == TUPLE) {
786     if (shape.dimensions_size() != 0) {
787       return InvalidArgument("tuples must not have dimensions specified");
788     }
789     for (auto& element_shape : shape.tuple_shapes()) {
790       TF_RETURN_IF_ERROR(
791           ValidateShapeWithOptionalLayoutInternal(element_shape));
792     }
793     return Status::OK();
794   }
795 
796   // Non-tuple shape.
797   if (shape.tuple_shapes_size() > 0) {
798     return InvalidArgument("non-tuple shape has tuple_shapes field");
799   }
800 
  // Tokens and opaques should not have a layout or dimensions.
802   if (shape.element_type() == TOKEN || shape.element_type() == OPAQUE_TYPE) {
803     if (shape.dimensions_size() != 0) {
804       return InvalidArgument(
805           "shape has %s element type, but has dimensions field: %s",
806           primitive_util::LowercasePrimitiveTypeName(shape.element_type()),
807           shape.ShortDebugString());
808     }
809     if (shape.has_layout()) {
810       return InvalidArgument(
811           "shape has %s element type, but has layout field: %s",
812           primitive_util::LowercasePrimitiveTypeName(shape.element_type()),
813           shape.ShortDebugString());
814     }
815     return Status::OK();
816   }
817 
818   for (int64 i = 0; i < shape.rank(); ++i) {
819     int64 dimension = shape.dimensions(i);
820     if (dimension < 0) {
821       return InvalidArgument(
822           "shape's dimensions must not be < 0; dimension at index %d was %d", i,
823           dimension);
824     }
825   }
826 
827   TF_RETURN_IF_ERROR(ValidateShapeSize(shape));
828   return Status::OK();
829 }
830 
/* static */ Status ShapeUtil::ValidateShapeSize(const Shape& shape) {
832   VLOG(3) << "Validating shape size: " << ShapeUtil::HumanString(shape);
833 
834   if (!shape.IsArray()) {
835     return Status::OK();
836   }
837 
838   int64 shape_size = [&]() {
839     int64 dense_shape_size = 1;
840     if (shape.dimensions().empty()) {
841       return dense_shape_size;
842     }
843 
844     absl::Span<const int64> shape_max_dimensions =
845         AsInt64Slice(shape.dimensions());
846     for (int64 dim : shape_max_dimensions) {
847       dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, dim);
848       if (dense_shape_size < 0) {
849         return dense_shape_size;
850       }
851     }
852     dense_shape_size = MultiplyWithoutOverflow(
853         dense_shape_size, ByteSizeOfPrimitiveType(shape.element_type()));
854     return dense_shape_size;
855   }();
856 
857   if (shape_size < 0) {
858     return InvalidArgument("Shape %s size may overflow int64.",
859                            ShapeUtil::HumanString(shape));
860   }
861 
862   VLOG(3) << "Shape size is valid: " << shape_size;
863   return Status::OK();
864 }
865 
/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayout(
867     const Shape& shape) {
868   TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape));
869 
870   return LayoutUtil::ValidateLayoutInShape(shape,
871                                            /*allow_missing_layouts=*/true);
872 }
873 
/* static */ Status ShapeUtil::ValidateShape(const Shape& shape) {
875   TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape));
876 
877   return LayoutUtil::ValidateLayoutInShape(shape);
878 }
879 
/* static */ Shape ShapeUtil::ChangeElementType(const Shape& original,
881                                                 PrimitiveType type) {
882   if (original.IsTuple()) {
883     std::vector<Shape> new_operands;
884     new_operands.reserve(original.tuple_shapes_size());
885     for (const Shape& operand : original.tuple_shapes()) {
886       new_operands.push_back(ChangeElementType(operand, type));
887     }
888     return MakeTupleShape(new_operands);
889   } else {
890     Shape new_shape = original;
891     new_shape.set_element_type(type);
892     return new_shape;
893   }
894 }
895 
/* static */ bool ShapeUtil::IndexIsValid(const Shape& shape,
897                                           ShapeIndexView index) {
898   const Shape* subshape = &shape;
899   for (auto i : index) {
900     if (!subshape->IsTuple() || i >= subshape->tuple_shapes_size() || i < 0) {
901       return false;
902     }
903     subshape = &subshape->tuple_shapes(i);
904   }
905   return true;
906 }
907 
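// For example, given the tuple shape (f32[2], (s32[], pred[])), the index
// {1, 0} designates the s32[] leaf.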
/* static */ const Shape& ShapeUtil::GetSubshape(const Shape& shape,
909                                                  ShapeIndexView index) {
910   const Shape* return_shape = &shape;
911   for (auto i : index) {
912     CHECK(return_shape->IsTuple())
913         << "Invalid index " << index << " for shape " << shape;
914     return_shape = &return_shape->tuple_shapes(i);
915   }
916   return *return_shape;
917 }
918 
/* static */ StatusOr<const Shape*> ShapeUtil::TryGetSubshape(
920     const Shape& shape, ShapeIndexView index) {
921   const Shape* return_shape = &shape;
922   for (auto i : index) {
923     if (!return_shape->IsTuple() || i < 0 ||
924         i >= return_shape->tuple_shapes_size()) {
925       return InvalidArgument(
926           "Shape index %s not a valid subshape index for tuple with shape %s",
927           index.ToString(), shape.DebugString());
928     }
929     return_shape = &return_shape->tuple_shapes(i);
930   }
931   return return_shape;
932 }
933 
/* static */ Shape* ShapeUtil::GetMutableSubshape(Shape* shape,
935                                                   ShapeIndexView index) {
936   Shape* return_shape = shape;
937   for (auto i : index) {
938     CHECK(return_shape->IsTuple());
939     return_shape = return_shape->mutable_tuple_shapes(i);
940   }
941   return return_shape;
942 }
943 
/* static */
bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
946   return !GetSubshape(shape, index).IsTuple();
947 }
948 
/* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) {
950   if (!shape.IsTuple()) {
951     return 1;
952   }
953   int64 count = 0;
954   for (const Shape& subshape : shape.tuple_shapes()) {
955     count += GetLeafCount(subshape);
956   }
957   return count;
958 }
959 
/* static */ std::vector<ShapeUtil::IndexedShape> ShapeUtil::GetLeafShapes(
961     const Shape& shape) {
962   std::vector<IndexedShape> leaves;
963   ForEachSubshape(shape, [&](const Shape& sub_shape, const ShapeIndex& index) {
964     if (IsLeafIndex(shape, index)) {
965       leaves.emplace_back(index, sub_shape);
966     }
967   });
968   return leaves;
969 }
970 
/* static */ bool ShapeUtil::HasDegenerateDimensions(const Shape& shape) {
972   CHECK(shape.IsArray());
973   return absl::c_linear_search(shape.dimensions(), 1);
974 }
975 
/* static */ Shape ShapeUtil::DropDegenerateDimensions(const Shape& shape) {
977   return FilterDimensions(
978       [&](int64 dim) -> bool { return shape.dimensions()[dim] != 1; }, shape);
979 }
980 
981 namespace {
982 
983 // Helper for ForEachSubshape which visits the subshapes of the given shape in
984 // DFS pre-order starting with the index.
Status ForEachSubshapeHelper(const Shape& shape,
986                              const ShapeUtil::StatusVisitorFunction& func,
987                              ShapeIndex* index) {
988   TF_RETURN_IF_ERROR(func(shape, *index));
989   if (shape.IsTuple()) {
990     for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
991       index->push_back(i);
992       TF_RETURN_IF_ERROR(ForEachSubshapeHelper(
993           ShapeUtil::GetTupleElementShape(shape, i), func, index));
994       index->pop_back();
995     }
996   }
997   return Status::OK();
998 }
999 
1000 // Helper for ForEachMutableSubshape which visits the subshapes of the given
1001 // shape in DFS pre-order starting with the index.
Status ForEachMutableSubshapeHelper(
1003     Shape* shape, const ShapeUtil::MutatingStatusVisitorFunction& func,
1004     ShapeIndex* index) {
1005   TF_RETURN_IF_ERROR(func(shape, *index));
1006   if (shape->IsTuple()) {
1007     for (int64 i = 0; i < ShapeUtil::TupleElementCount(*shape); ++i) {
1008       index->push_back(i);
1009       TF_RETURN_IF_ERROR(ForEachMutableSubshapeHelper(
1010           shape->mutable_tuple_shapes(i), func, index));
1011       index->pop_back();
1012     }
1013   }
1014   return Status::OK();
1015 }
1016 
1017 }  // namespace
1018 
/* static */ void ShapeUtil::ForEachSubshape(const Shape& shape,
1020                                              const VisitorFunction& func) {
1021   ShapeIndex index;
1022   ForEachSubshapeHelper(
1023       shape,
1024       [&func](const Shape& subshape, const ShapeIndex& index) {
1025         func(subshape, index);
1026         return Status::OK();
1027       },
1028       &index)
1029       .IgnoreError();
1030 }
1031 
/* static */ void ShapeUtil::ForEachMutableSubshape(
1033     Shape* shape, const MutatingVisitorFunction& func) {
1034   ShapeIndex index;
1035   ForEachMutableSubshapeHelper(
1036       shape,
1037       [&func](Shape* subshape, const ShapeIndex& index) {
1038         func(subshape, index);
1039         return Status::OK();
1040       },
1041       &index)
1042       .IgnoreError();
1043 }
1044 
/* static */ Status ShapeUtil::ForEachSubshapeWithStatus(
1046     const Shape& shape, const StatusVisitorFunction& func) {
1047   ShapeIndex index;
1048   return ForEachSubshapeHelper(shape, func, &index);
1049 }
1050 
/* static */ Status ShapeUtil::ForEachMutableSubshapeWithStatus(
1052     Shape* shape, const MutatingStatusVisitorFunction& func) {
1053   ShapeIndex index;
1054   return ForEachMutableSubshapeHelper(shape, func, &index);
1055 }
1056 
/* static */ Shape ShapeUtil::PermuteDimensions(
1058     absl::Span<const int64> permutation, const Shape& shape) {
1059   Shape new_shape = shape;
1060   new_shape.clear_dimensions();
1061   for (auto dim : Permute(shape.dimensions(), permutation)) {
1062     new_shape.add_dimensions(dim);
1063   }
1064   auto inv_permutation = InversePermutation(permutation);
1065   for (int64 i = 0; i < shape.rank(); i++) {
1066     new_shape.set_dynamic_dimension(inv_permutation[i],
1067                                     shape.is_dynamic_dimension(i));
1068   }
1069 
1070   // If `shape` has a layout, by contract we choose a new layout such that the
1071   // transpose defined by this permutation is a bitcast.
1072   //
1073   // Some formalism helps to understand the correct way to do this.  We're going
1074   // to do algebra in the group of permutations of the dimensions of `shape`.
1075   //
1076   // Since the order of `shape`'s dimensions is not permuted relative to itself,
1077   // `shape`'s list of dimensions is isomorphic to the identity I.
1078   //
1079   // Let `shape`'s layout be L.  A layout is a permutation which maps a
1080   // minor-to-major physical dimension ordering to a shape's logical dimension
1081   // ordering.  Therefore the inverse of a layout maps from logical to physical
1082   // dims, and so the physical ordering of I is simply L'.I = L', where L' is
1083   // the inverse of L.
1084   //
1085   // Let the argument `permutation` be P.  This is a permutation over `shape`'s
1086   // dimensions, so our return value will be a shape with dims P.I = P.  Our
1087   // goal is to construct a layout permutation L* for this shape. The physical
1088   // dimension ordering of this returned shape must be the same as that of the
1089   // original shape, namely L'.
1090   //
1091   // Our returned shape has dims P and layout L*, so its in-memory ordering is
1092   // L*'.P.  Setting this equal to L' and solving for L*, we get:
1093   //
1094   //   L*'.P = L'    =>
1095   //   L*'   = L'P'  =>
1096   //   L*    = P.L
1097   //
1098   if (shape.has_layout()) {
1099     CHECK(LayoutUtil::IsDenseArray(shape));
1100     Layout* new_layout = new_shape.mutable_layout();
1101     new_layout->set_format(DENSE);
1102     new_layout->clear_minor_to_major();
1103     for (auto index : ComposePermutations(
1104              inv_permutation, AsInt64Slice(shape.layout().minor_to_major()))) {
1105       new_layout->add_minor_to_major(index);
1106     }
1107     // The permutation accepted by TransposeIsBitcast is the inverse of the
1108     // permutation here.
1109     CHECK(TransposeIsBitcast(shape, new_shape, permutation))
1110         << "shape=" << HumanStringWithLayout(shape)
1111         << ", new_shape=" << HumanStringWithLayout(new_shape)
1112         << ", permutation={" << absl::StrJoin(permutation, ",") << "}";
1113   }
1114   return new_shape;
1115 }
1116 
/* static */ std::tuple<bool, std::vector<int64>, std::vector<int64>>
ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
1119                                              const Shape& shape_post) {
1120   CHECK(shape_pre.IsArray());
1121   CHECK(shape_post.IsArray());
1122 
1123   auto nil = std::make_tuple(false, std::vector<int64>(), std::vector<int64>());
1124 
1125   std::vector<int64> deleted_indices;
1126   std::vector<int64> inserted_indices;
  // Returns false if any input/output dimension between
  // prior_unmodified_dim_pair and unmodified_dim_pair has size >1. Otherwise,
  // returns true and appends the degenerate input/output dimensions in the gap
  // to deleted_indices/inserted_indices respectively.
1131   auto check_modified_dims =
1132       [&shape_pre, &shape_post, &deleted_indices, &inserted_indices](
1133           std::pair<int64, int64> prior_unmodified_dim_pair,
1134           std::pair<int64, int64> unmodified_dim_pair) {
1135         for (int64 modified_input_dim = prior_unmodified_dim_pair.first + 1;
1136              modified_input_dim < unmodified_dim_pair.first;
1137              ++modified_input_dim) {
1138           if (shape_pre.dimensions(modified_input_dim) > 1) {
1139             return false;
1140           }
1141           deleted_indices.push_back(modified_input_dim);
1142         }
1143         for (int64 modified_output_dim = prior_unmodified_dim_pair.second + 1;
1144              modified_output_dim < unmodified_dim_pair.second;
1145              ++modified_output_dim) {
1146           if (shape_post.dimensions(modified_output_dim) > 1) {
1147             return false;
1148           }
1149           inserted_indices.push_back(modified_output_dim);
1150         }
1151         return true;
1152       };
1153 
1154   std::vector<std::pair<int64, int64>> unmodified_dims =
1155       DimensionsUnmodifiedByReshape(shape_pre, shape_post);
1156   // Returns nil if the reshape modifies any non-degenerate input/output
1157   // dimension. DimensionsUnmodifiedByReshape gives us all unmodified
1158   // dimensions, so we only need to check whether dimensions in the gaps (thus
1159   // modified) have size >1.
1160   for (size_t i = 0; i <= unmodified_dims.size(); ++i) {
1161     // Check (modified) dimensions between unmodified_dims[i-1] and
1162     // unmodified_dims[i].
1163     auto prior_unmodified_dim_pair =
1164         i > 0 ? unmodified_dims[i - 1] : std::pair<int64, int64>(-1, -1);
1165     auto unmodified_dim_pair =
1166         i < unmodified_dims.size()
1167             ? unmodified_dims[i]
1168             : std::make_pair(shape_pre.rank(), shape_post.rank());
1169     if (!check_modified_dims(prior_unmodified_dim_pair, unmodified_dim_pair)) {
1170       return nil;
1171     }
1172   }
1173 
1174   return std::make_tuple(true, deleted_indices, inserted_indices);
1175 }
1176 
/* static */ std::vector<std::pair<int64, int64>>
ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
1179                                          const Shape& output_shape) {
1180   CHECK(input_shape.IsArray());
1181   CHECK(output_shape.IsArray());
1182 
1183   // Unmodified dimensions are merely common factors of rank 1.
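  // For example, reshaping [2,3,5] into [6,5] leaves only the size-5 dimension
  // unmodified, yielding {{2, 1}}.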
1184   auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()),
1185                                       AsInt64Slice(output_shape.dimensions()));
1186   for (size_t i = 0; i < common_factors.size() - 1;) {
1187     if (1 != common_factors[i + 1].first - common_factors[i].first ||
1188         1 != common_factors[i + 1].second - common_factors[i].second) {
1189       common_factors.erase(common_factors.begin() + i);
1190     } else {
1191       ++i;
1192     }
1193   }
1194   // `CommonFactors(a, b).back() == (a.rank, b.rank)` so we must pop it.
1195   common_factors.pop_back();
1196   return std::vector<std::pair<int64, int64>>(common_factors.begin(),
1197                                               common_factors.end());
1198 }
1199 
/* static */ absl::optional<std::vector<int64>>
ShapeUtil::ReshapeLeavesDimensionsUnmodified(
1202     const Shape& from_shape, const Shape& to_shape,
1203     absl::Span<const int64> input_dim_indices) {
1204   CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end()));
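  // For example, when reshaping [2,3,5] into [6,5], input dimension 2 survives
  // as output dimension 1, so input_dim_indices {2} maps to {1}, while a
  // request containing dimension 0 returns nullopt.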
1205 
1206   std::vector<int64> output_dim_indices;
1207   std::vector<std::pair<int64, int64>> unmodified_dims =
1208       ShapeUtil::DimensionsUnmodifiedByReshape(from_shape, to_shape);
1209   size_t i = 0;  // index to unmodified_dims
1210   for (int64 input_dim_index : input_dim_indices) {
1211     // Search unmodified_dims for input_dim_index. We can search from the last
1212     // matching position because input_dim_indices is guaranteed to be sorted.
1213     while (i < unmodified_dims.size() &&
1214            unmodified_dims[i].first < input_dim_index) {
1215       ++i;
1216     }
1217     if (i >= unmodified_dims.size() ||
1218         unmodified_dims[i].first != input_dim_index) {
1219       return absl::nullopt;
1220     }
1221     output_dim_indices.push_back(unmodified_dims[i].second);
1222   }
1223   return output_dim_indices;
1224 }
1225 
/* static */ bool ShapeUtil::TransposeIsBitcast(
1227     const Shape& input_shape, const Shape& output_shape,
1228     absl::Span<const int64> dimension_mapping) {
1229   CHECK(LayoutUtil::HasLayout(input_shape) &&
1230         LayoutUtil::HasLayout(output_shape));
1231 
1232   if (!SameElementType(input_shape, output_shape)) {
1233     return false;
1234   }
1235 
  // Check that the transpose permutes the positions of each dimension in the
  // minor-to-major order. positions[i]=k means dimension `i` is k-th minor.
1238   //   input_positions = apply(dimension_mapping, output_positions)
1239   //
1240   // Because the positions of each dimension are the inverse permutation of the
1241   // minor-to-major order, the above check is equivalent to
1242   //   inverse(input_dimensions) =
1243   //       apply(dimension_mapping, inverse(output_dimensions))
1244   //   # `I` indicates identity permutation.
1245   //   apply(input_dimensions, I) =
1246   //       apply(dimension_mapping, apply(output_dimensions, I))
1247   //   apply(input_dimensions, I) =
1248   //       apply((dimension_mapping * output_dimensions), I)
1249   //   input_dimensions = dimension_mapping * output_dimensions
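  //
  // For example, transposing f32[2,3]{1,0} with dimension_mapping {1,0} into
  // f32[3,2]{0,1} is a bitcast: both shapes place the same six elements at the
  // same physical offsets.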
1250   return absl::c_equal(
1251       ComposePermutations(dimension_mapping,
1252                           AsInt64Slice(output_shape.layout().minor_to_major())),
1253       input_shape.layout().minor_to_major());
1254 }
1255 
/* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape,
1257                                               const Shape& output_shape) {
1258   CHECK(input_shape.IsArray());
1259   CHECK(output_shape.IsArray());
1260   CHECK(LayoutUtil::HasLayout(input_shape));
1261   CHECK(LayoutUtil::HasLayout(output_shape));
1262 
1263   if (!SameElementType(input_shape, output_shape)) {
1264     return false;
1265   }
1266 
1267   CHECK_EQ(ElementsIn(input_shape), ElementsIn(output_shape))
1268       << "input_shape=" << input_shape.ShortDebugString()
1269       << ", output_shape=" << output_shape.ShortDebugString();
1270   if (ElementsIn(input_shape) == 0) {
1271     return true;
1272   }
1273 
1274   // TL;DR: The rest of the method checks that the reshape does not change the
1275   // physical location of any unit input or output index. Unit indices have
1276   // exactly one dimension that equals 1 and other dimensions 0. This condition
1277   // is necessary for the reshape to be a bitcast, because a bitcast-equivalent
1278   // reshape shouldn't change the physical location of any element. It is also a
1279   // sufficient condition as is proved below (note: many details are omitted for
1280   // space).
1281   //
1282   // Definitions:
1283   //
1284   // * Denote the input shape by IS and output shape by OS. IS[i] or OS[i] means
1285   // the size of i-th least significant dimension of IS or OS (this is opposite
1286   // to how we define the index of Shape::dimensions()).
1287   //
1288   // * Given an input or output index I, denote by p(I) I's physical linear
1289   // index (or physical index for short) and l(I) I's logical linear index (or
1290   // logical index for short).
1291   //
1292   // * Given a logical index k, denote by II(k) the input index whose linear
1293   // index is k, and OI(k) the corresponding output index.
1294   //
1295   // * Denote by IT[i] the increment of physical index if i-th dimension of the
1296   // input index is increased by 1. Similarly, OT[i] means the increment if i-th
1297   // dimension of the output index is increased by 1. Note that IT[i] or OT[i]
1298   // is a function of IS or OS and the layout, and not dependent on the specific
1299   // input or output index.
1300   //
1301   // To prove the reshape from IS to OS is a bitcast, it is sufficient to prove
1302   // that, for any linear index k, p(II(k))=p(OI(k)). We prove this by
1303   // induction. We know p(II(0))=p(OI(0)) is trivially true, so what's left is
1304   // to prove, with every increment on k, the above formula still holds.
1305   //
1306   // First, suppose reshaping from IS to OS is non-factorizable (we discuss
1307   // factorizable reshapes later). A reshape from IS to OS is factorizable, if
1308   // there exists (i,j) such that
1309   //
1310   //   0<=i<=|IS| and 0<=j<=|OS|
1311   //   i+j > 0 (i.e., i,j mustn't both point to the start)
1312   //   |IS|-i+|OS|-j > 0 (i.e., i,j mustn't both point to the end)
1313   //   product(IS[i], IS[i+1], ..., IS[|IS|-1])
1314   //     = product(OS[j], OS[j+1], ..., OS[|OS|-1])
1315   //
1316   // p(II(k))=p(OI(k)) is trivially true for k=0 because p(II(0)) and p(OI(0))
1317   // are both 0. It's also trivially true for k=1, because II(1) and OI(1) are
1318   // unit indices which are already tested. This also means IT[0]=OT[0]
1319   // because p(II(1))=IT[0] and p(OI(1))=OT[0].
1320   //
1321   // Furthermore, p(II(k))=p(OI(k)) for k<min(IS[0],OS[0]), because each
1322   // increment of k adds IT[0] to the input physical index and OT[0] (which
1323   // equals IT[0]) to the output physical index.
1324   //
1325   // When k=min(IS[0],OS[0]), the first wrap happens. Without loss of generality,
1326   // suppose IS[0]<OS[0] and thus k=IS[0]. A similar proof applies to IS[0]>OS[0].
1327   // Note that IS[0]!=OS[0] because the reshape is non-factorizable. From
1328   // logical index k-1 to logical index k, dimension 1 of the input index
1329   // is increased by 1 and dimension 0 is reset to 0 thus decreased by
1330   // IS[0]-1. Therefore, the physical input index is increased by
1331   //
1332   //   p(II(k)) - p(II(k-1)) = IT[1] - (IS[0]-1) * IT[0]
1333   //
1334   // Because IS[0]<OS[0], the only change to the output index is that its
1335   // dimension 0 is increased by one. Therefore,
1336   //
1337   //   p(OI(k)) - p(OI(k-1)) = OT[0] = IT[0]
1338   //
1339   // Because II(k) is a unit index -- (0,...,0,1,0) -- we already tested that
1340   // p(II(k))=p(OI(k)). Therefore,
1341   //   IT[1] - (IS[0]-1) * IT[0] = IT[0]
1342   //   IT[1] = IS[0] * IT[0]
1343   // In other words, input dimension 1 is immediately more major than input
1344   // dimension 0. We can now conceptually collapse these two dimensions because
1345   // an increment in the logical index affecting only these two dimensions maps
1346   // to IT[0] in the physical index.
1347   //
1348   // By induction (omitted here), we can prove IT[i]=IS[i-1]*IT[i-1] and
1349   // OT[i]=OS[i-1]*OT[i-1]. Therefore, both IS and OS are row-major and bitwise
1350   // identical.
1351   //
1352   // A factorizable reshape can be factorized into a list of non-factorizable
1353   // sub-reshapes, each of which can be handled similarly to the proof above.
1354   // For example,
1355   //
1356   //   [7x9x2x15] -> [63x6x5]
1357   //
1358   // can be factorized into
1359   //
1360   //   [7x9] -> [63] and [2x15] -> [6x5].
1361   //
1362   // Suppose input index I=(x3,x2,x1,x0) and output index O=(y2,y1,y0) have the
1363   // same logical linear index. According to the factorization, we know
1364   // l(x3,x2,0,0)=l(y2,0,0) and l(0,0,x1,x0)=l(0,y1,y0). Using the proof for
1365   // non-factorizable reshapes, we can prove p(0,0,x1,x0)=p(0,y1,y0). Using a
1366   // similar proof, with the increment of the logical index set to
1367   // IS[1]*IS[0]=OS[1]*OS[0]=30 instead of 1, we can prove
1368   // p(x3,x2,0,0)=p(y2,0,0) too. Therefore,
1369   //
1370   //   p(x3,x2,x1,x0) = p(x3,x2,0,0) + p(0,0,x1,x0)
1371   //                  = p(y2,0,0) + p(0,0,y1,y0)
1372   //                  = p(y2,y1,y0)
1373   //
1374   // check_input_unit_indices checks one direction of the condition: each input
1375   // unit index is mapped to an output index with the same physical location.
1376   // The lambda is called again with input_shape and output_shape swapped to
1377   // check the other direction.
1378   auto check_input_unit_indices = [](const Shape& input_shape,
1379                                      const Shape& output_shape) {
1380     // input_shape_dim0_major/output_shape_dim0_major have the same dimensions
1381     // as input_shape/output_shape but use the dimension-0-major (descending)
1382     // layout. These two shapes are used for converting between logical linear
1383     // indices and multi-dimensional indices.
1384     Shape input_shape_dim0_major = MakeShapeWithDescendingLayout(
1385         input_shape.element_type(), AsInt64Slice(input_shape.dimensions()));
1386     Shape output_shape_dim0_major = MakeShapeWithDescendingLayout(
1387         output_shape.element_type(), AsInt64Slice(output_shape.dimensions()));
1388 
1389     for (int64 input_dim = 0; input_dim < input_shape.rank(); ++input_dim) {
1390       if (input_shape.dimensions(input_dim) <= 1) {
1391         continue;
1392       }
1393 
1394       std::vector<int64> input_unit_index(input_shape.rank(), 0);
1395       input_unit_index[input_dim] = 1;
1396       int64 logical_linear_index =
1397           IndexUtil::MultidimensionalIndexToLinearIndex(input_shape_dim0_major,
1398                                                         input_unit_index);
1399       // output_index has the same logical linear index as input_unit_index.
1400       std::vector<int64> output_index =
1401           IndexUtil::LinearIndexToMultidimensionalIndex(output_shape_dim0_major,
1402                                                         logical_linear_index);
1403       // Check input_unit_index and output_index have the same physical linear
1404       // index.
1405       if (IndexUtil::MultidimensionalIndexToLinearIndex(input_shape,
1406                                                         input_unit_index) !=
1407           IndexUtil::MultidimensionalIndexToLinearIndex(output_shape,
1408                                                         output_index)) {
1409         return false;
1410       }
1411     }
1412     return true;
1413   };
1414   return check_input_unit_indices(input_shape, output_shape) &&
1415          check_input_unit_indices(output_shape, input_shape);
1416 }
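
// For illustration (shapes chosen for exposition, not taken from any caller):
// reshaping a row-major [4,6] array into a row-major [2,2,6] array is a
// bitcast, while reshaping it into a column-major [6,4] array is not:
//
//   Shape a = ShapeUtil::MakeShapeWithLayout(F32, {4, 6}, {1, 0});
//   Shape b = ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 6}, {2, 1, 0});
//   Shape c = ShapeUtil::MakeShapeWithLayout(F32, {6, 4}, {0, 1});
//   ShapeUtil::ReshapeIsBitcast(a, b);  // true: physical order is unchanged.
//   ShapeUtil::ReshapeIsBitcast(a, c);  // false: e.g. element (0,1) moves.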
1417 
1418 /* static */ absl::optional<Shape> ShapeUtil::AlignLayouts(
1419     const Shape& input_shape, const Shape& output_shape) {
1420   CHECK(input_shape.IsArray());
1421   CHECK(output_shape.IsArray());
1422   // Removing degenerate (size-1) dimensions from the shapes simplifies the
1423   // alignment algorithm, since such dimensions can go in any position.
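  // For example, if output_shape has dimensions [5,1,6] and the recursive call
  // returns the simplified shape [5,6] with minor_to_major {1,0}, the remapping
  // below turns the layout into {2,0} and the final loop appends the degenerate
  // dimension, producing minor_to_major {2,0,1}.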
1424   if (HasDegenerateDimensions(input_shape) ||
1425       HasDegenerateDimensions(output_shape)) {
1426     auto simple_output_shape =
1427         AlignLayouts(DropDegenerateDimensions(input_shape),
1428                      DropDegenerateDimensions(output_shape));
1429     if (!simple_output_shape) {
1430       return absl::nullopt;
1431     }
1432 
1433     std::vector<int64> layout =
1434         SpanToVector(simple_output_shape->layout().minor_to_major());
1435     // Map each dimension number in the simplified shape's layout back to the
1436     // corresponding dimension number in the original output shape.
1437     absl::InlinedVector<int64, 8> dim_map;
1438     dim_map.reserve(simple_output_shape->rank());
1439     for (int64 i = 0; i < output_shape.rank(); ++i) {
1440       if (output_shape.dimensions(i) != 1) {
1441         dim_map.push_back(i);
1442       }
1443     }
1444     for (int64& d : layout) {
1445       d = dim_map[d];
1446     }
1447 
1448     // Append the size-1 dimensions in descending order; descending layouts
1449     // tend to reduce the number of copies inserted in layout assignment.
1450     for (int64 i = output_shape.rank() - 1; i >= 0; --i) {
1451       if (output_shape.dimensions(i) == 1) {
1452         layout.push_back(i);
1453       }
1454     }
1455     Shape output_shape_with_layout = output_shape;
1456     *output_shape_with_layout.mutable_layout() = Layout{layout};
1457     return output_shape_with_layout;
1458   }
1459 
1460   int64 input_rank = input_shape.rank();
1461   int64 output_rank = output_shape.rank();
1462 
1463   // First, calculate an alignment of the dimensions. A consecutive sequence of
1464   // input dimensions and output dimensions belong to the same alignment part if
1465   // the products of their dimension bounds are the same. In the easiest case,
1466   // an alignment part consists of one input dimension and one output dimension
1467   // which both have the same dimension bound. An alignment part specifies which
1468   // dimensions need to be kept together in a physical layout if we want a
1469   // reshape to be a bitcast. The order of the alignment parts is defined by the
1470   // physical layout of the input shape, so when we construct the layout for the
1471   // output shape we just process the alignment parts in this order, and then
1472   // lay out the dimensions belonging to each part in descending (major to
1473   // minor) order.
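  //
  // For example, with input dimensions [4,3,2,5] and output dimensions [12,10]
  // there are two alignment parts: input dimensions {0,1} align with output
  // dimension {0} (4*3 = 12), and input dimensions {2,3} align with output
  // dimension {1} (2*5 = 10).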
1474 
1475   // Stores the input and output dimension numbers where each alignment part
1476   // starts.
1477   std::vector<std::pair<int64, int64>> alignment;
1478   alignment.push_back({0, 0});
1479 
1480   // Stores a mapping from the input dimension to the alignment part it belongs
1481   // to.
1482   std::vector<int64> dimension_to_alignment_index(input_rank);
1483   int64 input_dimension_product = 1, output_dimension_product = 1;
1484   for (int64 i = 0, j = 0; i < input_rank || j < output_rank;) {
1485     // Check if we have reached the end of an alignment part.
1486     if (input_dimension_product == output_dimension_product &&
1487         input_dimension_product > 1) {
1488       alignment.push_back({i, j});
1489       input_dimension_product = output_dimension_product = 1;
1490     }
1491     if (input_dimension_product < output_dimension_product ||
1492         j == output_rank) {
1493       if (i == input_rank) {
1494         return absl::nullopt;
1495       }
1496       dimension_to_alignment_index[i] = alignment.size() - 1;
1497       input_dimension_product *= input_shape.dimensions(i);
1498       ++i;
1499     } else {
1500       output_dimension_product *= output_shape.dimensions(j);
1501       ++j;
1502     }
1503   }
1504   if (input_dimension_product != output_dimension_product) {
1505     return absl::nullopt;
1506   }
1507 
1508   // We also need to store an end element so that we know where the last
1509   // alignment part ends.
1510   alignment.push_back({input_rank, output_rank});
1511   // Now check whether the input's physical layout can be matched by choosing a
1512   // suitable physical layout for the output shape. We need to check that all
1513   // input dimension numbers that belong to the same alignment part appear
1514   // consecutively in the input layout and are in descending order. (Dimensions
1515   // of size 1 were already handled above, so none remain at this point.)
1516   auto input_dimension_numbers = input_shape.layout().minor_to_major();
1517   std::vector<int64> output_layout;
1518   output_layout.reserve(output_rank);
1519   for (int64 i = 0; i < input_rank;) {
1520     int64 current_dimension_number = input_dimension_numbers[i];
1521 
1522     // Trivial dimensions are stripped.
1523     CHECK_NE(input_shape.dimensions(current_dimension_number), 1);
1524     const int64 current_alignment_index =
1525         dimension_to_alignment_index[current_dimension_number];
1526     // Because of the special end element that we added, we can be sure that
1527     // 'current_alignment_index' is < alignment.size() - 1.
1528     CHECK_LT(current_alignment_index, alignment.size() - 1);
1529 
1530     // Check that the next dimension numbers in the input layout -- one for each
1531     // input dimension of the current alignment part -- belong to that alignment
1532     // part and are in descending order.
1533     for (int64 j = 0; j < alignment[current_alignment_index + 1].first -
1534                               alignment[current_alignment_index].first;
1535          ++i, ++j) {
1536       if (i == input_rank) {
1537         return absl::nullopt;
1538       }
1539       // If the current dimension number belongs to a different alignment part,
1540       // or the dimension numbers are not in descending order, we can return
1541       // early.
1542       if (dimension_to_alignment_index[input_dimension_numbers[i]] !=
1543               current_alignment_index ||
1544           input_dimension_numbers[i] > current_dimension_number) {
1545         return absl::nullopt;
1546       }
1547       current_dimension_number = input_dimension_numbers[i];
1548     }
1549     // The output dimension numbers that belong to the current alignment part
1550     // need to appear in the same descending order as in the input.
1551     for (int64 j = alignment[current_alignment_index + 1].second - 1;
1552          j >= alignment[current_alignment_index].second; --j) {
1553       output_layout.push_back(j);
1554     }
1555   }
1556   CHECK_EQ(output_layout.size(), output_rank);
1557   Shape output_shape_with_layout = MakeShapeWithLayout(
1558       output_shape.element_type(), AsInt64Slice(output_shape.dimensions()),
1559       output_layout);
1560   CHECK(ReshapeIsBitcast(input_shape, output_shape_with_layout))
1561       << "reshape is not a bitcast for input_shape: "
1562       << ShapeUtil::HumanStringWithLayout(input_shape)
1563       << " and output_shape_with_layout: "
1564       << ShapeUtil::HumanStringWithLayout(output_shape_with_layout);
1565   return output_shape_with_layout;
1566 }
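
// For illustration (shapes chosen for exposition): aligning a row-major
// [4,3,2,5] input with a [12,10] output yields the output layout {1,0}
// (row-major), for which the reshape is a bitcast. If the input instead has
// minor_to_major {0,1,2,3}, the physical order within each alignment part is
// reversed relative to the logical order, and absl::nullopt is returned.
//
//   Shape input = ShapeUtil::MakeShapeWithLayout(F32, {4,3,2,5}, {3,2,1,0});
//   Shape output = ShapeUtil::MakeShape(F32, {12, 10});
//   auto aligned = ShapeUtil::AlignLayouts(input, output);  // [12,10]{1,0}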
1567 
1568 /* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete,
1569                                               Shape shape) {
1570   CHECK(shape.IsArray());
1571   shape.DeleteDimension(dim_to_delete);
1572   return shape;
1573 }
1574 
1575 /* static */ bool ShapeUtil::DynamicArrayShapeIsCompatible(
1576     const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) {
1577   if (dynamic_shape.rank() != bounded_shape.rank()) {
1578     return false;
1579   }
1580   for (int64 i = 0; i < dynamic_shape.rank(); ++i) {
1581     if (dynamic_shape.dimensions(i) > bounded_shape.dimensions(i)) {
1582       return false;
1583     }
1584   }
1585   return true;
1586 }
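
// For example, a dynamic array with dimensions [2,3] fits within a bound of
// [2,4] (every dimension is <= its bound) but not within a bound of [2,2].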
1587 
1588 /* static */ bool ShapeUtil::DynamicShapeIsCompatible(
1589     const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) {
1590   bool compatible = true;
1591   xla::ShapeUtil::ForEachSubshape(dynamic_shape, [&](const Shape& sub_shape,
1592                                                      const ShapeIndex& index) {
1593     if (compatible) {
1594       auto subshape_result = TryGetSubshape(bounded_shape, index);
1595       if (subshape_result.ok()) {
1596         const Shape* bounded_sub_shape = subshape_result.ConsumeValueOrDie();
1597         if (sub_shape.IsTuple()) {
1598           if (!bounded_sub_shape->IsTuple()) {
1599             compatible = false;
1600           }
1601         } else {
1602           if (bounded_sub_shape->IsTuple()) {
1603             compatible = false;
1604           } else if (!sub_shape.is_static() &&
1605                      !DynamicArrayShapeIsCompatible(sub_shape,
1606                                                     *bounded_sub_shape)) {
1607             compatible = false;
1608           }
1609         }
1610       } else {
1611         compatible = false;
1612       }
1613     }
1614   });
1615   return compatible;
1616 }
1617 
1618 /* static */ Shape ShapeUtil::FilterDimensions(
1619     const std::function<bool(int64)>& p, Shape shape) {
1620   CHECK(shape.IsArray());
1621   std::vector<int64> dims_to_delete;
1622   for (int64 i = shape.dimensions().size() - 1; i >= 0; --i) {
1623     if (!p(i)) {
1624       dims_to_delete.push_back(i);
1625     }
1626   }
1627   for (int64 dim : dims_to_delete) {
1628     shape = DeleteDimension(dim, shape);
1629   }
1630   return shape;
1631 }
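
// For example, filtering a [4,5,6] shape with a predicate that keeps only
// dimensions 0 and 2 deletes dimension 1 and returns a [4,6] shape.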
1632 
1633 /*static*/ size_t ShapeUtil::Hash(const Shape& shape) {
1634   using tensorflow::hash;
1635   using tensorflow::Hash64Combine;
1636 
1637   size_t hash_value = hash<PrimitiveType>()(shape.element_type());
1638 
1639   if (shape.tuple_shapes().empty()) {
1640     for (int i = 0; i < shape.dimensions_size(); ++i) {
1641       hash_value =
1642           Hash64Combine(hash_value, hash<int64>()(shape.dimensions(i)));
1643       hash_value = Hash64Combine(hash_value,
1644                                  hash<bool>()(shape.is_dynamic_dimension(i)));
1645     }
1646 
1647     hash_value = Hash64Combine(hash_value, LayoutUtil::Hash(shape.layout()));
1648   } else {
1649     hash_value = 0;
1650     for (const Shape& subshape : shape.tuple_shapes()) {
1651       hash_value = Hash64Combine(hash_value, ShapeUtil::Hash(subshape));
1652     }
1653   }
1654 
1655   return hash_value;
1656 }
1657 
1658 // Returns the index of the first element of each maximal run of consecutive
1659 // values (i.e., values increasing by exactly 1) in the given array. For
1660 // example: ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4}
1661 static std::vector<size_t> ConsecutiveSegments(absl::Span<const int64> xs) {
1662   std::vector<size_t> is = {0};
1663   for (size_t i = 1; i < xs.size(); ++i) {
1664     if (1 != xs[i] - xs[i - 1]) {
1665       is.push_back(i);
1666     }
1667   }
1668   return is;
1669 }
1670 
1671 // Merges the dimensions of the given shape into groups that start at the given
1672 // indices `segs`, and returns the merged shape with a descending layout.
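// For example, MergeDimensions({0, 2}, [2,3,5,7]) groups dimensions {0,1} and
// {2,3} and produces a shape with dimensions [6,35] and a descending layout.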
1673 static Shape MergeDimensions(absl::Span<const size_t> segs,
1674                              const Shape& shape) {
1675   std::vector<int64> dimensions;
1676   for (size_t i = 1; i <= segs.size(); ++i) {
1677     dimensions.push_back(std::accumulate(
1678         shape.dimensions().begin() + segs[i - 1],
1679         shape.dimensions().begin() +
1680             (segs.size() == i ? shape.dimensions().size() : segs[i]),
1681         1, std::multiplies<int64>()));
1682   }
1683   return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
1684                                                   dimensions);
1685 }
1686 
1687 /*static*/ absl::optional<std::vector<int64>> ShapeUtil::FindTranspose021(
1688     const Shape& a, const Shape& b) {
1689   if (!CompatibleIgnoringElementType(a, b)) {
1690     return absl::nullopt;
1691   }
1692 
1693   std::vector<int64> permutation(a.dimensions().size());
1694   absl::Span<const int64> minor_to_major_a = LayoutUtil::MinorToMajor(a);
1695   std::vector<int64> major_to_minor_a(minor_to_major_a.rbegin(),
1696                                       minor_to_major_a.rend());
1697   absl::Span<const int64> minor_to_major_b = LayoutUtil::MinorToMajor(b);
1698   std::vector<int64> major_to_minor_b(minor_to_major_b.rbegin(),
1699                                       minor_to_major_b.rend());
1700   for (size_t i = 0; i < permutation.size(); ++i) {
1701     permutation[i] = PositionInContainer(major_to_minor_b, major_to_minor_a[i]);
1702   }
1703 
1704   std::vector<size_t> segments = ConsecutiveSegments(permutation);
1705   if ((3 == segments.size() && 0 == permutation[0]) || 2 == segments.size()) {
1706     Shape descending_layout_shape =
1707         ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a);
1708     Shape normalized_shape = MergeDimensions(segments, descending_layout_shape);
1709     absl::Span<const int64> normalized_dims =
1710         AsInt64Slice(normalized_shape.dimensions());
1711     std::vector<int64> dims_021;
1712     if (2 == segments.size()) {
1713       // The logical component-0 is of size one.
1714       dims_021 = {1, normalized_dims[1], normalized_dims[0]};
1715     } else {
1716       dims_021 = {normalized_dims[0], normalized_dims[2], normalized_dims[1]};
1717     }
1718 
1719     return dims_021;
1720   }
1721 
1722   return absl::nullopt;
1723 }
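
// For illustration (shapes chosen for exposition): if `a` is [8,16,32] with
// minor_to_major {2,1,0} and `b` is [8,16,32] with minor_to_major {1,2,0}, the
// permutation between their physical orders is {0,2,1}, so FindTranspose021
// returns {8,32,16}: reading b's buffer as a row-major [8,32,16] array gives
// the 0-2-1 transpose of a's buffer read as a row-major [8,16,32] array.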
1724 
1725 Shape ShapeUtil::DeviceShapeToHostShape(Shape s) {
1726   ForEachMutableSubshape(&s, [](Shape* subshape, const ShapeIndex& index) {
1727     if (subshape->IsArray()) {
1728       subshape->mutable_layout()->clear_tiles();
1729       subshape->mutable_layout()->set_memory_space(Layout::kDefaultMemorySpace);
1730     }
1731   });
1732   return s;
1733 }
1734 
1735 /*static*/ bool ShapeUtil::ElementCanUpcast(const Shape& from,
1736                                             const Shape& to) {
1737   return ElementIsFloating(from) == ElementIsFloating(to) &&
1738          ElementIsSigned(from) == ElementIsSigned(to) &&
1739          HigherPrecisionElementType(from, to) == to.element_type();
1740 }
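
// For example, an S8 shape's elements can upcast to S32 and F16 to F32, but
// F32 cannot upcast to S64 (floating-point vs. integer) and S32 cannot upcast
// to U32 (signedness differs).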
1741 
1742 }  // namespace xla
1743