1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_LITERAL_H_
17 #define TENSORFLOW_COMPILER_XLA_LITERAL_H_
18 
19 #include <functional>
20 #include <initializer_list>
21 #include <iterator>
22 #include <memory>
23 #include <ostream>
24 #include <string>
25 #include <type_traits>
26 #include <vector>
27 
28 #include "absl/memory/memory.h"
29 #include "absl/strings/string_view.h"
30 #include "absl/types/optional.h"
31 #include "absl/types/span.h"
32 #include "tensorflow/compiler/xla/array2d.h"
33 #include "tensorflow/compiler/xla/array3d.h"
34 #include "tensorflow/compiler/xla/array4d.h"
35 #include "tensorflow/compiler/xla/index_util.h"
36 #include "tensorflow/compiler/xla/layout_util.h"
37 #include "tensorflow/compiler/xla/primitive_util.h"
38 #include "tensorflow/compiler/xla/shape_util.h"
39 #include "tensorflow/compiler/xla/status_macros.h"
40 #include "tensorflow/compiler/xla/types.h"
41 #include "tensorflow/compiler/xla/util.h"
42 #include "tensorflow/compiler/xla/xla_data.pb.h"
43 #include "tensorflow/core/lib/core/bitmap.h"
44 #include "tensorflow/core/lib/core/status.h"
45 #include "tensorflow/core/platform/logging.h"
46 #include "tensorflow/core/platform/macros.h"
47 #include "tensorflow/core/platform/protobuf.h"
48 #include "tensorflow/core/platform/types.h"
49 
50 namespace xla {
51 
// Forward declare Literal and LiteralSlice so they can be named before their
// definitions: several LiteralBase methods return Literal by value, and
// MutableLiteralBase methods take LiteralSlice by const reference.
class Literal;
class LiteralSlice;
56 
57 // Abstract base class for literals.
58 class LiteralBase {
59  public:
60   virtual ~LiteralBase() = 0;
61 
62   // Literals are equal if they have compatible shapes and the same data
63   // values. Layout is not compared.
64   bool operator==(const LiteralBase& other) const;
65   bool operator!=(const LiteralBase& other) const { return !(*this == other); }
66 
67   // Returns the shape of the literal.
shape()68   const Shape& shape() const { return root_piece().subshape(); }
69 
70   // Serialize to proto.
71   LiteralProto ToProto() const;
72 
73   // Returns a Span of the array for this literal for the given NativeT
74   // (e.g., float). CHECKs if the subshape of the literal at the given
75   // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type
76   // to native type.
77   template <typename NativeT>
78   absl::Span<const NativeT> data(const ShapeIndex& shape_index = {}) const;
79 
80   // Returns a const pointer to (or size of) the underlying buffer holding the
81   // array at the given shape index. CHECKs if the subshape of the literal at
82   // the given ShapeIndex is not array.
83   const void* untyped_data(const ShapeIndex& shape_index = {}) const;
84   int64 size_bytes(const ShapeIndex& shape_index = {}) const;
85 
86   // Returns this literal's data as a string. This literal must be a rank-1 U8
87   // array.
88   string GetR1U8AsString() const;
89 
90   // Returns a string representation of the literal value. The Shape of the
91   // literal is a prefix of the literal value in the string.
92 
93   // Warning: this function can take minutes for multi-million
94   // element Literals.
95   string ToString() const;
96 
97   // Similar to ToString, but return the result in a compact
98   // one-line form.
99   string ToStringOneline() const;
100 
101   // Returns a string representation of the literal value which does *not*
102   // include the shape string.
103   string ToStringWithoutShape() const;
104 
105   // Similar to ToStringWithoutShape, but return the result in a compact
106   // one-line form.
107   string ToStringWithoutShapeOneline() const;
108 
109   // Returns a string representation of the literal value which includes the
110   // shape string with its layout.does *not* include the shape string.
111   string ToStringWithLayout() const;
112 
113   // Gets an element in the literal at the given index. The multi_index is
114   // CHECKed against the dimension sizes.
115   template <typename NativeT>
116   NativeT Get(absl::Span<const int64> multi_index,
117               const ShapeIndex& shape_index) const;
118   // Overloads of Get for array literals. CHECKs if the literal is not
119   // array-shaped and dense.
120   template <typename NativeT>
121   NativeT Get(absl::Span<const int64> multi_index) const;
122 
123   // Get the dynamic size on dim_index in the literal at the given shape_index.
124   int32 GetDynamicSize(int64 dim_index, const ShapeIndex& shape_index) const;
125   int32 GetDynamicSize(int64 dim_index) const;
126 
127   // Returns the element value at index (0, ..., 0), however many zeroes are
128   // required for that index.
129   template <typename NativeT>
130   NativeT GetFirstElement() const;
131 
132   // As above but returns any integer type casted to an int64.
133   absl::optional<int64> GetFirstInteger() const;
134 
135   // As Get(), but determines the correct type and converts the value
136   // into text.
137   string GetAsString(absl::Span<const int64> multi_index,
138                      const ShapeIndex& shape_index = {}) const;
139 
140   // Return whether the value at the specified index is equal to the provided
141   // generic `value` (T must be an arithmetic type).
142   //
143   // Precondition: must be an array.
144   template <typename T>
145   typename std::enable_if<(std::is_arithmetic<T>::value ||
146                            std::is_same<T, Eigen::half>::value ||
147                            std::is_same<T, bfloat16>::value),
148                           bool>::type
IsEqualAt(absl::Span<const int64> multi_index,T value)149   IsEqualAt(absl::Span<const int64> multi_index, T value) const {
150     if (auto as_s64 = GetIntegralAsS64(multi_index)) {
151       return *as_s64 == value;
152     }
153     complex128 as_complex128 = *GetAsComplex128(multi_index);
154     return as_complex128.imag() == 0 && as_complex128.real() == value;
155   }
156 
IsEqualAt(absl::Span<const int64> multi_index,complex128 value)157   bool IsEqualAt(absl::Span<const int64> multi_index, complex128 value) const {
158     if (auto as_s64 = GetIntegralAsS64(multi_index)) {
159       return *as_s64 == value.real() && value.imag() == 0;
160     }
161     auto as_complex128 = GetAsComplex128(multi_index);
162     return *as_complex128 == value;
163   }
164 
165   // As Get(), but determines the correct type and converts the value into
166   // int64.  This literal must be an array.
167   absl::optional<int64> GetIntegralAsS64(
168       absl::Span<const int64> multi_index) const;
169 
170   // As Get(), but determines the correct type, and converts the value into
171   // double. This literal must be an array.
172   absl::optional<double> GetAsDouble(absl::Span<const int64> multi_index) const;
173 
174   // As Get(), but determines the correct type, and converts the value into
175   // complex128. All floating point types can be converted into complex128.
176   //
177   // This literal must be an array.
178   absl::optional<complex128> GetAsComplex128(
179       absl::Span<const int64> multi_index) const;
180 
181   // Invokes the "per cell" callback for each element in the provided
182   // literal with the element's indices and a string representation of
183   // the element's value.
184   //
185   // This function is useful if you want a polymorphic representation
186   // of the tensor's elements (turning it to a string for something
187   // like representation in a protobuf).
188   //
189   // This literal must have a dense layout.
190   void EachCellAsString(
191       const std::function<void(absl::Span<const int64> indices,
192                                const string& value)>& per_cell) const;
193   template <typename NativeT>
194   void EachCell(
195       std::function<void(absl::Span<const int64> indices, NativeT value)>
196           per_cell) const;
197 
198   // Returns whether every element in this literal is equal to value.
199   //
200   // value is an int8 because we expect this to be called with small
201   // compile-time constants (0, -1, etc.) and so that whatever value you pass
202   // can be represented exactly by floating-point types as small as 16 bits.
203   //
204   // If value doesn't fit in this literal's type, returns false.  Values of 1/0
205   // are considered equal to true/false; other values are not considered equal
206   // to true. Also if this literal is not array-shaped false is returned.
207   bool IsAll(int8 value) const;
208 
209   // Like IsAll(const Literal&, int8), except we check whether the literal is
210   // equal to a particular floating-point number.
211   //
212   // If the literal is not a floating-point value, this always returns false.
213   //
214   // This casts value to the type of literal, then compares using ==.  The usual
215   // admonishments about floating-point equality checks apply.  We expect you to
216   // use this to check for values that can be expressed precisely as a float,
217   // e.g. -0.5.  Also if this literal is not array-shaped false is returned.
218   bool IsAllFloat(float value) const;
219 
220   // Like IsAll(const Literal&, int8), except we check whether the literal is
221   // equal to a particular complex number.
222   //
223   // If the literal is not a complex value, this always returns false.
224   //
225   // This casts value to the type of literal, then compares using ==.  The usual
226   // admonishments about floating-point equality checks apply.  We expect you to
227   // use this to check for complex values that can be expressed precisely as
228   // float pairs e.g. (-0.5, 1.0).
229   //
230   // This literal must have a dense layout.
231   bool IsAllComplex(complex64 value) const;
232 
233   // Literal consists entirely of the first element of the literal.
234   bool IsAllFirst() const;
235 
236   // Literal consists entirely of an iota.
237   bool IsR1Iota() const;
238 
239   // Returns whether this literal is zero at the specified index. This literal
240   // must be an array with a dense layout.
241   bool IsZero(absl::Span<const int64> indices) const;
242 
243   // Returns the count of the elements in the array at the given shape index in
244   // this literal.
245   int64 element_count(const ShapeIndex& index = {}) const {
246     if (index.empty()) {
247       // Common case, avoid GetSubshape().
248       return ShapeUtil::ElementsIn(shape());
249     }
250     return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
251   }
252 
253   // Compute a hash for this literal.
254   size_t Hash() const;
255 
256   // Converts this literal to the given shape. Returns an error is the
257   // conversion is not possible.
258   StatusOr<Literal> ConvertToShape(const Shape& dest_shape) const;
259 
260   // Converts this literal to another primitive type using a bitcast
261   // conversion. The to and from primitive types must have the same bit
262   // width. Returns an error if the conversion is not possible. This literal
263   // must be array-shaped.
264   StatusOr<Literal> BitcastConvert(PrimitiveType primitive_dest_type) const;
265 
266   // Converts this literal to another primitive type. Returns an error if the
267   // conversion is not possible. This literal must be array-shaped.
268   StatusOr<Literal> Convert(PrimitiveType primitive_dest_type) const;
269 
270   // Clones the underlying buffers into a new Literal.
271   Literal Clone() const;
272 
273   // TODO(b/67651157): The methods below which perform computation on Literals
274   // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with
275   // evaluator code which operates on Literals.
276   //
277   // Creates a new value that has the equivalent value as this
278   // literal, but conforms to new_layout; e.g. a literal matrix that was in {0,
279   // 1} minor-to-major dimension layout can be re-layed-out as {1, 0}
280   // minor-to-major dimension layout and the value in the cell at any given
281   // logical index (i0, i1) will be the same.
282   //
283   // For tuple shaped literals, shape_index should be used to select the inner
284   // array that the new layout applies to.
285   //
286   // Note: this is useful when the client wants to ensure that a value placed in
287   // the XLA allocation tracker has a particular layout; for efficiency
288   // purposes or avoiding unimplemented operation/layout combinations.
289   Literal Relayout(const Layout& new_layout,
290                    const ShapeIndex& shape_index = {}) const;
291 
292   // An overload of Relayout which changes the layout of the entire shape rather
293   // than being limited to a single array within the shape.
294   Literal Relayout(const Shape& shape_with_layout) const;
295 
296   // Generate a new literal whose static sizes are equal to the previous
297   // literal's dynamic sizes.
298   Literal ToStatic() const;
299 
300   // Expand a static literal into a new one with a bounded dyanmic literal. The
301   // static dimensions of the original literal becomes dynamic dimensions of the
302   // new literal, where the argument `bounded_shape` becomes the bounded shape
303   // of the new literal.
304   //
305   // Precondition: bounded_shape.is_dynamic()
306   Literal ToBoundedDynamic(const Shape& bounded_shape) const;
307 
308   // Creates a new literal by reshaping this literal to have the given
309   // dimensions. The total number of elements must not change; The
310   // implementation currently only supports monotonic dim0-major layouts.
311   // This literal must be an array.
312   StatusOr<Literal> Reshape(absl::Span<const int64> dimensions) const;
313 
314   // Creates a new literal by broadcasting this literal with `dimensions` to
315   // yield a literal of shape `result_shape`.
316   StatusOr<Literal> Broadcast(const Shape& result_shape,
317                               absl::Span<const int64> dimensions) const;
318 
319   // Creates a new literal by reordering the dimensions of this literal.
320   // The given `permutation` must be a permutation of the dimension numbers
321   // in the original literal, and it specifies the order of the new dimensions
322   // in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
323   // For example, a transpose call on a literal of shape [3 x 8 x 4] and
324   // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
325   // This literal must be an array.
326   Literal Transpose(absl::Span<const int64> permutation) const;
327 
328   // Creates a sub-array from this literal by extracting the indices
329   // [start_index, limit_index) of each dimension. The result literal has the
330   // same rank and layout as for the given literal. The number of indices in
331   // start_indices and limit_indices must be the rank of the literal, and the
332   // indices follow the order of the dimensions.
333   // This literal must be an array.
334   Literal Slice(absl::Span<const int64> start_indices,
335                 absl::Span<const int64> limit_indices) const;
336 
337   // Creates a literal with a prepended dimension with bound "times"; e.g. a
338   // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
339   // literal replicated four times.
340   // This literal must be an array.
341   template <typename NativeT>
342   Literal Replicate(int64 times) const;
343 
344   // Creates a new Literal object with the shape specified as parameter.
345   // The content of the literal values is the default value of the primitive
346   // type of literal itself (0 for numeric types, and false for predicates).
347   //
348   // Note: It's an antipattern to use this method then immediately call
349   // MutableLiteralBase::Populate on the result (since that results in zero
350   // initialization, then reinitialization. Consider if a call to
351   // absl::make_unique<Literal>(shape), followed by the call to
352   // MutableLiteralBase::Populate can be used instead.
353   static Literal CreateFromShape(const Shape& shape);
354 
355  protected:
356   // A data structure representing a subshape at a particular ShapeIndex within
357   // the literal. For array-shaped ShapeIndexes, this data structure holds the
358   // pointer to the memory allocated for the array data.
359   class Piece {
360    public:
361     // Returns the buffer holding the array data for this piece as an array
362     // slice. This piece must be array-shaped.
363     template <typename NativeT>
364     absl::Span<const NativeT> data() const;
365     template <typename NativeT>
366     absl::Span<NativeT> data();
367 
368     // Returns the buffer holding the array data for this piece as a void*. This
369     // piece must be array-shaped.
370     void* untyped_data();
371     const void* untyped_data() const;
372 
373     // Gets or sets an element in the array at the given index. The multi_index
374     // is CHECKed against the dimension sizes of the array.  This piece must be
375     // array-shaped.
376     template <typename NativeT>
377     NativeT Get(absl::Span<const int64> index) const;
378     template <typename NativeT>
379     void Set(absl::Span<const int64> index, NativeT value);
380 
381     int32 GetDynamicSize(int64 dim_index) const;
382     void SetDynamicSize(int64 dim_index, int32 size);
383     // Gets/sets the buffer holding the array data.
buffer()384     char* buffer() const { return buffer_; }
set_buffer(char * buffer)385     void set_buffer(char* buffer) { buffer_ = buffer; }
386 
387     // Gets/sets the buffer holding dynamic sizes.
dynamic_size_buffer()388     int32* dynamic_size_buffer() const { return dynamic_size_buffer_; }
set_dynamic_size_buffer(int32 * dynamic_size_buffer)389     void set_dynamic_size_buffer(int32* dynamic_size_buffer) {
390       dynamic_size_buffer_ = dynamic_size_buffer;
391     }
392 
dynamic_size_buffer_bytes()393     int64 dynamic_size_buffer_bytes() const {
394       return subshape().dimensions_size() * sizeof(int32);
395     }
396 
397     // Gets or sets the subshape of this piece. This reference points to a
398     // subshape within the shape in the containing Literal (Literal::shape_).
subshape()399     const Shape& subshape() const { return *subshape_; }
set_subshape(const Shape * subshape)400     void set_subshape(const Shape* subshape) { subshape_ = subshape; }
401 
402     // Returns the size in bytes of the buffer holding the array data.
size_bytes()403     int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }
404 
405     // Returns the number of elements in this piece's array.
element_count()406     int64 element_count() const { return ShapeUtil::ElementsIn(subshape()); }
407 
408     // Returns the child piece at 'index' of this piece.
child(int64 index)409     Piece& child(int64 index) { return children_[index]; }
410 
411     // Adds a child piece to this piece's children.
emplace_back(Piece child_piece)412     void emplace_back(Piece child_piece) {
413       children_.emplace_back(std::move(child_piece));
414     }
415 
416     // Returns the size of children pieces of this piece.
children_size()417     int64 children_size() { return children_.size(); }
418 
419     // Visitor functions that recursively traverses the piece and calls the
420     // given function at each child piece. The function has the type:
421     //    void (const ShapeIndex& index, const Piece& piece)
422     template <typename Fn>
ForEachSubpiece(const Fn & func)423     void ForEachSubpiece(const Fn& func) const {
424       ShapeIndex index;
425       return ForEachHelper(
426                  [&func](const ShapeIndex& index, const Piece& piece) {
427                    func(index, piece);
428                    return Status::OK();
429                  },
430                  *this, &index)
431           .IgnoreError();
432     }
433     // Same as above, but the function has the type:
434     //    Status (const ShapeIndex& index, const Piece& piece)
435     // The first non-OK return value is returned by the function.
436     template <typename Fn>
ForEachSubpieceWithStatus(const Fn & func)437     Status ForEachSubpieceWithStatus(const Fn& func) const {
438       ShapeIndex index;
439       return ForEachHelper(func, *this, &index);
440     }
441     // Same as above, but the function has the type:
442     //    Bool (const ShapeIndex& index, const Piece& piece)
443     // The first non-true return value is returned by the function.
444     template <typename Fn>
ForEachSubpieceWithBool(const Fn & func)445     bool ForEachSubpieceWithBool(const Fn& func) const {
446       ShapeIndex index;
447       return ForEachHelperBool(func, *this, &index);
448     }
449     // Same as above, but the function has the type:
450     //    Void (const ShapeIndex& index, Piece& piece)
451     template <typename Fn>
ForEachMutableSubpiece(const Fn & func)452     void ForEachMutableSubpiece(const Fn& func) {
453       ShapeIndex index;
454       return ForEachMutableHelper(
455                  [&func](const ShapeIndex& index, Piece* piece) {
456                    func(index, piece);
457                    return Status::OK();
458                  },
459                  const_cast<xla::LiteralBase::Piece*>(this), &index)
460           .IgnoreError();
461     }
462     // Same as above, but the function has the type:
463     //    Status (const ShapeIndex& index, Piece& piece)
464     // The first non-OK return value is returned by the function.
465     template <typename Fn>
ForEachMutableSubpieceWithStatus(const Fn & func)466     Status ForEachMutableSubpieceWithStatus(const Fn& func) {
467       ShapeIndex index;
468       return ForEachMutableHelper(
469           func, const_cast<xla::LiteralBase::Piece*>(this), &index);
470     }
471 
472     // Returns true if this piece and 'other' contain the same data. This piece
473     // and 'other' must be array-shaped and compatible. If a literal has dynamic
474     // shape, comparison is done only for the valid elements.
475     bool EqualElements(const Piece& other) const;
476 
477     // Returns true if this piece and other pieces have the same dynamic
478     // dimension sizes.
479     bool EqualDynamicSize(const Piece& other) const;
480 
481     // Writes the shape and data (if array-shaped) into the given proto.
482     void WriteToProto(LiteralProto* proto) const;
483 
484     // Copy the data from 'src' into this piece's buffer. Shapes of this piece
485     // and src must be compatible. If only_dynamic_bound is true, only elements
486     // within dynamic bounds will be copied.
487     Status CopyFrom(const Piece& src, bool only_dynamic_bound);
488 
489     // Copies the data from the given proto into this piece. The shape of this
490     // piece must be equal (not just compatible) to the shape of the proto.
491     Status CopyFromProto(const LiteralProto& proto);
492 
493    private:
494     // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'.
495     // The first non-OK (or non-true) value is returned by the function.
496     // The callable 'func' has the same signature as described above in
497     // ForEachSubpiece*.
498     template <typename Fn>
ForEachHelper(const Fn & func,const Piece & piece,ShapeIndex * index)499     Status ForEachHelper(const Fn& func, const Piece& piece,
500                          ShapeIndex* index) const {
501       TF_RETURN_IF_ERROR(func(*index, piece));
502       for (int64 i = 0; i < piece.children_.size(); ++i) {
503         index->push_back(i);
504         TF_RETURN_IF_ERROR(ForEachHelper(func, piece.children_[i], index));
505         index->pop_back();
506       }
507       return Status::OK();
508     }
509     template <typename Fn>
ForEachHelperBool(const Fn & func,const Piece & piece,ShapeIndex * index)510     bool ForEachHelperBool(const Fn& func, const Piece& piece,
511                            ShapeIndex* index) const {
512       if (!func(*index, piece)) {
513         return false;
514       }
515       for (int64 i = 0; i < piece.children_.size(); ++i) {
516         index->push_back(i);
517         if (!ForEachHelperBool(func, piece.children_[i], index)) {
518           return false;
519         }
520         index->pop_back();
521       }
522       return true;
523     }
524     template <typename Fn>
ForEachMutableHelper(const Fn & func,Piece * piece,ShapeIndex * index)525     Status ForEachMutableHelper(const Fn& func, Piece* piece,
526                                 ShapeIndex* index) {
527       TF_RETURN_IF_ERROR(func(*index, piece));
528       for (int64 i = 0; i < piece->children_.size(); ++i) {
529         index->push_back(i);
530         TF_RETURN_IF_ERROR(
531             ForEachMutableHelper(func, &piece->children_[i], index));
532         index->pop_back();
533       }
534       return Status::OK();
535     }
536 
537     // Recursive helper for EqualElements.
538     template <typename NativeT>
539     bool EqualElementsInternal(const Piece& other,
540                                std::vector<int64>* multi_index) const;
541 
542     // Internal helper to copy elements from another given piece
543     template <typename NativeT>
544     void CopyElementsWithDynamicBound(const LiteralBase::Piece& src);
545 
546     // For array-shaped pieces, this is the buffer holding the literal data.
547     char* buffer_ = nullptr;
548 
549     int32* dynamic_size_buffer_ = nullptr;
550 
551     // The shape of piece. This points into the shape of the containing Literal
552     // (Literal::shape_).
553     const Shape* subshape_ = nullptr;
554 
555     // Children pieces for tuple shaped pieces.
556     std::vector<Piece> children_ = {};
557   };  // class Piece
558 
piece(const ShapeIndex & shape_index)559   const Piece& piece(const ShapeIndex& shape_index) const {
560     Piece* piece = &const_cast<Piece&>(root_piece());
561     for (const auto i : shape_index) {
562       DCHECK_GE(i, 0);
563       DCHECK_LT(i, piece->children_size());
564       piece = &piece->child(i);
565     }
566     return *piece;
567   }
568 
569   // Returns the piece at the root of the shape.
570   virtual const Piece& root_piece() const = 0;
571 
572   // LiteralSlice and Literal must access Pieces of other Literals.
573   friend class MutableLiteralBase;
574   friend class LiteralSlice;
575   friend class BorrowingLiteral;
576 
577  private:
578   template <typename NativeT>
579   Literal SliceInternal(const Shape& result_shape,
580                         absl::Span<const int64> start_indices) const;
581 };
582 
583 // Abstract base class representing a mutable literal in XLA.
584 class MutableLiteralBase : public LiteralBase {
585  public:
586   virtual ~MutableLiteralBase() = 0;
587 
588   // Returns a Span view of the array for this literal for the
589   // given NativeT (e.g., float). CHECKs if the subshape of the literal at the
590   // given ShapeIndex is not array. See primitive_util.h for the mapping from
591   // XLA type to native type.
592   template <typename NativeT>
593   absl::Span<NativeT> data(const ShapeIndex& shape_index = {});
594   // Unhide const method from parent class.
595   using LiteralBase::data;
596 
597   // TODO(b/67651157): Remove this accessor. Literal users should not be able to
598   // mutate the shape as this can produce malformed Literals.
mutable_shape_do_not_use()599   Shape* mutable_shape_do_not_use() { return shape_.get(); }
600 
601   // Set the dynamic size on dim_index in the literal at the given shape_index.
602   void SetDynamicSize(int64 dim_index, const ShapeIndex& shape_index,
603                       int32 size);
604   void SetDynamicSize(int64 dim_index, int32 size);
605 
606   // Returns a pointer to the underlying buffer holding the array at the given
607   // shape index. CHECKs if the subshape of the literal at the given ShapeIndex
608   // is not array.
609   void* untyped_data(const ShapeIndex& shape_index = {});
610   // Unhide const method from parent class.
611   using LiteralBase::untyped_data;
612 
613   // Copy values from 'src_literal' rooted at 'src_shape_index' into this
614   // literal rooted at 'dest_shape_index'. The subshape of this literal rooted
615   // at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
616   // rooted at 'src_shape_index', but need not be arrays. If only_dynamic_bound
617   // is true, only elements within dynamic bounds will be copied.
618   Status CopyFrom(const LiteralSlice& src_literal,
619                   const ShapeIndex& dest_shape_index = {},
620                   const ShapeIndex& src_shape_index = {},
621                   bool only_dynamic_bound = false);
622 
623   // Copies the values from src_literal, starting at src_base shape indexes,
624   // to this literal, starting at dest_base, where the copy size in each
625   // dimension is specified by copy_size.
626   // The src_literal and this literal must have the same primitive type,
627   // src_base+copy_size must fit the source literal dimensions, as well as
628   // dest_base+copy_size must fit the destination literal dimensions.
629   // Note: if either src_literal or this literal contains dimensions with zero
630   // element, then copy_size must be 0 in these dimensions while the
631   // corresponding base indices being 0.
632   // This literal and 'src_literal' must be arrays.
633   Status CopySliceFrom(const LiteralSlice& src_literal,
634                        absl::Span<const int64> src_base,
635                        absl::Span<const int64> dest_base,
636                        absl::Span<const int64> copy_size);
637 
638   // Copies one element from src_literal[src_index] to (*this)[dest_index].
639   Status CopyElementFrom(const LiteralSlice& src_literal,
640                          absl::Span<const int64> src_index,
641                          absl::Span<const int64> dest_index);
642 
643   // Sets an element in the literal at the given index. The multi_index is
644   // CHECKed against the dimension sizes.
645   template <typename NativeT>
646   void Set(absl::Span<const int64> multi_index, const ShapeIndex& shape_index,
647            NativeT value);
648   // Overloads of Set for array literals. CHECKs if the literal is not
649   // array-shaped and dense.
650   template <typename NativeT>
651   void Set(absl::Span<const int64> multi_index, NativeT value);
652 
653   // As Set(), but truncates `value` to the literal element type before storing.
654   // This literal must be an array.
655   Status SetIntegralAsS64(absl::Span<const int64> multi_index, int64 value);
656 
657   // As Set(), but truncates `value` to the literal element type before storing.
658   // This literal must be an array.
659   Status SetFromDouble(absl::Span<const int64> multi_index, double value);
660 
  // Populate this literal with the given values. Examples:
  //
  //   // Populate with floats.
  //   Array2D<float> float_values = ...
  //   literal.PopulateR2FromArray2D(float_values);
  //
  //   // Populate with int32s.
  //   literal.PopulateR2<int32>({{1, 2}, {3, 4}});
  //
  // The shape and element type of this literal must match given values. For
  // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2
  // array of S32.
  template <typename NativeT>
  void PopulateR1(absl::Span<const NativeT> values);
  // Bitmap overload of PopulateR1; presumably intended for PRED literals
  // (one bit per element) - TODO confirm against the implementation.
  void PopulateR1(const tensorflow::core::Bitmap& values);
  template <typename NativeT>
  void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values);
  template <typename NativeT>
  void PopulateFromArray(const Array<NativeT>& values);
  template <typename NativeT>
  void PopulateR2FromArray2D(const Array2D<NativeT>& values);
  template <typename NativeT>
  void PopulateR3FromArray3D(const Array3D<NativeT>& values);
  template <typename NativeT>
  void PopulateR4FromArray4D(const Array4D<NativeT>& values);
