1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/literal_util.h"
17
18 #include <algorithm>
19 #include <cstring>
20 #include <functional>
21 #include <limits>
22 #include <numeric>
23 #include <vector>
24
25 #include "tensorflow/compiler/xla/index_util.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/status_macros.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/util.h"
30 #include "tensorflow/core/lib/core/casts.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/strings/str_util.h"
33 #include "tensorflow/core/lib/strings/strcat.h"
34 #include "tensorflow/core/lib/strings/stringprintf.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/core/platform/types.h"
37
38 using tensorflow::strings::Printf;
39 using tensorflow::strings::StrCat;
40
41 namespace xla {
42
43 namespace {
44
// True when the host byte order is little endian. (Presumably consulted when
// deciding whether byte-swapping such as ConvertEndianShort is needed; the
// use site is not visible in this chunk.)
constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
46
47 // Converts between little and big endian, assuming elements in the array are 16
48 // bits long.
ConvertEndianShort(char * bytes,int64 size)49 void ConvertEndianShort(char* bytes, int64 size) {
50 CHECK_EQ(size / 2, 0);
51 for (int64 i = 0; i < size; i += 2) {
52 std::swap(bytes[i], bytes[i + 1]);
53 }
54 }
55
56 } // namespace
57
operator <<(std::ostream & out,const Literal & literal)58 std::ostream& operator<<(std::ostream& out, const Literal& literal) {
59 out << literal.ToString();
60 return out;
61 }
62
StrideConfig(const Shape & source_shape,const Shape & dest_shape,tensorflow::gtl::ArraySlice<int64> dimensions)63 Literal::StrideConfig::StrideConfig(
64 const Shape& source_shape, const Shape& dest_shape,
65 tensorflow::gtl::ArraySlice<int64> dimensions)
66 : dimensions(dimensions),
67 base(dimensions.size(), 0),
68 step(dimensions.size(), 1) {
69 if (!dimensions.empty()) {
70 // Selects the shape with the largest minor dimension as the one upon
71 // which to run the tight stride loop.
72 if (dimensions[LayoutUtil::Minor(source_shape.layout(), 0)] >=
73 dimensions[LayoutUtil::Minor(dest_shape.layout(), 0)]) {
74 minor_dimension = LayoutUtil::Minor(source_shape.layout(), 0);
75 dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension);
76 } else {
77 minor_dimension = LayoutUtil::Minor(dest_shape.layout(), 0);
78 source_stride =
79 IndexUtil::GetDimensionStride(source_shape, minor_dimension);
80 }
81 minor_loop_size = dimensions[minor_dimension];
82 step[minor_dimension] = minor_loop_size;
83 }
84 }
85
// Constructs a literal of `shape`, allocating backing buffers for every
// array subshape (delegates to the two-argument constructor).
Literal::Literal(const Shape& shape)
    : Literal(shape, /*allocate_arrays=*/true) {}
