1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/literal.h"
17 
18 #include <algorithm>
19 #include <cstring>
20 #include <functional>
21 #include <limits>
22 #include <numeric>
23 #include <vector>
24 
25 #include "absl/base/casts.h"
26 #include "absl/memory/memory.h"
27 #include "absl/strings/str_cat.h"
28 #include "absl/strings/str_format.h"
29 #include "absl/strings/str_join.h"
30 #include "absl/strings/str_split.h"
31 #include "absl/types/optional.h"
32 #include "absl/types/span.h"
33 #include "tensorflow/compiler/xla/index_util.h"
34 #include "tensorflow/compiler/xla/permutation_util.h"
35 #include "tensorflow/compiler/xla/primitive_util.h"
36 #include "tensorflow/compiler/xla/shape_util.h"
37 #include "tensorflow/compiler/xla/status_macros.h"
38 #include "tensorflow/compiler/xla/types.h"
39 #include "tensorflow/compiler/xla/util.h"
40 #include "tensorflow/compiler/xla/xla_data.pb.h"
41 #include "tensorflow/core/lib/core/errors.h"
42 #include "tensorflow/core/lib/hash/hash.h"
43 #include "tensorflow/core/platform/logging.h"
44 #include "tensorflow/core/platform/mem.h"
45 #include "tensorflow/core/platform/types.h"
46 
47 namespace xla {
48 namespace {
49 
using absl::StrCat;

// True iff the target we are compiling for stores multi-byte values
// least-significant byte first.
constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
// Literals can be used as DMA targets, which can require alignment. We
// force a tensorflow::Allocator::kAllocatorAlignment-byte minimum
// alignment.
constexpr int kMinimumAlignment = 64;
57 
58 // Converts between little and big endian.
59 //
60 // Precondition: size % 2 == 0 (elements in the array are 16 bits long)
ConvertEndianShort(string * bytes)61 void ConvertEndianShort(string* bytes) {
62   CHECK_EQ(bytes->size() % 2, 0);
63   for (int64 i = 0, end = bytes->size(); i < end; i += 2) {
64     std::swap((*bytes)[i], (*bytes)[i + 1]);
65   }
66 }
67 
ConvertEndianShort(char * bytes,int64 size)68 void ConvertEndianShort(char* bytes, int64 size) {
69   CHECK_EQ(size % 2, 0);
70   for (int64 i = 0; i < size; i += 2) {
71     std::swap(bytes[i], bytes[i + 1]);
72   }
73 }
74 
// Collapses `input` onto one line: splits on newlines and spaces, then joins
// the non-empty tokens with single spaces.
string CompactOneline(const string& input) {
  std::vector<string> tokens = absl::StrSplit(input, absl::ByAnyChar("\n "));
  string result;
  const char* separator = "";
  for (const string& token : tokens) {
    if (!token.empty()) {
      absl::StrAppend(&result, separator, token);
      separator = " ";
    }
  }
  return result;
}
90 
91 // Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be
92 // able to transparently access the raw 16-bit value contained within.
// Identity overload: the value already has a bit_cast-compatible
// representation, so return it unchanged.
template <typename T>
T GetRawValue(T val) {
  return val;
}
// Eigen::half exposes its raw 16-bit representation via the member `x`.
uint16 GetRawValue(Eigen::half val) { return val.x; }
98 
// Returns true if `proto` carries any element data in any of its per-type
// value fields, or contains nested tuple literals.
bool LiteralProtoHasValues(const LiteralProto& proto) {
  return proto.preds_size() || !proto.s8s().empty() || !proto.u8s().empty() ||
         proto.s32s_size() || proto.s64s_size() || proto.u32s_size() ||
         proto.u64s_size() || proto.f32s_size() || proto.f64s_size() ||
         proto.c64s_size() || proto.c128s_size() ||
         proto.tuple_literals_size() || !proto.f16s().empty() ||
         !proto.bf16s().empty() || !proto.u16s().empty() ||
         !proto.s16s().empty();
}
108 
109 }  // namespace
110 
~LiteralBase()111 LiteralBase::~LiteralBase() {}
112 
operator <<(std::ostream & out,const Literal & literal)113 std::ostream& operator<<(std::ostream& out, const Literal& literal) {
114   out << literal.ToString();
115   return out;
116 }
117 
// Configures a strided copy between two (possibly differently laid out)
// shapes: picks which side is iterated contiguously and what stride the other
// side advances by.
MutableLiteralBase::StrideConfig::StrideConfig(
    const Shape& source_shape, const Shape& dest_shape,
    absl::Span<const int64> dimensions)
    : dimensions(dimensions),
      base(dimensions.size(), 0),
      step(dimensions.size(), 1) {
  if (!dimensions.empty()) {
    // Selects the shape with the largest minor dimension as the one upon
    // which to run the tight stride loop.
    if (dimensions[LayoutUtil::Minor(source_shape.layout(), 0)] >=
        dimensions[LayoutUtil::Minor(dest_shape.layout(), 0)]) {
      // Iterate contiguously over the source; stride through the destination.
      minor_dimension = LayoutUtil::Minor(source_shape.layout(), 0);
      dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension);
    } else {
      // Iterate contiguously over the destination; stride through the source.
      minor_dimension = LayoutUtil::Minor(dest_shape.layout(), 0);
      source_stride =
          IndexUtil::GetDimensionStride(source_shape, minor_dimension);
    }
    // Each index-enumeration step covers one full run of the chosen minor
    // dimension.
    minor_loop_size = dimensions[minor_dimension];
    step[minor_dimension] = minor_loop_size;
  }
}
140 
// Convenience constructor: builds a literal of `shape` and allocates backing
// buffers for every array subshape.
Literal::Literal(const Shape& shape)
    : Literal(shape, /*allocate_arrays=*/true) {}
