1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/literal.h"
17 
18 #include <algorithm>
19 #include <cstring>
20 #include <functional>
21 #include <limits>
22 #include <numeric>
23 #include <vector>
24 
25 #include "absl/base/casts.h"
26 #include "absl/memory/memory.h"
27 #include "absl/strings/str_cat.h"
28 #include "absl/strings/str_format.h"
29 #include "absl/strings/str_join.h"
30 #include "absl/types/span.h"
31 #include "tensorflow/compiler/xla/index_util.h"
32 #include "tensorflow/compiler/xla/primitive_util.h"
33 #include "tensorflow/compiler/xla/shape_util.h"
34 #include "tensorflow/compiler/xla/status_macros.h"
35 #include "tensorflow/compiler/xla/types.h"
36 #include "tensorflow/compiler/xla/util.h"
37 #include "tensorflow/compiler/xla/xla_data.pb.h"
38 #include "tensorflow/core/lib/core/errors.h"
39 #include "tensorflow/core/lib/hash/hash.h"
40 #include "tensorflow/core/platform/logging.h"
41 #include "tensorflow/core/platform/types.h"
42 
43 namespace xla {
44 namespace {
45 
46 using absl::StrCat;
47 
// True when the host stores multi-byte values least-significant byte first.
// Evaluated at compile time from the compiler's predefined byte-order macros.
constexpr bool kLittleEndian = (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__);
49 
50 // Converts between little and big endian.
51 //
52 // Precondition: size % 2 == 0 (elements in the array are 16 bits long)
ConvertEndianShort(string * bytes)53 void ConvertEndianShort(string* bytes) {
54   CHECK_EQ(bytes->size() / 2, 0);
55   for (int64 i = 0; i < bytes->size(); i += 2) {
56     std::swap((*bytes)[i], (*bytes)[i + 1]);
57   }
58 }
59 
ConvertEndianShort(char * bytes,int64 size)60 void ConvertEndianShort(char* bytes, int64 size) {
61   CHECK_EQ(size / 2, 0);
62   for (int64 i = 0; i < size; i += 2) {
63     std::swap(bytes[i], bytes[i + 1]);
64   }
65 }
66 
// Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be
// able to transparently access the raw 16-bit value contained within. For all
// other types the "raw value" is simply the value itself, returned unchanged.
template <typename T>
T GetRawValue(T value) { return value; }
GetRawValue(Eigen::half val)73 uint16 GetRawValue(Eigen::half val) { return val.x; }
74 
75 }  // namespace
76 
~LiteralBase()77 LiteralBase::~LiteralBase() {}
78 
operator <<(std::ostream & out,const Literal & literal)79 std::ostream& operator<<(std::ostream& out, const Literal& literal) {
80   out << literal.ToString();
81   return out;
82 }
83 
StrideConfig(const Shape & source_shape,const Shape & dest_shape,absl::Span<const int64> dimensions)84 MutableLiteralBase::StrideConfig::StrideConfig(
85     const Shape& source_shape, const Shape& dest_shape,
86     absl::Span<const int64> dimensions)
87     : dimensions(dimensions),
88       base(dimensions.size(), 0),
89       step(dimensions.size(), 1) {
90   if (!dimensions.empty()) {
91     // Selects the shape with the largest minor dimension as the one upon
92     // which to run the tight stride loop.
93     if (dimensions[LayoutUtil::Minor(source_shape.layout(), 0)] >=
94         dimensions[LayoutUtil::Minor(dest_shape.layout(), 0)]) {
95       minor_dimension = LayoutUtil::Minor(source_shape.layout(), 0);
96       dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension);
97     } else {
98       minor_dimension = LayoutUtil::Minor(dest_shape.layout(), 0);
99       source_stride =
100           IndexUtil::GetDimensionStride(source_shape, minor_dimension);
101     }
102     minor_loop_size = dimensions[minor_dimension];
103     step[minor_dimension] = minor_loop_size;
104   }
105 }
106 
Literal(const Shape & shape)107 Literal::Literal(const Shape& shape)
108     : Literal(shape, /*allocate_arrays=*/true) {}
109 
SetPiece(const Shape & shape,Piece * piece,bool allocate_arrays)110 void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
111   if (shape.IsTuple()) {
112     for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
113       const Shape& subshape = shape.tuple_shapes(i);
114 
115       auto child_piece = Piece();
116       child_piece.set_subshape(&subshape);
117 
118       SetPiece(subshape, &child_piece, allocate_arrays);
119 
120       piece->emplace_back(std::move(child_piece));
121     }
122   } else if (shape.IsArray()) {
123     if (allocate_arrays) {
124       if (LayoutUtil::IsSparseArray(shape)) {
125         // For sparse arrays, the buffer must be of the size of the maximum
126         // number of sparse elements possible.
127         const int64 max_sparse_elements =
128             LayoutUtil::MaxSparseElements(shape.layout());
129         piece->set_buffer(
130             new char[max_sparse_elements *
131                      ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]);
132         piece->set_sparse_indices(
133             new SparseIndexArray(max_sparse_elements, shape.rank()));
134       } else {
135         piece->set_buffer(new char[piece->size_bytes()]);
136       }
137     }
138   } else {
139     // If the shape is neither an array nor tuple, then it must be
140     // zero-sized. Otherwise, some memory needs to be allocated for it.
141     CHECK_EQ(piece->size_bytes(), 0);
142   }
143 }
144 
Literal(const Shape & shape,bool allocate_arrays)145 Literal::Literal(const Shape& shape, bool allocate_arrays)
146     : MutableLiteralBase() {
147   shape_ = absl::make_unique<Shape>(shape);
148   CHECK(LayoutUtil::HasLayout(*shape_));
149   root_piece_ = new Piece();
150   root_piece_->set_subshape(shape_.get());
151   CHECK(&root_piece_->subshape() == shape_.get());
152 
153   SetPiece(*shape_, root_piece_, allocate_arrays);
154 }
155 
~Literal()156 Literal::~Literal() {
157   if (root_piece_ != nullptr) {
158     DeallocateBuffers();
159     delete root_piece_;
160   }
161 }
162 
DeallocateBuffers()163 void Literal::DeallocateBuffers() {
164   root_piece_->ForEachMutableSubpiece(
165       [&](const ShapeIndex& index, Piece* piece) {
166         if (piece->buffer() != nullptr) {
167           delete[] piece->buffer();
168           delete piece->sparse_indices();
169         }
170       });
171 }
172 
Literal(Literal && other)173 Literal::Literal(Literal&& other) : MutableLiteralBase() {
174   *this = std::move(other);
175 }
176 
operator =(Literal && other)177 Literal& Literal::operator=(Literal&& other) {
178   DCHECK(&other.root_piece_->subshape() == other.shape_.get());
179   using std::swap;
180   swap(shape_, other.shape_);
181   swap(root_piece_, other.root_piece_);
182   DCHECK(&root_piece_->subshape() == shape_.get());
183 
184   return *this;
185 }
186 
CreateFromShape(const Shape & shape)187 Literal LiteralBase::CreateFromShape(const Shape& shape) {
188   Literal literal(shape);
189   literal.root_piece_->ForEachMutableSubpiece(
190       [&](const ShapeIndex& index, Piece* piece) {
191         if (piece->subshape().IsArray()) {
192           memset(piece->untyped_data(), 0, piece->size_bytes());
193         }
194       });
195   return literal;
196 }
197 
sparse_indices(const ShapeIndex & shape_index) const198 const SparseIndexArray* LiteralBase::sparse_indices(
199     const ShapeIndex& shape_index) const {
200   return piece(shape_index).sparse_indices();
201 }
202 
sparse_indices(const ShapeIndex & shape_index)203 SparseIndexArray* MutableLiteralBase::sparse_indices(
204     const ShapeIndex& shape_index) {
205   return piece(shape_index).sparse_indices();
206 }
207 
208 template <typename NativeT>
CopySliceFromInternal(const LiteralBase & src_literal,absl::Span<const int64> src_base,absl::Span<const int64> dest_base,absl::Span<const int64> copy_size)209 Status MutableLiteralBase::CopySliceFromInternal(
210     const LiteralBase& src_literal, absl::Span<const int64> src_base,
211     absl::Span<const int64> dest_base, absl::Span<const int64> copy_size) {
212   TF_RET_CHECK(src_literal.shape().rank() == src_base.size());
213   TF_RET_CHECK(shape().rank() == dest_base.size());
214 
215   auto linear_index = [](const Shape& shape,
216                          absl::Span<const int64> multi_index) {
217     return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index);
218   };
219 
220   if (src_literal.shape().rank() == 0 || shape().rank() == 0) {
221     // If any of the two shapes are scalars, we can just call the StridedCopy()
222     // directly, and we know we will be copying only one value.
223     TF_RET_CHECK(copy_size.empty());
224     StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0,
225                 src_literal.data<NativeT>(),
226                 linear_index(src_literal.shape(), src_base), 0, 1);
227   } else if (!ShapeUtil::IsZeroElementArray(shape()) &&
228              !ShapeUtil::IsZeroElementArray(src_literal.shape())) {
229     // Perform copy if neither src nor dest has dimensions with zero element,
230     // otherwise it's a no-op.
231     TF_RET_CHECK(src_base.size() == dest_base.size());
232     TF_RET_CHECK(src_base.size() == copy_size.size());
233 
234     // Scan the source from minor, stepping in copy size blocks, then within
235     // the index enumaration functor, do a strided copy advancing source index
236     // by one (walking through the minor dimension), and destination index by
237     // proper stride size at the matching dimension.
238     DimensionVector src_indexes(src_base.size(), 0);
239     DimensionVector dest_indexes(dest_base.size(), 0);
240     MutableLiteralBase::StrideConfig stride_config(src_literal.shape(), shape(),
241                                                    copy_size);
242 
243     auto copy_proc = [&](absl::Span<const int64> indexes) {
244       // Map from multi-dimensional index, to source index.
245       std::transform(indexes.begin(), indexes.end(), src_base.begin(),
246                      src_indexes.begin(), std::plus<int64>());
247       // Map from multi-dimensional index, to destination index.
248       std::transform(indexes.begin(), indexes.end(), dest_base.begin(),
249                      dest_indexes.begin(), std::plus<int64>());
250 
251       int64 src_index = linear_index(src_literal.shape(), src_indexes);
252       int64 dest_index = linear_index(shape(), dest_indexes);
253 
254       // `this->` is needed to workaround MSVC bug: #16882
255       StridedCopy(this->data<NativeT>(), dest_index, stride_config.dest_stride,
256                   src_literal.data<NativeT>(), src_index,
257                   stride_config.source_stride, stride_config.minor_loop_size);
258       return true;
259     };
260 
261     ShapeUtil::ForEachIndex(src_literal.shape(), stride_config.base,
262                             stride_config.dimensions, stride_config.step,
263                             copy_proc);
264   }
265   return Status::OK();
266 }
267 
CopyElementFrom(const LiteralSlice & src_literal,absl::Span<const int64> src_index,absl::Span<const int64> dest_index)268 Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
269                                            absl::Span<const int64> src_index,
270                                            absl::Span<const int64> dest_index) {
271   DCHECK_EQ(shape().element_type(), src_literal.shape().element_type());
272   const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex(
273       src_literal.shape(), src_index);
274   const int64 dest_linear_index =
275       IndexUtil::MultidimensionalIndexToLinearIndex(shape(), dest_index);
276   const int64 primitive_size =
277       ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
278 
279   char* dest_address =
280       static_cast<char*>(untyped_data()) + dest_linear_index * primitive_size;
281   const char* source_address =
282       static_cast<const char*>(src_literal.untyped_data()) +
283       src_linear_index * primitive_size;
284   if (dest_address != source_address) {
285     memcpy(dest_address, source_address, primitive_size);
286   }
287   return Status::OK();
288 }
289 
CreateFromProto(const LiteralProto & proto)290 /* static */ StatusOr<Literal> MutableLiteralBase::CreateFromProto(
291     const LiteralProto& proto) {
292   if (!proto.has_shape()) {
293     return InvalidArgument("LiteralProto has no shape");
294   }
295   Shape shape(proto.shape());
296   if (ShapeUtil::HasPrimitiveType(shape, OPAQUE)) {
297     return InvalidArgument("Literal shape cannot include OPAQUE sub-shape");
298   }
299   if (!LayoutUtil::HasLayout(shape)) {
300     return InvalidArgument("LiteralProto has no layout");
301   }
302 
303   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
304 
305   Literal literal(shape);
306 
307   TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus(
308       [&](const ShapeIndex& index, Piece* piece) {
309         const LiteralProto* proto_element = &proto;
310         for (int64 i : index) {
311           CHECK(i < proto_element->tuple_literals_size());
312           proto_element = &proto_element->tuple_literals(i);
313         }
314 
315         if (piece->subshape().IsTuple()) {
316           if (proto_element->tuple_literals_size() !=
317               ShapeUtil::TupleElementCount(piece->subshape())) {
318             return InvalidArgument(
319                 "Expected %d tuple elements in LiteralProto, has %d",
320                 ShapeUtil::TupleElementCount(piece->subshape()),
321                 proto_element->tuple_literals_size());
322           }
323           return Status::OK();
324         }
325         if (piece->subshape().element_type() == TOKEN) {
326           return Status::OK();
327         }
328 
329         CHECK(piece->subshape().IsArray());
330         TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element));
331 
332         return Status::OK();
333       }));
334 
335   return std::move(literal);
336 }
337 
DecomposeTuple()338 std::vector<Literal> Literal::DecomposeTuple() {
339   CHECK(shape().IsTuple());
340   std::vector<Literal> elements;
341   for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
342     elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}),
343                                /*allocate_arrays=*/false));
344     Literal& element = elements.back();
345     element.root_piece_->ForEachMutableSubpiece(
346         [&](const ShapeIndex& index, Piece* dest_piece) {
347           ShapeIndex src_index = {i};
348           for (int64 j : index) {
349             src_index.push_back(j);
350           }
351           Piece& src_piece = piece(src_index);
352 
353           // Move the respective buffer and sparse indices over to the element
354           // Literal.
355           dest_piece->set_buffer(src_piece.buffer());
356           src_piece.set_buffer(nullptr);
357           dest_piece->set_sparse_indices(src_piece.sparse_indices());
358           src_piece.set_sparse_indices(nullptr);
359         });
360   }
361   // Set this literal to be nil-shaped.
362   *this = Literal();
363   return elements;
364 }
365 
366 namespace {
367 
368 // Copies the elements in 'src' to 'dest'. The shape and layout of the data in
369 // the array slices are indicated by dest_shape and src_shape respectively.
370 template <typename NativeT>
CopyElementsBetween(absl::Span<NativeT> dest,absl::Span<const NativeT> src,const Shape & dest_shape,const Shape & src_shape)371 void CopyElementsBetween(absl::Span<NativeT> dest,
372                          absl::Span<const NativeT> src, const Shape& dest_shape,
373                          const Shape& src_shape) {
374   CHECK(ShapeUtil::Compatible(dest_shape, src_shape));
375   if (ShapeUtil::IsZeroElementArray(dest_shape)) {
376     return;
377   }
378   std::vector<int64> index(dest_shape.rank());
379   do {
380     dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] =
381         src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)];
382   } while (IndexUtil::BumpIndices(dest_shape, absl::MakeSpan(index)));
383 }
384 
385 }  // namespace
386 
CopyFrom(const LiteralBase::Piece & src)387 Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) {
388   CHECK(subshape_ != nullptr);
389   CHECK(src.subshape_ != nullptr);
390   if (ShapeUtil::Equal(subshape(), src.subshape())) {
391     // If the layouts are equal it's faster just to memcpy.
392     memcpy(buffer(), src.buffer(), src.size_bytes());
393   } else {
394     TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape()));
395     std::vector<int64> origin(subshape().rank(), 0);
396     switch (subshape().element_type()) {
397 #define COPY_ELEMENTS(XLA_T, NATIVE_T)                                    \
398   case (XLA_T):                                                           \
399     CopyElementsBetween<NATIVE_T>(data<NATIVE_T>(), src.data<NATIVE_T>(), \
400                                   subshape(), src.subshape());            \
401     break;
402       COPY_ELEMENTS(U8, uint8);
403       COPY_ELEMENTS(U16, uint16);
404       COPY_ELEMENTS(U32, uint32);
405       COPY_ELEMENTS(U64, uint64);
406       COPY_ELEMENTS(S8, int8);
407       COPY_ELEMENTS(S16, int16);
408       COPY_ELEMENTS(S32, int32);
409       COPY_ELEMENTS(S64, int64);
410       COPY_ELEMENTS(F16, half);
411       COPY_ELEMENTS(BF16, bfloat16);
412       COPY_ELEMENTS(F32, float);
413       COPY_ELEMENTS(F64, double);
414       COPY_ELEMENTS(C64, complex64);
415       COPY_ELEMENTS(C128, complex128);
416       COPY_ELEMENTS(PRED, bool);
417 #undef COPY_ELEMENTS
418       default:
419         return Unimplemented(
420             "Copying a Literal object with element type %s is not implemented.",
421             PrimitiveType_Name(subshape().element_type()));
422     }
423   }
424   return Status::OK();
425 }
426 
CopyFrom(const LiteralSlice & src_literal,const ShapeIndex & dest_shape_index,const ShapeIndex & src_shape_index)427 Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal,
428                                     const ShapeIndex& dest_shape_index,
429                                     const ShapeIndex& src_shape_index) {
430   const Shape& dest_subshape =
431       ShapeUtil::GetSubshape(shape(), dest_shape_index);
432   const Shape& src_subshape =
433       ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index);
434   if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) {
435     return InvalidArgument(
436         "Destination subshape incompatible with source subshape: %s vs %s",
437         ShapeUtil::HumanString(dest_subshape),
438         ShapeUtil::HumanString(src_subshape));
439   }
440   return root_piece_->ForEachMutableSubpieceWithStatus(
441       [&](const ShapeIndex& index, Piece* piece) {
442         if (!piece->subshape().IsArray()) {
443           return Status::OK();
444         }
445 
446         // Determine if this index is in the part of this literal that we want
447         // to copy over from src_literal.
448         bool in_subtree_to_copy = true;
449         for (int i = 0; i < dest_shape_index.size(); ++i) {
450           if (index[i] != dest_shape_index[i]) {
451             in_subtree_to_copy = false;
452             break;
453           }
454         }
455         if (!in_subtree_to_copy) {
456           return Status::OK();
457         }
458         // Construct the index of the corresponding piece in the source literal.
459         ShapeIndex src_piece_index = src_shape_index;
460         for (int64 i = dest_shape_index.size(); i < index.size(); ++i) {
461           src_piece_index.push_back(index[i]);
462         }
463         TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index)));
464         return Status::OK();
465       });
466 }
467 
MoveFrom(Literal && src_literal,const ShapeIndex & dest_shape_index)468 Status Literal::MoveFrom(Literal&& src_literal,
469                          const ShapeIndex& dest_shape_index) {
470   const Shape& dest_subshape =
471       ShapeUtil::GetSubshape(shape(), dest_shape_index);
472   if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) {
473     return InvalidArgument(
474         "Destination subshape not equal to source shape: %s vs %s",
475         ShapeUtil::HumanString(dest_subshape),
476         ShapeUtil::HumanString(src_literal.shape()));
477   }
478 
479   src_literal.root_piece_->ForEachSubpiece(
480       [&](const ShapeIndex& src_index, const Piece& src_piece) {
481         if (!src_piece.subshape().IsArray()) {
482           return;
483         }
484 
485         ShapeIndex dest_index = dest_shape_index;
486         for (int64 i : src_index) {
487           dest_index.push_back(i);
488         }
489         Piece& dest_piece = piece(dest_index);
490         delete[] dest_piece.buffer();
491         dest_piece.set_buffer(src_piece.buffer());
492         delete dest_piece.sparse_indices();
493         dest_piece.set_sparse_indices(src_piece.sparse_indices());
494       });
495 
496   src_literal.shape_ = absl::make_unique<Shape>(ShapeUtil::MakeNil());
497   delete src_literal.root_piece_;
498   src_literal.root_piece_ = new LiteralBase::Piece();
499   src_literal.root_piece_->set_subshape(src_literal.shape_.get());
500 
501   return Status::OK();
502 }
503 
CopySliceFrom(const LiteralSlice & src_literal,absl::Span<const int64> src_base,absl::Span<const int64> dest_base,absl::Span<const int64> copy_size)504 Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal,
505                                          absl::Span<const int64> src_base,
506                                          absl::Span<const int64> dest_base,
507                                          absl::Span<const int64> copy_size) {
508   TF_RET_CHECK(shape().IsArray()) << ShapeUtil::HumanString(shape());
509   TF_RET_CHECK(src_literal.shape().IsArray())
510       << ShapeUtil::HumanString(src_literal.shape());
511   TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape()));
512 
513   switch (shape().element_type()) {
514     case U8:
515       return CopySliceFromInternal<uint8>(src_literal, src_base, dest_base,
516                                           copy_size);
517     case U16:
518       return CopySliceFromInternal<uint16>(src_literal, src_base, dest_base,
519                                            copy_size);
520     case U32:
521       return CopySliceFromInternal<uint32>(src_literal, src_base, dest_base,
522                                            copy_size);
523     case U64:
524       return CopySliceFromInternal<uint64>(src_literal, src_base, dest_base,
525                                            copy_size);
526     case S8:
527       return CopySliceFromInternal<int8>(src_literal, src_base, dest_base,
528                                          copy_size);
529     case S16:
530       return CopySliceFromInternal<int16>(src_literal, src_base, dest_base,
531                                           copy_size);
532     case S32:
533       return CopySliceFromInternal<int32>(src_literal, src_base, dest_base,
534                                           copy_size);
535     case S64:
536       return CopySliceFromInternal<int64>(src_literal, src_base, dest_base,
537                                           copy_size);
538     case F16:
539       return CopySliceFromInternal<half>(src_literal, src_base, dest_base,
540                                          copy_size);
541     case BF16:
542       return CopySliceFromInternal<bfloat16>(src_literal, src_base, dest_base,
543                                              copy_size);
544     case F32:
545       return CopySliceFromInternal<float>(src_literal, src_base, dest_base,
546                                           copy_size);
547     case F64:
548       return CopySliceFromInternal<double>(src_literal, src_base, dest_base,
549                                            copy_size);
550     case C64:
551       return CopySliceFromInternal<complex64>(src_literal, src_base, dest_base,
552                                               copy_size);
553     case C128:
554       return CopySliceFromInternal<complex128>(src_literal, src_base, dest_base,
555                                                copy_size);
556     case PRED:
557       return CopySliceFromInternal<bool>(src_literal, src_base, dest_base,
558                                          copy_size);
559     default:
560       break;
561   }
562   return Unimplemented(
563       "Copying a slice from a Literal object with element type %d is not "
564       "implemented.",
565       shape().element_type());
566 }
567 
PopulateR1(const tensorflow::core::Bitmap & values)568 void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) {
569   CHECK(shape().IsArray());
570   CHECK_EQ(shape().rank(), 1);
571   CHECK_EQ(element_count(), values.bits());
572   CHECK_EQ(shape().element_type(), PRED);
573   for (int64 i = 0; i < static_cast<int64>(values.bits()); ++i) {
574     Set({i}, values.get(i));
575   }
576 }
577 
Relayout(const Layout & new_layout,const ShapeIndex & shape_index) const578 Literal LiteralBase::Relayout(const Layout& new_layout,
579                               const ShapeIndex& shape_index) const {
580   // Create new shape with 'new_layout' set at the given shape index.
581   Shape new_shape = shape();
582   Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
583   TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
584   *subshape->mutable_layout() = new_layout;
585   Literal result(new_shape);
586   TF_CHECK_OK(result.CopyFrom(*this));
587   return result;
588 }
589 
Relayout(const Shape & shape_with_layout) const590 Literal LiteralBase::Relayout(const Shape& shape_with_layout) const {
591   CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
592       << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
593       << " not compatible with literal shape "
594       << ShapeUtil::HumanString(shape());
595   Literal result = CreateFromShape(shape_with_layout);
596   ShapeUtil::ForEachSubshape(
597       result.shape(),
598       [this, &result](const Shape& subshape, const ShapeIndex& index) {
599         if (subshape.IsArray()) {
600           TF_CHECK_OK(result.CopyFrom(*this,
601                                       /*dest_shape_index=*/index,
602                                       /*src_shape_index=*/index));
603         }
604       });
605   return result;
606 }
607 
Broadcast(const Shape & result_shape,absl::Span<const int64> dimensions) const608 StatusOr<Literal> LiteralBase::Broadcast(
609     const Shape& result_shape, absl::Span<const int64> dimensions) const {
610   if (!shape().IsArray()) {
611     return InvalidArgument("Broadcast only supports arrays.");
612   }
613 
614   for (int64 i = 0; i < dimensions.size(); i++) {
615     TF_RET_CHECK(shape().dimensions(i) ==
616                  result_shape.dimensions(dimensions[i]));
617   }
618 
619   Literal result(result_shape);
620 
621   // scratch_source_index is temporary storage space for the computed index into
622   // the input literal.  We put it here to avoid allocating an std::vector in
623   // every iteration of ShapeUtil::ForEachIndex.
624   std::vector<int64> scratch_source_index(shape().dimensions_size());
625 
626   char* dest_data = static_cast<char*>(result.untyped_data());
627   const char* source_data = static_cast<const char*>(untyped_data());
628   const int64 primitive_size =
629       ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
630 
631   ShapeUtil::ForEachIndex(
632       result_shape, [&](absl::Span<const int64> output_index) {
633         for (int64 i = 0; i < dimensions.size(); ++i) {
634           scratch_source_index[i] = output_index[dimensions[i]];
635         }
636         int64 dest_index = IndexUtil::MultidimensionalIndexToLinearIndex(
637             result_shape, output_index);
638         int64 source_index = IndexUtil::MultidimensionalIndexToLinearIndex(
639             shape(), scratch_source_index);
640         memcpy(dest_data + primitive_size * dest_index,
641                source_data + primitive_size * source_index, primitive_size);
642         return true;
643       });
644 
645   return std::move(result);
646 }
647 
Reshape(absl::Span<const int64> dimensions) const648 StatusOr<Literal> LiteralBase::Reshape(
649     absl::Span<const int64> dimensions) const {
650   if (!shape().IsArray()) {
651     return InvalidArgument("Reshape does not support tuples.");
652   }
653   Literal output;
654   if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
655     output = Relayout(LayoutUtil::GetDefaultLayoutForRank(shape().rank()));
656   } else {
657     output = Clone();
658   }
659   // Because the layout is monotonic, we can simply reuse the same sequence of
660   // values without changing their order.
661   *output.mutable_shape_do_not_use() =
662       ShapeUtil::MakeShape(shape().element_type(), dimensions);
663 
664   int64 elements_before = ShapeUtil::ElementsIn(shape());
665   int64 elements_after = ShapeUtil::ElementsIn(output.shape());
666   if (elements_before != elements_after) {
667     return InvalidArgument(
668         "Shapes before and after Literal::Reshape have different numbers "
669         "of elements: %s vs %s.",
670         ShapeUtil::HumanString(shape()),
671         ShapeUtil::HumanString(output.shape()));
672   }
673   return std::move(output);
674 }
675 
Transpose(absl::Span<const int64> permutation) const676 Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
677   CHECK(shape().IsArray()) << "Tuple is not supported for transpose";
678   CHECK(IsPermutation(permutation, shape().rank()))
679       << "Given permutation is not a permutation of dimension numbers";
680   // To transpose the array, we just permute the dimensions and layout, and
681   // do a straight memory copy of the raw data set.
682   // This is considerably faster than iterating over every array element using
683   // the EachCell<>() and Set<>() APIs.
684   std::vector<int64> inverse_permutation = InversePermutation(permutation);
685   Shape permuted_shape =
686       ShapeUtil::PermuteDimensions(inverse_permutation, shape());
687   // Replace the layout with one affine to this shape, such that a
688   // transpose operation can be performed by leaving the flat values
689   // representation intact.
690   // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation.
691   // The shape with affine layout resulting from that operation will be
692   // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the
693   // most minor.
694   //
695   // Essentially, given MinMaj(Di) the position of the Di dimension within the
696   // minor to major vector, and given T(Di) the index that the original Di
697   // dimension has within the transposed array, a layout is affine if
698   // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major
699   // vector of the affine layout.
700   CHECK(LayoutUtil::IsDenseArray(permuted_shape));
701   Layout* layout = permuted_shape.mutable_layout();
702   layout->clear_minor_to_major();
703   for (auto index : LayoutUtil::MinorToMajor(shape())) {
704     layout->add_minor_to_major(inverse_permutation[index]);
705   }
706   Literal new_literal(permuted_shape);
707   DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()),
708             ShapeUtil::ByteSizeOf(shape()));
709   std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes());
710   return new_literal;
711 }
712 
713 template <typename NativeT>
SliceInternal(const Shape & result_shape,absl::Span<const int64> start_indices) const714 Literal LiteralBase::SliceInternal(
715     const Shape& result_shape, absl::Span<const int64> start_indices) const {
716   Literal result_literal(result_shape);
717   DimensionVector new_indices(result_shape.rank());
718   result_literal.EachCell<NativeT>(
719       [&](absl::Span<const int64> indices, NativeT /*value*/) {
720         for (int64 i = 0; i < result_shape.rank(); ++i) {
721           new_indices[i] = indices[i] + start_indices[i];
722         }
723         NativeT value = Get<NativeT>(new_indices);
724         result_literal.Set<NativeT>(indices, value);
725       });
726   return result_literal;
727 }
728 
Slice(absl::Span<const int64> start_indices,absl::Span<const int64> limit_indices) const729 Literal LiteralBase::Slice(absl::Span<const int64> start_indices,
730                            absl::Span<const int64> limit_indices) const {
731   CHECK(shape().IsArray()) << "tuple is not supported for slice";
732 
733   DimensionVector result_dimensions;
734   for (int64 dnum = 0; dnum < shape().rank(); ++dnum) {
735     CHECK_GE(start_indices[dnum], 0);
736     CHECK_LE(limit_indices[dnum], shape().dimensions(dnum))
737         << "dnum = " << dnum;
738     int64 dimension = limit_indices[dnum] - start_indices[dnum];
739     CHECK_GE(dimension, 0) << "dnum = " << dnum;
740     result_dimensions.push_back(dimension);
741   }
742   const auto result_shape =
743       ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions,
744                                      LayoutUtil::MinorToMajor(shape()));
745   switch (result_shape.element_type()) {
746     case PRED:
747       return SliceInternal<bool>(result_shape, start_indices);
748     case U8:
749       return SliceInternal<uint8>(result_shape, start_indices);
750     case U16:
751       return SliceInternal<uint16>(result_shape, start_indices);
752     case U32:
753       return SliceInternal<uint32>(result_shape, start_indices);
754     case U64:
755       return SliceInternal<uint64>(result_shape, start_indices);
756     case S8:
757       return SliceInternal<int8>(result_shape, start_indices);
758     case S16:
759       return SliceInternal<int16>(result_shape, start_indices);
760     case S32:
761       return SliceInternal<int32>(result_shape, start_indices);
762     case S64:
763       return SliceInternal<int64>(result_shape, start_indices);
764     case F16:
765       return SliceInternal<half>(result_shape, start_indices);
766     case BF16:
767       return SliceInternal<bfloat16>(result_shape, start_indices);
768     case F32:
769       return SliceInternal<float>(result_shape, start_indices);
770     case F64:
771       return SliceInternal<double>(result_shape, start_indices);
772     case C64:
773       return SliceInternal<complex64>(result_shape, start_indices);
774     case C128:
775       return SliceInternal<complex128>(result_shape, start_indices);
776     default:
777       LOG(FATAL) << "not yet implemented: "
778                  << PrimitiveType_Name(result_shape.element_type());
779   }
780 }
781 
Clone() const782 Literal LiteralBase::Clone() const {
783   Literal result(shape());
784   TF_CHECK_OK(result.CopyFrom(*this));
785   return result;
786 }
787 
GetAsString(absl::Span<const int64> multi_index,const ShapeIndex & shape_index) const788 string LiteralBase::GetAsString(absl::Span<const int64> multi_index,
789                                 const ShapeIndex& shape_index) const {
790   const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
791   CHECK(LayoutUtil::IsDenseArray(subshape));
792   switch (subshape.element_type()) {
793     case PRED:
794       return Get<bool>(multi_index, shape_index) ? "true" : "false";
795     case S8:
796       return StrCat(Get<int8>(multi_index, shape_index));
797     case S16:
798       return StrCat(Get<int16>(multi_index, shape_index));
799     case S32:
800       return StrCat(Get<int32>(multi_index, shape_index));
801     case S64:
802       return StrCat(Get<int64>(multi_index, shape_index));
803     case U8:
804       return StrCat(Get<uint8>(multi_index, shape_index));
805     case U16:
806       return StrCat(Get<uint16>(multi_index, shape_index));
807     case U32:
808       return StrCat(Get<uint32>(multi_index, shape_index));
809     case U64:
810       return StrCat(Get<uint64>(multi_index, shape_index));
811     case F16:
812       return StrCat(static_cast<float>(Get<half>(multi_index, shape_index)));
813     case F32:
814       return StrCat(Get<float>(multi_index, shape_index));
815     case BF16:
816       return StrCat(
817           static_cast<float>(Get<bfloat16>(multi_index, shape_index)));
818     case F64:
819       return StrCat(Get<double>(multi_index, shape_index));
820     case C64: {
821       complex64 c = Get<complex64>(multi_index, shape_index);
822       return StrCat("(", c.real(), ", ", c.imag(), ")");
823     }
824     case C128: {
825       complex128 c = Get<complex128>(multi_index, shape_index);
826       return StrCat("(", c.real(), ", ", c.imag(), ")");
827     }
828     default:
829       LOG(FATAL) << PrimitiveType_Name(subshape.element_type());
830   }
831 }
832 
GetSparseElementAsString(int64 sparse_element_number,const ShapeIndex & shape_index) const833 string LiteralBase::GetSparseElementAsString(
834     int64 sparse_element_number, const ShapeIndex& shape_index) const {
835   const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
836   CHECK(LayoutUtil::IsSparseArray(subshape));
837   switch (subshape.element_type()) {
838     case PRED:
839       return GetSparseElement<bool>(sparse_element_number, shape_index)
840                  ? "true"
841                  : "false";
842     case S8:
843       return StrCat(GetSparseElement<int8>(sparse_element_number, shape_index));
844     case S16:
845       return StrCat(
846           GetSparseElement<int16>(sparse_element_number, shape_index));
847     case S32:
848       return StrCat(
849           GetSparseElement<int32>(sparse_element_number, shape_index));
850     case S64:
851       return StrCat(
852           GetSparseElement<int64>(sparse_element_number, shape_index));
853     case U8:
854       return StrCat(
855           GetSparseElement<uint8>(sparse_element_number, shape_index));
856     case U16:
857       return StrCat(
858           GetSparseElement<uint16>(sparse_element_number, shape_index));
859     case U32:
860       return StrCat(
861           GetSparseElement<uint32>(sparse_element_number, shape_index));
862     case U64:
863       return StrCat(
864           GetSparseElement<uint64>(sparse_element_number, shape_index));
865     case F16:
866       return StrCat(static_cast<float>(
867           GetSparseElement<half>(sparse_element_number, shape_index)));
868     case F32:
869       return StrCat(
870           GetSparseElement<float>(sparse_element_number, shape_index));
871     case BF16:
872       return StrCat(static_cast<float>(
873           GetSparseElement<bfloat16>(sparse_element_number, shape_index)));
874     case F64:
875       return StrCat(
876           GetSparseElement<double>(sparse_element_number, shape_index));
877     case C64: {
878       complex64 c =
879           GetSparseElement<complex64>(sparse_element_number, shape_index);
880       return StrCat("(", c.real(), ", ", c.imag(), ")");
881     }
882     case C128: {
883       complex128 c =
884           GetSparseElement<complex128>(sparse_element_number, shape_index);
885       return StrCat("(", c.real(), ", ", c.imag(), ")");
886     }
887     default:
888       LOG(FATAL) << "Invalid element type for sparse arrays: "
889                  << PrimitiveType_Name(subshape.element_type());
890   }
891 }
892 
GetIntegralAsS64(absl::Span<const int64> multi_index) const893 StatusOr<int64> LiteralBase::GetIntegralAsS64(
894     absl::Span<const int64> multi_index) const {
895   CHECK(LayoutUtil::IsDenseArray(shape()));
896   switch (shape().element_type()) {
897     case PRED:
898       return Get<bool>(multi_index);
899     case U8:
900       return Get<uint8>(multi_index);
901     case S32:
902       return Get<int32>(multi_index);
903     case S64:
904       return Get<int64>(multi_index);
905     case U32:
906       return Get<uint32>(multi_index);
907     case U64:
908       return Get<uint64>(multi_index);
909     default:
910       return FailedPrecondition("Array element type is not integral: %s",
911                                 PrimitiveType_Name(shape().element_type()));
912   }
913 }
914 
Hash() const915 size_t LiteralBase::Hash() const {
916   using tensorflow::Hash64;
917   using tensorflow::Hash64Combine;
918 
919   size_t hash_value = ShapeUtil::Hash(shape());
920 
921   ShapeUtil::ForEachSubshape(
922       shape(), [&](const Shape& subshape, const ShapeIndex& index) {
923         if (!subshape.IsArray()) {
924           return;
925         }
926 
927         CHECK(LayoutUtil::IsDense(subshape.layout()));
928         hash_value = Hash64Combine(
929             hash_value, Hash64(static_cast<const char*>(untyped_data(index)),
930                                size_bytes(index)));
931       });
932 
933   return hash_value;
934 }
935 
SetIntegralAsS64(absl::Span<const int64> multi_index,int64 value)936 Status MutableLiteralBase::SetIntegralAsS64(absl::Span<const int64> multi_index,
937                                             int64 value) {
938   CHECK(LayoutUtil::IsDenseArray(shape()));
939   switch (shape().element_type()) {
940     case PRED:
941       Set<bool>(multi_index, value);
942       break;
943     case U8:
944       Set<uint8>(multi_index, value);
945       break;
946     case S32:
947       Set<int32>(multi_index, value);
948       break;
949     case S64:
950       Set<int64>(multi_index, value);
951       break;
952     case U32:
953       Set<uint32>(multi_index, value);
954       break;
955     case U64:
956       Set<uint64>(multi_index, value);
957       break;
958     default:
959       return FailedPrecondition("Array element type is not integral: %s",
960                                 PrimitiveType_Name(shape().element_type()));
961   }
962   return Status::OK();
963 }
964 
GetSparseIndex(int64 sparse_element_number,const ShapeIndex & shape_index) const965 absl::Span<const int64> LiteralBase::GetSparseIndex(
966     int64 sparse_element_number, const ShapeIndex& shape_index) const {
967   const Piece& p = piece(shape_index);
968   CHECK_GE(sparse_element_number, 0);
969   CHECK_LT(sparse_element_number, p.sparse_indices()->index_count());
970   return p.sparse_indices()->At(sparse_element_number);
971 }
972 
SortSparseElements(const ShapeIndex & shape_index)973 void MutableLiteralBase::SortSparseElements(const ShapeIndex& shape_index) {
974   piece(shape_index).SortSparseElements();
975 }
976 
SortSparseElements()977 void LiteralBase::Piece::SortSparseElements() {
978   switch (subshape().element_type()) {
979     case PRED:
980       SortSparseElementsInternal<bool>();
981       break;
982     case S8:
983       SortSparseElementsInternal<int8>();
984       break;
985     case U8:
986       SortSparseElementsInternal<uint8>();
987       break;
988     case S16:
989       SortSparseElementsInternal<int16>();
990       break;
991     case U16:
992       SortSparseElementsInternal<uint16>();
993       break;
994     case S32:
995       SortSparseElementsInternal<int32>();
996       break;
997     case U32:
998       SortSparseElementsInternal<uint32>();
999       break;
1000     case S64:
1001       SortSparseElementsInternal<int64>();
1002       break;
1003     case U64:
1004       SortSparseElementsInternal<uint64>();
1005       break;
1006     case F32:
1007       SortSparseElementsInternal<float>();
1008       break;
1009     case F64:
1010       SortSparseElementsInternal<double>();
1011       break;
1012     case C64:
1013       SortSparseElementsInternal<complex64>();
1014       break;
1015     case C128:
1016       SortSparseElementsInternal<complex128>();
1017       break;
1018     case F16:
1019       SortSparseElementsInternal<half>();
1020       break;
1021     case BF16:
1022       SortSparseElementsInternal<bfloat16>();
1023       break;
1024     default:
1025       LOG(FATAL) << "Element type not valid for sparse array: "
1026                  << PrimitiveType_Name(subshape().element_type());
1027   }
1028 }
1029 
1030 template <typename NativeT>
SortSparseElementsInternal()1031 void LiteralBase::Piece::SortSparseElementsInternal() {
1032   CHECK(LayoutUtil::IsSparseArray(subshape()));
1033   int64 num_elements = sparse_indices()->index_count();
1034   auto values = data<NativeT>();
1035   CHECK_LE(num_elements, values.size());
1036   sparse_indices()->SortWithValues(
1037       absl::Span<NativeT>(values.data(), num_elements));
1038 }
1039 
1040 namespace {
1041 
ShapeToString(bool print_layout,const Shape & shape)1042 string ShapeToString(bool print_layout, const Shape& shape) {
1043   return print_layout ? ShapeUtil::HumanStringWithLayout(shape)
1044                       : ShapeUtil::HumanString(shape);
1045 }
1046 
1047 void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
1048                     bool print_shape, bool print_layout,
1049                     std::vector<string>* pieces);
1050 
TupleToStringHelper(const LiteralBase & literal,const ShapeIndex & shape_index,bool print_shape,bool print_layout,std::vector<string> * pieces)1051 void TupleToStringHelper(const LiteralBase& literal,
1052                          const ShapeIndex& shape_index, bool print_shape,
1053                          bool print_layout, std::vector<string>* pieces) {
1054   const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
1055   pieces->push_back("(\n");
1056   std::vector<string> tuple_pieces;
1057   for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) {
1058     ShapeIndex element_index = shape_index;
1059     element_index.push_back(i);
1060     std::vector<string> element_pieces;
1061     ToStringHelper(literal, element_index, print_shape, print_layout,
1062                    &element_pieces);
1063     tuple_pieces.push_back(absl::StrJoin(element_pieces, ""));
1064   }
1065   pieces->push_back(absl::StrJoin(tuple_pieces, ",\n"));
1066   pieces->push_back("\n)");
1067 }
1068 
SparseArrayToStringHelper(const LiteralBase & literal,const Shape & subshape,bool print_shape,bool print_layout,std::vector<string> * pieces)1069 void SparseArrayToStringHelper(const LiteralBase& literal,
1070                                const Shape& subshape, bool print_shape,
1071                                bool print_layout, std::vector<string>* pieces) {
1072   if (print_shape) {
1073     pieces->push_back(ShapeToString(print_layout, subshape));
1074   }
1075   pieces->push_back("{");
1076   int64 rank = subshape.rank();
1077   int64 num_elements = literal.sparse_element_count();
1078   for (int64 i = 0; i < num_elements; ++i) {
1079     if (i > 0) {
1080       pieces->push_back(", ");
1081     }
1082     if (rank == 1) {
1083       pieces->push_back(StrCat(literal.GetSparseIndex(i)[0]));
1084       pieces->push_back(": ");
1085     } else {
1086       pieces->push_back("[");
1087       pieces->push_back(absl::StrJoin(literal.GetSparseIndex(i), ", "));
1088       pieces->push_back("]: ");
1089     }
1090     pieces->push_back(literal.GetSparseElementAsString(i));
1091   }
1092   pieces->push_back("}");
1093 }
1094 
DenseArrayToStringHelper(const LiteralBase & literal,const ShapeIndex & shape_index,bool print_shape,bool print_layout,std::vector<string> * pieces)1095 void DenseArrayToStringHelper(const LiteralBase& literal,
1096                               const ShapeIndex& shape_index, bool print_shape,
1097                               bool print_layout, std::vector<string>* pieces) {
1098   const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
1099   int64 rank = subshape.rank();
1100 
1101   std::function<void(absl::Span<const int64> dimensions, std::vector<int64>*)>
1102       to_string_recursive = [&](absl::Span<const int64> dimensions,
1103                                 std::vector<int64>* accum_indices) {
1104         // dimensions.size() decreases by 1 at each recursive call,
1105         // and accum_indices->size() increases by 1.
1106         // Their sum is equal to the rank of the tensor.
1107         CHECK_EQ(rank, dimensions.size() + accum_indices->size());
1108 
1109         auto brace_to_string = [&](string brace) -> string {
1110           // Handle 1D tensor
1111           if (rank == 1) {
1112             return brace;
1113           }
1114           // Handle the innermost tensor of a 2D+ tensor.
1115           if (dimensions.size() == 1 && brace == "{") {
1116             return StrCat("  ", brace, dimensions[0] <= 1 ? "" : " ");
1117           }
1118           if (dimensions.size() == 1 && brace == "}") {
1119             return StrCat(dimensions[0] <= 1 ? "" : " ", brace);
1120           }
1121           // Handle the non-innermost tensors of a 2D+ tensor.
1122           if (brace == "{") {
1123             if (rank > 3 && !accum_indices->empty() &&
1124                 accum_indices->size() < rank) {
1125               int index = accum_indices->size() - 1;
1126               int value = accum_indices->back();
1127               return StrCat(brace, " /*i", index, "=", value, "*/\n");
1128             }
1129             return StrCat(brace, "\n");
1130           }
1131           return StrCat("\n", brace);
1132         };
1133 
1134         if (dimensions.empty()) {
1135           // Display predicates as 0s and 1s so that the string is more dense.
1136           string elem;
1137           if (subshape.element_type() == PRED && rank > 0) {
1138             elem = literal.Get<bool>(*accum_indices, shape_index) ? "1" : "0";
1139           } else {
1140             elem = literal.GetAsString(*accum_indices, shape_index);
1141           }
1142           pieces->push_back(elem);
1143         } else {
1144           pieces->push_back(brace_to_string("{"));
1145           for (int i = 0; i < dimensions[0]; ++i) {
1146             std::vector<int64> cloned_indices(*accum_indices);
1147             cloned_indices.push_back(i);
1148             to_string_recursive(dimensions.subspan(1), &cloned_indices);
1149             if (i < dimensions[0] - 1) {
1150               pieces->push_back(",");
1151               pieces->push_back(dimensions.size() > 1 ? "\n" : " ");
1152             }
1153           }
1154           pieces->push_back(brace_to_string("}"));
1155         }
1156       };
1157 
1158   if (print_shape) {
1159     pieces->push_back(ShapeToString(print_layout, subshape));
1160     pieces->push_back(" ");
1161   }
1162   std::vector<int64> indices = {};
1163   std::vector<int64> dimensions(subshape.dimensions().begin(),
1164                                 subshape.dimensions().end());
1165   to_string_recursive(dimensions, &indices);
1166 }
1167 
ToStringHelper(const LiteralBase & literal,const ShapeIndex & shape_index,bool print_shape,bool print_layout,std::vector<string> * pieces)1168 void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
1169                     bool print_shape, bool print_layout,
1170                     std::vector<string>* pieces) {
1171   const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
1172   CHECK(LayoutUtil::HasLayout(literal.shape()));
1173   CHECK(LayoutUtil::HasLayout(subshape));
1174   if (subshape.IsTuple()) {
1175     TupleToStringHelper(literal, shape_index, print_shape, print_layout,
1176                         pieces);
1177   } else if (subshape.IsToken()) {
1178     pieces->push_back("token");
1179   } else if (LayoutUtil::IsSparseArray(subshape)) {
1180     SparseArrayToStringHelper(literal, subshape, print_shape, print_layout,
1181                               pieces);
1182   } else {
1183     CHECK(LayoutUtil::IsDenseArray(subshape));
1184     DenseArrayToStringHelper(literal, shape_index, print_shape, print_layout,
1185                              pieces);
1186   }
1187 }
1188 
1189 }  // namespace
1190 
sparse_element_count() const1191 int64 LiteralBase::sparse_element_count() const {
1192   CHECK(LayoutUtil::IsSparseArray(shape()));
1193   return sparse_indices()->index_count();
1194 }
1195 
ToString() const1196 string LiteralBase::ToString() const {
1197   std::vector<string> pieces;
1198   CHECK(LayoutUtil::HasLayout(this->shape()));
1199   ToStringHelper(*this, {}, /*print_shape=*/true,
1200                  /*print_layout=*/false, &pieces);
1201   return absl::StrJoin(pieces, "");
1202 }
1203 
ToStringWithoutShape() const1204 string LiteralBase::ToStringWithoutShape() const {
1205   std::vector<string> pieces;
1206   CHECK(LayoutUtil::HasLayout(this->shape()));
1207   ToStringHelper(*this, {}, /*print_shape=*/false,
1208                  /*print_layout=*/false, &pieces);
1209   return absl::StrJoin(pieces, "");
1210 }
1211 
ToStringWithLayout() const1212 string LiteralBase::ToStringWithLayout() const {
1213   std::vector<string> pieces;
1214   CHECK(LayoutUtil::HasLayout(this->shape()));
1215   ToStringHelper(*this, {}, /*print_shape=*/true,
1216                  /*print_layout=*/true, &pieces);
1217   return absl::StrJoin(pieces, "");
1218 }
1219 
EachCellAsString(const std::function<void (absl::Span<const int64> indices,const string & value)> & per_cell) const1220 void LiteralBase::EachCellAsString(
1221     const std::function<void(absl::Span<const int64> indices,
1222                              const string& value)>& per_cell) const {
1223   if (ShapeUtil::IsZeroElementArray(shape())) {
1224     return;
1225   }
1226   std::vector<int64> indices = IndexUtil::LinearIndexToMultidimensionalIndex(
1227       shape(), /*linear_index=*/0);
1228   do {
1229     per_cell(indices, GetAsString(indices));
1230   } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices)));
1231 }
1232 
1233 namespace {
1234 template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
ConvertBetweenNativeTypesWithConverter(const LiteralBase & src_literal,const ConverterType & converter)1235 Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal,
1236                                                const ConverterType& converter) {
1237   CHECK(src_literal.shape().IsArray());
1238   Literal result_literal(ShapeUtil::ChangeElementType(
1239       src_literal.shape(),
1240       primitive_util::NativeToPrimitiveType<NativeDestT>()));
1241   auto src_data = src_literal.data<NativeSrcT>();
1242   auto dest_data = result_literal.template data<NativeDestT>();
1243   int64 num_elements = src_literal.element_count();
1244 
1245   for (int64 i = 0; i < num_elements; ++i) {
1246     dest_data[i] = converter(src_data[i]);
1247   }
1248   return result_literal;
1249 }
1250 
1251 template <typename NativeSrcT, typename NativeDestT>
1252 typename std::enable_if<(std::is_same<NativeSrcT, Eigen::half>::value) &&
1253                             (std::is_same<NativeDestT, complex64>::value ||
1254                              std::is_same<NativeDestT, complex128>::value),
1255                         Literal>::type
ConvertBetweenNativeTypes(const LiteralBase & src_literal)1256 ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
1257   auto converter = [](NativeSrcT src) {
1258     return NativeDestT(static_cast<typename NativeDestT::value_type>(src));
1259   };
1260   return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
1261       src_literal, converter);
1262 }
1263 
1264 template <typename NativeSrcT, typename NativeDestT>
1265 typename std::enable_if<(!std::is_same<NativeSrcT, Eigen::half>::value) ||
1266                             (!std::is_same<NativeDestT, complex64>::value &&
1267                              !std::is_same<NativeDestT, complex128>::value),
1268                         Literal>::type
ConvertBetweenNativeTypes(const LiteralBase & src_literal)1269 ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
1270   auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
1271   return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
1272       src_literal, converter);
1273 }
1274 
1275 template <typename NativeSrcT, typename NativeDestT>
1276 typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT) &&
1277                          !std::is_same<NativeDestT, Eigen::half>::value),
1278                         Literal>::type
BitcastBetweenNativeTypes(const LiteralBase & src_literal)1279 BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
1280   auto converter = [](NativeSrcT src) {
1281     return absl::bit_cast<NativeDestT>(GetRawValue(src));
1282   };
1283   return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
1284       src_literal, converter);
1285 }
1286 
1287 template <typename NativeSrcT, typename NativeDestT>
1288 typename std::enable_if<(sizeof(NativeSrcT) == sizeof(Eigen::half) &&
1289                          std::is_same<NativeDestT, Eigen::half>::value),
1290                         Literal>::type
BitcastBetweenNativeTypes(const LiteralBase & src_literal)1291 BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
1292   // Eigen::half doesn't satisfy the absl::bit_cast contract, so explicitly
1293   // cast to unsigned short and then use raw_uint16_to_half.
1294   auto converter = [](NativeSrcT src) {
1295     return Eigen::half_impl::raw_uint16_to_half(
1296         absl::bit_cast<uint16>(GetRawValue(src)));
1297   };
1298   return ConvertBetweenNativeTypesWithConverter<NativeSrcT, Eigen::half>(
1299       src_literal, converter);
1300 }
1301 
1302 // This template specialization is here to make the compiler happy. bit_cast has
1303 // a static check that the types are the same size. This specialization should
1304 // never be used because the source and destination types are checked for
1305 // identical sizes higher up.
1306 template <typename NativeSrcT, typename NativeDestT>
1307 typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
1308                         Literal>::type
BitcastBetweenNativeTypes(const LiteralBase & src_literal)1309 BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
1310   LOG(FATAL) << "Invalid bitcast between types of different sizes.";
1311 }
1312 
1313 template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
ConvertIfTypesMatch(const LiteralBase & src_literal,bool bitcast)1314 Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) {
1315   CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
1316   if (bitcast) {
1317     return BitcastBetweenNativeTypes<
1318         typename primitive_util::PrimitiveTypeToNative<
1319             primitive_src_type>::type,
1320         typename primitive_util::PrimitiveTypeToNative<
1321             primitive_dest_type>::type>(src_literal);
1322   } else {
1323     return ConvertBetweenNativeTypes<
1324         typename primitive_util::PrimitiveTypeToNative<
1325             primitive_src_type>::type,
1326         typename primitive_util::PrimitiveTypeToNative<
1327             primitive_dest_type>::type>(src_literal);
1328   }
1329 }
1330 
1331 template <PrimitiveType primitive_src_type>
ConvertIfDestTypeMatches(const LiteralBase & src_literal,PrimitiveType primitive_dest_type,bool bitcast)1332 StatusOr<Literal> ConvertIfDestTypeMatches(const LiteralBase& src_literal,
1333                                            PrimitiveType primitive_dest_type,
1334                                            bool bitcast) {
1335   switch (primitive_dest_type) {
1336 #define CONVERT_IF_TYPES_MATCH(type)                                    \
1337   case (type):                                                          \
1338     return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal, \
1339                                                            bitcast);
1340     CONVERT_IF_TYPES_MATCH(PRED)
1341     CONVERT_IF_TYPES_MATCH(S8)
1342     CONVERT_IF_TYPES_MATCH(S16)
1343     CONVERT_IF_TYPES_MATCH(S32)
1344     CONVERT_IF_TYPES_MATCH(S64)
1345     CONVERT_IF_TYPES_MATCH(U8)
1346     CONVERT_IF_TYPES_MATCH(U16)
1347     CONVERT_IF_TYPES_MATCH(U32)
1348     CONVERT_IF_TYPES_MATCH(U64)
1349     CONVERT_IF_TYPES_MATCH(F16)
1350     CONVERT_IF_TYPES_MATCH(F32)
1351     CONVERT_IF_TYPES_MATCH(F64)
1352     CONVERT_IF_TYPES_MATCH(BF16)
1353 #undef CONVERT_IF_TYPES_MATCH
1354     case C64:
1355       if (bitcast) {
1356         break;
1357       }
1358       return ConvertIfTypesMatch<primitive_src_type, C64>(src_literal, false);
1359     case C128:
1360       if (bitcast) {
1361         break;
1362       }
1363       return ConvertIfTypesMatch<primitive_src_type, C128>(src_literal, false);
1364     // Other types are not yet supported.
1365     default:
1366       break;
1367   }
1368   return Unimplemented("Converting from type %s to type %s is not implemented.",
1369                        PrimitiveType_Name(src_literal.shape().element_type()),
1370                        PrimitiveType_Name(primitive_dest_type));
1371 }
1372 
ConvertSwitch(const LiteralBase & literal,PrimitiveType primitive_dest_type,bool bitcast)1373 StatusOr<Literal> ConvertSwitch(const LiteralBase& literal,
1374                                 PrimitiveType primitive_dest_type,
1375                                 bool bitcast) {
1376   TF_RET_CHECK(literal.shape().IsArray());
1377   if (literal.shape().element_type() == primitive_dest_type) {
1378     return literal.Clone();
1379   }
1380   switch (literal.shape().element_type()) {
1381 #define CONVERT_IF_DEST_TYPE_MATCHES(type)                                \
1382   case (type):                                                            \
1383     return ConvertIfDestTypeMatches<(type)>(literal, primitive_dest_type, \
1384                                             bitcast);
1385     CONVERT_IF_DEST_TYPE_MATCHES(PRED)
1386     CONVERT_IF_DEST_TYPE_MATCHES(S8)
1387     CONVERT_IF_DEST_TYPE_MATCHES(S16)
1388     CONVERT_IF_DEST_TYPE_MATCHES(S32)
1389     CONVERT_IF_DEST_TYPE_MATCHES(S64)
1390     CONVERT_IF_DEST_TYPE_MATCHES(U8)
1391     CONVERT_IF_DEST_TYPE_MATCHES(U16)
1392     CONVERT_IF_DEST_TYPE_MATCHES(U32)
1393     CONVERT_IF_DEST_TYPE_MATCHES(U64)
1394     CONVERT_IF_DEST_TYPE_MATCHES(F16)
1395     CONVERT_IF_DEST_TYPE_MATCHES(F32)
1396     CONVERT_IF_DEST_TYPE_MATCHES(F64)
1397     CONVERT_IF_DEST_TYPE_MATCHES(BF16)
1398 #undef CONVERT_IF_DEST_TYPE_MATCHES
1399       // Other types are not yet supported.
1400     default:
1401       return Unimplemented("%s from type %s to type %s is not implemented.",
1402                            (bitcast ? "Bitcast converting" : "Converting"),
1403                            PrimitiveType_Name(literal.shape().element_type()),
1404                            PrimitiveType_Name(primitive_dest_type));
1405   }
1406 }
1407 
1408 }  // namespace
1409 
Convert(PrimitiveType primitive_dest_type) const1410 StatusOr<Literal> LiteralBase::Convert(
1411     PrimitiveType primitive_dest_type) const {
1412   return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
1413 }
1414 
BitcastConvert(PrimitiveType primitive_dest_type) const1415 StatusOr<Literal> LiteralBase::BitcastConvert(
1416     PrimitiveType primitive_dest_type) const {
1417   if (primitive_util::BitWidth(shape().element_type()) !=
1418       primitive_util::BitWidth(primitive_dest_type)) {
1419     return InvalidArgument(
1420         "Cannot bitcast convert from %s to %s, bit widths are different: %d != "
1421         "%d",
1422         PrimitiveType_Name(shape().element_type()),
1423         PrimitiveType_Name(primitive_dest_type),
1424         primitive_util::BitWidth(shape().element_type()),
1425         primitive_util::BitWidth(primitive_dest_type));
1426   }
1427   return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true);
1428 }
1429 
ConvertToShape(const Shape & dest_shape) const1430 StatusOr<Literal> LiteralBase::ConvertToShape(const Shape& dest_shape) const {
1431   if (!dest_shape.IsTuple()) {
1432     return Convert(dest_shape.element_type());
1433   }
1434   std::vector<Literal> elements;
1435   for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
1436     auto element = LiteralSlice(*this, {i});
1437     TF_ASSIGN_OR_RETURN(
1438         auto new_element,
1439         element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
1440     elements.push_back(std::move(new_element));
1441   }
1442   return MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
1443 }
1444 
MoveIntoTuple(absl::Span<Literal> elements)1445 /* static */ Literal MutableLiteralBase::MoveIntoTuple(
1446     absl::Span<Literal> elements) {
1447   std::vector<Shape> element_shapes;
1448   for (const Literal& element : elements) {
1449     element_shapes.push_back(element.shape());
1450   }
1451   Literal literal(ShapeUtil::MakeTupleShape(element_shapes),
1452                   /*allocate_arrays=*/false);
1453   for (int i = 0; i < elements.size(); ++i) {
1454     TF_CHECK_OK(
1455         literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
1456   }
1457   return literal;
1458 }
1459 
1460 template <typename NativeT>
EqualElementsInternal(const LiteralBase::Piece & other,std::vector<int64> * multi_index) const1461 bool LiteralBase::Piece::EqualElementsInternal(
1462     const LiteralBase::Piece& other, std::vector<int64>* multi_index) const {
1463   if (multi_index->size() == subshape().rank()) {
1464     return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
1465   }
1466   for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) {
1467     multi_index->push_back(i);
1468     if (!EqualElementsInternal<NativeT>(other, multi_index)) {
1469       return false;
1470     }
1471     multi_index->pop_back();
1472   }
1473   return true;
1474 }
1475 
EqualElements(const LiteralBase::Piece & other) const1476 bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
1477   DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
1478 
1479   if (ShapeUtil::Equal(subshape(), other.subshape()) &&
1480       LayoutUtil::IsDenseArray(subshape())) {
1481     CHECK_EQ(size_bytes(), other.size_bytes());
1482     return memcmp(buffer(), other.buffer(), size_bytes()) == 0;
1483   }
1484 
1485   std::vector<int64> multi_index;
1486   switch (subshape().element_type()) {
1487     case PRED:
1488       return EqualElementsInternal<bool>(other, &multi_index);
1489     case U8:
1490       return EqualElementsInternal<uint8>(other, &multi_index);
1491     case S16:
1492       return EqualElementsInternal<int16>(other, &multi_index);
1493     case S32:
1494       return EqualElementsInternal<int32>(other, &multi_index);
1495     case S64:
1496       return EqualElementsInternal<int64>(other, &multi_index);
1497     case U16:
1498       return EqualElementsInternal<uint16>(other, &multi_index);
1499     case U32:
1500       return EqualElementsInternal<uint32>(other, &multi_index);
1501     case U64:
1502       return EqualElementsInternal<uint64>(other, &multi_index);
1503     case F32:
1504       return EqualElementsInternal<float>(other, &multi_index);
1505     case F64:
1506       return EqualElementsInternal<double>(other, &multi_index);
1507     case F16:
1508       return EqualElementsInternal<half>(other, &multi_index);
1509     case BF16:
1510       return EqualElementsInternal<bfloat16>(other, &multi_index);
1511     case C64:
1512       return EqualElementsInternal<complex64>(other, &multi_index);
1513     case C128:
1514       return EqualElementsInternal<complex128>(other, &multi_index);
1515     default:
1516       LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type "
1517                  << PrimitiveType_Name(subshape().element_type());
1518   }
1519 }
1520 
operator ==(const LiteralBase & other) const1521 bool LiteralBase::operator==(const LiteralBase& other) const {
1522   if (!ShapeUtil::Compatible(shape(), other.shape())) {
1523     return false;
1524   }
1525 
1526   return root_piece().ForEachSubpieceWithBool(
1527       [&](const ShapeIndex& index, const Piece& piece) {
1528         if (!piece.subshape().IsArray()) {
1529           return true;
1530         }
1531 
1532         const Piece& other_piece = other.piece(index);
1533         if (!piece.EqualElements(other_piece)) {
1534           return false;
1535         }
1536         return true;
1537       });
1538 }
1539 
1540 namespace {
1541 
1542 template <typename NativeT>
AllElementsEqualValue(absl::Span<const NativeT> data,NativeT value)1543 static bool AllElementsEqualValue(absl::Span<const NativeT> data,
1544                                   NativeT value) {
1545   for (int64 i = 0; i < data.size(); ++i) {
1546     if (data[i] != value) {
1547       return false;
1548     }
1549   }
1550   return true;
1551 }
1552 
1553 }  // namespace
1554 
IsAll(int8 value) const1555 bool LiteralBase::IsAll(int8 value) const {
1556   return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index,
1557                                                   const Piece& piece) {
1558     if (!piece.subshape().IsArray()) {
1559       return true;
1560     }
1561 
1562     auto piece_is_all = [&]() {
1563       switch (shape().element_type()) {
1564         case U8:
1565           if (value >= 0) {
1566             return AllElementsEqualValue<uint8>(piece.data<uint8>(), value);
1567           }
1568           return false;
1569         case U16:
1570           if (value >= 0) {
1571             return AllElementsEqualValue<uint16>(piece.data<uint16>(), value);
1572           }
1573           return false;
1574         case U32:
1575           if (value >= 0) {
1576             return AllElementsEqualValue<uint32>(piece.data<uint32>(), value);
1577           }
1578           return false;
1579         case U64:
1580           if (value >= 0) {
1581             return AllElementsEqualValue<uint64>(piece.data<uint64>(), value);
1582           }
1583           return false;
1584         case S8:
1585           return AllElementsEqualValue<int8>(piece.data<int8>(), value);
1586         case S16:
1587           return AllElementsEqualValue<int16>(piece.data<int16>(), value);
1588         case S32:
1589           return AllElementsEqualValue<int32>(piece.data<int32>(), value);
1590         case S64:
1591           return AllElementsEqualValue<int64>(piece.data<int64>(), value);
1592         case F32:
1593           return AllElementsEqualValue<float>(piece.data<float>(), value);
1594         case F64:
1595           return AllElementsEqualValue<double>(piece.data<double>(), value);
1596         case F16:
1597           return AllElementsEqualValue<half>(piece.data<half>(),
1598                                              static_cast<half>(value));
1599         case BF16:
1600           return AllElementsEqualValue<bfloat16>(piece.data<bfloat16>(),
1601                                                  static_cast<bfloat16>(value));
1602         case PRED:
1603           if (value == 0) {
1604             return AllElementsEqualValue<bool>(piece.data<bool>(), false);
1605           }
1606           if (value == 1) {
1607             return AllElementsEqualValue<bool>(piece.data<bool>(), true);
1608           }
1609           return false;
1610         default:
1611           return false;
1612       }
1613       return false;
1614     };
1615 
1616     if (!piece_is_all()) {
1617       return false;
1618     }
1619     return true;
1620   });
1621 }
1622 
IsAllFloat(float value) const1623 bool LiteralBase::IsAllFloat(float value) const {
1624   return root_piece().ForEachSubpieceWithBool(
1625       [&](const ShapeIndex& index, const Piece& piece) {
1626         if (!piece.subshape().IsArray()) {
1627           return true;
1628         }
1629 
1630         switch (shape().element_type()) {
1631           case F32:
1632             return AllElementsEqualValue<float>(piece.data<float>(), value);
1633           case F64:
1634             return AllElementsEqualValue<double>(piece.data<double>(), value);
1635           case F16:
1636             return AllElementsEqualValue<half>(piece.data<half>(),
1637                                                static_cast<half>(value));
1638           case BF16:
1639             return AllElementsEqualValue<bfloat16>(
1640                 piece.data<bfloat16>(), static_cast<bfloat16>(value));
1641           default:
1642             return false;
1643         }
1644       });
1645 }
1646 
IsAllComplex(complex64 value) const1647 bool LiteralBase::IsAllComplex(complex64 value) const {
1648   switch (shape().element_type()) {
1649     case C64:
1650       return AllElementsEqualValue<complex64>(root_piece().data<complex64>(),
1651                                               value);
1652     case C128:
1653       return AllElementsEqualValue<complex128>(root_piece().data<complex128>(),
1654                                                value);
1655     default:
1656       return false;
1657   }
1658 }
1659 
IsAllFirst() const1660 bool LiteralBase::IsAllFirst() const {
1661   return root_piece().ForEachSubpieceWithBool(
1662       [&](const ShapeIndex& index, const Piece& piece) {
1663         if (!piece.subshape().IsArray()) {
1664           return true;
1665         }
1666 
1667         // Empty shapes are not all the first element since there is no first
1668         // element.
1669         if (ShapeUtil::IsZeroElementArray(piece.subshape())) {
1670           return false;
1671         }
1672         auto piece_is_all = [&]() {
1673           switch (piece.subshape().element_type()) {
1674             case PRED: {
1675               auto data = piece.data<bool>();
1676               return AllElementsEqualValue<bool>(data, data[0]);
1677             }
1678             // 8 bit types
1679             case S8: {
1680               auto data = piece.data<int8>();
1681               return AllElementsEqualValue<int8>(data, data[0]);
1682             }
1683             case U8: {
1684               auto data = piece.data<uint8>();
1685               return AllElementsEqualValue<uint8>(data, data[0]);
1686             }
1687             // 16 bit types
1688             case BF16: {
1689               auto data = piece.data<bfloat16>();
1690               return AllElementsEqualValue<bfloat16>(data, data[0]);
1691             }
1692             case F16: {
1693               auto data = piece.data<half>();
1694               return AllElementsEqualValue<half>(data, data[0]);
1695             }
1696             case S16: {
1697               auto data = piece.data<int16>();
1698               return AllElementsEqualValue<int16>(data, data[0]);
1699             }
1700             case U16: {
1701               auto data = piece.data<uint16>();
1702               return AllElementsEqualValue<uint16>(data, data[0]);
1703             }
1704             // 32 bit types
1705             case F32: {
1706               auto data = piece.data<float>();
1707               return AllElementsEqualValue<float>(data, data[0]);
1708             }
1709             case U32: {
1710               auto data = piece.data<uint32>();
1711               return AllElementsEqualValue<uint32>(data, data[0]);
1712             }
1713             case S32: {
1714               auto data = piece.data<int32>();
1715               return AllElementsEqualValue<int32>(data, data[0]);
1716             }
1717             // 64 bit types
1718             case C64: {
1719               auto data = piece.data<complex64>();
1720               return AllElementsEqualValue<complex64>(data, data[0]);
1721             }
1722             case F64: {
1723               auto data = piece.data<double>();
1724               return AllElementsEqualValue<double>(data, data[0]);
1725             }
1726             case S64: {
1727               auto data = piece.data<int64>();
1728               return AllElementsEqualValue<int64>(data, data[0]);
1729             }
1730             case U64: {
1731               auto data = piece.data<uint64>();
1732               return AllElementsEqualValue<uint64>(data, data[0]);
1733             }
1734 
1735             case C128: {
1736               auto data = piece.data<complex128>();
1737               return AllElementsEqualValue<complex128>(data, data[0]);
1738             }
1739             default:
1740               return false;
1741           }
1742         };
1743 
1744         if (!piece_is_all()) {
1745           return false;
1746         }
1747         return true;
1748       });
1749 }
1750 
IsR1Iota() const1751 bool LiteralBase::IsR1Iota() const {
1752   if (!shape().IsArray()) {
1753     return false;
1754   }
1755 
1756   if (shape().rank() != 1) {
1757     return false;
1758   }
1759 
1760   auto is_iota_at_idx = [&](const int64 idx) {
1761     switch (shape().element_type()) {
1762       case U8:
1763         return Get<uint8>({idx}) == idx;
1764       case U16:
1765         return Get<uint16>({idx}) == idx;
1766       case U32:
1767         return Get<uint32>({idx}) == idx;
1768       case U64:
1769         return Get<uint64>({idx}) == idx;
1770       case S8:
1771         return Get<int8>({idx}) == idx;
1772       case S16:
1773         return Get<int16>({idx}) == idx;
1774       case S32:
1775         return Get<int32>({idx}) == idx;
1776       case S64:
1777         return Get<int64>({idx}) == idx;
1778       case F32:
1779         return Get<float>({idx}) == idx;
1780       case F64:
1781         return Get<double>({idx}) == idx;
1782       case F16:
1783         return Get<half>({idx}) == static_cast<half>(idx);
1784       case BF16:
1785         return Get<bfloat16>({idx}) == static_cast<bfloat16>(idx);
1786       case C64:
1787         return Get<complex64>({idx}) == complex64(idx, 0.0f);
1788       case C128:
1789         return Get<complex128>({idx}) == complex128(idx, 0.0f);
1790       case PRED:
1791         return Get<bool>({idx}) == idx;
1792       // token, opaque, tuple, etc. are all not iota.
1793       default:
1794         return false;
1795     }
1796   };
1797 
1798   const int64 elements = ShapeUtil::ElementsIn(shape());
1799   for (int64 idx = 0; idx < elements; ++idx) {
1800     if (!is_iota_at_idx(idx)) {
1801       return false;
1802     }
1803   }
1804 
1805   return true;
1806 }
1807 
IsZero(absl::Span<const int64> indices) const1808 bool LiteralBase::IsZero(absl::Span<const int64> indices) const {
1809   CHECK(shape().IsArray());
1810   switch (shape().element_type()) {
1811     case U8:
1812       return Get<uint8>(indices) == 0;
1813     case U16:
1814       return Get<uint16>(indices) == 0;
1815     case U32:
1816       return Get<uint32>(indices) == 0;
1817     case U64:
1818       return Get<uint64>(indices) == 0;
1819     case S8:
1820       return Get<int8>(indices) == 0;
1821     case S16:
1822       return Get<int16>(indices) == 0;
1823     case S32:
1824       return Get<int32>(indices) == 0;
1825     case S64:
1826       return Get<int64>(indices) == 0;
1827     case F32:
1828       return Get<float>(indices) == 0.0f;
1829     case F64:
1830       return Get<double>(indices) == 0.0;
1831     case C64:
1832       return Get<complex64>(indices) == complex64(0.0f, 0.0f);
1833     case C128:
1834       return Get<complex128>(indices) == complex128(0.0f, 0.0f);
1835     case F16:
1836       return Get<half>(indices) == static_cast<half>(0.0f);
1837     case BF16:
1838       return Get<bfloat16>(indices) == static_cast<bfloat16>(0.0f);
1839     case PRED:
1840       return Get<bool>(indices) == false;
1841     default:
1842       LOG(FATAL) << "Input literal must be an array.";
1843   }
1844 }
1845 
1846 namespace {
1847 
1848 template <typename RepeatedFieldT, typename NativeT>
CopyToRepeatedField(RepeatedFieldT * dest,const absl::Span<const NativeT> src)1849 void CopyToRepeatedField(RepeatedFieldT* dest,
1850                          const absl::Span<const NativeT> src) {
1851   *dest = RepeatedFieldT(src.begin(), src.end());
1852 }
1853 
1854 }  // namespace
1855 
WriteToProto(LiteralProto * proto) const1856 void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
1857   *proto->mutable_shape() = subshape().ToProto();
1858   switch (subshape().element_type()) {
1859     case PRED:
1860       CopyToRepeatedField(proto->mutable_preds(), data<bool>());
1861       break;
1862     case S8:
1863       proto->set_s8s(static_cast<const signed char*>(data<int8>().data()),
1864                      element_count());
1865       break;
1866     case U8:
1867       proto->set_u8s(static_cast<const unsigned char*>(data<uint8>().data()),
1868                      element_count());
1869       break;
1870     case U32:
1871       CopyToRepeatedField(proto->mutable_u32s(), data<uint32>());
1872       break;
1873     case U64:
1874       CopyToRepeatedField(proto->mutable_u64s(), data<uint64>());
1875       break;
1876     case S32:
1877       CopyToRepeatedField(proto->mutable_s32s(), data<int32>());
1878       break;
1879     case S64:
1880       CopyToRepeatedField(proto->mutable_s64s(), data<int64>());
1881       break;
1882     case U16:
1883       *proto->mutable_u16s() = string(
1884           reinterpret_cast<const char*>(data<uint16_t>().data()), size_bytes());
1885       if (!kLittleEndian) {
1886         ConvertEndianShort(proto->mutable_u16s());
1887       }
1888       break;
1889     case S16:
1890       *proto->mutable_s16s() = string(
1891           reinterpret_cast<const char*>(data<int16_t>().data()), size_bytes());
1892       if (!kLittleEndian) {
1893         ConvertEndianShort(proto->mutable_s16s());
1894       }
1895       break;
1896     case F16:
1897       *proto->mutable_f16s() = string(
1898           reinterpret_cast<const char*>(data<half>().data()), size_bytes());
1899       if (!kLittleEndian) {
1900         ConvertEndianShort(proto->mutable_f16s());
1901       }
1902       break;
1903     case BF16:
1904       *proto->mutable_bf16s() = string(
1905           reinterpret_cast<const char*>(data<bfloat16>().data()), size_bytes());
1906       if (!kLittleEndian) {
1907         ConvertEndianShort(proto->mutable_bf16s());
1908       }
1909       break;
1910     case F32:
1911       CopyToRepeatedField(proto->mutable_f32s(), data<float>());
1912       break;
1913     case F64:
1914       CopyToRepeatedField(proto->mutable_f64s(), data<double>());
1915       break;
1916     case C64:
1917       for (complex64 value : data<complex64>()) {
1918         proto->add_c64s(value.real());
1919         proto->add_c64s(value.imag());
1920       }
1921       break;
1922     case C128:
1923       for (complex128 value : data<complex128>()) {
1924         proto->add_c128s(value.real());
1925         proto->add_c128s(value.imag());
1926       }
1927       break;
1928     case TUPLE:
1929     case TOKEN:
1930       // Nothing to do but assign the shape which is done above.
1931       return;
1932     default:
1933       // TODO(b/111551621): Support serializing more PrimitiveTypes.
1934       LOG(FATAL) << "Unhandled primitive type "
1935                  << PrimitiveType_Name(subshape().element_type());
1936   }
1937 }
1938 
untyped_data() const1939 const void* LiteralBase::Piece::untyped_data() const {
1940   CHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
1941   return buffer();
1942 }
1943 
untyped_data()1944 void* LiteralBase::Piece::untyped_data() {
1945   CHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
1946   return buffer();
1947 }
1948 
1949 namespace {
1950 
1951 template <typename RepeatedFieldT, typename NativeT>
CopyFromRepeatedField(absl::Span<NativeT> dest,const RepeatedFieldT & src)1952 Status CopyFromRepeatedField(absl::Span<NativeT> dest,
1953                              const RepeatedFieldT& src) {
1954   if (dest.size() != src.size()) {
1955     return InvalidArgument(
1956         "Expected %lu elements in LiteralProto repeated field, has %d",
1957         dest.size(), src.size());
1958   }
1959   std::copy(src.begin(), src.end(), dest.begin());
1960   return Status::OK();
1961 }
1962 
1963 }  // namespace
1964 
CopyFromProto(const LiteralProto & proto)1965 Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
1966   // These conditions should have been checked in
1967   // MutableLiteralBase::CreateFromProto.
1968   TF_RET_CHECK(proto.has_shape());
1969   Shape shape(proto.shape());
1970   TF_RET_CHECK(LayoutUtil::HasLayout(shape));
1971   TF_RET_CHECK(ShapeUtil::Equal(shape, subshape()));
1972 
1973   if (LayoutUtil::IsSparseArray(subshape())) {
1974     // Compute the number of elements (indices) in the sparse shape and reserve
1975     // the necessary space in spare_indices.
1976     TF_RET_CHECK(subshape().rank() != 0) << "Scalar shapes cannot be sparse";
1977     TF_RET_CHECK(proto.sparse_indices_size() % subshape().rank() == 0)
1978         << "Unexpected number of indices in proto ("
1979         << proto.sparse_indices_size() << ") for shape of rank "
1980         << subshape().rank();
1981     const int64 index_count = proto.sparse_indices_size() / subshape().rank();
1982     sparse_indices()->Resize(index_count);
1983 
1984     // Copy the indices from the proto into the SparseIndexArray object.
1985     TF_RETURN_IF_ERROR(CopyFromRepeatedField(sparse_indices()->mutable_data(),
1986                                              proto.sparse_indices()));
1987   }
1988 
1989   switch (subshape().element_type()) {
1990     case PRED:
1991       TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
1992       break;
1993     case S8: {
1994       auto s8_data = data<int8>();
1995       TF_RET_CHECK(proto.s8s().size() == s8_data.size());
1996       std::copy(proto.s8s().begin(), proto.s8s().end(), s8_data.begin());
1997     } break;
1998     case U8: {
1999       auto u8_data = data<uint8>();
2000       TF_RET_CHECK(proto.u8s().size() == u8_data.size());
2001       std::copy(proto.u8s().begin(), proto.u8s().end(), u8_data.begin());
2002     } break;
2003     case S32:
2004       TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int32>(), proto.s32s()));
2005       break;
2006     case S64:
2007       TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int64>(), proto.s64s()));
2008       break;
2009     case U32:
2010       TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint32>(), proto.u32s()));
2011       break;
2012     case U64:
2013       TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint64>(), proto.u64s()));
2014       break;
2015     case S16: {
2016       const string& s(proto.s16s());
2017       TF_RET_CHECK(data<int16_t>().size() * sizeof(int16_t) == s.size());
2018       memcpy(untyped_data(), s.data(), s.size());
2019       if (!kLittleEndian) {
2020         ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
2021       }
2022     } break;
2023     case U16: {
2024       const string& s(proto.u16s());
2025       TF_RET_CHECK(data<uint16_t>().size() * sizeof(uint16_t) == s.size());
2026       memcpy(untyped_data(), s.data(), s.size());
2027       if (!kLittleEndian) {
2028         ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
2029       }
2030     } break;
2031     case F16: {
2032       const string& s(proto.f16s());
2033       TF_RET_CHECK(data<half>().size() * sizeof(half) == s.size());
2034       memcpy(untyped_data(), s.data(), s.size());
2035       if (!kLittleEndian) {
2036         ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
2037       }
2038     } break;
2039 
2040     case BF16: {
2041       const string& s(proto.bf16s());
2042       TF_RET_CHECK(data<bfloat16>().size() * sizeof(bfloat16) == s.size());
2043       memcpy(untyped_data(), s.data(), s.size());
2044       if (!kLittleEndian) {
2045         ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
2046       }
2047     } break;
2048     case F32:
2049       TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<float>(), proto.f32s()));
2050       break;
2051     case F64:
2052       TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<double>(), proto.f64s()));
2053       break;
2054     case C64: {
2055       auto complex_data = data<complex64>();
2056       TF_RET_CHECK(proto.c64s_size() == complex_data.size() * 2);
2057       for (int64 i = 0; i < complex_data.size(); ++i) {
2058         complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)};
2059       }
2060       break;
2061     }
2062     case C128: {
2063       auto complex_data = data<complex128>();
2064       TF_RET_CHECK(proto.c128s_size() == complex_data.size() * 2);
2065       for (int64 i = 0; i < complex_data.size(); ++i) {
2066         complex_data[i] =
2067             complex128{proto.c128s(i * 2), proto.c128s(i * 2 + 1)};
2068       }
2069       break;
2070     }
2071     case TUPLE:
2072       return InvalidArgument("Should not be called on tuple shapes: %s",
2073                              ShapeUtil::HumanString(subshape()));
2074     default:
2075       return InvalidArgument("Is called on unsupported shape: %s",
2076                              ShapeUtil::HumanString(subshape()));
2077   }
2078   return Status::OK();
2079 }
2080 
ToProto() const2081 LiteralProto LiteralBase::ToProto() const {
2082   LiteralProto proto;
2083   root_piece().ForEachSubpiece(
2084       [&](const ShapeIndex& index, const Piece& piece) {
2085         LiteralProto* proto_piece = &proto;
2086         for (int64 i : index) {
2087           while (proto_piece->tuple_literals_size() <= i) {
2088             proto_piece->add_tuple_literals();
2089           }
2090           proto_piece = proto_piece->mutable_tuple_literals(i);
2091         }
2092         piece.WriteToProto(proto_piece);
2093       });
2094 
2095   if (LayoutUtil::IsSparseArray(shape())) {
2096     CopyToRepeatedField(proto.mutable_sparse_indices(),
2097                         sparse_indices()->data());
2098   }
2099 
2100   return proto;
2101 }
2102 
untyped_data(const ShapeIndex & shape_index) const2103 const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
2104   return piece(shape_index).untyped_data();
2105 }
2106 
untyped_data(const ShapeIndex & shape_index)2107 void* MutableLiteralBase::untyped_data(const ShapeIndex& shape_index) {
2108   return piece(shape_index).untyped_data();
2109 }
2110 
size_bytes(const ShapeIndex & shape_index) const2111 int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const {
2112   return piece(shape_index).size_bytes();
2113 }
2114 
GetR1U8AsString() const2115 string LiteralBase::GetR1U8AsString() const {
2116   CHECK(shape().IsArray());
2117   CHECK_EQ(shape().rank(), 1);
2118   CHECK_EQ(shape().element_type(), U8);
2119   return string(absl::bit_cast<const char*>(data<uint8>().data()),
2120                 ShapeUtil::ElementsIn(shape()));
2121 }
2122 
CopyPieceSubtree(const Shape & shape,Piece * src_piece,Piece * dest_piece)2123 void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape,
2124                                                Piece* src_piece,
2125                                                Piece* dest_piece) {
2126   DCHECK(ShapeUtil::Equal(src_piece->subshape(), dest_piece->subshape()))
2127       << "src_piece has shape: "
2128       << ShapeUtil::HumanString(src_piece->subshape())
2129       << "dest_piece has shape: "
2130       << ShapeUtil::HumanString(dest_piece->subshape());
2131   if (shape.IsTuple()) {
2132     for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
2133       const Shape& subshape = shape.tuple_shapes(i);
2134 
2135       auto child_piece = Piece();
2136       child_piece.set_subshape(&subshape);
2137 
2138       CopyPieceSubtree(subshape, &src_piece->child(i), &child_piece);
2139 
2140       dest_piece->emplace_back(std::move(child_piece));
2141     }
2142   } else if (shape.IsArray()) {
2143     dest_piece->set_buffer(src_piece->buffer());
2144   } else {
2145     // If the shape is neither an array nor tuple, then it must be
2146     // zero-sized. Otherwise, some memory needs to be allocated for it.
2147     CHECK_EQ(dest_piece->size_bytes(), 0);
2148   }
2149 }
2150 
~MutableLiteralBase()2151 MutableLiteralBase::~MutableLiteralBase() {}
2152 
// Copy constructor. It deep-copies the shape and builds a fresh piece tree
// whose array pieces alias the same buffers as `literal`'s.
MutableBorrowingLiteral::MutableBorrowingLiteral(
    const MutableBorrowingLiteral& literal)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(literal.shape());
  CHECK(LayoutUtil::HasLayout(*shape_));

  Piece* root = new Piece();
  root->set_subshape(shape_.get());
  root_piece_ = root;

  CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);
}
2164 
// Copy-assigns from another borrowing literal. The shape is deep-copied and a
// new piece tree is built whose array pieces alias the same buffers as
// `literal`'s; the buffers themselves are not copied.
//
// The piece tree previously held by *this is released first, using the same
// cleanup as the destructor; without that, every reassignment leaked it.
// Self-assignment is a no-op. Rebuilding from our own root would read pieces
// whose subshape pointers refer into the shape being replaced.
MutableBorrowingLiteral& MutableBorrowingLiteral::operator=(
    const MutableBorrowingLiteral& literal) {
  if (this == &literal) {
    return *this;
  }

  if (root_piece_ != nullptr) {
    root_piece_->ForEachMutableSubpiece(
        [&](const ShapeIndex& index, Piece* piece) {
          if (piece->buffer() != nullptr) {
            delete piece->sparse_indices();
          }
        });
    delete root_piece_;
    root_piece_ = nullptr;
  }

  shape_ = absl::make_unique<Shape>(literal.shape());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);

  return *this;
}
2177 
MutableBorrowingLiteral(const MutableLiteralBase & literal)2178 MutableBorrowingLiteral::MutableBorrowingLiteral(
2179     const MutableLiteralBase& literal)
2180     : MutableLiteralBase() {
2181   shape_ = absl::make_unique<Shape>(literal.shape());
2182   CHECK(LayoutUtil::HasLayout(*shape_));
2183 
2184   root_piece_ = new Piece();
2185   root_piece_->set_subshape(shape_.get());
2186 
2187   CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);
2188 }
2189 
MutableBorrowingLiteral(MutableLiteralBase * literal)2190 MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal)
2191     : MutableLiteralBase() {
2192   shape_ = absl::make_unique<Shape>(literal->shape());
2193   CHECK(LayoutUtil::HasLayout(*shape_));
2194 
2195   root_piece_ = new Piece();
2196   root_piece_->set_subshape(shape_.get());
2197 
2198   CopyPieceSubtree(*shape_, &literal->root_piece(), root_piece_);
2199 }
2200 
MutableBorrowingLiteral(MutableBorrowingLiteral literal,const ShapeIndex & view_root)2201 MutableBorrowingLiteral::MutableBorrowingLiteral(
2202     MutableBorrowingLiteral literal, const ShapeIndex& view_root)
2203     : MutableLiteralBase() {
2204   shape_ = absl::make_unique<Shape>(literal.piece(view_root).subshape());
2205   CHECK(LayoutUtil::HasLayout(*shape_));
2206 
2207   root_piece_ = new Piece();
2208   root_piece_->set_subshape(shape_.get());
2209 
2210   CopyPieceSubtree(*shape_, &literal.piece(view_root), root_piece_);
2211 }
2212 
MutableBorrowingLiteral(const char * src_buf_ptr,const Shape & shape)2213 MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr,
2214                                                  const Shape& shape)
2215     : MutableLiteralBase() {
2216   shape_ = absl::make_unique<Shape>(shape);
2217   CHECK(LayoutUtil::HasLayout(*shape_));
2218   CHECK(!shape_->IsTuple());
2219 
2220   root_piece_ = new Piece();
2221   root_piece_->set_buffer(const_cast<char*>(src_buf_ptr));
2222   root_piece_->set_subshape(shape_.get());
2223 }
2224 
~MutableBorrowingLiteral()2225 MutableBorrowingLiteral::~MutableBorrowingLiteral() {
2226   if (root_piece_ != nullptr) {
2227     root_piece_->ForEachMutableSubpiece(
2228         [&](const ShapeIndex& index, Piece* piece) {
2229           if (piece->buffer() != nullptr) {
2230             delete piece->sparse_indices();
2231           }
2232         });
2233     delete root_piece_;
2234   }
2235 }
2236 
LiteralSlice(const LiteralBase & literal)2237 LiteralSlice::LiteralSlice(const LiteralBase& literal)
2238     : LiteralBase(), root_piece_(&literal.root_piece()) {}
2239 
LiteralSlice(const LiteralBase & literal,const ShapeIndex & view_root)2240 LiteralSlice::LiteralSlice(const LiteralBase& literal,
2241                            const ShapeIndex& view_root)
2242     : LiteralBase(), root_piece_(&literal.piece(view_root)) {}
2243 
BuildPieceSubtree(const Shape & shape,Piece * piece)2244 void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) {
2245   CHECK(shape.IsTuple());
2246   for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
2247     const Shape& subshape = shape.tuple_shapes(i);
2248 
2249     auto child_piece = Piece();
2250     child_piece.set_subshape(&subshape);
2251 
2252     if (subshape.IsTuple()) {
2253       BuildPieceSubtree(subshape, &child_piece);
2254     }
2255 
2256     piece->emplace_back(std::move(child_piece));
2257   }
2258 }
2259 
BorrowingLiteral(const char * src_buf_ptr,const Shape & shape)2260 BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
2261     : LiteralBase(), shape_(absl::make_unique<Shape>(shape)) {
2262   CHECK(shape_->IsArray());
2263   CHECK(LayoutUtil::HasLayout(*shape_));
2264 
2265   root_piece_ = Piece();
2266   root_piece_.set_buffer(const_cast<char*>(src_buf_ptr));
2267   root_piece_.set_subshape(shape_.get());
2268 }
2269 
BorrowingLiteral(absl::Span<const char * const> src_buf_ptrs,const Shape & shape)2270 BorrowingLiteral::BorrowingLiteral(absl::Span<const char* const> src_buf_ptrs,
2271                                    const Shape& shape)
2272     : LiteralBase(), shape_(absl::make_unique<Shape>(shape)) {
2273   CHECK(shape_->IsTuple());
2274   CHECK(!ShapeUtil::IsNestedTuple(*shape_));
2275   CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_));
2276   root_piece_ = Piece();
2277   root_piece_.set_subshape(shape_.get());
2278   BuildPieceSubtree(*shape_, &root_piece_);
2279 
2280   for (int i = 0; i < src_buf_ptrs.size(); ++i) {
2281     const auto& src_shape = shape_->tuple_shapes(i);
2282     CHECK(src_shape.IsArray());
2283     root_piece_.child(i).set_buffer(const_cast<char*>(src_buf_ptrs[i]));
2284   }
2285 }
2286 
2287 }  // namespace xla
2288