686 
  // Populates literal values by calling the generator function for every cell
  // in this literal object.
  //
  // generator must be a callable of the type
  // NativeT(absl::Span<int64> indexes) or compatible.
  //
  // This literal must have a dense layout and an element type matching
  // NativeT; otherwise a non-OK Status is returned (via TF_RET_CHECK in
  // PopulateInternal).
  template <typename NativeT, typename FnType>
  Status Populate(const FnType& generator);

  // A parallel version of Populate(). This can be used if the generator is
  // thread-safe and the values for the shape's different elements are
  // independent. The scan is driven by ShapeUtil::ForEachIndexParallel.
  template <typename NativeT, typename FnType>
  Status PopulateParallel(const FnType& generator);

  // Fills this literal with the given value. CHECKs that this literal is an
  // array whose element type matches NativeT.
  template <typename NativeT>
  void PopulateWithValue(NativeT value);
706 
  // This operation is the inverse of DecomposeTuple. The given elements are
  // moved into the tuple elements of a new tuple-shaped Literal which is
  // returned. Upon return, each of the Literals in 'elements' is set to a nil
  // shape (empty tuple).
  static Literal MoveIntoTuple(absl::Span<Literal> elements);

  // Deserializes a Literal from `proto`, returning an error Status on failure.
  // NOTE(review): the exact meaning of `prohibit_empty_literal` is defined in
  // the out-of-line implementation — confirm before relying on it.
  static StatusOr<Literal> CreateFromProto(const LiteralProto& proto,
                                           bool prohibit_empty_literal = true);