143 
// Recursively initializes `piece` to mirror `shape`: one child piece per
// tuple element for tuples, and (optionally) an aligned backing buffer —
// plus a dynamic-size buffer for dynamic shapes — for arrays.
void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
  if (shape.IsTuple()) {
    for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
      const Shape& subshape = shape.tuple_shapes(i);

      auto child_piece = Piece();
      child_piece.set_subshape(&subshape);

      SetPiece(subshape, &child_piece, allocate_arrays);

      piece->emplace_back(std::move(child_piece));
    }
  } else if (shape.IsArray()) {
    if (allocate_arrays) {
      // Aligned so the buffer can serve as a DMA target (see
      // kMinimumAlignment above).
      piece->set_buffer(static_cast<char*>(tensorflow::port::AlignedMalloc(
          piece->size_bytes(), kMinimumAlignment)));
      if (shape.is_dynamic()) {
        CHECK_EQ(piece->dynamic_size_buffer(), nullptr);
        piece->set_dynamic_size_buffer(
            static_cast<int32*>(tensorflow::port::AlignedMalloc(
                piece->dynamic_size_buffer_bytes(), kMinimumAlignment)));
      }
    }
  } else {
    // If the shape is neither an array nor tuple, then it must be
    // zero-sized. Otherwise, some memory needs to be allocated for it.
    CHECK_EQ(piece->size_bytes(), 0);
  }
}
173 
// Builds a literal of `shape`. When `allocate_arrays` is false, the piece
// tree is created but no data buffers are allocated — used when buffers are
// moved in from elsewhere (see DecomposeTuple).
Literal::Literal(const Shape& shape, bool allocate_arrays)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(shape);
  CHECK(LayoutUtil::HasLayout(*shape_));
  root_piece_ = new Piece();
  // Pieces reference the shape owned by this literal, not a copy.
  root_piece_->set_subshape(shape_.get());
  CHECK(&root_piece_->subshape() == shape_.get());

  SetPiece(*shape_, root_piece_, allocate_arrays);
}
184 
Literal::~Literal() {
  // NOTE(review): guards against a null root_piece_; the constructors shown
  // here always set it, so presumably some other path (default construction
  // or teardown elsewhere) can leave it null — confirm against literal.h.
  if (root_piece_ != nullptr) {
    DeallocateBuffers();
    delete root_piece_;
  }
}
191 
// Frees the data buffer and dynamic-size buffer of every piece owned by this
// literal. The Piece objects themselves are not destroyed here.
void Literal::DeallocateBuffers() {
  root_piece_->ForEachMutableSubpiece(
      [&](const ShapeIndex& index, Piece* piece) {
        if (piece->buffer() != nullptr) {
          tensorflow::port::AlignedFree(piece->buffer());
        }
        if (piece->dynamic_size_buffer() != nullptr) {
          tensorflow::port::AlignedFree(piece->dynamic_size_buffer());
        }
      });
}
203 
// Move constructor: default-initializes this literal, then swaps state with
// `other` via the move-assignment operator.
Literal::Literal(Literal&& other) : MutableLiteralBase() {
  *this = std::move(other);
}
207 
// Move assignment implemented as a swap: `other` takes ownership of this
// literal's previous state and releases it when destroyed. The swap also
// makes self-move-assignment harmless.
Literal& Literal::operator=(Literal&& other) {
  DCHECK(&other.root_piece_->subshape() == other.shape_.get());
  using std::swap;
  swap(shape_, other.shape_);
  swap(root_piece_, other.root_piece_);
  // Invariant: the root piece always points at the shape we own.
  DCHECK(&root_piece_->subshape() == shape_.get());

  return *this;
}
217 
// Creates a literal of `shape` with every array buffer zero-filled.
Literal LiteralBase::CreateFromShape(const Shape& shape) {
  Literal literal(shape);
  literal.root_piece_->ForEachMutableSubpiece(
      [&](const ShapeIndex& index, Piece* piece) {
        if (piece->subshape().IsArray()) {
          memset(piece->untyped_data(), 0, piece->size_bytes());
        }
      });
  return literal;
}
228 
GetDynamicSize(int64 dim_index) const229 int32 LiteralBase::GetDynamicSize(int64 dim_index) const {
230   return GetDynamicSize(dim_index, {});
231 }
232 
// Returns the runtime size of dimension `dim_index` of the subshape at
// `shape_index` (the static bound if the dimension is not dynamic — see
// Piece::GetDynamicSize).
int32 LiteralBase::GetDynamicSize(int64 dim_index,
                                  const ShapeIndex& shape_index) const {
  return piece(shape_index).GetDynamicSize(dim_index);
}
237 
GetFirstInteger() const238 absl::optional<int64> LiteralBase::GetFirstInteger() const {
239   switch (shape().element_type()) {
240     case U8:
241       return GetFirstElement<uint8>();
242     case U16:
243       return GetFirstElement<uint16>();
244     case U32:
245       return GetFirstElement<uint32>();
246     case U64: {
247       int64 v = GetFirstElement<uint64>();
248       if (v < 0) {
249         return absl::nullopt;
250       }
251       return v;
252     }
253     case S8:
254       return GetFirstElement<int8>();
255     case S16:
256       return GetFirstElement<int16>();
257     case S32:
258       return GetFirstElement<int32>();
259     case S64:
260       return GetFirstElement<int64>();
261     default:
262       return absl::nullopt;
263   }
264 }
265 
// Copies the window `copy_size`, anchored at `src_base` in `src_literal`, to
// the position `dest_base` in this literal. Handles arbitrary (differing)
// layouts by enumerating multi-dimensional indices and doing strided copies
// along the chosen minor dimension (see StrideConfig).
template <typename NativeT>
Status MutableLiteralBase::CopySliceFromInternal(
    const LiteralBase& src_literal, absl::Span<const int64> src_base,
    absl::Span<const int64> dest_base, absl::Span<const int64> copy_size) {
  const int64 src_base_size = src_base.size();
  const int64 dest_base_size = dest_base.size();
  TF_RET_CHECK(src_literal.shape().rank() == src_base_size);
  TF_RET_CHECK(shape().rank() == dest_base_size);

  // Maps a multi-dimensional index to the linear element offset under the
  // shape's layout.
  auto linear_index = [](const Shape& shape,
                         absl::Span<const int64> multi_index) {
    return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index);
  };

  if (src_literal.shape().rank() == 0 || shape().rank() == 0) {
    // If any of the two shapes are scalars, we can just call the StridedCopy()
    // directly, and we know we will be copying only one value.
    TF_RET_CHECK(copy_size.empty());
    StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0,
                src_literal.data<NativeT>(),
                linear_index(src_literal.shape(), src_base), 0, 1);
  } else if (!ShapeUtil::IsZeroElementArray(shape()) &&
             !ShapeUtil::IsZeroElementArray(src_literal.shape())) {
    // Perform copy if neither src nor dest has dimensions with zero element,
    // otherwise it's a no-op.
    TF_RET_CHECK(src_base.size() == dest_base.size());
    TF_RET_CHECK(src_base.size() == copy_size.size());

    // Scan the source from minor, stepping in copy size blocks, then within
    // the index enumeration functor, do a strided copy advancing source index
    // by one (walking through the minor dimension), and destination index by
    // proper stride size at the matching dimension.
    DimensionVector src_indexes(src_base.size(), 0);
    DimensionVector dest_indexes(dest_base.size(), 0);
    MutableLiteralBase::StrideConfig stride_config(src_literal.shape(), shape(),
                                                   copy_size);

    auto copy_proc = [&](absl::Span<const int64> indexes) {
      // Map from multi-dimensional index, to source index.
      std::transform(indexes.begin(), indexes.end(), src_base.begin(),
                     src_indexes.begin(), std::plus<int64>());
      // Map from multi-dimensional index, to destination index.
      std::transform(indexes.begin(), indexes.end(), dest_base.begin(),
                     dest_indexes.begin(), std::plus<int64>());

      int64 src_index = linear_index(src_literal.shape(), src_indexes);
      int64 dest_index = linear_index(shape(), dest_indexes);

      // `this->` is needed to workaround MSVC bug: #16882
      StridedCopy(this->data<NativeT>(), dest_index, stride_config.dest_stride,
                  src_literal.data<NativeT>(), src_index,
                  stride_config.source_stride, stride_config.minor_loop_size);
      return true;
    };

    ShapeUtil::ForEachIndex(src_literal.shape(), stride_config.base,
                            stride_config.dimensions, stride_config.step,
                            copy_proc);
  }
  return Status::OK();
}
327 
// Copies the single element at `src_index` in `src_literal` to `dest_index`
// in this literal. Element types must match; copies primitive-size bytes via
// memcpy (skipped if source and destination addresses coincide).
Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
                                           absl::Span<const int64> src_index,
                                           absl::Span<const int64> dest_index) {
  DCHECK_EQ(shape().element_type(), src_literal.shape().element_type());
  const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex(
      src_literal.shape(), src_index);
  const int64 dest_linear_index =
      IndexUtil::MultidimensionalIndexToLinearIndex(shape(), dest_index);
  const int64 primitive_size =
      ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());

  char* dest_address =
      static_cast<char*>(untyped_data()) + dest_linear_index * primitive_size;
  const char* source_address =
      static_cast<const char*>(src_literal.untyped_data()) +
      src_linear_index * primitive_size;
  if (dest_address != source_address) {
    memcpy(dest_address, source_address, primitive_size);
  }
  return Status::OK();
}
349 
CreateFromProto(const LiteralProto & proto,bool prohibit_empty_literal)350 /* static */ StatusOr<Literal> MutableLiteralBase::CreateFromProto(
351     const LiteralProto& proto, bool prohibit_empty_literal) {
352   if (!proto.has_shape()) {
353     return InvalidArgument("LiteralProto has no shape");
354   }
355   Shape shape(proto.shape());
356   if (ShapeUtil::HasPrimitiveType(shape, OPAQUE_TYPE)) {
357     return InvalidArgument(
358         "Literal shape cannot include OPAQUE_TYPE sub-shape");
359   }
360   if (!LayoutUtil::HasLayout(shape)) {
361     return InvalidArgument("LiteralProto has no layout");
362   }
363 
364   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
365 
366   Literal literal(shape);
367 
368   TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus(
369       [&](const ShapeIndex& index, Piece* piece) {
370         const LiteralProto* proto_element = &proto;
371         for (int64 i : index) {
372           CHECK(i < proto_element->tuple_literals_size());
373           proto_element = &proto_element->tuple_literals(i);
374         }
375 
376         if (piece->subshape().IsTuple()) {
377           if (proto_element->tuple_literals_size() !=
378               ShapeUtil::TupleElementCount(piece->subshape())) {
379             return InvalidArgument(
380                 "Expected %d tuple elements in LiteralProto, has %d",
381                 ShapeUtil::TupleElementCount(piece->subshape()),
382                 proto_element->tuple_literals_size());
383           }
384           return Status::OK();
385         }
386         if (piece->subshape().element_type() == TOKEN) {
387           return Status::OK();
388         }
389 
390         CHECK(piece->subshape().IsArray());
391 
392         // When prohibit_empty_literal is false (allowing literal with no
393         // values), only copy from proto if the literal proto has values. This
394         // mode is used for a learned cost model.
395         if (prohibit_empty_literal || LiteralProtoHasValues(*proto_element)) {
396           TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element));
397         }
398 
399         return Status::OK();
400       }));
401 
402   return std::move(literal);
403 }
404 
// Splits a tuple literal into one Literal per top-level element, transferring
// buffer ownership (no element data is copied). Afterwards this literal is
// reset to the nil shape.
std::vector<Literal> Literal::DecomposeTuple() {
  CHECK(shape().IsTuple());
  std::vector<Literal> elements;
  for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
    // Create the element without buffers; they are moved over below.
    elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}),
                               /*allocate_arrays=*/false));
    Literal& element = elements.back();
    element.root_piece_->ForEachMutableSubpiece(
        [&](const ShapeIndex& index, Piece* dest_piece) {
          // The corresponding source piece lives at {i, index...} here.
          ShapeIndex src_index = {i};
          for (int64 j : index) {
            src_index.push_back(j);
          }
          Piece& src_piece = piece(src_index);

          // Move the respective buffer over to the element Literal.
          dest_piece->set_buffer(src_piece.buffer());
          dest_piece->set_dynamic_size_buffer(src_piece.dynamic_size_buffer());
          // Null out the source so our destructor does not double-free.
          src_piece.set_buffer(nullptr);
          src_piece.set_dynamic_size_buffer(nullptr);
        });
  }
  // Set this literal to be nil-shaped.
  *this = Literal();
  return elements;
}
431 
432 namespace {
433 
434 // Copies the elements in 'src' to 'dest'. The shape and layout of the data in
435 // the array slices are indicated by dest_shape and src_shape respectively.
template <typename NativeT>
void CopyElementsBetween(absl::Span<NativeT> dest,
                         absl::Span<const NativeT> src, const Shape& dest_shape,
                         const Shape& src_shape) {
  CHECK(ShapeUtil::Compatible(dest_shape, src_shape));
  if (ShapeUtil::IsZeroElementArray(dest_shape)) {
    return;
  }
  // Enumerate every multi-dimensional index, mapping it through each shape's
  // own layout to a linear position — this is what makes copies between
  // differing layouts correct.
  std::vector<int64> index(dest_shape.rank());
  do {
    dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] =
        src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)];
  } while (IndexUtil::BumpIndices(dest_shape, absl::MakeSpan(index)));
}
450 }  // namespace
451 
// Returns the runtime size of dimension `dim_index`: the static bound for a
// static dimension, otherwise the value recorded via SetDynamicSize.
int32 LiteralBase::Piece::GetDynamicSize(int64 dim_index) const {
  CHECK(LayoutUtil::IsDenseArray(subshape()));
  if (!subshape_->is_dynamic_dimension(dim_index)) {
    // This is a static dimension, return size.
    return subshape_->dimensions(dim_index);
  }
  // A dynamic dimension must have had its size recorded already.
  CHECK_NE(dynamic_size_buffer(), nullptr);
  return dynamic_size_buffer_[dim_index];
}
461 
SetDynamicSize(int64 dim_index,int32 size)462 void LiteralBase::Piece::SetDynamicSize(int64 dim_index, int32 size) {
463   CHECK(LayoutUtil::IsDenseArray(subshape()));
464   CHECK(subshape_->is_dynamic_dimension(dim_index));
465   if (dynamic_size_buffer() == nullptr) {
466     // Lazily initialize the dynamic size buffer.
467     set_dynamic_size_buffer(static_cast<int32*>(tensorflow::port::AlignedMalloc(
468         dynamic_size_buffer_bytes(), kMinimumAlignment)));
469     /*for (int64 i = 0; i < subshape().rank(); ++i) {
470       // Initialized to -1 to help debug.
471       dynamic_size_buffer_[i] = -1;
472     }*/
473   }
474   dynamic_size_buffer_[dim_index] = size;
475 }
476 
// Copies the element data (and dynamic-size metadata) of `src` into this
// piece. When the subshapes are Equal (including layout) a raw memcpy
// suffices; otherwise elements are copied one-by-one with index remapping
// between layouts. With `only_dynamic_bound` set, copying is delegated to
// CopyElementsWithDynamicBound.
Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src,
                                    bool only_dynamic_bound) {
  CHECK(subshape_ != nullptr);
  CHECK(src.subshape_ != nullptr);
  if (ShapeUtil::Equal(subshape(), src.subshape())) {
    // If the layouts are equal it's faster just to memcpy.
    memcpy(buffer(), src.buffer(), src.size_bytes());
  } else {
    std::vector<int64> origin(subshape().rank(), 0);
    switch (subshape().element_type()) {
// Expands to one case per primitive type, dispatching to the appropriate
// typed element-copy routine.
#define COPY_ELEMENTS(XLA_T, NATIVE_T)                                      \
  case (XLA_T):                                                             \
    if (only_dynamic_bound) {                                               \
      CopyElementsWithDynamicBound<NATIVE_T>(src);                          \
    } else {                                                                \
      CopyElementsBetween<NATIVE_T>(data<NATIVE_T>(), src.data<NATIVE_T>(), \
                                    subshape(), src.subshape());            \
    }                                                                       \
    break;
      COPY_ELEMENTS(U8, uint8);
      COPY_ELEMENTS(U16, uint16);
      COPY_ELEMENTS(U32, uint32);
      COPY_ELEMENTS(U64, uint64);
      COPY_ELEMENTS(S8, int8);
      COPY_ELEMENTS(S16, int16);
      COPY_ELEMENTS(S32, int32);
      COPY_ELEMENTS(S64, int64);
      COPY_ELEMENTS(F16, half);
      COPY_ELEMENTS(BF16, bfloat16);
      COPY_ELEMENTS(F32, float);
      COPY_ELEMENTS(F64, double);
      COPY_ELEMENTS(C64, complex64);
      COPY_ELEMENTS(C128, complex128);
      COPY_ELEMENTS(PRED, bool);
#undef COPY_ELEMENTS
      default:
        return Unimplemented(
            "Copying a Literal object with element type %s is not implemented.",
            PrimitiveType_Name(subshape().element_type()));
    }
  }
  // Dynamic-size metadata is copied only when both sides are dynamic.
  DCHECK_EQ(dynamic_size_buffer_bytes(), src.dynamic_size_buffer_bytes());
  if (subshape().is_dynamic() && src.subshape().is_dynamic()) {
    CHECK_NE(dynamic_size_buffer_, nullptr);
    CHECK_NE(src.dynamic_size_buffer_, nullptr);
    memcpy(dynamic_size_buffer(), src.dynamic_size_buffer(),
           src.dynamic_size_buffer_bytes());
  }
  return Status::OK();
}
527 
SetDynamicSize(int64 dim_index,int32 size)528 void MutableLiteralBase::SetDynamicSize(int64 dim_index, int32 size) {
529   return SetDynamicSize(dim_index, {}, size);
530 }
531 
SetDynamicSize(int64 dim_index,const ShapeIndex & shape_index,int32 size)532 void MutableLiteralBase::SetDynamicSize(int64 dim_index,
533                                         const ShapeIndex& shape_index,
534                                         int32 size) {
535   Shape* subshape_ = ShapeUtil::GetMutableSubshape(shape_.get(), shape_index);
536   CHECK_GE(subshape_->dimensions(dim_index), size);
537   if (subshape_->dimensions(dim_index) == size) {
538     subshape_->set_dynamic_dimension(dim_index, false);
539     return;
540   }
541   subshape_->set_dynamic_dimension(dim_index, true);
542   piece(shape_index).SetDynamicSize(dim_index, size);
543 }
544 
// Copies the subtree of `src_literal` rooted at `src_shape_index` into the
// subtree of this literal rooted at `dest_shape_index`. With
// `only_dynamic_bound`, one subshape must be the dynamic bound of the other
// (checked); otherwise the subshapes must be Compatible.
Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal,
                                    const ShapeIndex& dest_shape_index,
                                    const ShapeIndex& src_shape_index,
                                    bool only_dynamic_bound) {
  const Shape& dest_subshape =
      ShapeUtil::GetSubshape(shape(), dest_shape_index);
  const Shape& src_subshape =
      ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index);
  if (only_dynamic_bound) {
    // Whichever side is static is treated as the compact shape; the dynamic
    // side is its bound.
    auto bound_shape = dest_subshape.is_static() ? src_subshape : dest_subshape;
    auto compact_shape =
        dest_subshape.is_static() ? dest_subshape : src_subshape;
    CHECK(ShapeUtil::DynamicShapeIsCompatible(compact_shape, bound_shape))
        << compact_shape.ToString() << " vs " << bound_shape.ToString();
  } else {
    if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) {
      return InvalidArgument(
          "Destination subshape incompatible with source subshape: %s vs %s",
          ShapeUtil::HumanString(dest_subshape),
          ShapeUtil::HumanString(src_subshape));
    }
  }
  return root_piece_->ForEachMutableSubpieceWithStatus(
      [&](const ShapeIndex& index, Piece* piece) {
        if (!piece->subshape().IsArray()) {
          return Status::OK();
        }

        // Determine if this index is in the part of this literal that we want
        // to copy over from src_literal.
        bool in_subtree_to_copy = true;
        for (int i = 0; i < dest_shape_index.size(); ++i) {
          if (index[i] != dest_shape_index[i]) {
            in_subtree_to_copy = false;
            break;
          }
        }
        if (!in_subtree_to_copy) {
          return Status::OK();
        }
        // Construct the index of the corresponding piece in the source literal.
        ShapeIndex src_piece_index = src_shape_index;
        for (int64 i = dest_shape_index.size(), end = index.size(); i < end;
             ++i) {
          src_piece_index.push_back(index[i]);
        }
        TF_RETURN_IF_ERROR(
            piece->CopyFrom(src_literal.piece(src_piece_index),
                            /*only_dynamic_bound=*/only_dynamic_bound));
        return Status::OK();
      });
}
597 
// Transfers buffer ownership from `src_literal` into the subtree of this
// literal rooted at `dest_shape_index`; no element data is copied. The
// shapes must be Equal (including layout). `src_literal` is left nil-shaped.
Status Literal::MoveFrom(Literal&& src_literal,
                         const ShapeIndex& dest_shape_index) {
  const Shape& dest_subshape =
      ShapeUtil::GetSubshape(shape(), dest_shape_index);
  if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) {
    return InvalidArgument(
        "Destination subshape not equal to source shape: %s vs %s",
        ShapeUtil::HumanString(dest_subshape),
        ShapeUtil::HumanString(src_literal.shape()));
  }

  src_literal.root_piece_->ForEachSubpiece(
      [&](const ShapeIndex& src_index, const Piece& src_piece) {
        if (!src_piece.subshape().IsArray()) {
          return;
        }

        // Destination piece lives at {dest_shape_index..., src_index...}.
        ShapeIndex dest_index = dest_shape_index;
        for (int64 i : src_index) {
          dest_index.push_back(i);
        }
        Piece& dest_piece = piece(dest_index);
        // Free the destination's old buffers before adopting the source's.
        tensorflow::port::AlignedFree(dest_piece.buffer());
        tensorflow::port::AlignedFree(dest_piece.dynamic_size_buffer());
        dest_piece.set_buffer(src_piece.buffer());
        dest_piece.set_dynamic_size_buffer(src_piece.dynamic_size_buffer());
      });

  // Reset the source to an empty (nil-shaped) literal so its destructor does
  // not free the buffers we just took.
  src_literal.shape_ = absl::make_unique<Shape>(ShapeUtil::MakeNil());
  delete src_literal.root_piece_;
  src_literal.root_piece_ = new LiteralBase::Piece();
  src_literal.root_piece_->set_subshape(src_literal.shape_.get());

  return Status::OK();
}
633 
// Copies a rectangular window of `copy_size` elements, anchored at `src_base`
// in `src_literal`, to `dest_base` in this literal. Both literals must be
// arrays with the same element type; dispatches on that type to the typed
// CopySliceFromInternal.
Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal,
                                         absl::Span<const int64> src_base,
                                         absl::Span<const int64> dest_base,
                                         absl::Span<const int64> copy_size) {
  TF_RET_CHECK(shape().IsArray()) << ShapeUtil::HumanString(shape());
  TF_RET_CHECK(src_literal.shape().IsArray())
      << ShapeUtil::HumanString(src_literal.shape());
  TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape()));

  switch (shape().element_type()) {
    case U8:
      return CopySliceFromInternal<uint8>(src_literal, src_base, dest_base,
                                          copy_size);
    case U16:
      return CopySliceFromInternal<uint16>(src_literal, src_base, dest_base,
                                           copy_size);
    case U32:
      return CopySliceFromInternal<uint32>(src_literal, src_base, dest_base,
                                           copy_size);
    case U64:
      return CopySliceFromInternal<uint64>(src_literal, src_base, dest_base,
                                           copy_size);
    case S8:
      return CopySliceFromInternal<int8>(src_literal, src_base, dest_base,
                                         copy_size);
    case S16:
      return CopySliceFromInternal<int16>(src_literal, src_base, dest_base,
                                          copy_size);
    case S32:
      return CopySliceFromInternal<int32>(src_literal, src_base, dest_base,
                                          copy_size);
    case S64:
      return CopySliceFromInternal<int64>(src_literal, src_base, dest_base,
                                          copy_size);
    case F16:
      return CopySliceFromInternal<half>(src_literal, src_base, dest_base,
                                         copy_size);
    case BF16:
      return CopySliceFromInternal<bfloat16>(src_literal, src_base, dest_base,
                                             copy_size);
    case F32:
      return CopySliceFromInternal<float>(src_literal, src_base, dest_base,
                                          copy_size);
    case F64:
      return CopySliceFromInternal<double>(src_literal, src_base, dest_base,
                                           copy_size);
    case C64:
      return CopySliceFromInternal<complex64>(src_literal, src_base, dest_base,
                                              copy_size);
    case C128:
      return CopySliceFromInternal<complex128>(src_literal, src_base, dest_base,
                                               copy_size);
    case PRED:
      return CopySliceFromInternal<bool>(src_literal, src_base, dest_base,
                                         copy_size);
    default:
      break;
  }
  return Unimplemented(
      "Copying a slice from a Literal object with element type %d is not "
      "implemented.",
      shape().element_type());
}
697 
PopulateR1(const tensorflow::core::Bitmap & values)698 void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) {
699   CHECK(shape().IsArray());
700   CHECK_EQ(shape().rank(), 1);
701   CHECK_EQ(element_count(), values.bits());
702   CHECK_EQ(shape().element_type(), PRED);
703   for (int64 i = 0; i < static_cast<int64>(values.bits()); ++i) {
704     Set({i}, values.get(i));
705   }
706 }
707 
Relayout(const Layout & new_layout,const ShapeIndex & shape_index) const708 Literal LiteralBase::Relayout(const Layout& new_layout,
709                               const ShapeIndex& shape_index) const {
710   // Create new shape with 'new_layout' set at the given shape index.
711   Shape new_shape = shape();
712   Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
713   TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
714   *subshape->mutable_layout() = new_layout;
715   Literal result(new_shape);
716   TF_CHECK_OK(result.CopyFrom(*this));
717   return result;
718 }
719 
// Returns a copy of this literal re-laid-out to match `shape_with_layout`,
// which must be compatible (same element types and dimensions) with this
// literal's shape. Works for tuple shapes by copying each array leaf.
Literal LiteralBase::Relayout(const Shape& shape_with_layout) const {
  CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
      << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
      << " not compatible with literal shape "
      << ShapeUtil::HumanString(shape());
  Literal result = CreateFromShape(shape_with_layout);
  ShapeUtil::ForEachSubshape(
      result.shape(),
      [this, &result](const Shape& subshape, const ShapeIndex& index) {
        // Copy each array leaf individually; tuple nodes have no data of
        // their own.
        if (subshape.IsArray()) {
          TF_CHECK_OK(result.CopyFrom(*this,
                                      /*dest_shape_index=*/index,
                                      /*src_shape_index=*/index));
        }
      });
  return result;
}
737 
ToBoundedDynamic(const Shape & bounded_shape) const738 Literal LiteralBase::ToBoundedDynamic(const Shape& bounded_shape) const {
739   CHECK(bounded_shape.is_dynamic());
740   Literal result(bounded_shape);
741   ShapeUtil::ForEachSubshape(
742       shape(), [&](const Shape& subshape, const ShapeIndex& index) {
743         if (!subshape.IsArray()) {
744           return;
745         }
746         for (int64 i = 0; i < subshape.rank(); ++i) {
747           result.SetDynamicSize(i, subshape.dimensions(i));
748         }
749       });
750   TF_CHECK_OK(result.CopyFrom(*this, {}, {}, /*only_dynamic_bound=*/true));
751 
752   return result;
753 }
754 
// Returns a copy of this literal with all dynamic dimensions frozen: each
// dynamic dimension becomes a static dimension whose size is the literal's
// current dynamic size for that dimension.
Literal LiteralBase::ToStatic() const {
  // Build a fully static shape whose dimension sizes equal the current
  // dynamic sizes of this literal.
  Shape new_shape = shape();
  ShapeUtil::ForEachMutableSubshape(
      &new_shape, [this](Shape* subshape, const ShapeIndex& index) {
        if (!subshape->IsArray()) {
          return;
        }
        for (int64 i = 0; i < subshape->rank(); ++i) {
          subshape->set_dynamic_dimension(i, false);
          // Shrink the dimension bound down to the actual dynamic size.
          subshape->set_dimensions(i, GetDynamicSize(i, index));
        }
      });
  Literal result(new_shape);
  // only_dynamic_bound allows copying between shapes that differ only in
  // their static bounds.
  TF_CHECK_OK(result.CopyFrom(*this, {}, {}, /*only_dynamic_bound=*/true));
  return result;
}
772 
Broadcast(const Shape & result_shape,absl::Span<const int64> dimensions) const773 StatusOr<Literal> LiteralBase::Broadcast(
774     const Shape& result_shape, absl::Span<const int64> dimensions) const {
775   if (!shape().IsArray()) {
776     return InvalidArgument("Broadcast only supports arrays.");
777   }
778 
779   for (int64 i = 0, end = dimensions.size(); i < end; i++) {
780     TF_RET_CHECK(shape().dimensions(i) ==
781                  result_shape.dimensions(dimensions[i]));
782   }
783 
784   Literal result(result_shape);
785 
786   // scratch_source_index is temporary storage space for the computed index into
787   // the input literal.  We put it here to avoid allocating an std::vector in
788   // every iteration of ShapeUtil::ForEachIndex.
789   std::vector<int64> scratch_source_index(shape().dimensions_size());
790 
791   char* dest_data = static_cast<char*>(result.untyped_data());
792   const char* source_data = static_cast<const char*>(untyped_data());
793   const int64 primitive_size =
794       ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
795 
796   for (int64 i = 0; i < dimensions.size(); ++i) {
797     int64 dynamic_size = GetDynamicSize(i);
798     result.SetDynamicSize(dimensions[i], dynamic_size);
799   }
800 
801   ShapeUtil::ForEachIndex(
802       result_shape, [&](absl::Span<const int64> output_index) {
803         for (int64 i = 0, end = dimensions.size(); i < end; ++i) {
804           scratch_source_index[i] = output_index[dimensions[i]];
805         }
806         int64 dest_index = IndexUtil::MultidimensionalIndexToLinearIndex(
807             result_shape, output_index);
808         int64 source_index = IndexUtil::MultidimensionalIndexToLinearIndex(
809             shape(), scratch_source_index);
810         memcpy(dest_data + primitive_size * dest_index,
811                source_data + primitive_size * source_index, primitive_size);
812         return true;
813       });
814 
815   return std::move(result);
816 }
817 
// Returns a literal containing the same elements reinterpreted under the
// given `dimensions` in row-major order. Fails for tuples, dynamic shapes,
// and element-count mismatches.
StatusOr<Literal> LiteralBase::Reshape(
    absl::Span<const int64> dimensions) const {
  if (!shape().IsArray()) {
    return InvalidArgument("Reshape does not support tuples.");
  }
  if (shape().is_dynamic()) {
    return Unimplemented("Dynamic reshape is not implemented.");
  }
  Literal output;
  // Reshape is defined on row-major element order, so first bring the data
  // into the default dim0-major layout if it is not already there.
  if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
    output = Relayout(LayoutUtil::GetDefaultLayoutForRank(shape().rank()));
  } else {
    output = Clone();
  }
  // Because the layout is monotonic, we can simply reuse the same sequence of
  // values without changing their order.
  *output.mutable_shape_do_not_use() =
      ShapeUtil::MakeShape(shape().element_type(), dimensions);

  int64 elements_before = ShapeUtil::ElementsIn(shape());
  int64 elements_after = ShapeUtil::ElementsIn(output.shape());
  if (elements_before != elements_after) {
    return InvalidArgument(
        "Shapes before and after Literal::Reshape have different numbers "
        "of elements: %s vs %s.",
        ShapeUtil::HumanString(shape()),
        ShapeUtil::HumanString(output.shape()));
  }
  return std::move(output);
}
848 
// Returns this array literal with its dimensions permuted: output dimension
// permutation[i] corresponds to input dimension i. The data buffer is not
// reordered; instead the layout is permuted so a flat memcpy suffices.
Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
  CHECK(shape().IsArray()) << "Tuple is not supported for transpose";
  CHECK(shape().rank() == permutation.size() && IsPermutation(permutation))
      << "Given permutation is not a permutation of dimension numbers";
  // To transpose the array, we just permute the dimensions and layout, and
  // do a straight memory copy of the raw data set.
  // This is considerably faster than iterating over every array element using
  // the EachCell<>() and Set<>() APIs.
  Shape permuted_shape = ShapeUtil::PermuteDimensions(permutation, shape());
  // Replace the layout with one affine to this shape, such that a
  // transpose operation can be performed by leaving the flat values
  // representation intact.
  // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation.
  // The shape with affine layout resulting from that operation will be
  // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the
  // most minor.
  //
  // Essentially, given MinMaj(Di) the position of the Di dimension within the
  // minor to major vector, and given T(Di) the index that the original Di
  // dimension has within the transposed array, a layout is affine if
  // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major
  // vector of the affine layout.
  std::vector<int64> inverse_permutation = InversePermutation(permutation);
  CHECK(LayoutUtil::IsDenseArray(permuted_shape));
  Layout* layout = permuted_shape.mutable_layout();
  layout->clear_minor_to_major();
  for (auto index : LayoutUtil::MinorToMajor(shape())) {
    layout->add_minor_to_major(inverse_permutation[index]);
  }
  Literal new_literal(permuted_shape);
  // Carry dynamic sizes over to their permuted positions.
  for (int64 i = 0; i < shape().rank(); i++) {
    new_literal.SetDynamicSize(inverse_permutation[i], GetDynamicSize(i));
  }
  DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()),
            ShapeUtil::ByteSizeOf(shape()));
  // Same byte size and an affine layout: the flat representation is reusable
  // verbatim.
  std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes());
  return new_literal;
}
887 
// Copies the sub-region of this literal beginning at `start_indices`, with
// extents given by `result_shape`, into a new literal of that shape.
// NativeT must match the literal's element type.
template <typename NativeT>
Literal LiteralBase::SliceInternal(
    const Shape& result_shape, absl::Span<const int64> start_indices) const {
  Literal result_literal(result_shape);
  DimensionVector new_indices(result_shape.rank());
  // Each output element at `indices` is read from the source at
  // `indices + start_indices`.
  CHECK(result_literal
            .Populate<NativeT>([&](absl::Span<const int64> indices) {
              for (int64 i = 0; i < result_shape.rank(); ++i) {
                new_indices[i] = indices[i] + start_indices[i];
              }
              return Get<NativeT>(new_indices);
            })
            .ok());
  // For dynamic dimensions, shift the dynamic size by the slice start and
  // clamp it to the sliced extent.
  for (int64 dnum = 0; dnum < shape().rank(); ++dnum) {
    if (shape().is_dynamic_dimension(dnum)) {
      int64 dynamic_size = GetDynamicSize(dnum) - start_indices[dnum];
      CHECK_GE(dynamic_size, 0) << GetDynamicSize(dnum);
      dynamic_size = std::min(dynamic_size, result_shape.dimensions(dnum));
      result_literal.SetDynamicSize(dnum, dynamic_size);
    }
  }
  return result_literal;
}
911 
// Returns the [start_indices, limit_indices) slice of this array literal.
// The result keeps this literal's element type and minor-to-major layout;
// dynamic dimension markers are copied from this shape. Dispatches to
// SliceInternal<T> on the element type; unhandled types are fatal.
Literal LiteralBase::Slice(absl::Span<const int64> start_indices,
                           absl::Span<const int64> limit_indices) const {
  CHECK(shape().IsArray()) << "tuple is not supported for slice";

  DimensionVector result_dimensions;
  for (int64 dnum = 0; dnum < shape().rank(); ++dnum) {
    CHECK_GE(start_indices[dnum], 0);
    CHECK_LE(limit_indices[dnum], shape().dimensions(dnum))
        << "dnum = " << dnum;
    int64 dimension = limit_indices[dnum] - start_indices[dnum];
    CHECK_GE(dimension, 0) << "dnum = " << dnum;
    result_dimensions.push_back(dimension);
  }
  auto result_shape =
      ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions,
                                     LayoutUtil::MinorToMajor(shape()));
  ShapeUtil::CopyDynamicDimensions(&result_shape, shape());
  switch (result_shape.element_type()) {
    case PRED:
      return SliceInternal<bool>(result_shape, start_indices);
    case U8:
      return SliceInternal<uint8>(result_shape, start_indices);
    case U16:
      return SliceInternal<uint16>(result_shape, start_indices);
    case U32:
      return SliceInternal<uint32>(result_shape, start_indices);
    case U64:
      return SliceInternal<uint64>(result_shape, start_indices);
    case S8:
      return SliceInternal<int8>(result_shape, start_indices);
    case S16:
      return SliceInternal<int16>(result_shape, start_indices);
    case S32:
      return SliceInternal<int32>(result_shape, start_indices);
    case S64:
      return SliceInternal<int64>(result_shape, start_indices);
    case F16:
      return SliceInternal<half>(result_shape, start_indices);
    case BF16:
      return SliceInternal<bfloat16>(result_shape, start_indices);
    case F32:
      return SliceInternal<float>(result_shape, start_indices);
    case F64:
      return SliceInternal<double>(result_shape, start_indices);
    case C64:
      return SliceInternal<complex64>(result_shape, start_indices);
    case C128:
      return SliceInternal<complex128>(result_shape, start_indices);
    default:
      LOG(FATAL) << "not yet implemented: "
                 << PrimitiveType_Name(result_shape.element_type());
  }
}
965 
Clone() const966 Literal LiteralBase::Clone() const {
967   Literal result(shape());
968   TF_CHECK_OK(result.CopyFrom(*this));
969   return result;
970 }
971 
// Renders the element at `multi_index` of the array at `shape_index` as a
// string: "true"/"false" for PRED, decimal for integers, round-trippable
// decimal for floats, and "(re, im)" for complex types. Unsupported element
// types are fatal.
string LiteralBase::GetAsString(absl::Span<const int64> multi_index,
                                const ShapeIndex& shape_index) const {
  const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
  CHECK(LayoutUtil::IsDenseArray(subshape));
  switch (subshape.element_type()) {
    case PRED:
      return Get<bool>(multi_index, shape_index) ? "true" : "false";
    case S8:
      return StrCat(Get<int8>(multi_index, shape_index));
    case S16:
      return StrCat(Get<int16>(multi_index, shape_index));
    case S32:
      return StrCat(Get<int32>(multi_index, shape_index));
    case S64:
      return StrCat(Get<int64>(multi_index, shape_index));
    case U8:
      return StrCat(Get<uint8>(multi_index, shape_index));
    case U16:
      return StrCat(Get<uint16>(multi_index, shape_index));
    case U32:
      return StrCat(Get<uint32>(multi_index, shape_index));
    case U64:
      return StrCat(Get<uint64>(multi_index, shape_index));
    case F16:
      return RoundTripFpToString(Get<half>(multi_index, shape_index));
    case F32:
      return RoundTripFpToString(Get<float>(multi_index, shape_index));
    case BF16:
      return RoundTripFpToString(Get<bfloat16>(multi_index, shape_index));
    case F64:
      return RoundTripFpToString(Get<double>(multi_index, shape_index));
    case C64: {
      complex64 c = Get<complex64>(multi_index, shape_index);
      return StrCat("(", RoundTripFpToString(c.real()), ", ",
                    RoundTripFpToString(c.imag()), ")");
    }
    case C128: {
      complex128 c = Get<complex128>(multi_index, shape_index);
      return StrCat("(", RoundTripFpToString(c.real()), ", ",
                    RoundTripFpToString(c.imag()), ")");
    }
    default:
      LOG(FATAL) << PrimitiveType_Name(subshape.element_type());
  }
}
1017 
// Returns the element at `multi_index` widened to int64, or nullopt if the
// element type is not integral (or PRED). Note U64 values above INT64_MAX
// wrap on conversion to the signed return type.
absl::optional<int64> LiteralBase::GetIntegralAsS64(
    absl::Span<const int64> multi_index) const {
  CHECK(LayoutUtil::IsDenseArray(shape()));
  switch (shape().element_type()) {
    case PRED:
      return Get<bool>(multi_index);
    case S8:
      return Get<int8>(multi_index);
    case U8:
      return Get<uint8>(multi_index);
    case S16:
      return Get<int16>(multi_index);
    case U16:
      return Get<uint16>(multi_index);
    case S32:
      return Get<int32>(multi_index);
    case U32:
      return Get<uint32>(multi_index);
    case S64:
      return Get<int64>(multi_index);
    case U64:
      return Get<uint64>(multi_index);
    default:
      return absl::nullopt;
  }
}
1044 
GetAsDouble(absl::Span<const int64> multi_index) const1045 absl::optional<double> LiteralBase::GetAsDouble(
1046     absl::Span<const int64> multi_index) const {
1047   CHECK(LayoutUtil::IsDenseArray(shape()));
1048   switch (shape().element_type()) {
1049     case F16:
1050       return static_cast<double>(Get<half>(multi_index));
1051     case F32:
1052       return static_cast<double>(Get<float>(multi_index));
1053     case F64:
1054       return Get<double>(multi_index);
1055     case BF16:
1056       return static_cast<double>(Get<bfloat16>(multi_index));
1057     default:
1058       return absl::nullopt;
1059   }
1060 }
1061 
// Returns the element at `multi_index` as a complex128: real floating types
// become (value, 0), complex types are converted directly, and anything else
// yields nullopt.
// NOTE(review): S8 is the only integral type handled here — presumably a
// deliberate special case for a particular caller; confirm before extending.
absl::optional<complex128> LiteralBase::GetAsComplex128(
    absl::Span<const int64> multi_index) const {
  switch (shape().element_type()) {
    case BF16:
      return {{static_cast<double>(Get<bfloat16>(multi_index)), 0}};
    case F16:
      return {{static_cast<double>(Get<Eigen::half>(multi_index)), 0}};
    case F32:
      return {{Get<float>(multi_index), 0}};
    case F64:
      return {{Get<double>(multi_index), 0}};
    case C64:
      return {Get<complex64>(multi_index)};
    case C128:
      return {Get<complex128>(multi_index)};
    case S8:
      return {Get<int8>(multi_index)};
    default:
      return absl::nullopt;
  }
}
1083 
// Computes a hash of this literal by combining the shape hash with a hash of
// the raw backing bytes of every dense array leaf.
size_t LiteralBase::Hash() const {
  using tensorflow::Hash64;
  using tensorflow::Hash64Combine;

  // Seed with the shape so literals with identical data but different
  // shapes/layouts hash differently.
  size_t hash_value = ShapeUtil::Hash(shape());

  ShapeUtil::ForEachSubshape(
      shape(), [&](const Shape& subshape, const ShapeIndex& index) {
        if (!subshape.IsArray()) {
          return;
        }

        CHECK(LayoutUtil::IsDense(subshape.layout()));
        // Fold in each array leaf's buffer bytes.
        hash_value = Hash64Combine(
            hash_value, Hash64(static_cast<const char*>(untyped_data(index)),
                               size_bytes(index)));
      });

  return hash_value;
}
1104 
SetIntegralAsS64(absl::Span<const int64> multi_index,int64 value)1105 Status MutableLiteralBase::SetIntegralAsS64(absl::Span<const int64> multi_index,
1106                                             int64 value) {
1107   CHECK(LayoutUtil::IsDenseArray(shape()));
1108   switch (shape().element_type()) {
1109     case PRED:
1110       Set<bool>(multi_index, value);
1111       break;
1112     case U8:
1113       Set<uint8>(multi_index, value);
1114       break;
1115     case S32:
1116       Set<int32>(multi_index, value);
1117       break;
1118     case S64:
1119       Set<int64>(multi_index, value);
1120       break;
1121     case U32:
1122       Set<uint32>(multi_index, value);
1123       break;
1124     case U64:
1125       Set<uint64>(multi_index, value);
1126       break;
1127     default:
1128       return FailedPrecondition("Array element type is not integral: %s",
1129                                 PrimitiveType_Name(shape().element_type()));
1130   }
1131   return Status::OK();
1132 }
1133 
// Stores `value` (narrowed to the literal's floating-point element type) into
// the element at `multi_index`. Returns FailedPrecondition for non-floating
// element types.
Status MutableLiteralBase::SetFromDouble(absl::Span<const int64> multi_index,
                                         double value) {
  CHECK(LayoutUtil::IsDenseArray(shape()));
  switch (shape().element_type()) {
    case F16:
      Set<half>(multi_index, Eigen::half(value));
      break;
    case F32:
      Set<float>(multi_index, value);
      break;
    case F64:
      Set<double>(multi_index, value);
      break;
    case BF16:
      Set<bfloat16>(multi_index, static_cast<bfloat16>(value));
      break;
    default:
      return FailedPrecondition("Array element type is not floating: %s",
                                PrimitiveType_Name(shape().element_type()));
  }
  return Status::OK();
}
1156 
1157 namespace {
1158 
ShapeToString(bool print_layout,const Shape & shape)1159 string ShapeToString(bool print_layout, const Shape& shape) {
1160   return print_layout ? ShapeUtil::HumanStringWithLayout(shape)
1161                       : ShapeUtil::HumanString(shape);
1162 }
1163 
1164 void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
1165                     bool print_shape, bool print_layout,
1166                     std::vector<string>* pieces);
1167 
// Appends the string form of the tuple at `shape_index` to `pieces`:
// "(\n<elem0>,\n<elem1>...\n)", recursing into each element via
// ToStringHelper.
void TupleToStringHelper(const LiteralBase& literal,
                         const ShapeIndex& shape_index, bool print_shape,
                         bool print_layout, std::vector<string>* pieces) {
  const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
  pieces->push_back("(\n");
  std::vector<string> tuple_pieces;
  for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) {
    // Descend one level: the child's shape index is shape_index + {i}.
    ShapeIndex element_index = shape_index;
    element_index.push_back(i);
    std::vector<string> element_pieces;
    ToStringHelper(literal, element_index, print_shape, print_layout,
                   &element_pieces);
    tuple_pieces.push_back(absl::StrJoin(element_pieces, ""));
  }
  pieces->push_back(absl::StrJoin(tuple_pieces, ",\n"));
  pieces->push_back("\n)");
}
1185 
// Appends the string form of the dense array at `shape_index` to `pieces`:
// optionally the shape (with dynamic sizes in parentheses), then the nested
// brace-delimited element values. Dynamic dimensions are printed only up to
// their dynamic size.
void DenseArrayToStringHelper(const LiteralBase& literal,
                              const ShapeIndex& shape_index, bool print_shape,
                              bool print_layout, std::vector<string>* pieces) {
  const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
  int64 rank = subshape.rank();

  // Recursively emits one nesting level per dimension. `dimensions` holds the
  // extents still to be expanded; `accum_indices` holds the indices fixed so
  // far.
  std::function<void(absl::Span<const int64> dimensions, std::vector<int64>*)>
      to_string_recursive = [&](absl::Span<const int64> dimensions,
                                std::vector<int64>* accum_indices) {
        // dimensions.size() decreases by 1 at each recursive call,
        // and accum_indices->size() increases by 1.
        // Their sum is equal to the rank of the tensor.
        CHECK_EQ(rank, dimensions.size() + accum_indices->size());

        // Chooses how to render an opening or closing brace at the current
        // nesting depth (spacing, newlines, and /*iN=V*/ annotations).
        auto brace_to_string = [&](string brace) -> string {
          // Handle 1D tensor
          if (rank == 1) {
            return brace;
          }
          // Handle the innermost tensor of a 2D+ tensor.
          if (dimensions.size() == 1 && brace == "{") {
            return StrCat("  ", brace, dimensions[0] <= 1 ? "" : " ");
          }
          if (dimensions.size() == 1 && brace == "}") {
            return StrCat(dimensions[0] <= 1 ? "" : " ", brace);
          }
          // Handle the non-innermost tensors of a 2D+ tensor.
          if (brace == "{") {
            const int64 accum_indices_size = accum_indices->size();
            if (rank > 3 && !accum_indices->empty() &&
                accum_indices_size < rank) {
              int index = accum_indices->size() - 1;
              int value = accum_indices->back();
              // Annotate deep nesting with the enclosing index for
              // readability, e.g. "{ /*i0=3*/".
              return StrCat(brace, " /*i", index, "=", value, "*/\n");
            }
            return StrCat(brace, "\n");
          }
          return StrCat("\n", brace);
        };

        if (dimensions.empty()) {
          // Display predicates as 0s and 1s so that the string is more dense.
          string elem;
          if (subshape.element_type() == PRED && rank > 0) {
            elem = literal.Get<bool>(*accum_indices, shape_index) ? "1" : "0";
          } else {
            elem = literal.GetAsString(*accum_indices, shape_index);
          }
          pieces->push_back(elem);
        } else {
          pieces->push_back(brace_to_string("{"));
          for (int i = 0; i < dimensions[0]; ++i) {
            std::vector<int64> cloned_indices(*accum_indices);
            cloned_indices.push_back(i);
            to_string_recursive(dimensions.subspan(1), &cloned_indices);
            if (i < dimensions[0] - 1) {
              pieces->push_back(",");
              pieces->push_back(dimensions.size() > 1 ? "\n" : " ");
            }
          }
          pieces->push_back(brace_to_string("}"));
        }
      };

  if (print_shape) {
    pieces->push_back(ShapeToString(print_layout, subshape));
    if (subshape.is_dynamic()) {
      // Print the current dynamic sizes, e.g. "(3,5)".
      pieces->push_back("(");
      for (int64 i = 0; i < subshape.dimensions_size(); ++i) {
        pieces->push_back(StrCat(literal.GetDynamicSize(i, shape_index)));
        if (i < subshape.dimensions_size() - 1) {
          pieces->push_back(",");
        }
      }
      pieces->push_back(")");
    }
    pieces->push_back(" ");
  }
  std::vector<int64> indices = {};
  // Use the dynamic sizes as extents so only live elements are printed.
  std::vector<int64> dimensions;
  dimensions.reserve(subshape.rank());
  for (int64 i = 0; i < subshape.rank(); ++i) {
    dimensions.push_back(literal.GetDynamicSize(i, shape_index));
  }
  to_string_recursive(dimensions, &indices);
}
1272 
// Appends the string form of the subliteral at `shape_index` to `pieces`,
// dispatching on the subshape kind: tuple, token, or dense array.
void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
                    bool print_shape, bool print_layout,
                    std::vector<string>* pieces) {
  const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
  CHECK(LayoutUtil::HasLayout(literal.shape()));
  CHECK(LayoutUtil::HasLayout(subshape));
  if (subshape.IsTuple()) {
    TupleToStringHelper(literal, shape_index, print_shape, print_layout,
                        pieces);
  } else if (subshape.IsToken()) {
    pieces->push_back("token");
  } else {
    CHECK(LayoutUtil::IsDenseArray(subshape));
    DenseArrayToStringHelper(literal, shape_index, print_shape, print_layout,
                             pieces);
  }
}
1290 
1291 }  // namespace
1292 
ToString() const1293 string LiteralBase::ToString() const {
1294   std::vector<string> pieces;
1295   CHECK(LayoutUtil::HasLayout(this->shape()));
1296   ToStringHelper(*this, {}, /*print_shape=*/true,
1297                  /*print_layout=*/false, &pieces);
1298   return absl::StrJoin(pieces, "");
1299 }
1300 
// Like ToString(), but collapses the multi-line output onto a single line.
string LiteralBase::ToStringOneline() const {
  return CompactOneline(ToString());
}
1304 
ToStringWithoutShape() const1305 string LiteralBase::ToStringWithoutShape() const {
1306   std::vector<string> pieces;
1307   CHECK(LayoutUtil::HasLayout(this->shape()));
1308   ToStringHelper(*this, {}, /*print_shape=*/false,
1309                  /*print_layout=*/false, &pieces);
1310   return absl::StrJoin(pieces, "");
1311 }
1312 
// Like ToStringWithoutShape(), but collapses the output onto a single line.
string LiteralBase::ToStringWithoutShapeOneline() const {
  return CompactOneline(ToStringWithoutShape());
}
1316 
// Like ToString(), but the printed shape also includes layout information.
string LiteralBase::ToStringWithLayout() const {
  std::vector<string> pieces;
  CHECK(LayoutUtil::HasLayout(this->shape()));
  ToStringHelper(*this, {}, /*print_shape=*/true,
                 /*print_layout=*/true, &pieces);
  return absl::StrJoin(pieces, "");
}
1324 
// Invokes `per_cell` for every element of this array literal, passing the
// element's multi-dimensional index and its string rendering. No-op for
// zero-element arrays.
void LiteralBase::EachCellAsString(
    const std::function<void(absl::Span<const int64> indices,
                             const string& value)>& per_cell) const {
  if (ShapeUtil::IsZeroElementArray(shape())) {
    return;
  }
  // Start at the index of linear position 0 and bump through all indices.
  std::vector<int64> indices = IndexUtil::LinearIndexToMultidimensionalIndex(
      shape(), /*linear_index=*/0);
  do {
    per_cell(indices, GetAsString(indices));
  } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices)));
}
1337 
1338 namespace {
// Builds a new literal with element type NativeDestT whose elements are
// `converter` applied to each element of `src_literal` (read as NativeSrcT).
// Dimensions and layout are preserved; only the element type changes.
template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal,
                                               const ConverterType& converter) {
  CHECK(src_literal.shape().IsArray());
  Literal result_literal(ShapeUtil::ChangeElementType(
      src_literal.shape(),
      primitive_util::NativeToPrimitiveType<NativeDestT>()));
  auto src_data = src_literal.data<NativeSrcT>();
  auto dest_data = result_literal.template data<NativeDestT>();
  int64 num_elements = src_literal.element_count();

  for (int64 i = 0; i < num_elements; ++i) {
    dest_data[i] = converter(src_data[i]);
  }
  return result_literal;
}
1355 
// Overload for Eigen::half -> complex64/complex128 conversion. Eigen::half
// does not convert directly to the complex type, so the value is first cast
// to the complex type's real component type (its value_type).
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(std::is_same<NativeSrcT, Eigen::half>::value) &&
                            (std::is_same<NativeDestT, complex64>::value ||
                             std::is_same<NativeDestT, complex128>::value),
                        Literal>::type
ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
  auto converter = [](NativeSrcT src) {
    return NativeDestT(static_cast<typename NativeDestT::value_type>(src));
  };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
      src_literal, converter);
}
1368 
// General overload: every source/destination pair other than
// half -> complex converts each element with a plain static_cast.
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(!std::is_same<NativeSrcT, Eigen::half>::value) ||
                            (!std::is_same<NativeDestT, complex64>::value &&
                             !std::is_same<NativeDestT, complex128>::value),
                        Literal>::type
ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
  auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
      src_literal, converter);
}
1379 
// Reinterprets each element's raw bits as NativeDestT via absl::bit_cast.
// Enabled only for same-sized types with a non-Eigen::half destination (half
// needs the raw_uint16_to_half overload below).
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT) &&
                         !std::is_same<NativeDestT, Eigen::half>::value),
                        Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
  auto converter = [](NativeSrcT src) {
    return absl::bit_cast<NativeDestT>(GetRawValue(src));
  };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
      src_literal, converter);
}
1391 
// Overload for bit-casting a same-sized source type to Eigen::half.
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) == sizeof(Eigen::half) &&
                         std::is_same<NativeDestT, Eigen::half>::value),
                        Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
  // Eigen::half doesn't satisfy the absl::bit_cast contract, so explicitly
  // cast to unsigned short and then use raw_uint16_to_half.
  auto converter = [](NativeSrcT src) {
    return Eigen::half_impl::raw_uint16_to_half(
        absl::bit_cast<uint16>(GetRawValue(src)));
  };
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, Eigen::half>(
      src_literal, converter);
}
1406 
// This template specialization is here to make the compiler happy. bit_cast
// has a static check that the types are the same size. This specialization
// should never be used because the source and destination types are checked
// for identical sizes higher up; reaching it at runtime is a logic error and
// aborts.
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
                        Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
  LOG(FATAL) << "Invalid bitcast between types of different sizes.";
}
1417 
// Converts `src_literal` (whose element type must be primitive_src_type) to
// an element type of primitive_dest_type: value conversion by default, or a
// raw bit reinterpretation when `bitcast` is true.
template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) {
  CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
  if (bitcast) {
    return BitcastBetweenNativeTypes<
        typename primitive_util::PrimitiveTypeToNative<
            primitive_src_type>::type,
        typename primitive_util::PrimitiveTypeToNative<
            primitive_dest_type>::type>(src_literal);
  } else {
    return ConvertBetweenNativeTypes<
        typename primitive_util::PrimitiveTypeToNative<
            primitive_src_type>::type,
        typename primitive_util::PrimitiveTypeToNative<
            primitive_dest_type>::type>(src_literal);
  }
}
1435 
// Dispatches on the runtime destination type: once it is fixed at compile
// time via the macro expansion, forwards to ConvertIfTypesMatch. Complex
// destinations support value conversion only (bitcast falls through to the
// Unimplemented error); unlisted destination types are unimplemented.
template <PrimitiveType primitive_src_type>
StatusOr<Literal> ConvertIfDestTypeMatches(const LiteralBase& src_literal,
                                           PrimitiveType primitive_dest_type,
                                           bool bitcast) {
  switch (primitive_dest_type) {
#define CONVERT_IF_TYPES_MATCH(type)                                    \
  case (type):                                                          \
    return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal, \
                                                           bitcast);
    CONVERT_IF_TYPES_MATCH(PRED)
    CONVERT_IF_TYPES_MATCH(S8)
    CONVERT_IF_TYPES_MATCH(S16)
    CONVERT_IF_TYPES_MATCH(S32)
    CONVERT_IF_TYPES_MATCH(S64)
    CONVERT_IF_TYPES_MATCH(U8)
    CONVERT_IF_TYPES_MATCH(U16)
    CONVERT_IF_TYPES_MATCH(U32)
    CONVERT_IF_TYPES_MATCH(U64)
    CONVERT_IF_TYPES_MATCH(F16)
    CONVERT_IF_TYPES_MATCH(F32)
    CONVERT_IF_TYPES_MATCH(F64)
    CONVERT_IF_TYPES_MATCH(BF16)
#undef CONVERT_IF_TYPES_MATCH
    case C64:
      if (bitcast) {
        break;
      }
      return ConvertIfTypesMatch<primitive_src_type, C64>(src_literal, false);
    case C128:
      if (bitcast) {
        break;
      }
      return ConvertIfTypesMatch<primitive_src_type, C128>(src_literal, false);
    // Other types are not yet supported.
    default:
      break;
  }
  return Unimplemented("Converting from type %s to type %s is not implemented.",
                       PrimitiveType_Name(src_literal.shape().element_type()),
                       PrimitiveType_Name(primitive_dest_type));
}
1477 
// Shared entry point for Convert and BitcastConvert on array literals:
// dispatches on the source element type, then delegates destination-type
// dispatch to ConvertIfDestTypeMatches. Converting to the same element type
// is a plain clone.
StatusOr<Literal> ConvertSwitch(const LiteralBase& literal,
                                PrimitiveType primitive_dest_type,
                                bool bitcast) {
  TF_RET_CHECK(literal.shape().IsArray());
  if (literal.shape().element_type() == primitive_dest_type) {
    return literal.Clone();
  }
  switch (literal.shape().element_type()) {
#define CONVERT_IF_DEST_TYPE_MATCHES(type)                                \
  case (type):                                                            \
    return ConvertIfDestTypeMatches<(type)>(literal, primitive_dest_type, \
                                            bitcast);
    CONVERT_IF_DEST_TYPE_MATCHES(PRED)
    CONVERT_IF_DEST_TYPE_MATCHES(S8)
    CONVERT_IF_DEST_TYPE_MATCHES(S16)
    CONVERT_IF_DEST_TYPE_MATCHES(S32)
    CONVERT_IF_DEST_TYPE_MATCHES(S64)
    CONVERT_IF_DEST_TYPE_MATCHES(U8)
    CONVERT_IF_DEST_TYPE_MATCHES(U16)
    CONVERT_IF_DEST_TYPE_MATCHES(U32)
    CONVERT_IF_DEST_TYPE_MATCHES(U64)
    CONVERT_IF_DEST_TYPE_MATCHES(F16)
    CONVERT_IF_DEST_TYPE_MATCHES(F32)
    CONVERT_IF_DEST_TYPE_MATCHES(F64)
    CONVERT_IF_DEST_TYPE_MATCHES(BF16)
#undef CONVERT_IF_DEST_TYPE_MATCHES
      // Other types are not yet supported.
    default:
      return Unimplemented("%s from type %s to type %s is not implemented.",
                           (bitcast ? "Bitcast converting" : "Converting"),
                           PrimitiveType_Name(literal.shape().element_type()),
                           PrimitiveType_Name(primitive_dest_type));
  }
}
1512 
1513 }  // namespace
1514 
Convert(PrimitiveType primitive_dest_type) const1515 StatusOr<Literal> LiteralBase::Convert(
1516     PrimitiveType primitive_dest_type) const {
1517   return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
1518 }
1519 
BitcastConvert(PrimitiveType primitive_dest_type) const1520 StatusOr<Literal> LiteralBase::BitcastConvert(
1521     PrimitiveType primitive_dest_type) const {
1522   if (primitive_util::BitWidth(shape().element_type()) !=
1523       primitive_util::BitWidth(primitive_dest_type)) {
1524     return InvalidArgument(
1525         "Cannot bitcast convert from %s to %s, bit widths are different: %d != "
1526         "%d",
1527         PrimitiveType_Name(shape().element_type()),
1528         PrimitiveType_Name(primitive_dest_type),
1529         primitive_util::BitWidth(shape().element_type()),
1530         primitive_util::BitWidth(primitive_dest_type));
1531   }
1532   return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true);
1533 }
1534 
ConvertToShape(const Shape & dest_shape) const1535 StatusOr<Literal> LiteralBase::ConvertToShape(const Shape& dest_shape) const {
1536   if (!dest_shape.IsTuple()) {
1537     return Convert(dest_shape.element_type());
1538   }
1539   std::vector<Literal> elements;
1540   for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
1541     auto element = LiteralSlice(*this, {i});
1542     TF_ASSIGN_OR_RETURN(
1543         auto new_element,
1544         element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
1545     elements.push_back(std::move(new_element));
1546   }
1547   return MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
1548 }
1549 
MoveIntoTuple(absl::Span<Literal> elements)1550 /* static */ Literal MutableLiteralBase::MoveIntoTuple(
1551     absl::Span<Literal> elements) {
1552   std::vector<Shape> element_shapes;
1553   for (const Literal& element : elements) {
1554     element_shapes.push_back(element.shape());
1555   }
1556   Literal literal(ShapeUtil::MakeTupleShape(element_shapes),
1557                   /*allocate_arrays=*/false);
1558   for (int i = 0, end = elements.size(); i < end; ++i) {
1559     TF_CHECK_OK(
1560         literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
1561   }
1562   return literal;
1563 }
1564 
// Copies elements from `src` into this piece when either shape carries
// dynamic dimensions. Iteration runs over the non-static shape's bound and
// only indices that lie within BOTH pieces' dynamic sizes are copied.
template <typename NativeT>
void LiteralBase::Piece::CopyElementsWithDynamicBound(
    const LiteralBase::Piece& src) {
  auto dest_shape = subshape();
  auto src_shape = src.subshape();

  // At least one shape has to be static as bound.
  CHECK(dest_shape.is_static() || src_shape.is_static());
  // Iterate using whichever shape is dynamic; if dest is static, src's
  // bounds drive the traversal (and vice versa).
  auto bound_shape = dest_shape.is_static() ? src_shape : dest_shape;
  if (ShapeUtil::IsZeroElementArray(dest_shape)) {
    return;
  }
  // Multidimensional cursor, advanced row-major-style by BumpIndices below.
  std::vector<int64> index(dest_shape.rank());
  do {
    bool out_of_bound = false;
    for (int64 i = 0; i < index.size(); ++i) {
      // Do not copy elements beyond dynamic bound.
      if (index[i] >= GetDynamicSize(i) || index[i] >= src.GetDynamicSize(i)) {
        out_of_bound = true;
      }
    }
    if (out_of_bound) {
      continue;
    }
    // Linearize the index separately against each piece's own layout, since
    // dest and src may have different static bounds.
    data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape,
                                                                  index)] =
        src.data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
            src_shape, index)];
  } while (IndexUtil::BumpIndices(bound_shape, absl::MakeSpan(index)));
}
1595 
// Recursively compares this piece's elements with `other`, descending one
// dimension per recursion level. `multi_index` accumulates the partial index;
// each dimension is walked only up to this piece's dynamic size.
template <typename NativeT>
bool LiteralBase::Piece::EqualElementsInternal(
    const LiteralBase::Piece& other, std::vector<int64>* multi_index) const {
  // Base case: the index addresses a single element; compare it directly.
  if (multi_index->size() == subshape().rank()) {
    return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
  }
  for (int64 i = 0; i < GetDynamicSize(multi_index->size()); ++i) {
    multi_index->push_back(i);
    if (!EqualElementsInternal<NativeT>(other, multi_index)) {
      return false;
    }
    multi_index->pop_back();
  }
  return true;
}
1611 
EqualDynamicSize(const LiteralBase::Piece & other) const1612 bool LiteralBase::Piece::EqualDynamicSize(
1613     const LiteralBase::Piece& other) const {
1614   DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
1615   if (subshape().is_static()) {
1616     return true;
1617   }
1618 
1619   for (int64 i = 0; i < subshape().rank(); ++i) {
1620     if (GetDynamicSize(i) != other.GetDynamicSize(i)) {
1621       return false;
1622     }
1623   }
1624   return true;
1625 }
1626 
// Returns true if this piece and `other` hold element-wise equal data.
// Statically-shaped dense arrays with equal shapes compare raw bytes; all
// other cases fall back to a typed recursive element walk.
bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
  // Fast path: identical static dense layouts allow a single memcmp.
  if (subshape().is_static() &&
      ShapeUtil::Equal(subshape(), other.subshape()) &&
      LayoutUtil::IsDenseArray(subshape())) {
    CHECK_EQ(size_bytes(), other.size_bytes());
    return memcmp(buffer(), other.buffer(), size_bytes()) == 0;
  }

  std::vector<int64> multi_index;
  switch (subshape().element_type()) {
    case PRED:
      return EqualElementsInternal<bool>(other, &multi_index);
    case S8:
      return EqualElementsInternal<int8>(other, &multi_index);
    case S16:
      return EqualElementsInternal<int16>(other, &multi_index);
    case S32:
      return EqualElementsInternal<int32>(other, &multi_index);
    case S64:
      return EqualElementsInternal<int64>(other, &multi_index);
    case U8:
      return EqualElementsInternal<uint8>(other, &multi_index);
    case U16:
      return EqualElementsInternal<uint16>(other, &multi_index);
    case U32:
      return EqualElementsInternal<uint32>(other, &multi_index);
    case U64:
      return EqualElementsInternal<uint64>(other, &multi_index);
    case F32:
      return EqualElementsInternal<float>(other, &multi_index);
    case F64:
      return EqualElementsInternal<double>(other, &multi_index);
    case F16:
      return EqualElementsInternal<half>(other, &multi_index);
    case BF16:
      return EqualElementsInternal<bfloat16>(other, &multi_index);
    case C64:
      return EqualElementsInternal<complex64>(other, &multi_index);
    case C128:
      return EqualElementsInternal<complex128>(other, &multi_index);
    default:
      LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type "
                 << PrimitiveType_Name(subshape().element_type());
  }
}
1672 
operator ==(const LiteralBase & other) const1673 bool LiteralBase::operator==(const LiteralBase& other) const {
1674   // Checking the structure of tuple literals. Checks for dense arrays are
1675   // performed below.
1676   if (!ShapeUtil::EqualStructure(shape(), other.shape())) {
1677     return false;
1678   }
1679 
1680   return root_piece().ForEachSubpieceWithBool(
1681       [&](const ShapeIndex& index, const Piece& piece) {
1682         const Piece& other_piece = other.piece(index);
1683         const Shape& subshape = piece.subshape();
1684         const Shape& other_subshape = other_piece.subshape();
1685         if (subshape.element_type() != other_subshape.element_type()) {
1686           return false;
1687         }
1688         if (!piece.subshape().IsArray()) {
1689           return true;
1690         }
1691         if (subshape.rank() != other_subshape.rank()) {
1692           return false;
1693         }
1694 
1695         for (int64 i = 0; i < subshape.rank(); ++i) {
1696           if (piece.GetDynamicSize(i) != other_piece.GetDynamicSize(i)) {
1697             return false;
1698           }
1699         }
1700 
1701         if (!piece.EqualElements(other_piece)) {
1702           return false;
1703         }
1704         return true;
1705       });
1706 }
1707 
1708 namespace {
1709 
1710 template <typename NativeT>
AllElementsEqualValue(absl::Span<const NativeT> data,NativeT value)1711 static bool AllElementsEqualValue(absl::Span<const NativeT> data,
1712                                   NativeT value) {
1713   for (int64 i = 0; i < data.size(); ++i) {
1714     if (data[i] != value) {
1715       return false;
1716     }
1717   }
1718   return true;
1719 }
1720 
1721 }  // namespace
1722 
IsAll(int8 value) const1723 bool LiteralBase::IsAll(int8 value) const {
1724   return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index,
1725                                                   const Piece& piece) {
1726     if (!piece.subshape().IsArray()) {
1727       return true;
1728     }
1729 
1730     auto piece_is_all = [&]() {
1731       switch (shape().element_type()) {
1732         case U8:
1733           if (value >= 0) {
1734             return AllElementsEqualValue<uint8>(piece.data<uint8>(), value);
1735           }
1736           return false;
1737         case U16:
1738           if (value >= 0) {
1739             return AllElementsEqualValue<uint16>(piece.data<uint16>(), value);
1740           }
1741           return false;
1742         case U32:
1743           if (value >= 0) {
1744             return AllElementsEqualValue<uint32>(piece.data<uint32>(), value);
1745           }
1746           return false;
1747         case U64:
1748           if (value >= 0) {
1749             return AllElementsEqualValue<uint64>(piece.data<uint64>(), value);
1750           }
1751           return false;
1752         case S8:
1753           return AllElementsEqualValue<int8>(piece.data<int8>(), value);
1754         case S16:
1755           return AllElementsEqualValue<int16>(piece.data<int16>(), value);
1756         case S32:
1757           return AllElementsEqualValue<int32>(piece.data<int32>(), value);
1758         case S64:
1759           return AllElementsEqualValue<int64>(piece.data<int64>(), value);
1760         case F32:
1761           return AllElementsEqualValue<float>(piece.data<float>(), value);
1762         case F64:
1763           return AllElementsEqualValue<double>(piece.data<double>(), value);
1764         case F16:
1765           return AllElementsEqualValue<half>(piece.data<half>(),
1766                                              static_cast<half>(value));
1767         case BF16:
1768           return AllElementsEqualValue<bfloat16>(piece.data<bfloat16>(),
1769                                                  static_cast<bfloat16>(value));
1770         case PRED:
1771           if (value == 0) {
1772             return AllElementsEqualValue<bool>(piece.data<bool>(), false);
1773           }
1774           if (value == 1) {
1775             return AllElementsEqualValue<bool>(piece.data<bool>(), true);
1776           }
1777           return false;
1778         default:
1779           return false;
1780       }
1781       return false;
1782     };
1783 
1784     if (!piece_is_all()) {
1785       return false;
1786     }
1787     return true;
1788   });
1789 }
1790 
IsAllFloat(float value) const1791 bool LiteralBase::IsAllFloat(float value) const {
1792   return root_piece().ForEachSubpieceWithBool(
1793       [&](const ShapeIndex& index, const Piece& piece) {
1794         if (!piece.subshape().IsArray()) {
1795           return true;
1796         }
1797 
1798         switch (shape().element_type()) {
1799           case F32:
1800             return AllElementsEqualValue<float>(piece.data<float>(), value);
1801           case F64:
1802             return AllElementsEqualValue<double>(piece.data<double>(), value);
1803           case F16:
1804             return AllElementsEqualValue<half>(piece.data<half>(),
1805                                                static_cast<half>(value));
1806           case BF16:
1807             return AllElementsEqualValue<bfloat16>(
1808                 piece.data<bfloat16>(), static_cast<bfloat16>(value));
1809           default:
1810             return false;
1811         }
1812       });
1813 }
1814 
IsAllComplex(complex64 value) const1815 bool LiteralBase::IsAllComplex(complex64 value) const {
1816   switch (shape().element_type()) {
1817     case C64:
1818       return AllElementsEqualValue<complex64>(root_piece().data<complex64>(),
1819                                               value);
1820     case C128:
1821       return AllElementsEqualValue<complex128>(root_piece().data<complex128>(),
1822                                                value);
1823     default:
1824       return false;
1825   }
1826 }
1827 
// Returns true if, within every array piece, all elements equal that piece's
// first element. Zero-element arrays return false (there is no first
// element); unsupported element types return false.
bool LiteralBase::IsAllFirst() const {
  return root_piece().ForEachSubpieceWithBool(
      [&](const ShapeIndex& index, const Piece& piece) {
        // Non-array pieces (tuple nodes, tokens) impose no constraint.
        if (!piece.subshape().IsArray()) {
          return true;
        }

        // Empty shapes are not all the first element since there is no first
        // element.
        if (ShapeUtil::IsZeroElementArray(piece.subshape())) {
          return false;
        }
        // Typed comparison of every element against data[0].
        auto piece_is_all = [&]() {
          switch (piece.subshape().element_type()) {
            case PRED: {
              auto data = piece.data<bool>();
              return AllElementsEqualValue<bool>(data, data[0]);
            }
            // 8 bit types
            case S8: {
              auto data = piece.data<int8>();
              return AllElementsEqualValue<int8>(data, data[0]);
            }
            case U8: {
              auto data = piece.data<uint8>();
              return AllElementsEqualValue<uint8>(data, data[0]);
            }
            // 16 bit types
            case BF16: {
              auto data = piece.data<bfloat16>();
              return AllElementsEqualValue<bfloat16>(data, data[0]);
            }
            case F16: {
              auto data = piece.data<half>();
              return AllElementsEqualValue<half>(data, data[0]);
            }
            case S16: {
              auto data = piece.data<int16>();
              return AllElementsEqualValue<int16>(data, data[0]);
            }
            case U16: {
              auto data = piece.data<uint16>();
              return AllElementsEqualValue<uint16>(data, data[0]);
            }
            // 32 bit types
            case F32: {
              auto data = piece.data<float>();
              return AllElementsEqualValue<float>(data, data[0]);
            }
            case U32: {
              auto data = piece.data<uint32>();
              return AllElementsEqualValue<uint32>(data, data[0]);
            }
            case S32: {
              auto data = piece.data<int32>();
              return AllElementsEqualValue<int32>(data, data[0]);
            }
            // 64 bit types
            case C64: {
              auto data = piece.data<complex64>();
              return AllElementsEqualValue<complex64>(data, data[0]);
            }
            case F64: {
              auto data = piece.data<double>();
              return AllElementsEqualValue<double>(data, data[0]);
            }
            case S64: {
              auto data = piece.data<int64>();
              return AllElementsEqualValue<int64>(data, data[0]);
            }
            case U64: {
              auto data = piece.data<uint64>();
              return AllElementsEqualValue<uint64>(data, data[0]);
            }

            case C128: {
              auto data = piece.data<complex128>();
              return AllElementsEqualValue<complex128>(data, data[0]);
            }
            default:
              return false;
          }
        };

        if (!piece_is_all()) {
          return false;
        }
        return true;
      });
}
1918 
// Returns true if this literal is a rank-1 array whose element at position i
// equals i (an iota sequence). Non-array and non-rank-1 shapes return false;
// element types with no case below (PRED, etc.) are never iota.
bool LiteralBase::IsR1Iota() const {
  if (!shape().IsArray()) {
    return false;
  }

  if (shape().rank() != 1) {
    return false;
  }

  // Checks whether the element at linear position idx equals idx in the
  // literal's own element type.
  auto is_iota_at_idx = [&](const int64 idx) {
    switch (shape().element_type()) {
      case U8:
        return static_cast<int64>(Get<uint8>({idx})) == idx;
      case U16:
        return static_cast<int64>(Get<uint16>({idx})) == idx;
      case U32:
        return static_cast<int64>(Get<uint32>({idx})) == idx;
      case U64:
        return static_cast<int64>(Get<uint64>({idx})) == idx;
      case S8:
        return Get<int8>({idx}) == idx;
      case S16:
        return Get<int16>({idx}) == idx;
      case S32:
        return Get<int32>({idx}) == idx;
      case S64:
        return Get<int64>({idx}) == idx;
      case F32:
        return Get<float>({idx}) == idx;
      case F64:
        return Get<double>({idx}) == idx;
      case F16:
        return Get<half>({idx}) == static_cast<half>(idx);
      case BF16:
        return Get<bfloat16>({idx}) == static_cast<bfloat16>(idx);
      case C64:
        // Complex values are iota only with a zero imaginary part.
        return Get<complex64>({idx}) == complex64(idx, 0.0f);
      case C128:
        return Get<complex128>({idx}) == complex128(idx, 0.0f);
      // pred, token, opaque, tuple, etc. are all not iota.
      default:
        return false;
    }
  };

  const int64 elements = ShapeUtil::ElementsIn(shape());
  for (int64 idx = 0; idx < elements; ++idx) {
    if (!is_iota_at_idx(idx)) {
      return false;
    }
  }

  return true;
}
1973 
// Returns true if the element at `indices` is zero for this literal's element
// type (false/0/0.0/0+0i). Check-fails on non-array shapes and fatally logs
// on element types without a case below.
bool LiteralBase::IsZero(absl::Span<const int64> indices) const {
  CHECK(shape().IsArray());
  switch (shape().element_type()) {
    case U8:
      return Get<uint8>(indices) == 0;
    case U16:
      return Get<uint16>(indices) == 0;
    case U32:
      return Get<uint32>(indices) == 0;
    case U64:
      return Get<uint64>(indices) == 0;
    case S8:
      return Get<int8>(indices) == 0;
    case S16:
      return Get<int16>(indices) == 0;
    case S32:
      return Get<int32>(indices) == 0;
    case S64:
      return Get<int64>(indices) == 0;
    case F32:
      return Get<float>(indices) == 0.0f;
    case F64:
      return Get<double>(indices) == 0.0;
    case C64:
      return Get<complex64>(indices) == complex64(0.0f, 0.0f);
    case C128:
      return Get<complex128>(indices) == complex128(0.0f, 0.0f);
    case F16:
      return Get<half>(indices) == static_cast<half>(0.0f);
    case BF16:
      return Get<bfloat16>(indices) == static_cast<bfloat16>(0.0f);
    case PRED:
      return Get<bool>(indices) == false;
    default:
      LOG(FATAL) << "Input literal must be an array.";
  }
}
2011 
2012 namespace {
2013 
2014 template <typename RepeatedFieldT, typename NativeT>
CopyToRepeatedField(RepeatedFieldT * dest,const absl::Span<const NativeT> src)2015 void CopyToRepeatedField(RepeatedFieldT* dest,
2016                          const absl::Span<const NativeT> src) {
2017   *dest = RepeatedFieldT(src.begin(), src.end());
2018 }
2019 
2020 }  // namespace
2021 
// Serializes this piece (shape + element data) into `proto`. 8-bit types are
// stored as byte strings, most wider types in typed repeated fields; 16-bit
// types are stored as raw byte strings in little-endian order, so big-endian
// hosts byte-swap first. Complex values are stored as interleaved
// real/imaginary pairs.
void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
  *proto->mutable_shape() = subshape().ToProto();
  switch (subshape().element_type()) {
    case PRED:
      CopyToRepeatedField(proto->mutable_preds(), data<bool>());
      break;
    case S8:
      proto->set_s8s(static_cast<const signed char*>(data<int8>().data()),
                     element_count());
      break;
    case U8:
      proto->set_u8s(static_cast<const unsigned char*>(data<uint8>().data()),
                     element_count());
      break;
    case U32:
      CopyToRepeatedField(proto->mutable_u32s(), data<uint32>());
      break;
    case U64:
      CopyToRepeatedField(proto->mutable_u64s(), data<uint64>());
      break;
    case S32:
      CopyToRepeatedField(proto->mutable_s32s(), data<int32>());
      break;
    case S64:
      CopyToRepeatedField(proto->mutable_s64s(), data<int64>());
      break;
    case U16:
      // 16-bit payloads are serialized as raw bytes; the on-wire order is
      // little-endian, so big-endian hosts swap before writing.
      *proto->mutable_u16s() = string(
          reinterpret_cast<const char*>(data<uint16_t>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_u16s());
      }
      break;
    case S16:
      *proto->mutable_s16s() = string(
          reinterpret_cast<const char*>(data<int16_t>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_s16s());
      }
      break;
    case F16:
      *proto->mutable_f16s() = string(
          reinterpret_cast<const char*>(data<half>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_f16s());
      }
      break;
    case BF16:
      *proto->mutable_bf16s() = string(
          reinterpret_cast<const char*>(data<bfloat16>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(proto->mutable_bf16s());
      }
      break;
    case F32:
      CopyToRepeatedField(proto->mutable_f32s(), data<float>());
      break;
    case F64:
      CopyToRepeatedField(proto->mutable_f64s(), data<double>());
      break;
    case C64:
      // Complex values are flattened to [re0, im0, re1, im1, ...].
      for (complex64 value : data<complex64>()) {
        proto->add_c64s(value.real());
        proto->add_c64s(value.imag());
      }
      break;
    case C128:
      for (complex128 value : data<complex128>()) {
        proto->add_c128s(value.real());
        proto->add_c128s(value.imag());
      }
      break;
    case TUPLE:
    case TOKEN:
      // Nothing to do but assign the shape which is done above.
      return;
    default:
      // TODO(b/111551621): Support serializing more PrimitiveTypes.
      LOG(FATAL) << "Unhandled primitive type "
                 << PrimitiveType_Name(subshape().element_type());
  }
}
2104 
// Returns the piece's raw element buffer; valid only for array-shaped pieces.
const void* LiteralBase::Piece::untyped_data() const {
  CHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
  return buffer();
}
2109 
// Mutable variant: returns the piece's raw element buffer; valid only for
// array-shaped pieces.
void* LiteralBase::Piece::untyped_data() {
  CHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
  return buffer();
}
2114 
2115 namespace {
2116 
2117 template <typename RepeatedFieldT, typename NativeT>
CopyFromRepeatedField(absl::Span<NativeT> dest,const RepeatedFieldT & src)2118 Status CopyFromRepeatedField(absl::Span<NativeT> dest,
2119                              const RepeatedFieldT& src) {
2120   if (dest.size() != src.size()) {
2121     return InvalidArgument(
2122         "Expected %lu elements in LiteralProto repeated field, has %d",
2123         dest.size(), src.size());
2124   }
2125   std::copy(src.begin(), src.end(), dest.begin());
2126   return Status::OK();
2127 }
2128 
2129 }  // namespace
2130 
// Deserializes element data from `proto` into this piece. The proto's shape
// must equal this piece's subshape (structure checks are done by
// MutableLiteralBase::CreateFromProto). 16-bit types arrive as raw
// little-endian byte strings, so big-endian hosts byte-swap after copying;
// complex values arrive as interleaved real/imaginary pairs.
Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
  // These conditions should have been checked in
  // MutableLiteralBase::CreateFromProto.
  TF_RET_CHECK(proto.has_shape());
  Shape shape(proto.shape());
  TF_RET_CHECK(LayoutUtil::HasLayout(shape));
  TF_RET_CHECK(ShapeUtil::Equal(shape, subshape()));

  switch (subshape().element_type()) {
    case PRED:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
      break;
    case S8: {
      auto s8_data = data<int8>();
      TF_RET_CHECK(proto.s8s().size() == s8_data.size());
      std::copy(proto.s8s().begin(), proto.s8s().end(), s8_data.begin());
    } break;
    case U8: {
      auto u8_data = data<uint8>();
      TF_RET_CHECK(proto.u8s().size() == u8_data.size());
      std::copy(proto.u8s().begin(), proto.u8s().end(), u8_data.begin());
    } break;
    case S32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int32>(), proto.s32s()));
      break;
    case S64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int64>(), proto.s64s()));
      break;
    case U32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint32>(), proto.u32s()));
      break;
    case U64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint64>(), proto.u64s()));
      break;
    case S16: {
      // 16-bit payloads were serialized as raw little-endian bytes; copy them
      // in and swap on big-endian hosts.
      const string& s(proto.s16s());
      TF_RET_CHECK(data<int16_t>().size() * sizeof(int16_t) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case U16: {
      const string& s(proto.u16s());
      TF_RET_CHECK(data<uint16_t>().size() * sizeof(uint16_t) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case F16: {
      const string& s(proto.f16s());
      TF_RET_CHECK(data<half>().size() * sizeof(half) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;

    case BF16: {
      const string& s(proto.bf16s());
      TF_RET_CHECK(data<bfloat16>().size() * sizeof(bfloat16) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case F32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<float>(), proto.f32s()));
      break;
    case F64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<double>(), proto.f64s()));
      break;
    case C64: {
      // Complex values were flattened as [re0, im0, re1, im1, ...].
      auto complex_data = data<complex64>();
      TF_RET_CHECK(proto.c64s_size() == complex_data.size() * 2);
      for (int64 i = 0; i < complex_data.size(); ++i) {
        complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)};
      }
      break;
    }
    case C128: {
      auto complex_data = data<complex128>();
      const int64 complex_data_size_doubled = complex_data.size() * 2;
      TF_RET_CHECK(proto.c128s_size() == complex_data_size_doubled);
      for (int64 i = 0, end = complex_data.size(); i < end; ++i) {
        complex_data[i] =
            complex128{proto.c128s(i * 2), proto.c128s(i * 2 + 1)};
      }
      break;
    }
    case TUPLE:
      return InvalidArgument("Should not be called on tuple shapes: %s",
                             ShapeUtil::HumanString(subshape()));
    default:
      return InvalidArgument("Is called on unsupported shape: %s",
                             ShapeUtil::HumanString(subshape()));
  }
  return Status::OK();
}
2231 
ToProto() const2232 LiteralProto LiteralBase::ToProto() const {
2233   LiteralProto proto;
2234   root_piece().ForEachSubpiece(
2235       [&](const ShapeIndex& index, const Piece& piece) {
2236         LiteralProto* proto_piece = &proto;
2237         for (int64 i : index) {
2238           while (proto_piece->tuple_literals_size() <= i) {
2239             proto_piece->add_tuple_literals();
2240           }
2241           proto_piece = proto_piece->mutable_tuple_literals(i);
2242         }
2243         piece.WriteToProto(proto_piece);
2244       });
2245 
2246   return proto;
2247 }
2248 
untyped_data(const ShapeIndex & shape_index) const2249 const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
2250   return piece(shape_index).untyped_data();
2251 }
2252 
untyped_data(const ShapeIndex & shape_index)2253 void* MutableLiteralBase::untyped_data(const ShapeIndex& shape_index) {
2254   return piece(shape_index).untyped_data();
2255 }
2256 
size_bytes(const ShapeIndex & shape_index) const2257 int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const {
2258   return piece(shape_index).size_bytes();
2259 }
2260 
GetR1U8AsString() const2261 string LiteralBase::GetR1U8AsString() const {
2262   CHECK(shape().IsArray());
2263   CHECK_EQ(shape().rank(), 1);
2264   CHECK_EQ(shape().element_type(), U8);
2265   return string(absl::bit_cast<const char*>(data<uint8>().data()),
2266                 ShapeUtil::ElementsIn(shape()));
2267 }
2268 
CopyPieceSubtree(const Shape & shape,Piece * src_piece,Piece * dest_piece)2269 void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape,
2270                                                Piece* src_piece,
2271                                                Piece* dest_piece) {
2272   DCHECK(ShapeUtil::Equal(src_piece->subshape(), dest_piece->subshape()))
2273       << "src_piece has shape: "
2274       << ShapeUtil::HumanString(src_piece->subshape())
2275       << "dest_piece has shape: "
2276       << ShapeUtil::HumanString(dest_piece->subshape());
2277   if (shape.IsTuple()) {
2278     for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
2279       const Shape& subshape = shape.tuple_shapes(i);
2280 
2281       auto child_piece = Piece();
2282       child_piece.set_subshape(&subshape);
2283 
2284       CopyPieceSubtree(subshape, &src_piece->child(i), &child_piece);
2285 
2286       dest_piece->emplace_back(std::move(child_piece));
2287     }
2288   } else if (shape.IsArray()) {
2289     dest_piece->set_buffer(src_piece->buffer());
2290     dest_piece->set_dynamic_size_buffer(src_piece->dynamic_size_buffer());
2291   } else {
2292     // If the shape is neither an array nor tuple, then it must be
2293     // zero-sized. Otherwise, some memory needs to be allocated for it.
2294     CHECK_EQ(dest_piece->size_bytes(), 0);
2295   }
2296 }
2297 
~MutableLiteralBase()2298 MutableLiteralBase::~MutableLiteralBase() {}
2299 
// Copy constructor: clones the shape, then builds a piece tree that borrows
// the same underlying buffers as `literal` (no data is copied).
MutableBorrowingLiteral::MutableBorrowingLiteral(
    const MutableBorrowingLiteral& literal)
    : MutableLiteralBase() {
  shape_ = absl::make_unique<Shape>(literal.shape());
  CHECK(LayoutUtil::HasLayout(*shape_));
  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());
  CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);
}
2311 
// Copy assignment: rebuilds this literal as a borrowing view of `literal`.
// Fixes two defects in the previous version: the old root_piece_ tree was
// leaked (overwritten with `new Piece()` without being freed), and
// self-assignment would have read from the freed tree once that leak was
// plugged.
MutableBorrowingLiteral& MutableBorrowingLiteral::operator=(
    const MutableBorrowingLiteral& literal) {
  if (this == &literal) {
    return *this;
  }

  // Free the piece tree owned from construction or a prior assignment
  // before installing the new one. Do this while the old shape_ (which the
  // old pieces point into) is still alive.
  delete root_piece_;
  root_piece_ = nullptr;

  shape_ = absl::make_unique<Shape>(literal.shape());
  CHECK(LayoutUtil::HasLayout(*shape_));

  root_piece_ = new Piece();
  root_piece_->set_subshape(shape_.get());

  CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);

  return *this;
}
2324 
MutableBorrowingLiteral(MutableLiteralBase * literal)2325 MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal)
2326     : MutableLiteralBase() {
2327   shape_ = absl::make_unique<Shape>(literal->shape());
2328   CHECK(LayoutUtil::HasLayout(*shape_));
2329 
2330   root_piece_ = new Piece();
2331   root_piece_->set_subshape(shape_.get());
2332 
2333   CopyPieceSubtree(*shape_, &literal->root_piece(), root_piece_);
2334 }
2335 
MutableBorrowingLiteral(MutableBorrowingLiteral literal,const ShapeIndex & view_root)2336 MutableBorrowingLiteral::MutableBorrowingLiteral(
2337     MutableBorrowingLiteral literal, const ShapeIndex& view_root)
2338     : MutableLiteralBase() {
2339   shape_ = absl::make_unique<Shape>(literal.piece(view_root).subshape());
2340   CHECK(LayoutUtil::HasLayout(*shape_));
2341 
2342   root_piece_ = new Piece();
2343   root_piece_->set_subshape(shape_.get());
2344 
2345   CopyPieceSubtree(*shape_, &literal.piece(view_root), root_piece_);
2346 }
2347 
MutableBorrowingLiteral(const char * src_buf_ptr,const Shape & shape)2348 MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr,
2349                                                  const Shape& shape)
2350     : MutableLiteralBase() {
2351   shape_ = absl::make_unique<Shape>(shape);
2352   CHECK(LayoutUtil::HasLayout(*shape_));
2353   CHECK(!shape_->IsTuple());
2354 
2355   root_piece_ = new Piece();
2356   root_piece_->set_buffer(const_cast<char*>(src_buf_ptr));
2357   root_piece_->set_subshape(shape_.get());
2358 }
2359 
MutableBorrowingLiteral(absl::Span<char * > src_buf_ptrs,const Shape & shape)2360 MutableBorrowingLiteral::MutableBorrowingLiteral(absl::Span<char*> src_buf_ptrs,
2361                                                  const Shape& shape)
2362     : MutableLiteralBase() {
2363   shape_ = absl::make_unique<Shape>(shape);
2364   if (!shape_->IsTuple()) {
2365     CHECK_EQ(src_buf_ptrs.size(), 1);
2366     root_piece_ = new Piece();
2367     root_piece_->set_buffer(const_cast<char*>(src_buf_ptrs[0]));
2368     root_piece_->set_subshape(shape_.get());
2369   } else {
2370     CHECK(!ShapeUtil::IsNestedTuple(*shape_));
2371     CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_));
2372     root_piece_ = new Piece();
2373     root_piece_->set_subshape(shape_.get());
2374 
2375     for (int i = 0; i < src_buf_ptrs.size(); ++i) {
2376       Piece child_piece;
2377       const auto& src_shape = shape_->tuple_shapes(i);
2378       CHECK(src_shape.IsArray());
2379       child_piece.set_subshape(&src_shape);
2380       child_piece.set_buffer(src_buf_ptrs[i]);
2381       root_piece_->emplace_back(std::move(child_piece));
2382     }
2383   }
2384 }
2385 
~MutableBorrowingLiteral()2386 MutableBorrowingLiteral::~MutableBorrowingLiteral() {
2387   if (root_piece_ != nullptr) {
2388     delete root_piece_;
2389   }
2390 }
2391 
// Read-only view over an entire literal; shares the source's piece tree.
// The source literal must outlive this slice.
LiteralSlice::LiteralSlice(const LiteralBase& literal)
    : LiteralBase(), root_piece_(&literal.root_piece()) {}