88
Literal(const Shape & shape,bool allocate_arrays)89 Literal::Literal(const Shape& shape, bool allocate_arrays)
90 : shape_(shape), pieces_(shape), owns_buffers_(true) {
91 CHECK(LayoutUtil::HasLayout(shape));
92 for (auto& pair : pieces_) {
93 const ShapeIndex& index = pair.first;
94 Piece& piece = pair.second;
95
96 piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
97 const Shape& subshape = piece.subshape();
98 if (ShapeUtil::IsArray(subshape)) {
99 if (allocate_arrays) {
100 piece.set_buffer(new char[piece.size_bytes()]);
101 if (LayoutUtil::IsSparseArray(subshape)) {
102 piece.set_sparse_indices(new SparseIndexArray(
103 LayoutUtil::MaxSparseElements(subshape.layout()),
104 ShapeUtil::Rank(subshape)));
105 }
106 } else {
107 piece.set_buffer(nullptr);
108 }
109 }
110 }
111 }
112
// Frees the backing buffers and sparse index arrays if this literal owns
// them; see DeallocateBuffers().
Literal::~Literal() { DeallocateBuffers(); }
114
DeallocateBuffers()115 void Literal::DeallocateBuffers() {
116 if (owns_buffers_) {
117 for (auto& pair : pieces_) {
118 Piece& piece = pair.second;
119 if (piece.buffer() != nullptr) {
120 delete[] piece.buffer();
121 delete piece.sparse_indices();
122 }
123 }
124 }
125 }
126
Literal(Literal && other)127 Literal::Literal(Literal&& other) {
128 shape_ = std::move(other.shape_);
129 pieces_ = std::move(other.pieces_);
130 // We need to iterate through the pieces to set the subshape pointer
131 // properly. It must refer to subshapes within shape_.
132 for (auto& pair : pieces_) {
133 const ShapeIndex& index = pair.first;
134 Piece& piece = pair.second;
135 piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
136 }
137 owns_buffers_ = other.owns_buffers_;
138
139 other.shape_ = ShapeUtil::MakeNil();
140 other.pieces_ = ShapeTree<Piece>(other.shape_);
141 other.piece({}).set_subshape(&other.shape_);
142 }
143
operator =(Literal && other)144 Literal& Literal::operator=(Literal&& other) {
145 DeallocateBuffers();
146 shape_ = std::move(other.shape_);
147 pieces_ = std::move(other.pieces_);
148 // We need to iterate through the pieces to set the subshape pointer
149 // properly. It must refer to subshapes within shape_.
150 for (auto& pair : pieces_) {
151 const ShapeIndex& index = pair.first;
152 Piece& piece = pair.second;
153 piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
154 }
155 owns_buffers_ = other.owns_buffers_;
156
157 other.shape_ = ShapeUtil::MakeNil();
158 other.pieces_ = ShapeTree<Piece>(other.shape_);
159 other.piece({}).set_subshape(&other.shape_);
160 return *this;
161 }
162
CreateFromShape(const Shape & shape)163 std::unique_ptr<Literal> Literal::CreateFromShape(const Shape& shape) {
164 auto literal = MakeUnique<Literal>(shape);
165 for (auto& pair : literal->pieces_) {
166 Piece& piece = pair.second;
167 if (ShapeUtil::IsArray(piece.subshape())) {
168 memset(piece.untyped_data(), 0, piece.size_bytes());
169 }
170 }
171 return literal;
172 }
173
// Returns the sparse index array of the piece at `shape_index` (const
// overload); as set up by the constructor this is only non-null for pieces
// with a sparse layout.
const SparseIndexArray* Literal::sparse_indices(
    const ShapeIndex& shape_index) const {
  return piece(shape_index).sparse_indices();
}
178
// Mutable overload: returns the sparse index array of the piece at
// `shape_index`.
SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) {
  return piece(shape_index).sparse_indices();
}
182
// Creates a zero-initialized literal of `primitive_type` with the given
// dimensions (default layout, via ShapeUtil::MakeShape).
/* static */ std::unique_ptr<Literal> Literal::CreateFromDimensions(
    PrimitiveType primitive_type,
    tensorflow::gtl::ArraySlice<int64> dimensions) {
  return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions));
}
188
// Copies a slice of `src_literal` (starting at `src_base`, of extent
// `copy_size` along each dimension) into this literal at `dest_base`. Both
// literals must hold NativeT elements. Scalars (rank 0) are special-cased
// since they carry no dimensions.
template <typename NativeT>
Status Literal::CopySliceFromInternal(
    const Literal& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
    tensorflow::gtl::ArraySlice<int64> dest_base,
    tensorflow::gtl::ArraySlice<int64> copy_size) {
  TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size());
  TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size());

  // Maps a multi-dimensional index to the flat buffer offset for `shape`.
  auto linear_index = [](const Shape& shape,
                         tensorflow::gtl::ArraySlice<int64> multi_index) {
    return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index);
  };

  if (ShapeUtil::Rank(src_literal.shape()) == 0 ||
      ShapeUtil::Rank(shape()) == 0) {
    // If any of the two shapes are scalars, we can just call the StridedCopy()
    // directly, and we know we will be copying only one value.
    TF_RET_CHECK(copy_size.empty());
    StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0,
                src_literal.data<NativeT>(),
                linear_index(src_literal.shape(), src_base), 0, 1);
  } else if (!ShapeUtil::HasZeroElements(shape()) &&
             !ShapeUtil::HasZeroElements(src_literal.shape())) {
    // Perform copy if neither src nor dest has dimensions with zero element,
    // otherwise it's a no-op.
    TF_RET_CHECK(src_base.size() == dest_base.size());
    TF_RET_CHECK(src_base.size() == copy_size.size());

    // Scan the source from minor, stepping in copy size blocks, then within
    // the index enumeration functor, do a strided copy advancing source index
    // by one (walking through the minor dimension), and destination index by
    // proper stride size at the matching dimension.
    DimensionVector src_indexes(src_base.size(), 0);
    DimensionVector dest_indexes(dest_base.size(), 0);
    Literal::StrideConfig stride_config(src_literal.shape(), shape(),
                                        copy_size);

    auto copy_proc = [&](const std::vector<int64>& indexes) {
      // Map from multi-dimensional index, to source index.
      std::transform(indexes.begin(), indexes.end(), src_base.begin(),
                     src_indexes.begin(), std::plus<int64>());
      // Map from multi-dimensional index, to destination index.
      std::transform(indexes.begin(), indexes.end(), dest_base.begin(),
                     dest_indexes.begin(), std::plus<int64>());

      int64 src_index = linear_index(src_literal.shape(), src_indexes);
      int64 dest_index = linear_index(shape(), dest_indexes);

      // `this->` is needed to workaround MSVC bug: #16882
      StridedCopy(this->data<NativeT>(), dest_index, stride_config.dest_stride,
                  src_literal.data<NativeT>(), src_index,
                  stride_config.source_stride, stride_config.minor_loop_size);
      return true;
    };

    ShapeUtil::ForEachIndex(src_literal.shape(), stride_config.base,
                            stride_config.dimensions, stride_config.step,
                            copy_proc);
  }
  return Status::OK();
}
250
DecomposeTuple()251 std::vector<Literal> Literal::DecomposeTuple() {
252 CHECK(ShapeUtil::IsTuple(shape()));
253 std::vector<Literal> elements;
254 for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
255 elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}),
256 /*allocate_arrays=*/false));
257 Literal& element = elements.back();
258 for (auto& pair : element.pieces_) {
259 const ShapeIndex& index = pair.first;
260 Piece& dest_piece = pair.second;
261 ShapeIndex src_index = {i};
262 for (int64 j : index) {
263 src_index.push_back(j);
264 }
265 Piece& src_piece = piece(src_index);
266
267 // Move the respective buffer and sparse indices over to the element
268 // Literal.
269 dest_piece.set_buffer(src_piece.buffer());
270 src_piece.set_buffer(nullptr);
271 dest_piece.set_sparse_indices(src_piece.sparse_indices());
272 src_piece.set_sparse_indices(nullptr);
273 }
274 }
275 // Set this literal to be nil-shaped.
276 *this = Literal();
277 return elements;
278 }
279
MoveIntoTuple(tensorflow::gtl::MutableArraySlice<Literal> elements)280 /* static */ Literal Literal::MoveIntoTuple(
281 tensorflow::gtl::MutableArraySlice<Literal> elements) {
282 std::vector<Shape> element_shapes;
283 for (const Literal& element : elements) {
284 element_shapes.push_back(element.shape());
285 }
286 Literal literal(ShapeUtil::MakeTupleShape(element_shapes),
287 /*allocate_arrays=*/false);
288 for (int i = 0; i < elements.size(); ++i) {
289 TF_CHECK_OK(
290 literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
291 }
292 return literal;
293 }
294
295 namespace {
296
// Copies the elements in 'src' to 'dest'. The shape and layout of the data in
// the array slices are indicated by dest_shape and src_shape respectively.
// The shapes must be compatible; their layouts may differ, which is why
// elements are copied one multi-index at a time instead of with a flat
// memcpy.
template <typename NativeT>
void CopyElementsBetween(tensorflow::gtl::MutableArraySlice<NativeT> dest,
                         tensorflow::gtl::ArraySlice<NativeT> src,
                         const Shape& dest_shape, const Shape& src_shape) {
  CHECK(ShapeUtil::Compatible(dest_shape, src_shape));
  if (ShapeUtil::HasZeroElements(dest_shape)) {
    return;
  }
  // Walk every multi-dimensional index, translating it to the linear offset
  // appropriate for each layout. do/while: a rank-0 (scalar) shape still has
  // exactly one element to copy.
  std::vector<int64> index(ShapeUtil::Rank(dest_shape));
  do {
    dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] =
        src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)];
  } while (IndexUtil::BumpIndices(dest_shape, &index));
}
313
314 } // namespace
315
CopyFrom(const Literal::Piece & src)316 Status Literal::Piece::CopyFrom(const Literal::Piece& src) {
317 if (ShapeUtil::Equal(subshape(), src.subshape())) {
318 // If the layouts are equal it's faster just to memcpy.
319 memcpy(buffer(), src.buffer(), src.size_bytes());
320 } else {
321 TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape()));
322 std::vector<int64> origin(ShapeUtil::Rank(subshape()), 0);
323 switch (subshape().element_type()) {
324 #define COPY_ELEMENTS(XLA_T, NATIVE_T) \
325 case (XLA_T): \
326 CopyElementsBetween<NATIVE_T>(data<NATIVE_T>(), src.data<NATIVE_T>(), \
327 subshape(), src.subshape()); \
328 break;
329 COPY_ELEMENTS(U8, uint8);
330 COPY_ELEMENTS(U16, uint16);
331 COPY_ELEMENTS(U32, uint32);
332 COPY_ELEMENTS(U64, uint64);
333 COPY_ELEMENTS(S8, int8);
334 COPY_ELEMENTS(S16, int16);
335 COPY_ELEMENTS(S32, int32);
336 COPY_ELEMENTS(S64, int64);
337 COPY_ELEMENTS(F16, half);
338 COPY_ELEMENTS(BF16, bfloat16);
339 COPY_ELEMENTS(F32, float);
340 COPY_ELEMENTS(F64, double);
341 COPY_ELEMENTS(C64, complex64);
342 COPY_ELEMENTS(PRED, bool);
343 #undef COPY_ELEMENTS
344 default:
345 return Unimplemented(
346 "Unhandled primitive type %s",
347 PrimitiveType_Name(subshape().element_type()).c_str());
348 }
349 }
350 return Status::OK();
351 }
352
CopyFrom(const Literal & src_literal,const ShapeIndex & dest_shape_index,const ShapeIndex & src_shape_index)353 Status Literal::CopyFrom(const Literal& src_literal,
354 const ShapeIndex& dest_shape_index,
355 const ShapeIndex& src_shape_index) {
356 const Shape& dest_subshape =
357 ShapeUtil::GetSubshape(shape(), dest_shape_index);
358 const Shape& src_subshape =
359 ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index);
360 if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) {
361 return InvalidArgument(
362 "Destination subshape incompatible with source subshape: %s vs %s",
363 ShapeUtil::HumanString(dest_subshape).c_str(),
364 ShapeUtil::HumanString(src_subshape).c_str());
365 }
366
367 for (auto& pair : pieces_) {
368 const ShapeIndex& index = pair.first;
369 Piece& piece = pair.second;
370 if (!ShapeUtil::IsArray(piece.subshape())) {
371 continue;
372 }
373
374 // Determine if this index is in the part of this literal that we want to
375 // copy over from src_literal.
376 bool in_subtree_to_copy = true;
377 for (int i = 0; i < dest_shape_index.size(); ++i) {
378 if (index[i] != dest_shape_index[i]) {
379 in_subtree_to_copy = false;
380 break;
381 }
382 }
383 if (!in_subtree_to_copy) {
384 continue;
385 }
386
387 // Construct the index of the corresponding piece in the source literal.
388 ShapeIndex src_piece_index = src_shape_index;
389 for (int64 i = dest_shape_index.size(); i < index.size(); ++i) {
390 src_piece_index.push_back(index[i]);
391 }
392
393 TF_RETURN_IF_ERROR(piece.CopyFrom(src_literal.piece(src_piece_index)));
394 }
395 return Status::OK();
396 }
397
MoveFrom(Literal && src_literal,const ShapeIndex & dest_shape_index)398 Status Literal::MoveFrom(Literal&& src_literal,
399 const ShapeIndex& dest_shape_index) {
400 const Shape& dest_subshape =
401 ShapeUtil::GetSubshape(shape(), dest_shape_index);
402 if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) {
403 return InvalidArgument(
404 "Destination subshape not equal to source shape: %s vs %s",
405 ShapeUtil::HumanString(dest_subshape).c_str(),
406 ShapeUtil::HumanString(src_literal.shape()).c_str());
407 }
408
409 if (!(owns_buffers_ && src_literal.owns_buffers_)) {
410 return InvalidArgument(
411 "Source and destination literals must both own their buffers (ie, not "
412 "be views)");
413 }
414
415 for (auto& pair : src_literal.pieces_) {
416 const ShapeIndex& src_index = pair.first;
417 Piece& src_piece = pair.second;
418 if (!ShapeUtil::IsArray(src_piece.subshape())) {
419 continue;
420 }
421
422 ShapeIndex dest_index = dest_shape_index;
423 for (int64 i : src_index) {
424 dest_index.push_back(i);
425 }
426 Piece& dest_piece = piece(dest_index);
427 delete[] dest_piece.buffer();
428 dest_piece.set_buffer(src_piece.buffer());
429 delete dest_piece.sparse_indices();
430 dest_piece.set_sparse_indices(src_piece.sparse_indices());
431 }
432
433 src_literal.shape_ = ShapeUtil::MakeNil();
434 src_literal.pieces_ = ShapeTree<Piece>(src_literal.shape_);
435 src_literal.piece({}).set_subshape(&src_literal.shape_);
436 return Status::OK();
437 }
438
// Copies a slice of `src_literal` into this literal, dispatching to the
// typed CopySliceFromInternal<> based on the element type. Both literals
// must be arrays of the same element type.
Status Literal::CopySliceFrom(const Literal& src_literal,
                              tensorflow::gtl::ArraySlice<int64> src_base,
                              tensorflow::gtl::ArraySlice<int64> dest_base,
                              tensorflow::gtl::ArraySlice<int64> copy_size) {
  TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape());
  TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape()))
      << ShapeUtil::HumanString(src_literal.shape());
  TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape()));

  switch (shape().element_type()) {
    case U8:
      return CopySliceFromInternal<uint8>(src_literal, src_base, dest_base,
                                          copy_size);
    case U16:
      return CopySliceFromInternal<uint16>(src_literal, src_base, dest_base,
                                           copy_size);
    case U32:
      return CopySliceFromInternal<uint32>(src_literal, src_base, dest_base,
                                           copy_size);
    case U64:
      return CopySliceFromInternal<uint64>(src_literal, src_base, dest_base,
                                           copy_size);
    case S8:
      return CopySliceFromInternal<int8>(src_literal, src_base, dest_base,
                                         copy_size);
    case S16:
      return CopySliceFromInternal<int16>(src_literal, src_base, dest_base,
                                          copy_size);
    case S32:
      return CopySliceFromInternal<int32>(src_literal, src_base, dest_base,
                                          copy_size);
    case S64:
      return CopySliceFromInternal<int64>(src_literal, src_base, dest_base,
                                          copy_size);
    case F16:
      return CopySliceFromInternal<half>(src_literal, src_base, dest_base,
                                         copy_size);
    case BF16:
      return CopySliceFromInternal<bfloat16>(src_literal, src_base, dest_base,
                                             copy_size);
    case F32:
      return CopySliceFromInternal<float>(src_literal, src_base, dest_base,
                                          copy_size);
    case F64:
      return CopySliceFromInternal<double>(src_literal, src_base, dest_base,
                                           copy_size);
    case C64:
      return CopySliceFromInternal<complex64>(src_literal, src_base, dest_base,
                                              copy_size);
    case PRED:
      return CopySliceFromInternal<bool>(src_literal, src_base, dest_base,
                                         copy_size);
    default:
      break;
  }
  return Unimplemented("Unhandled primitive type %d", shape().element_type());
}
496
// Returns a rank-0 (scalar) literal holding the zero value of
// `primitive_type`. LOG(FATAL)s for types with no zero value (TUPLE, OPAQUE)
// and for types not yet implemented (U16/S16).
/* static */ Literal Literal::Zero(PrimitiveType primitive_type) {
  switch (primitive_type) {
    case U8:
      return std::move(*Literal::CreateR0<uint8>(0));
    case U32:
      return std::move(*Literal::CreateR0<uint32>(0));
    case U64:
      return std::move(*Literal::CreateR0<uint64>(0));
    case S8:
      return std::move(*Literal::CreateR0<int8>(0));
    case S32:
      return std::move(*Literal::CreateR0<int32>(0));
    case S64:
      return std::move(*Literal::CreateR0<int64>(0));
    case F16:
      return std::move(*Literal::CreateR0<half>(static_cast<half>(0.0f)));
    case BF16:
      return std::move(
          *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f)));
    case F32:
      return std::move(*Literal::CreateR0<float>(0));
    case F64:
      return std::move(*Literal::CreateR0<double>(0));
    case C64:
      return std::move(*Literal::CreateR0<complex64>(0));
    case PRED:
      return std::move(*Literal::CreateR0<bool>(false));
    case S16:
    case U16:
      LOG(FATAL) << "u16/s16 literals not yet implemented";
    case TUPLE:
      LOG(FATAL) << "tuple element type cannot take on value of 0";
    case OPAQUE:
      LOG(FATAL) << "opaque element type cannot take on value of 0";
    default:
      LOG(FATAL) << "Unhandled primitive type " << primitive_type;
  }
}
535
// Returns a rank-0 (scalar) literal holding the value one of
// `primitive_type` (true for PRED). LOG(FATAL)s for types with no such value
// (TUPLE, OPAQUE) and for types not yet implemented (U16/S16).
/* static */ Literal Literal::One(PrimitiveType primitive_type) {
  switch (primitive_type) {
    case U8:
      return std::move(*Literal::CreateR0<uint8>(1));
    case U32:
      return std::move(*Literal::CreateR0<uint32>(1));
    case U64:
      return std::move(*Literal::CreateR0<uint64>(1));
    case S8:
      return std::move(*Literal::CreateR0<int8>(1));
    case S32:
      return std::move(*Literal::CreateR0<int32>(1));
    case S64:
      return std::move(*Literal::CreateR0<int64>(1));
    case F16:
      return std::move(*Literal::CreateR0<half>(static_cast<half>(1.0f)));
    case BF16:
      return std::move(
          *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f)));
    case F32:
      return std::move(*Literal::CreateR0<float>(1));
    case F64:
      return std::move(*Literal::CreateR0<double>(1));
    case C64:
      return std::move(*Literal::CreateR0<complex64>(1));
    case PRED:
      return std::move(*Literal::CreateR0<bool>(true));
    case S16:
    case U16:
      LOG(FATAL) << "u16/s16 literals not yet implemented";
    case TUPLE:
      LOG(FATAL) << "tuple element type cannot take on value of 1";
    case OPAQUE:
      LOG(FATAL) << "opaque element type cannot take on value of 1";
    default:
      LOG(FATAL) << "Unhandled primitive type " << primitive_type;
  }
}
574
// Returns a scalar literal holding the minimum value of `primitive_type`.
// For floating-point types this is negative infinity (not the smallest
// finite value); for PRED it is false. LOG(FATAL)s for types with no minimum
// (C64, TUPLE, OPAQUE) and for unimplemented U16/S16.
/* static */ Literal Literal::MinValue(PrimitiveType primitive_type) {
  switch (primitive_type) {
    case U8:
      return std::move(
          *Literal::CreateR0<uint8>(std::numeric_limits<uint8>::min()));
    case U32:
      return std::move(
          *Literal::CreateR0<uint32>(std::numeric_limits<uint32>::min()));
    case U64:
      return std::move(
          *Literal::CreateR0<uint64>(std::numeric_limits<uint64>::min()));
    case S8:
      return std::move(
          *Literal::CreateR0<int8>(std::numeric_limits<int8>::min()));
    case S32:
      return std::move(
          *Literal::CreateR0<int32>(std::numeric_limits<int32>::min()));
    case S64:
      return std::move(
          *Literal::CreateR0<int64>(std::numeric_limits<int64>::min()));
    case F32:
      return std::move(
          *Literal::CreateR0<float>(-std::numeric_limits<float>::infinity()));
    case F64:
      return std::move(
          *Literal::CreateR0<double>(-std::numeric_limits<double>::infinity()));
    case C64:
      LOG(FATAL) << "C64 element type has no minimum value";
    case PRED:
      return std::move(*Literal::CreateR0<bool>(false));
    case S16:
    case U16:
      LOG(FATAL) << "u16/s16 literals not yet implemented";
    case F16:
      return std::move(*Literal::CreateR0<half>(
          static_cast<half>(-std::numeric_limits<float>::infinity())));
    case BF16:
      return std::move(*Literal::CreateR0<bfloat16>(
          static_cast<bfloat16>(-std::numeric_limits<float>::infinity())));
    case TUPLE:
      LOG(FATAL) << "tuple element type has no minimum value";
    case OPAQUE:
      LOG(FATAL) << "opaque element type has no minimum value";
    default:
      LOG(FATAL) << "Unhandled primitive type " << primitive_type;
  }
}
622
// Returns a scalar literal holding the maximum value of `primitive_type`.
// For floating-point types this is positive infinity (not the largest finite
// value); for PRED it is true. LOG(FATAL)s for types with no maximum (TUPLE,
// OPAQUE) and for unimplemented U16/S16.
/* static */ Literal Literal::MaxValue(PrimitiveType primitive_type) {
  switch (primitive_type) {
    case U8:
      return std::move(
          *Literal::CreateR0<uint8>(std::numeric_limits<uint8>::max()));
    case U32:
      return std::move(
          *Literal::CreateR0<uint32>(std::numeric_limits<uint32>::max()));
    case U64:
      return std::move(
          *Literal::CreateR0<uint64>(std::numeric_limits<uint64>::max()));
    case S8:
      return std::move(
          *Literal::CreateR0<int8>(std::numeric_limits<int8>::max()));
    case S32:
      return std::move(
          *Literal::CreateR0<int32>(std::numeric_limits<int32>::max()));
    case S64:
      return std::move(
          *Literal::CreateR0<int64>(std::numeric_limits<int64>::max()));
    case F32:
      return std::move(
          *Literal::CreateR0<float>(std::numeric_limits<float>::infinity()));
    case F64:
      return std::move(
          *Literal::CreateR0<double>(std::numeric_limits<double>::infinity()));
    case PRED:
      return std::move(*Literal::CreateR0<bool>(true));
    case S16:
    case U16:
      LOG(FATAL) << "u16/s16 literals not yet implemented";
    case F16:
      return std::move(*Literal::CreateR0<half>(
          static_cast<half>(std::numeric_limits<float>::infinity())));
    case BF16:
      return std::move(*Literal::CreateR0<bfloat16>(
          static_cast<bfloat16>(std::numeric_limits<float>::infinity())));
    case TUPLE:
      LOG(FATAL) << "tuple element type has no maximum value";
    case OPAQUE:
      LOG(FATAL) << "opaque element type has no maximum value";
    default:
      LOG(FATAL) << "Unhandled primitive type " << primitive_type;
  }
}
668
CreateR1(const tensorflow::core::Bitmap & values)669 /* static */ std::unique_ptr<Literal> Literal::CreateR1(
670 const tensorflow::core::Bitmap& values) {
671 auto literal = MakeUnique<Literal>(
672 ShapeUtil::MakeShape(PRED, {static_cast<int64>(values.bits())}));
673 literal->PopulateR1(values);
674 return literal;
675 }
676
PopulateR1(const tensorflow::core::Bitmap & values)677 void Literal::PopulateR1(const tensorflow::core::Bitmap& values) {
678 CHECK(ShapeUtil::IsArray(shape()));
679 CHECK_EQ(ShapeUtil::Rank(shape()), 1);
680 CHECK_EQ(element_count(), values.bits());
681 CHECK_EQ(shape().element_type(), PRED);
682 for (int64 i = 0; i < static_cast<int64>(values.bits()); ++i) {
683 Set({i}, values.get(i));
684 }
685 }
686
CreateR1U8(tensorflow::StringPiece value)687 /* static */ std::unique_ptr<Literal> Literal::CreateR1U8(
688 tensorflow::StringPiece value) {
689 auto literal = MakeUnique<Literal>(
690 ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
691 for (int i = 0; i < value.size(); ++i) {
692 literal->Set<uint8>({i}, value[i]);
693 }
694 return literal;
695 }
696
CreateR2F32Linspace(float from,float to,int64 rows,int64 cols)697 /* static */ std::unique_ptr<Literal> Literal::CreateR2F32Linspace(float from,
698 float to,
699 int64 rows,
700 int64 cols) {
701 auto value = MakeLinspaceArray2D(from, to, rows, cols);
702 return CreateR2FromArray2D(*value);
703 }
704
Relayout(const Layout & new_layout,const ShapeIndex & shape_index) const705 std::unique_ptr<Literal> Literal::Relayout(
706 const Layout& new_layout, const ShapeIndex& shape_index) const {
707 // Create new shape with 'new_layout' set at the given shape index.
708 Shape new_shape = shape();
709 Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
710 TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
711 *subshape->mutable_layout() = new_layout;
712 auto result = MakeUnique<Literal>(new_shape);
713 TF_CHECK_OK(result->CopyFrom(*this));
714 return result;
715 }
716
Relayout(const Shape & shape_with_layout) const717 std::unique_ptr<Literal> Literal::Relayout(
718 const Shape& shape_with_layout) const {
719 CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
720 << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
721 << " not compatible with literal shape "
722 << ShapeUtil::HumanString(shape());
723 std::unique_ptr<Literal> result = CreateFromShape(shape_with_layout);
724 ShapeUtil::ForEachSubshape(
725 result->shape(),
726 [this, &result](const Shape& subshape, const ShapeIndex& index) {
727 if (ShapeUtil::IsArray(subshape)) {
728 TF_CHECK_OK(result->CopyFrom(*this,
729 /*dest_shape_index=*/index,
730 /*src_shape_index=*/index));
731 }
732 });
733 return result;
734 }
735
Reshape(tensorflow::gtl::ArraySlice<int64> dimensions) const736 StatusOr<std::unique_ptr<Literal>> Literal::Reshape(
737 tensorflow::gtl::ArraySlice<int64> dimensions) const {
738 if (!ShapeUtil::IsArray(shape())) {
739 return InvalidArgument("Reshape does not support tuples.");
740 }
741 std::unique_ptr<Literal> output;
742 if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
743 output =
744 Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape())));
745 } else {
746 output = CloneToUnique();
747 }
748 // Because the layout is monotonic, we can simply reuse the same sequence of
749 // values without changing their order.
750 output->shape_ = ShapeUtil::MakeShape(shape().element_type(), dimensions);
751
752 int64 elements_before = ShapeUtil::ElementsIn(shape());
753 int64 elements_after = ShapeUtil::ElementsIn(output->shape());
754 if (elements_before != elements_after) {
755 return InvalidArgument(
756 "Shapes before and after Literal::Reshape have different numbers "
757 "of elements: %s vs %s.",
758 ShapeUtil::HumanString(shape()).c_str(),
759 ShapeUtil::HumanString(output->shape()).c_str());
760 }
761 return std::move(output);
762 }
763
// Returns a transposed copy of this array literal, where `permutation` maps
// output dimensions to input dimensions. Implemented as a metadata-only
// permutation of dimensions and layout plus one flat memcpy, rather than an
// element-by-element copy.
std::unique_ptr<Literal> Literal::Transpose(
    tensorflow::gtl::ArraySlice<int64> permutation) const {
  CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
  CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape())))
      << "Given permutation is not a permutation of dimension numbers";
  // To transpose the array, we just permute the dimensions and layout, and
  // do a straight memory copy of the raw data set.
  // This is considerably faster than iterating over every array element using
  // the EachCell<>() and Set<>() APIs.
  std::vector<int64> inverse_permutation = InversePermutation(permutation);
  Shape permuted_shape =
      ShapeUtil::PermuteDimensions(inverse_permutation, shape());
  // Replace the layout with one affine to this shape, such that a
  // transpose operation can be performed by leaving the flat values
  // representation intact.
  // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation.
  // The shape with affine layout resulting from that operation will be
  // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the
  // most minor.
  //
  // Essentially, given MinMaj(Di) the position of the Di dimension within the
  // minor to major vector, and given T(Di) the index that the original Di
  // dimension has within the transposed array, a layout is affine if
  // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major
  // vector of the affine layout.
  CHECK(LayoutUtil::IsDenseArray(permuted_shape));
  Layout* layout = permuted_shape.mutable_layout();
  layout->clear_minor_to_major();
  for (auto index : LayoutUtil::MinorToMajor(shape())) {
    layout->add_minor_to_major(inverse_permutation[index]);
  }
  // With the affine layout chosen above the flat data is unchanged, so the
  // whole buffer can be copied verbatim. The destination buffer is at least
  // as large as ours (checked below).
  std::unique_ptr<Literal> new_literal = CreateFromShape(permuted_shape);
  DCHECK_GE(ShapeUtil::ByteSizeOf(new_literal->shape()),
            ShapeUtil::ByteSizeOf(shape()));
  std::memcpy(new_literal->root_piece().buffer(), root_piece().buffer(),
              root_piece().size_bytes());
  return new_literal;
}
802
Slice(tensorflow::gtl::ArraySlice<int64> start_indices,tensorflow::gtl::ArraySlice<int64> limit_indices) const803 std::unique_ptr<Literal> Literal::Slice(
804 tensorflow::gtl::ArraySlice<int64> start_indices,
805 tensorflow::gtl::ArraySlice<int64> limit_indices) const {
806 CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice";
807
808 DimensionVector result_dimensions;
809 for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) {
810 CHECK_GE(start_indices[dnum], 0);
811 CHECK_LE(limit_indices[dnum], shape().dimensions(dnum));
812 int64 dimension = limit_indices[dnum] - start_indices[dnum];
813 CHECK_GT(dimension, 0);
814 result_dimensions.push_back(dimension);
815 }
816 const auto result_shape =
817 ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions,
818 LayoutUtil::MinorToMajor(shape()));
819
820 auto result_literal = MakeUnique<Literal>(result_shape);
821
822 DimensionVector new_indices(ShapeUtil::Rank(result_shape));
823 switch (result_shape.element_type()) {
824 case F32:
825 result_literal->EachCell<float>(
826 [&](tensorflow::gtl::ArraySlice<int64> indices, float /*value*/) {
827 for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
828 new_indices[i] = indices[i] + start_indices[i];
829 }
830 float value = Get<float>(new_indices);
831 result_literal->Set<float>(indices, value);
832 });
833 return result_literal;
834 case C64:
835 result_literal->EachCell<complex64>(
836 [&](tensorflow::gtl::ArraySlice<int64> indices, complex64 /*value*/) {
837 for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
838 new_indices[i] = indices[i] + start_indices[i];
839 }
840 complex64 value = Get<complex64>(new_indices);
841 result_literal->Set<complex64>(indices, value);
842 });
843 return result_literal;
844 case S32:
845 result_literal->EachCell<int32>(
846 [&](tensorflow::gtl::ArraySlice<int64> indices, int32 /*value*/) {
847 for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
848 new_indices[i] = indices[i] + start_indices[i];
849 }
850 int32 value = Get<int32>(new_indices);
851 result_literal->Set<int32>(indices, value);
852 });
853 return result_literal;
854 case U32:
855 result_literal->EachCell<uint32>(
856 [&](tensorflow::gtl::ArraySlice<int64> indices, uint32 /*value*/) {
857 for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
858 new_indices[i] = indices[i] + start_indices[i];
859 }
860 uint32 value = Get<uint32>(new_indices);
861 result_literal->Set<uint32>(indices, value);
862 });
863 return result_literal;
864 default:
865 LOG(FATAL) << "not yet implemented: "
866 << PrimitiveType_Name(result_shape.element_type());
867 }
868 }
869
Clone() const870 Literal Literal::Clone() const {
871 Literal result(shape());
872 TF_CHECK_OK(result.CopyFrom(*this));
873 return result;
874 }
875
CloneToUnique() const876 std::unique_ptr<Literal> Literal::CloneToUnique() const {
877 auto result = MakeUnique<Literal>(shape());
878 TF_CHECK_OK(result->CopyFrom(*this));
879 return result;
880 }
881
// Returns a human-readable rendering of the element at `multi_index` within
// the dense-array subshape at `shape_index`. PRED prints "true"/"false",
// C64 prints "(re, im)", and BF16 is widened to float before printing (there
// is no StrCat overload for bfloat16). Dies on a non-dense subshape or an
// unhandled element type.
string Literal::GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
                            const ShapeIndex& shape_index) const {
  const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
  CHECK(LayoutUtil::IsDenseArray(subshape));
  switch (subshape.element_type()) {
    case PRED:
      return Get<bool>(multi_index, shape_index) ? "true" : "false";
    case S8:
      return StrCat(Get<int8>(multi_index, shape_index));
    case S16:
      return StrCat(Get<int16>(multi_index, shape_index));
    case S32:
      return StrCat(Get<int32>(multi_index, shape_index));
    case S64:
      return StrCat(Get<int64>(multi_index, shape_index));
    case U8:
      return StrCat(Get<uint8>(multi_index, shape_index));
    case U16:
      return StrCat(Get<uint16>(multi_index, shape_index));
    case U32:
      return StrCat(Get<uint32>(multi_index, shape_index));
    case U64:
      return StrCat(Get<uint64>(multi_index, shape_index));
    case F16:
      return StrCat(Get<half>(multi_index, shape_index));
    case F32:
      return StrCat(Get<float>(multi_index, shape_index));
    case BF16:
      // Widen to float: StrCat has no bfloat16 overload.
      return StrCat(
          static_cast<float>(Get<bfloat16>(multi_index, shape_index)));
    case F64:
      return StrCat(Get<double>(multi_index, shape_index));
    case C64: {
      complex64 c = Get<complex64>(multi_index, shape_index);
      return StrCat("(", c.real(), ", ", c.imag(), ")");
    }
    default:
      LOG(FATAL) << PrimitiveType_Name(subshape.element_type());
  }
}
922
// Sparse-array analogue of GetAsString: renders the value of the
// `sparse_element_number`-th stored element of the sparse subshape at
// `shape_index`. Formatting matches GetAsString (PRED as "true"/"false",
// C64 as "(re, im)", BF16 widened to float). Dies on a non-sparse subshape
// or an element type not valid for sparse arrays.
string Literal::GetSparseElementAsString(int64 sparse_element_number,
                                         const ShapeIndex& shape_index) const {
  const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
  CHECK(LayoutUtil::IsSparseArray(subshape));
  switch (subshape.element_type()) {
    case PRED:
      return GetSparseElement<bool>(sparse_element_number, shape_index)
                 ? "true"
                 : "false";
    case S8:
      return StrCat(GetSparseElement<int8>(sparse_element_number, shape_index));
    case S16:
      return StrCat(
          GetSparseElement<int16>(sparse_element_number, shape_index));
    case S32:
      return StrCat(
          GetSparseElement<int32>(sparse_element_number, shape_index));
    case S64:
      return StrCat(
          GetSparseElement<int64>(sparse_element_number, shape_index));
    case U8:
      return StrCat(
          GetSparseElement<uint8>(sparse_element_number, shape_index));
    case U16:
      return StrCat(
          GetSparseElement<uint16>(sparse_element_number, shape_index));
    case U32:
      return StrCat(
          GetSparseElement<uint32>(sparse_element_number, shape_index));
    case U64:
      return StrCat(
          GetSparseElement<uint64>(sparse_element_number, shape_index));
    case F16:
      return StrCat(GetSparseElement<half>(sparse_element_number, shape_index));
    case F32:
      return StrCat(
          GetSparseElement<float>(sparse_element_number, shape_index));
    case BF16:
      // Widen to float: StrCat has no bfloat16 overload.
      return StrCat(static_cast<float>(
          GetSparseElement<bfloat16>(sparse_element_number, shape_index)));
    case F64:
      return StrCat(
          GetSparseElement<double>(sparse_element_number, shape_index));
    case C64: {
      complex64 c =
          GetSparseElement<complex64>(sparse_element_number, shape_index);
      return StrCat("(", c.real(), ", ", c.imag(), ")");
    }
    default:
      LOG(FATAL) << "Invalid element type for sparse arrays: "
                 << PrimitiveType_Name(subshape.element_type());
  }
}
976
GetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index) const977 StatusOr<int64> Literal::GetIntegralAsS64(
978 tensorflow::gtl::ArraySlice<int64> multi_index) const {
979 CHECK(LayoutUtil::IsDenseArray(shape()));
980 switch (shape().element_type()) {
981 case PRED:
982 return Get<bool>(multi_index);
983 case U8:
984 return Get<uint8>(multi_index);
985 case S32:
986 return Get<int32>(multi_index);
987 case S64:
988 return Get<int64>(multi_index);
989 case U32:
990 return Get<uint32>(multi_index);
991 case U64:
992 return Get<uint64>(multi_index);
993 default:
994 return FailedPrecondition(
995 "Array element type is not integral: %s",
996 PrimitiveType_Name(shape().element_type()).c_str());
997 }
998 }
999
GetSparseIndex(int64 sparse_element_number,const ShapeIndex & shape_index) const1000 tensorflow::gtl::ArraySlice<int64> Literal::GetSparseIndex(
1001 int64 sparse_element_number, const ShapeIndex& shape_index) const {
1002 const Piece& p = piece(shape_index);
1003 CHECK_GE(sparse_element_number, 0);
1004 CHECK_LT(sparse_element_number, p.sparse_indices()->index_count());
1005 return p.sparse_indices()->At(sparse_element_number);
1006 }
1007
// Sorts the (index, value) pairs of the sparse piece at `shape_index` into
// ascending index order; forwards to Piece::SortSparseElements.
void Literal::SortSparseElements(const ShapeIndex& shape_index) {
  piece(shape_index).SortSparseElements();
}
1011
SortSparseElements()1012 void Literal::Piece::SortSparseElements() {
1013 switch (subshape().element_type()) {
1014 case PRED:
1015 SortSparseElementsInternal<bool>();
1016 break;
1017 case S8:
1018 SortSparseElementsInternal<int8>();
1019 break;
1020 case U8:
1021 SortSparseElementsInternal<uint8>();
1022 break;
1023 case S16:
1024 SortSparseElementsInternal<int16>();
1025 break;
1026 case U16:
1027 SortSparseElementsInternal<uint16>();
1028 break;
1029 case S32:
1030 SortSparseElementsInternal<int32>();
1031 break;
1032 case U32:
1033 SortSparseElementsInternal<uint32>();
1034 break;
1035 case S64:
1036 SortSparseElementsInternal<int64>();
1037 break;
1038 case U64:
1039 SortSparseElementsInternal<uint64>();
1040 break;
1041 case F32:
1042 SortSparseElementsInternal<float>();
1043 break;
1044 case F64:
1045 SortSparseElementsInternal<double>();
1046 break;
1047 case C64:
1048 SortSparseElementsInternal<complex64>();
1049 break;
1050 case F16:
1051 SortSparseElementsInternal<half>();
1052 break;
1053 case BF16:
1054 SortSparseElementsInternal<bfloat16>();
1055 break;
1056 default:
1057 LOG(FATAL) << "Element type not valid for sparse array: "
1058 << PrimitiveType_Name(subshape().element_type());
1059 }
1060 }
1061
1062 template <typename NativeT>
SortSparseElementsInternal()1063 void Literal::Piece::SortSparseElementsInternal() {
1064 CHECK(LayoutUtil::IsSparseArray(subshape()));
1065 int64 num_elements = sparse_indices()->index_count();
1066 auto values = data<NativeT>();
1067 CHECK_LE(num_elements, values.size());
1068 sparse_indices()->SortWithValues(
1069 tensorflow::gtl::MutableArraySlice<NativeT>(values.data(), num_elements));
1070 }
1071
1072 namespace {
1073
// Recursively renders the piece of `literal` at `shape_index` into `pieces`;
// the caller joins the fragments into the final string. `print_layout`
// selects whether shapes are printed with their layouts. Tuples recurse per
// element; sparse arrays print "{index: value, ...}"; dense arrays have
// hand-written nesting for ranks 0-5 and a flat element dump otherwise.
void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index,
                    bool print_layout, std::vector<string>* pieces) {
  const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);

  auto shape_to_string = [print_layout](const Shape& shape) {
    if (print_layout) {
      return ShapeUtil::HumanStringWithLayout(shape);
    } else {
      return ShapeUtil::HumanString(shape);
    }
  };

  // TODO(b/32894291): refactor this code to reduce code duplication.
  if (ShapeUtil::IsTuple(subshape)) {
    // Render each tuple element recursively, comma-separated inside "( ... )".
    pieces->push_back(shape_to_string(subshape));
    pieces->push_back(" (\n");
    std::vector<string> tuple_pieces;
    for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) {
      ShapeIndex element_index = shape_index;
      element_index.push_back(i);
      std::vector<string> element_pieces;
      ToStringHelper(literal, element_index, print_layout, &element_pieces);
      tuple_pieces.push_back(tensorflow::str_util::Join(element_pieces, ""));
    }
    pieces->push_back(tensorflow::str_util::Join(tuple_pieces, ",\n"));
    pieces->push_back("\n)");
    return;
  }

  if (LayoutUtil::IsSparseArray(subshape)) {
    // Sparse form: "{i: v, ...}" for rank 1, "{[i,j,...]: v, ...}" otherwise.
    pieces->push_back(shape_to_string(subshape));
    pieces->push_back("{");
    int64 rank = ShapeUtil::Rank(subshape);
    int64 num_elements = literal.sparse_element_count();
    for (int64 i = 0; i < num_elements; ++i) {
      if (i > 0) {
        pieces->push_back(", ");
      }
      if (rank == 1) {
        pieces->push_back(StrCat(literal.GetSparseIndex(i)[0]));
        pieces->push_back(": ");
      } else {
        pieces->push_back("[");
        pieces->push_back(
            tensorflow::str_util::Join(literal.GetSparseIndex(i), ", "));
        pieces->push_back("]: ");
      }
      pieces->push_back(literal.GetSparseElementAsString(i));
    }
    pieces->push_back("}");
    return;
  }

  CHECK(LayoutUtil::IsDenseArray(subshape));

  // Renders one element, prefixing ", " except at the start of a row.
  // Predicates are densely packed with no separator at all.
  auto element_to_string =
      [&](tensorflow::gtl::ArraySlice<int64> indices) -> string {
    PrimitiveType element_type = subshape.element_type();
    if (element_type == PRED) {
      // We display predicates in a densely packed form.
      return literal.Get<bool>(indices, shape_index) ? "1" : "0";
    }
    return ((!indices.empty() && indices.back() > 0) ? ", " : "") +
           literal.GetAsString(indices, shape_index);
  };

  if (ShapeUtil::Rank(subshape) == 0) {
    pieces->push_back(literal.GetAsString({}, shape_index));
  } else if (ShapeUtil::Rank(subshape) == 1) {
    pieces->push_back("{");
    for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
      pieces->push_back(element_to_string({i0}));
    }
    pieces->push_back("}");
  } else if (ShapeUtil::Rank(subshape) == 2) {
    pieces->push_back(shape_to_string(subshape));
    pieces->push_back(" {\n");
    for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
      pieces->push_back("  { ");
      for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
        pieces->push_back(element_to_string({i0, i1}));
      }
      pieces->push_back(" ");
      pieces->push_back(i0 == subshape.dimensions(0) - 1 ? "}\n" : "},\n");
    }
    pieces->push_back("}");
  } else if (ShapeUtil::Rank(subshape) == 3) {
    pieces->push_back(shape_to_string(subshape));
    pieces->push_back(" {\n");
    for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
      pieces->push_back(i0 > 0 ? ",\n{" : "{");
      for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
        pieces->push_back(i1 > 0 ? ",\n  { " : " { ");
        for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) {
          pieces->push_back(element_to_string({i0, i1, i2}));
        }
        pieces->push_back(" }");
      }
      pieces->push_back(" }");
    }
    pieces->push_back("\n}");
  } else if (ShapeUtil::Rank(subshape) == 4) {
    pieces->push_back(shape_to_string(subshape));
    pieces->push_back(" {\n");
    for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
      // /*iN=...*/ comments mark the outer-loop positions for readability.
      pieces->push_back(Printf("  {  /*i0=%lld*/\n", i0));
      for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
        pieces->push_back(Printf("    {  /*i1=%lld*/\n", i1));
        for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) {
          pieces->push_back("      {");
          for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) {
            pieces->push_back(element_to_string({i0, i1, i2, i3}));
          }
          pieces->push_back(i2 == subshape.dimensions(2) - 1 ? "}\n" : "},\n");
        }
        pieces->push_back(i1 == subshape.dimensions(1) - 1 ? "    }\n"
                                                           : "    },\n");
      }
      pieces->push_back(i0 == subshape.dimensions(0) - 1 ? "  }\n" : "  },\n");
    }
    pieces->push_back("}");
  } else if (ShapeUtil::Rank(subshape) == 5) {
    pieces->push_back(shape_to_string(subshape));
    pieces->push_back(" {\n");
    for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
      pieces->push_back(Printf("  {  /*i0=%lld*/\n", i0));
      for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
        pieces->push_back(Printf("    {  /*i1=%lld*/\n", i1));
        for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) {
          pieces->push_back(Printf("      {  /*i2=%lld*/\n", i2));
          for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) {
            pieces->push_back("        {");
            for (int64 i4 = 0; i4 < subshape.dimensions(4); ++i4) {
              pieces->push_back(element_to_string({i0, i1, i2, i3, i4}));
            }
            pieces->push_back(i3 == subshape.dimensions(3) - 1 ? "}\n"
                                                               : "},\n");
          }
          pieces->push_back(i2 == subshape.dimensions(2) - 1 ? "      }\n"
                                                             : "      },\n");
        }
        pieces->push_back(i1 == subshape.dimensions(1) - 1 ? "    }\n"
                                                           : "    },\n");
      }
      pieces->push_back(i0 == subshape.dimensions(0) - 1 ? "  }\n" : "  },\n");
    }
    pieces->push_back("}");
  } else {
    // Rank > 5: no nested formatting; dump elements space-separated.
    pieces->push_back(shape_to_string(subshape));
    pieces->push_back(" {");
    literal.EachCellAsString(
        [&](tensorflow::gtl::ArraySlice<int64> indices, const string& value) {
          pieces->push_back(" ");
          pieces->push_back(value);
        });
    pieces->push_back("}");
  }
}
1232
1233 } // namespace
1234
// Returns the number of stored elements in this sparse literal; dies if the
// literal is not a sparse array.
int64 Literal::sparse_element_count() const {
  CHECK(LayoutUtil::IsSparseArray(shape()));
  return sparse_indices()->index_count();
}
1239
ToString(bool print_layout) const1240 string Literal::ToString(bool print_layout) const {
1241 std::vector<string> pieces;
1242 ToStringHelper(*this, {}, print_layout, &pieces);
1243 return tensorflow::str_util::Join(pieces, "");
1244 }
1245
MakeTuple(tensorflow::gtl::ArraySlice<const Literal * > elements)1246 /* static */ std::unique_ptr<Literal> Literal::MakeTuple(
1247 tensorflow::gtl::ArraySlice<const Literal*> elements) {
1248 std::vector<Shape> element_shapes;
1249 for (const Literal* element : elements) {
1250 element_shapes.push_back(element->shape());
1251 }
1252 auto literal = MakeUnique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
1253 for (int i = 0; i < elements.size(); ++i) {
1254 TF_CHECK_OK(literal->CopyFrom(*elements[i], /*dest_shape_index=*/{i}));
1255 }
1256 return literal;
1257 }
1258
MakeTupleOwned(std::vector<std::unique_ptr<Literal>> elements)1259 /* static */ std::unique_ptr<Literal> Literal::MakeTupleOwned(
1260 std::vector<std::unique_ptr<Literal>> elements) {
1261 std::vector<Shape> element_shapes;
1262 element_shapes.reserve(elements.size());
1263 for (const auto& element : elements) {
1264 element_shapes.push_back(element->shape());
1265 }
1266 auto literal = MakeUnique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
1267 for (int64 i = 0; i < elements.size(); ++i) {
1268 TF_CHECK_OK(
1269 literal->MoveFrom(std::move(*elements[i]), /*dest_shape_index=*/{i}));
1270 }
1271 return literal;
1272 }
1273
EachCellAsString(const std::function<void (tensorflow::gtl::ArraySlice<int64> indices,const string & value)> & per_cell) const1274 void Literal::EachCellAsString(
1275 const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
1276 const string& value)>& per_cell) const {
1277 if (ShapeUtil::HasZeroElements(shape())) {
1278 return;
1279 }
1280 std::vector<int64> indices = IndexUtil::LinearIndexToMultidimensionalIndex(
1281 shape(), /*linear_index=*/0);
1282 do {
1283 per_cell(indices, GetAsString(indices));
1284 } while (IndexUtil::BumpIndices(shape(), &indices));
1285 }
1286
1287 namespace {
1288 template <typename NativeSrcT, typename NativeDestT>
ConvertBetweenNativeTypes(const Literal & src_literal)1289 std::unique_ptr<Literal> ConvertBetweenNativeTypes(const Literal& src_literal) {
1290 CHECK(ShapeUtil::IsArray(src_literal.shape()));
1291 auto result_literal = MakeUnique<Literal>(ShapeUtil::ChangeElementType(
1292 src_literal.shape(),
1293 primitive_util::NativeToPrimitiveType<NativeDestT>()));
1294 auto src_data = src_literal.data<NativeSrcT>();
1295 auto dest_data = result_literal->template data<NativeDestT>();
1296 int64 num_elements = src_literal.element_count();
1297
1298 for (int64 i = 0; i < num_elements; ++i) {
1299 dest_data[i] = static_cast<NativeDestT>(src_data[i]);
1300 }
1301 return result_literal;
1302 }
1303
1304 template <PrimitiveType primitive_src_type>
ConvertToC64(const Literal & src_literal)1305 std::unique_ptr<Literal> ConvertToC64(const Literal& src_literal) {
1306 CHECK(ShapeUtil::IsArray(src_literal.shape()));
1307 auto result_literal = MakeUnique<Literal>(
1308 ShapeUtil::ChangeElementType(src_literal.shape(), C64));
1309 using NativeSrcT =
1310 typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type;
1311 tensorflow::gtl::ArraySlice<NativeSrcT> src_data =
1312 src_literal.data<NativeSrcT>();
1313 tensorflow::gtl::MutableArraySlice<complex64> dest_data =
1314 result_literal->data<complex64>();
1315 int64 num_elements = src_literal.element_count();
1316 for (int64 i = 0; i < num_elements; ++i) {
1317 dest_data[i] = complex64(static_cast<float>(src_data[i]), 0);
1318 }
1319 return result_literal;
1320 }
1321
1322 template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
ConvertIfTypesMatch(const Literal & src_literal)1323 std::unique_ptr<Literal> ConvertIfTypesMatch(const Literal& src_literal) {
1324 CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
1325 return ConvertBetweenNativeTypes<
1326 typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type,
1327 typename primitive_util::PrimitiveTypeToNative<
1328 primitive_dest_type>::type>(src_literal);
1329 }
1330
1331 template <PrimitiveType primitive_src_type>
ConvertIfDestTypeMatches(const Literal & src_literal,PrimitiveType primitive_dest_type)1332 StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
1333 const Literal& src_literal, PrimitiveType primitive_dest_type) {
1334 switch (primitive_dest_type) {
1335 #define CONVERT_IF_TYPES_MATCH(type) \
1336 case (type): \
1337 return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal);
1338 CONVERT_IF_TYPES_MATCH(PRED)
1339 CONVERT_IF_TYPES_MATCH(S8)
1340 CONVERT_IF_TYPES_MATCH(S32)
1341 CONVERT_IF_TYPES_MATCH(S64)
1342 CONVERT_IF_TYPES_MATCH(U8)
1343 CONVERT_IF_TYPES_MATCH(U32)
1344 CONVERT_IF_TYPES_MATCH(U64)
1345 CONVERT_IF_TYPES_MATCH(F16)
1346 CONVERT_IF_TYPES_MATCH(F32)
1347 CONVERT_IF_TYPES_MATCH(F64)
1348 CONVERT_IF_TYPES_MATCH(BF16)
1349 #undef CONVERT_IF_TYPES_MATCH
1350 case C64:
1351 return ConvertToC64<primitive_src_type>(src_literal);
1352 // Other types are not yet supported.
1353 default:
1354 return InvalidArgument(
1355 "Unimplemented: Convert from type %s to type %s",
1356 PrimitiveType_Name(src_literal.shape().element_type()).c_str(),
1357 PrimitiveType_Name(primitive_dest_type).c_str());
1358 }
1359 }
1360
1361 } // namespace
1362
Convert(PrimitiveType primitive_dest_type) const1363 StatusOr<std::unique_ptr<Literal>> Literal::Convert(
1364 PrimitiveType primitive_dest_type) const {
1365 TF_RET_CHECK(ShapeUtil::IsArray(shape()));
1366 switch (shape().element_type()) {
1367 #define CONVERT_IF_DEST_TYPE_MATCHES(type) \
1368 case (type): \
1369 return ConvertIfDestTypeMatches<(type)>(*this, primitive_dest_type);
1370 CONVERT_IF_DEST_TYPE_MATCHES(PRED)
1371 CONVERT_IF_DEST_TYPE_MATCHES(S8)
1372 CONVERT_IF_DEST_TYPE_MATCHES(S32)
1373 CONVERT_IF_DEST_TYPE_MATCHES(S64)
1374 CONVERT_IF_DEST_TYPE_MATCHES(U8)
1375 CONVERT_IF_DEST_TYPE_MATCHES(U32)
1376 CONVERT_IF_DEST_TYPE_MATCHES(U64)
1377 CONVERT_IF_DEST_TYPE_MATCHES(F16)
1378 CONVERT_IF_DEST_TYPE_MATCHES(F32)
1379 CONVERT_IF_DEST_TYPE_MATCHES(F64)
1380 CONVERT_IF_DEST_TYPE_MATCHES(BF16)
1381 #undef CONVERT_IF_DEST_TYPE_MATCHES
1382 // Other types are not yet supported.
1383 default:
1384 return InvalidArgument("Unimplemented: Convert from type %s to type %s",
1385 PrimitiveType_Name(shape().element_type()).c_str(),
1386 PrimitiveType_Name(primitive_dest_type).c_str());
1387 }
1388 }
1389
1390 template <typename NativeT>
EqualElementsInternal(const Literal::Piece & other,std::vector<int64> * multi_index) const1391 bool Literal::Piece::EqualElementsInternal(
1392 const Literal::Piece& other, std::vector<int64>* multi_index) const {
1393 if (multi_index->size() == ShapeUtil::Rank(subshape())) {
1394 return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
1395 }
1396 for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) {
1397 multi_index->push_back(i);
1398 if (!EqualElementsInternal<NativeT>(other, multi_index)) {
1399 return false;
1400 }
1401 multi_index->pop_back();
1402 }
1403 return true;
1404 }
1405
EqualElements(const Literal::Piece & other) const1406 bool Literal::Piece::EqualElements(const Literal::Piece& other) const {
1407 DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
1408
1409 std::vector<int64> multi_index;
1410 switch (subshape().element_type()) {
1411 case PRED:
1412 return EqualElementsInternal<bool>(other, &multi_index);
1413 case U8:
1414 return EqualElementsInternal<uint8>(other, &multi_index);
1415 case S32:
1416 return EqualElementsInternal<int32>(other, &multi_index);
1417 case S64:
1418 return EqualElementsInternal<int64>(other, &multi_index);
1419 case U32:
1420 return EqualElementsInternal<uint32>(other, &multi_index);
1421 case U64:
1422 return EqualElementsInternal<uint64>(other, &multi_index);
1423 case F32:
1424 return EqualElementsInternal<float>(other, &multi_index);
1425 case F64:
1426 return EqualElementsInternal<double>(other, &multi_index);
1427 case F16:
1428 return EqualElementsInternal<half>(other, &multi_index);
1429 case BF16:
1430 return EqualElementsInternal<bfloat16>(other, &multi_index);
1431 case C64:
1432 return EqualElementsInternal<complex64>(other, &multi_index);
1433 default:
1434 LOG(FATAL) << "Unimplemented: Literal::Piece::EqualElements for type "
1435 << PrimitiveType_Name(subshape().element_type());
1436 }
1437 }
1438
operator ==(const Literal & other) const1439 bool Literal::operator==(const Literal& other) const {
1440 if (!ShapeUtil::Compatible(shape(), other.shape())) {
1441 return false;
1442 }
1443 for (const auto& pair : pieces_) {
1444 const ShapeIndex& index = pair.first;
1445 const Piece& piece = pair.second;
1446 if (!ShapeUtil::IsArray(piece.subshape())) {
1447 continue;
1448 }
1449
1450 const Piece& other_piece = other.piece(index);
1451 if (!piece.EqualElements(other_piece)) {
1452 return false;
1453 }
1454 }
1455 return true;
1456 }
1457
1458 namespace {
1459
1460 template <typename NativeT>
AllElementsEqualValue(tensorflow::gtl::ArraySlice<NativeT> data,NativeT value)1461 static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice<NativeT> data,
1462 NativeT value) {
1463 for (int64 i = 0; i < data.size(); ++i) {
1464 if (data[i] != value) {
1465 return false;
1466 }
1467 }
1468 return true;
1469 }
1470
1471 } // namespace
1472
IsAll(int8 value) const1473 bool Literal::IsAll(int8 value) const {
1474 for (const auto& pair : pieces_) {
1475 const Piece& piece = pair.second;
1476 if (!ShapeUtil::IsArray(piece.subshape())) {
1477 continue;
1478 }
1479
1480 auto piece_is_all = [&]() {
1481 switch (shape().element_type()) {
1482 case U8:
1483 if (value >= 0) {
1484 return AllElementsEqualValue<uint8>(piece.data<uint8>(), value);
1485 }
1486 return false;
1487 case U32:
1488 if (value >= 0) {
1489 return AllElementsEqualValue<uint32>(piece.data<uint32>(), value);
1490 }
1491 return false;
1492 case U64:
1493 if (value >= 0) {
1494 return AllElementsEqualValue<uint64>(piece.data<uint64>(), value);
1495 }
1496 return false;
1497 case S8:
1498 return AllElementsEqualValue<int8>(piece.data<int8>(), value);
1499 case S32:
1500 return AllElementsEqualValue<int32>(piece.data<int32>(), value);
1501 case S64:
1502 return AllElementsEqualValue<int64>(piece.data<int64>(), value);
1503 case F32:
1504 return AllElementsEqualValue<float>(piece.data<float>(), value);
1505 case F64:
1506 return AllElementsEqualValue<double>(piece.data<double>(), value);
1507 case F16:
1508 return AllElementsEqualValue<half>(piece.data<half>(),
1509 static_cast<half>(value));
1510 case BF16:
1511 return AllElementsEqualValue<bfloat16>(piece.data<bfloat16>(),
1512 static_cast<bfloat16>(value));
1513 case PRED:
1514 if (value == 0) {
1515 return AllElementsEqualValue<bool>(piece.data<bool>(), false);
1516 }
1517 if (value == 1) {
1518 return AllElementsEqualValue<bool>(piece.data<bool>(), true);
1519 }
1520 return false;
1521 default:
1522 return false;
1523 }
1524 return false;
1525 };
1526
1527 if (!piece_is_all()) {
1528 return false;
1529 }
1530 }
1531 return true;
1532 }
1533
IsAllFloat(float value) const1534 bool Literal::IsAllFloat(float value) const {
1535 for (const auto& pair : pieces_) {
1536 const Piece& piece = pair.second;
1537 if (!ShapeUtil::IsArray(piece.subshape())) {
1538 continue;
1539 }
1540
1541 auto piece_is_all = [&]() {
1542 switch (shape().element_type()) {
1543 case F32:
1544 return AllElementsEqualValue<float>(piece.data<float>(), value);
1545 case F64:
1546 return AllElementsEqualValue<double>(piece.data<double>(), value);
1547 case F16:
1548 return AllElementsEqualValue<half>(piece.data<half>(),
1549 static_cast<half>(value));
1550 case BF16:
1551 return AllElementsEqualValue<bfloat16>(piece.data<bfloat16>(),
1552 static_cast<bfloat16>(value));
1553 default:
1554 return false;
1555 }
1556 };
1557 if (!piece_is_all()) {
1558 return false;
1559 }
1560 }
1561 return true;
1562 }
1563
IsAllComplex(complex64 value) const1564 bool Literal::IsAllComplex(complex64 value) const {
1565 switch (shape().element_type()) {
1566 case C64:
1567 return AllElementsEqualValue<complex64>(root_piece().data<complex64>(),
1568 value);
1569 default:
1570 return false;
1571 }
1572 }
1573
IsZero(tensorflow::gtl::ArraySlice<int64> indices) const1574 bool Literal::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
1575 CHECK(ShapeUtil::IsArray(shape()));
1576 switch (shape().element_type()) {
1577 case U8:
1578 return Get<uint8>(indices) == 0;
1579 case U32:
1580 return Get<uint32>(indices) == 0;
1581 case U64:
1582 return Get<uint64>(indices) == 0;
1583 case S8:
1584 return Get<int8>(indices) == 0;
1585 case S32:
1586 return Get<int32>(indices) == 0;
1587 case S64:
1588 return Get<int64>(indices) == 0;
1589 case F32:
1590 return Get<float>(indices) == 0.0f;
1591 case F64:
1592 return Get<double>(indices) == 0.0;
1593 case C64:
1594 return Get<complex64>(indices) == complex64(0.0f, 0.0f);
1595 case F16:
1596 return Get<half>(indices) == static_cast<half>(0.0f);
1597 case BF16:
1598 return Get<bfloat16>(indices) == static_cast<bfloat16>(0.0f);
1599 case PRED:
1600 return Get<bool>(indices) == false;
1601 default:
1602 LOG(FATAL) << "Input literal must be an array.";
1603 }
1604 }
1605
1606 namespace {
1607
1608 template <typename RepeatedFieldT, typename NativeT>
CopyToRepeatedField(RepeatedFieldT * dest,const tensorflow::gtl::ArraySlice<NativeT> src)1609 void CopyToRepeatedField(RepeatedFieldT* dest,
1610 const tensorflow::gtl::ArraySlice<NativeT> src) {
1611 *dest = RepeatedFieldT(src.begin(), src.end());
1612 }
1613
1614 } // namespace
1615
// Serializes this piece into `proto`: the subshape always, plus the element
// data for array pieces. Most types use a typed repeated field; U8 uses the
// bytes field; F16/BF16 store raw 16-bit values as a byte string (converted
// to little-endian on big-endian hosts); C64 interleaves real/imag floats.
// Tuples contribute only their shape. Dies on unhandled element types.
void Literal::Piece::WriteToProto(LiteralProto* proto) const {
  *proto->mutable_shape() = subshape();
  switch (subshape().element_type()) {
    case PRED:
      CopyToRepeatedField(proto->mutable_preds(), data<bool>());
      break;
    case U8:
      proto->set_u8s(static_cast<const unsigned char*>(data<uint8>().data()),
                     element_count());
      break;
    case U32:
      CopyToRepeatedField(proto->mutable_u32s(), data<uint32>());
      break;
    case U64:
      CopyToRepeatedField(proto->mutable_u64s(), data<uint64>());
      break;
    case S32:
      CopyToRepeatedField(proto->mutable_s32s(), data<int32>());
      break;
    case S64:
      CopyToRepeatedField(proto->mutable_s64s(), data<int64>());
      break;
    case F16:
      // Raw copy of the 16-bit values; the proto field is wire-format
      // little-endian, so byte-swap on big-endian hosts.
      *proto->mutable_f16s() = string(
          reinterpret_cast<const char*>(data<half>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(const_cast<char*>(proto->mutable_f16s()->data()),
                           proto->f16s().size());
      }
      break;
    case BF16:
      // Same raw-bytes scheme as F16.
      *proto->mutable_bf16s() = string(
          reinterpret_cast<const char*>(data<bfloat16>().data()), size_bytes());
      if (!kLittleEndian) {
        ConvertEndianShort(const_cast<char*>(proto->mutable_bf16s()->data()),
                           proto->bf16s().size());
      }
      break;
    case F32:
      CopyToRepeatedField(proto->mutable_f32s(), data<float>());
      break;
    case F64:
      CopyToRepeatedField(proto->mutable_f64s(), data<double>());
      break;
    case C64:
      // Interleave as [re0, im0, re1, im1, ...].
      for (complex64 value : data<complex64>()) {
        proto->add_c64s(value.real());
        proto->add_c64s(value.imag());
      }
      break;
    case TUPLE:
      // Nothing to do but assign the shape which is done above.
      return;
    default:
      LOG(FATAL) << "Unhandled primitive type " << subshape().element_type();
  }
}
1673
// Read-only raw pointer to this piece's element buffer; only meaningful for
// array-shaped pieces (dies otherwise).
const void* Literal::Piece::untyped_data() const {
  CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
  return buffer();
}
1678
// Mutable raw pointer to this piece's element buffer; only meaningful for
// array-shaped pieces (dies otherwise).
void* Literal::Piece::untyped_data() {
  CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
  return buffer();
}
1683
1684 namespace {
1685
1686 template <typename RepeatedFieldT, typename NativeT>
CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice<NativeT> dest,const RepeatedFieldT & src)1687 Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice<NativeT> dest,
1688 const RepeatedFieldT& src) {
1689 if (dest.size() != src.size()) {
1690 return InvalidArgument(
1691 "Expected %lu elements in LiteralProto repeated field, has %d",
1692 dest.size(), src.size());
1693 }
1694 std::copy(src.begin(), src.end(), dest.begin());
1695 return Status::OK();
1696 }
1697
1698 } // namespace
1699
// Deserializes array data from `proto` into this piece's buffer, dispatching
// on the piece's element type. Must not be called on tuple-shaped pieces.
// Precondition (TF_RET_CHECKed): `proto` has a shape with layout equal to
// this piece's subshape.
Status Literal::Piece::CopyFromProto(const LiteralProto& proto) {
  // These conditions should have been checked in Literal::CreateFromProto.
  TF_RET_CHECK(proto.has_shape());
  TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape()));
  TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape()));

  switch (subshape().element_type()) {
    case PRED:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
      break;
    case U8: {
      // U8 data lives in a packed bytes field rather than a repeated field,
      // so it is copied element-wise here instead of via
      // CopyFromRepeatedField.
      auto u8_data = data<uint8>();
      TF_RET_CHECK(proto.u8s().size() == u8_data.size());
      std::copy(proto.u8s().begin(), proto.u8s().end(), u8_data.begin());
    } break;
    case S32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int32>(), proto.s32s()));
      break;
    case S64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int64>(), proto.s64s()));
      break;
    case U32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint32>(), proto.u32s()));
      break;
    case U64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint64>(), proto.u64s()));
      break;
    case F16: {
      // F16 values are stored as raw bytes (little-endian, per the matching
      // WriteToProto path); memcpy then byte-swap in place on big-endian
      // hosts.
      const string& s(proto.f16s());
      TF_RET_CHECK(data<half>().size() * sizeof(half) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;

    case BF16: {
      // Same raw-bytes representation and endian handling as F16 above.
      const string& s(proto.bf16s());
      TF_RET_CHECK(data<bfloat16>().size() * sizeof(bfloat16) == s.size());
      memcpy(untyped_data(), s.data(), s.size());
      if (!kLittleEndian) {
        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
      }
    } break;
    case F32:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<float>(), proto.f32s()));
      break;
    case F64:
      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<double>(), proto.f64s()));
      break;
    case C64: {
      // Complex values are stored as interleaved (real, imag) scalar pairs,
      // hence the 2x size check and strided reassembly.
      auto complex_data = data<complex64>();
      TF_RET_CHECK(proto.c64s_size() == complex_data.size() * 2);
      for (int64 i = 0; i < complex_data.size(); ++i) {
        complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)};
      }
    } break;
    case TUPLE:
      // Tuple pieces carry no data; callers must only invoke this on leaves.
      LOG(FATAL) << "Should not be called on tuple shapes: "
                 << ShapeUtil::HumanString(subshape());
      break;
    default:
      LOG(FATAL) << "Unhandled primitive type " << subshape().element_type();
  }
  return Status::OK();
}
1766
ToProto() const1767 LiteralProto Literal::ToProto() const {
1768 LiteralProto proto;
1769 for (const auto& pair : pieces_) {
1770 const ShapeIndex& index = pair.first;
1771 const Piece& piece = pair.second;
1772
1773 LiteralProto* proto_piece = &proto;
1774 for (int64 i : index) {
1775 while (proto_piece->tuple_literals_size() <= i) {
1776 proto_piece->add_tuple_literals();
1777 }
1778 proto_piece = proto_piece->mutable_tuple_literals(i);
1779 }
1780 piece.WriteToProto(proto_piece);
1781 }
1782
1783 if (LayoutUtil::IsSparseArray(shape())) {
1784 CopyToRepeatedField(proto.mutable_sparse_indices(),
1785 sparse_indices()->data());
1786 }
1787
1788 return proto;
1789 }
1790
/* static */
// Deserializes a LiteralProto (possibly a nested tuple) into a freshly
// allocated Literal. Returns InvalidArgument if the proto lacks a shape or
// layout, or if a tuple node's element count disagrees with the shape.
StatusOr<std::unique_ptr<Literal>> Literal::CreateFromProto(
    const LiteralProto& proto) {
  if (!proto.has_shape()) {
    return InvalidArgument("LiteralProto has no shape");
  }
  if (!LayoutUtil::HasLayout(proto.shape())) {
    return InvalidArgument("LiteralProto has no layout");
  }

  auto literal = MakeUnique<Literal>(proto.shape());

  for (auto& pair : literal->pieces_) {
    const ShapeIndex& index = pair.first;
    Piece& piece = pair.second;
    // Walk the nested tuple_literals to the proto node for this piece.
    const LiteralProto* proto_element = &proto;
    for (int64 i : index) {
      TF_RET_CHECK(i < proto_element->tuple_literals_size());
      proto_element = &proto_element->tuple_literals(i);
    }

    if (ShapeUtil::IsTuple(piece.subshape())) {
      // Tuple pieces hold no data of their own; only validate the arity.
      if (proto_element->tuple_literals_size() !=
          ShapeUtil::TupleElementCount(piece.subshape())) {
        return InvalidArgument(
            "Expected %lld tuple elements in LiteralProto, has %d",
            ShapeUtil::TupleElementCount(piece.subshape()),
            proto_element->tuple_literals_size());
      }
      continue;
    }

    // Leaf (array) piece: copy the actual element data.
    TF_RET_CHECK(ShapeUtil::IsArray(piece.subshape()));
    TF_RETURN_IF_ERROR(piece.CopyFromProto(*proto_element));
  }
  return std::move(literal);
}
1828
untyped_data(const ShapeIndex & shape_index) const1829 const void* Literal::untyped_data(const ShapeIndex& shape_index) const {
1830 return piece(shape_index).untyped_data();
1831 }
1832
untyped_data(const ShapeIndex & shape_index)1833 void* Literal::untyped_data(const ShapeIndex& shape_index) {
1834 return piece(shape_index).untyped_data();
1835 }
1836
size_bytes(const ShapeIndex & shape_index) const1837 int64 Literal::size_bytes(const ShapeIndex& shape_index) const {
1838 return piece(shape_index).size_bytes();
1839 }
1840
GetR1U8AsString() const1841 string Literal::GetR1U8AsString() const {
1842 CHECK(ShapeUtil::IsArray(shape()));
1843 CHECK_EQ(ShapeUtil::Rank(shape()), 1);
1844 CHECK_EQ(shape().element_type(), U8);
1845 return string(tensorflow::bit_cast<const char*>(data<uint8>().data()),
1846 ShapeUtil::ElementsIn(shape()));
1847 }
1848
// Factory for a non-owning view of the sub-literal rooted at `view_root`
// within `literal`; simply forwards to the constructor.
/* static */ const LiteralView LiteralView::Create(
    const Literal& literal, const ShapeIndex& view_root) {
  return LiteralView(literal, view_root);
}
1853
LiteralView(const Literal & literal,const ShapeIndex & view_root)1854 LiteralView::LiteralView(const Literal& literal, const ShapeIndex& view_root) {
1855 shape_ = ShapeUtil::GetSubshape(literal.shape(), view_root);
1856 pieces_ = ShapeTree<Piece>(shape_);
1857 owns_buffers_ = false;
1858 for (auto& pair : pieces_) {
1859 const ShapeIndex& index = pair.first;
1860 Piece& piece = pair.second;
1861
1862 ShapeIndex src_index = view_root;
1863 for (int64 i : index) {
1864 src_index.push_back(i);
1865 }
1866 const Piece& src_piece = literal.piece(src_index);
1867 piece.set_buffer(src_piece.buffer());
1868 piece.set_sparse_indices(src_piece.sparse_indices());
1869 piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
1870 }
1871 }
1872
// Nothing to release: views never own their buffers (owns_buffers_ = false).
LiteralView::~LiteralView() {}
1874
// Copy constructor; delegates to CopyFrom, which rebinds the per-piece
// subshape pointers into this view's own shape.
LiteralView::LiteralView(const LiteralView& other) { CopyFrom(other); }
1876
// Copy assignment; delegates to CopyFrom for the same pointer-fixup reason
// as the copy constructor.
LiteralView& LiteralView::operator=(const LiteralView& other) {
  CopyFrom(other);
  return *this;
}
1881
CopyFrom(const LiteralView & other)1882 void LiteralView::CopyFrom(const LiteralView& other) {
1883 // We can't use the default copy-constructor/copy-assignment because
1884 // Piece::subshape_ points to subshapes within the Shape of the owning
1885 // Literal/LiteralView.
1886 shape_ = other.shape();
1887 pieces_ = other.pieces_;
1888 for (auto& pair : pieces_) {
1889 const ShapeIndex& index = pair.first;
1890 Piece& piece = pair.second;
1891 piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
1892 }
1893 owns_buffers_ = false;
1894 }
1895
1896 } // namespace xla
1897