716 
717  protected:
718   // Returns the piece at the given ShapeIndex.
piece(const ShapeIndex & shape_index)719   Piece& piece(const ShapeIndex& shape_index) {
720     return const_cast<Piece&>(LiteralBase::piece(shape_index));
721   }
722 
root_piece()723   Piece& root_piece() const override { return *root_piece_; };
724 
  // Internal template helper for the Literal::CopySliceFrom(), matching its
  // arguments one by one. NativeT is presumably the common element type of
  // both literals (the dispatch happens in the out-of-line CopySliceFrom) —
  // confirm. Note it takes a LiteralBase rather than a LiteralSlice.
  template <typename NativeT>
  Status CopySliceFromInternal(const LiteralBase& src_literal,
                               absl::Span<const int64> src_base,
                               absl::Span<const int64> dest_base,
                               absl::Span<const int64> copy_size);
732 
  // Utility structure which is used to create the optimal configuration for
  // a ShapeUtil::ForEachIndex() scan across two literals. The constructor is
  // defined out of line; PopulateInternal() builds one with the same shape as
  // both source and destination.
  struct StrideConfig {
    StrideConfig(const Shape& source_shape, const Shape& dest_shape,
                 absl::Span<const int64> dimensions);

    // The dimensions of the stride operation. Essentially every dimension
    // will be iterated from base[i] to base[i]+dimensions[i], in step[i]
    // steps.
    // NOTE(review): this is a non-owning view; whatever it refers to must
    // outlive the StrideConfig — confirm what the constructor stores here.
    absl::Span<const int64> dimensions;
    DimensionVector base;
    DimensionVector step;
    // Index of the dimension scanned by the inner loop (the most minor one).
    int64 minor_dimension = 0;
    // The size of the strides for source and destination. One of the two
    // (the one looping through its most minor dimension) will be 1, while
    // the other will be the stride size at the dimension matching the other
    // shape's most minor dimension being scanned.
    int64 dest_stride = 1;
    int64 source_stride = 1;
    // The size of the inner loop on the most minor dimension.
    int64 minor_loop_size = 1;
  };