2394 
// Read-only view over the sub-literal rooted at view_root; shares the
// source's piece subtree. The source literal must outlive this slice.
LiteralSlice::LiteralSlice(const LiteralBase& literal,
                           const ShapeIndex& view_root)
    : LiteralBase(), root_piece_(&literal.piece(view_root)) {}
2398 
BuildPieceSubtree(const Shape & shape,Piece * piece)2399 void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) {
2400   CHECK(shape.IsTuple());
2401   for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
2402     const Shape& subshape = shape.tuple_shapes(i);
2403 
2404     auto child_piece = Piece();
2405     child_piece.set_subshape(&subshape);
2406 
2407     if (subshape.IsTuple()) {
2408       BuildPieceSubtree(subshape, &child_piece);
2409     }
2410 
2411     piece->emplace_back(std::move(child_piece));
2412   }
2413 }
2414 
BorrowingLiteral(const char * src_buf_ptr,const Shape & shape)2415 BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
2416     : LiteralBase(), shape_(absl::make_unique<Shape>(shape)) {
2417   CHECK(shape_->IsArray());
2418   CHECK(LayoutUtil::HasLayout(*shape_));
2419 
2420   root_piece_ = Piece();
2421   root_piece_.set_buffer(const_cast<char*>(src_buf_ptr));
2422   root_piece_.set_subshape(shape_.get());
2423 }
2424 
BorrowingLiteral(absl::Span<const char * const> src_buf_ptrs,const Shape & shape)2425 BorrowingLiteral::BorrowingLiteral(absl::Span<const char* const> src_buf_ptrs,
2426                                    const Shape& shape)
2427     : LiteralBase(), shape_(absl::make_unique<Shape>(shape)) {
2428   CHECK(shape_->IsTuple());
2429   CHECK(!ShapeUtil::IsNestedTuple(*shape_));
2430   CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_));
2431   root_piece_ = Piece();
2432   root_piece_.set_subshape(shape_.get());
2433   BuildPieceSubtree(*shape_, &root_piece_);
2434 
2435   for (int i = 0, end = src_buf_ptrs.size(); i < end; ++i) {
2436     const auto& src_shape = shape_->tuple_shapes(i);
2437     CHECK(src_shape.IsArray());
2438     root_piece_.child(i).set_buffer(const_cast<char*>(src_buf_ptrs[i]));
2439   }
2440 }
2441 
2442 }  // namespace xla
2443