755 
  // Literal class always owns the shape. The parent class borrows this shape.
  std::unique_ptr<Shape> shape_;

  // Root of this literal's piece tree, returned by root_piece(). Starts out
  // null. NOTE(review): allocation/ownership is handled by the subclasses'
  // out-of-line code — confirm before relying on it.
  Piece* root_piece_ = nullptr;

  // Implementation details shared between Populate() and PopulateParallel().
  // `parallel` selects ShapeUtil::ForEachIndexParallel instead of
  // ShapeUtil::ForEachIndex for the scan over the literal's indices.
  template <typename NativeT, typename FnType>
  Status PopulateInternal(const FnType& generator, bool parallel);

  friend class LiteralBase;
  friend class MutableBorrowingLiteral;
767 };
// Stream-insertion operator for Literal; declared here, defined out of line.
std::ostream& operator<<(std::ostream& out, const Literal& literal);
769 
770 // The underlying buffer and shape is always owned by this class.
771 class Literal : public MutableLiteralBase {
772  public:
Literal()773   Literal() : Literal(ShapeUtil::MakeNil()) {}
774 
775   // Create a literal of the given shape. The literal is allocated sufficient
776   // memory to hold the shape. Memory is uninitialized.
777   explicit Literal(const Shape& shape);
778   virtual ~Literal();
779 
780   // Literals are moveable, but not copyable. To copy a literal use
781   // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
782   // of literals which can be expensive.
783   Literal(const Literal& other) = delete;
784   Literal& operator=(const Literal& other) = delete;
785   Literal(Literal&& other);
786   // 'allocate_arrays' indicates whether to allocate memory for the arrays in
787   // the shape. If false, buffer pointers inside of the Literal::Pieces are set
788   // to nullptr.
789   Literal(const Shape& shape, bool allocate_arrays);
790   Literal& operator=(Literal&& other);
791 
792   // Similar to CopyFrom, but with move semantics. The subshape of this literal
793   // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
794   // (layouts and shapes must match), but need not be arrays. The memory
795   // allocated in this literal for the subshape at dest_shape_index is
796   // deallocated, and the respective buffers are replaced with those in
797   // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
798   virtual Status MoveFrom(Literal&& src_literal,
799                           const ShapeIndex& dest_shape_index = {});
800 
801   // Returns a vector containing the tuple elements of this Literal as separate
802   // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
803   // elements are moved into the new Literals; no data is copied. Upon return
804   // this Literal is set to a nil shape (empty tuple)
805   std::vector<Literal> DecomposeTuple();
806 
807  private:
808   // Deallocate the buffers held by this literal.
809   void DeallocateBuffers();
810 
811   // Recursively sets the subshapes and buffers of all subpieces rooted at
812   // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in
813   // the shape.
814   void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays);
815 };
816 
817 // The underlying buffer is not owned by this class and is always owned by
818 // others. The shape is not owned by this class and not mutable.
819 class MutableBorrowingLiteral : public MutableLiteralBase {
820  public:
821   virtual ~MutableBorrowingLiteral();
822 
MutableBorrowingLiteral()823   MutableBorrowingLiteral() : MutableLiteralBase() {}
824 
825   MutableBorrowingLiteral(const MutableBorrowingLiteral& literal);
826   MutableBorrowingLiteral& operator=(const MutableBorrowingLiteral& literal);
827 
828   // Implicit conversion constructors.
829   MutableBorrowingLiteral(MutableLiteralBase* literal);
830   MutableBorrowingLiteral(MutableBorrowingLiteral literal,
831                           const ShapeIndex& view_root);
832   MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape);
833 
834   // Create a literal from a list of buffers and a shape.
835   // Returns a tuple literal if `shape` is a tuple type.
836   MutableBorrowingLiteral(absl::Span<char*> src_buf_ptrs, const Shape& shape);
837 
838  private:
839   // Recursively copies the subtree from the `src_piece` at the given child
840   // index to the `dest_piece`. For buffers only the pointers are copied, but
841   // not the content.
842   void CopyPieceSubtree(const Shape& shape, Piece* src_piece,
843                         Piece* dest_piece);
844 };
845 
846 // A read-only view of a Literal. A LiteralSlice contains pointers to shape and
847 // literal buffers always owned by others.
848 class LiteralSlice : public LiteralBase {
849  public:
LiteralSlice()850   LiteralSlice() : LiteralBase() {}
851 
852   // Implicit conversion constructors.
853   LiteralSlice(const LiteralBase& literal);
854   LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root);
855 
856  private:
root_piece()857   const Piece& root_piece() const override { return *root_piece_; };
858 
859   const Piece* root_piece_;  // Not owned.
860 };
861 
// A read-only Literal where the underlying buffers are never owned by this
// class.
class BorrowingLiteral : public LiteralBase {
 public:
  // Default construction leaves shape_ null and root_piece_ default-built.
  BorrowingLiteral() : LiteralBase() {}

  // 'src_buf_ptr' is not owned by this class and must outlive the
  // lifetime of this class. It points to an appropriately sized buffer with
  // data interpreted as indicated by 'shape'.
  // This constructor is only used for array shapes.
  BorrowingLiteral(const char* src_buf_ptr, const Shape& shape);
  // Similar to the above, except to be used for constructing non-nested
  // tuples.
  BorrowingLiteral(absl::Span<const char* const> src_buf_ptrs,
                   const Shape& shape);
  // TODO(b/79707221): adding constructors for nested tuples as well.

 private:
  // Recursively builds the subtree for the given piece and sets the subshapes
  // of the given piece with the given shape.
  void BuildPieceSubtree(const Shape& shape, Piece* piece);

  // Accessor for the root piece of this literal.
  const Piece& root_piece() const override { return root_piece_; };
  Piece root_piece_;

  // Shape of this literal. Stored as unique_ptr such that the (default) move
  // construction of this class would be trivially correct: the pointer to Shape
  // root_piece_ stores will still point to the correct address.
  // Keep the member order as is; the comment above relies on this layout.
  std::unique_ptr<Shape> shape_;
};
892 
893 template <typename NativeT>
data()894 absl::Span<const NativeT> LiteralBase::Piece::data() const {
895   DCHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
896   DCHECK_EQ(subshape().element_type(),
897             primitive_util::NativeToPrimitiveType<NativeT>())
898       << "Attempting to access "
899       << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
900       << " type, but literal element type is "
901       << PrimitiveType_Name(subshape().element_type());
902   return absl::Span<const NativeT>(reinterpret_cast<const NativeT*>(buffer()),
903                                    element_count());
904 }
905 
906 template <typename NativeT>
data()907 absl::Span<NativeT> LiteralBase::Piece::data() {
908   DCHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
909   DCHECK_EQ(subshape().element_type(),
910             primitive_util::NativeToPrimitiveType<NativeT>())
911       << "Attempting to access "
912       << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
913       << " type, but literal element type is "
914       << PrimitiveType_Name(subshape().element_type());
915   return absl::Span<NativeT>(reinterpret_cast<NativeT*>(buffer()),
916                              element_count());
917 }
918 
919 template <typename NativeT>
Get(absl::Span<const int64> multi_index)920 NativeT LiteralBase::Piece::Get(absl::Span<const int64> multi_index) const {
921   CHECK(LayoutUtil::IsDenseArray(subshape()));
922   return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
923       subshape(), multi_index)];
924 }
925 
926 template <typename NativeT>
Set(absl::Span<const int64> multi_index,NativeT value)927 void LiteralBase::Piece::Set(absl::Span<const int64> multi_index,
928                              NativeT value) {
929   CHECK(LayoutUtil::IsDenseArray(subshape()));
930   data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
931       subshape(), multi_index)] = value;
932 }
933 
934 template <typename NativeT>
data(const ShapeIndex & shape_index)935 absl::Span<const NativeT> LiteralBase::data(
936     const ShapeIndex& shape_index) const {
937   return piece(shape_index).data<NativeT>();
938 }
939 
940 template <typename NativeT>
data(const ShapeIndex & shape_index)941 absl::Span<NativeT> MutableLiteralBase::data(const ShapeIndex& shape_index) {
942   return piece(shape_index).data<NativeT>();
943 }
944 
945 template <typename NativeT>
Get(absl::Span<const int64> multi_index,const ShapeIndex & shape_index)946 inline NativeT LiteralBase::Get(absl::Span<const int64> multi_index,
947                                 const ShapeIndex& shape_index) const {
948   return piece(shape_index).Get<NativeT>(multi_index);
949 }
950 
951 template <typename NativeT>
Get(absl::Span<const int64> multi_index)952 inline NativeT LiteralBase::Get(absl::Span<const int64> multi_index) const {
953   return root_piece().Get<NativeT>(multi_index);
954 }
955 
956 template <typename NativeT>
Set(absl::Span<const int64> multi_index,const ShapeIndex & shape_index,NativeT value)957 inline void MutableLiteralBase::Set(absl::Span<const int64> multi_index,
958                                     const ShapeIndex& shape_index,
959                                     NativeT value) {
960   return piece(shape_index).Set<NativeT>(multi_index, value);
961 }
962 
963 template <typename NativeT>
Set(absl::Span<const int64> multi_index,NativeT value)964 inline void MutableLiteralBase::Set(absl::Span<const int64> multi_index,
965                                     NativeT value) {
966   return root_piece().Set<NativeT>(multi_index, value);
967 }
968 
969 template <typename NativeT>
GetFirstElement()970 NativeT LiteralBase::GetFirstElement() const {
971   return data<NativeT>().at(0);
972 }
973 
974 template <typename NativeT>
EachCell(std::function<void (absl::Span<const int64> indices,NativeT value)> per_cell)975 void LiteralBase::EachCell(
976     std::function<void(absl::Span<const int64> indices, NativeT value)>
977         per_cell) const {
978   if (ShapeUtil::IsZeroElementArray(shape())) {
979     return;
980   }
981   std::vector<int64> indices(shape().rank(), 0);
982 
983   Shape shape_dynamic = shape();
984   for (int64 i = 0; i < shape_dynamic.rank(); ++i) {
985     shape_dynamic.set_dimensions(i, GetDynamicSize(i));
986   }
987   do {
988     per_cell(indices, Get<NativeT>(indices));
989   } while (IndexUtil::BumpIndices(shape_dynamic, absl::MakeSpan(indices)));
990 }
991 
992 template <typename NativeT>
PopulateR1(absl::Span<const NativeT> values)993 inline void MutableLiteralBase::PopulateR1(absl::Span<const NativeT> values) {
994   CHECK(shape().IsArray());
995   CHECK_EQ(shape().rank(), 1);
996   CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
997   CHECK_EQ(shape().element_type(),
998            primitive_util::NativeToPrimitiveType<NativeT>());
999   auto data_span = data<NativeT>();
1000   std::copy(values.begin(), values.end(), data_span.begin());
1001 }
1002 
1003 template <typename NativeT>
PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values)1004 void MutableLiteralBase::PopulateR2(
1005     std::initializer_list<std::initializer_list<NativeT>> values) {
1006   CHECK(shape().IsArray());
1007   CHECK_EQ(shape().rank(), 2);
1008   CHECK_EQ(shape().element_type(),
1009            primitive_util::NativeToPrimitiveType<NativeT>());
1010 
1011   const int64 dim0_size = values.size();
1012   const int64 dim1_size = values.begin()->size();
1013   CHECK_EQ(dim0_size, shape().dimensions(0));
1014   CHECK_EQ(dim1_size, shape().dimensions(1));
1015 
1016   int64 dim0 = 0;
1017   for (auto inner_list : values) {
1018     int64 dim1 = 0;
1019     for (auto value : inner_list) {
1020       Set({dim0, dim1}, value);
1021       ++dim1;
1022     }
1023     CHECK_EQ(dim1_size, dim1);
1024     ++dim0;
1025   }
1026 }
1027 
1028 template <typename NativeT>
PopulateFromArray(const Array<NativeT> & values)1029 void MutableLiteralBase::PopulateFromArray(const Array<NativeT>& values) {
1030   CHECK(shape().IsArray());
1031   CHECK_EQ(shape().element_type(),
1032            primitive_util::NativeToPrimitiveType<NativeT>());
1033   CHECK_EQ(shape().rank(), values.num_dimensions());
1034   for (int dim = 0; dim < values.num_dimensions(); ++dim) {
1035     CHECK_EQ(values.dim(dim), shape().dimensions(dim));
1036   }
1037   values.Each([this](absl::Span<const int64> indices, NativeT value) {
1038     this->Set(indices, value);
1039   });
1040 }
1041 
1042 template <typename NativeT>
PopulateR2FromArray2D(const Array2D<NativeT> & values)1043 void MutableLiteralBase::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
1044   PopulateFromArray(values);
1045 }
1046 
1047 template <typename NativeT>
PopulateR3FromArray3D(const Array3D<NativeT> & values)1048 void MutableLiteralBase::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
1049   PopulateFromArray(values);
1050 }
1051 
1052 template <typename NativeT>
PopulateR4FromArray4D(const Array4D<NativeT> & values)1053 void MutableLiteralBase::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
1054   PopulateFromArray(values);
1055 }
1056 
1057 template <typename NativeT, typename FnType>
PopulateInternal(const FnType & generator,bool parallel)1058 Status MutableLiteralBase::PopulateInternal(const FnType& generator,
1059                                             bool parallel) {
1060   const Shape& this_shape = shape();
1061   const int64 rank = this_shape.rank();
1062   TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
1063   TF_RET_CHECK(this_shape.element_type() ==
1064                primitive_util::NativeToPrimitiveType<NativeT>());
1065   absl::Span<NativeT> literal_data = data<NativeT>();
1066   if (rank > 0) {
1067     StrideConfig stride_config(this_shape, this_shape,
1068                                AsInt64Slice(this_shape.dimensions()));
1069     int64 minor_dimension_size =
1070         ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
1071 
1072     auto init_function = [&](absl::Span<const int64> indexes) {
1073       DimensionVector minor_scan_indexes(rank, 0);
1074       const int64 index =
1075           IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
1076       std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
1077       for (int64 i = 0; i < minor_dimension_size; ++i) {
1078         minor_scan_indexes[stride_config.minor_dimension] = i;
1079         literal_data.at(index + i) = generator(minor_scan_indexes);
1080       }
1081     };
1082     if (parallel) {
1083       ShapeUtil::ForEachIndexParallel(this_shape, stride_config.base,
1084                                       stride_config.dimensions,
1085                                       stride_config.step, init_function);
1086     } else {
1087       ShapeUtil::ForEachIndex(
1088           this_shape, stride_config.base, stride_config.dimensions,
1089           stride_config.step,
1090           [&init_function](absl::Span<const int64> indexes) {
1091             init_function(indexes);
1092             return true;
1093           });
1094     }
1095   } else {
1096     // For scalars.
1097     literal_data.at(0) = generator({});
1098   }
1099   return Status::OK();
1100 }
1101 template <typename NativeT, typename FnType>
Populate(const FnType & generator)1102 Status MutableLiteralBase::Populate(const FnType& generator) {
1103   return PopulateInternal<NativeT>(generator, /*parallel=*/false);
1104 }
1105 
1106 template <typename NativeT, typename FnType>
PopulateParallel(const FnType & generator)1107 Status MutableLiteralBase::PopulateParallel(const FnType& generator) {
1108   return PopulateInternal<NativeT>(generator, /*parallel=*/true);
1109 }
1110 
1111 template <typename NativeT>
PopulateWithValue(NativeT value)1112 void MutableLiteralBase::PopulateWithValue(NativeT value) {
1113   CHECK(shape().IsArray());
1114   CHECK_EQ(shape().element_type(),
1115            primitive_util::NativeToPrimitiveType<NativeT>());
1116   for (NativeT& element : data<NativeT>()) {
1117     element = value;
1118   }
1119 }
1120 
1121 template <typename NativeT>
Replicate(int64 times)1122 Literal LiteralBase::Replicate(int64 times) const {
1123   DimensionVector bounds = {times};
1124   bounds.reserve(shape().dimensions_size() + 1);
1125   for (int64 bound : shape().dimensions()) {
1126     bounds.push_back(bound);
1127   }
1128   Literal literal(ShapeUtil::MakeShape(shape().element_type(), bounds));
1129   int64 elements = ShapeUtil::ElementsIn(literal.shape());
1130   if (elements == 0) {
1131     return literal;
1132   }
1133 
1134   DimensionVector output_indices(bounds.size(), 0);
1135   absl::Span<const int64> input_indices = output_indices;
1136   input_indices.remove_prefix(1);
1137 
1138   bool done = false;
1139   while (!done) {
1140     const auto element = Get<NativeT>(input_indices);
1141     literal.Set<NativeT>(output_indices, element);
1142 
1143     done = true;
1144     for (int n = 0; n < output_indices.size(); ++n) {
1145       ++output_indices[n];
1146       if (output_indices[n] < bounds[n]) {
1147         done = false;
1148         break;
1149       }
1150       output_indices[n] = 0;
1151     }
1152   }
1153   return literal;
1154 }
1155 
1156 }  // namespace xla
1157 
1158 #endif  // TENSORFLOW_COMPILER_XLA_LITERAL_H_
1159