1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Utility class for managing sparse array indices.
17 
18 #ifndef TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_
19 #define TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_
20 
21 #include <vector>
22 
23 #include "absl/container/inlined_vector.h"
24 #include "absl/types/span.h"
25 #include "tensorflow/compiler/xla/array2d.h"
26 #include "tensorflow/compiler/xla/index_util.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 
29 namespace xla {
30 
// Encapsulates the array of indices for a sparse array.  A SparseIndexArray
// contains indices for up to `max_indices` elements of a sparse array.  Each
// sparse index is an array of `rank` int64 values that gives the location of a
// value within a sparse array.  Note that the dimensions of the array are not
// checked (except for the rank).  To avoid confusion, we refer to the position
// of an index within a SparseIndexArray as a sparse index number.
37 class SparseIndexArray {
38  public:
39   SparseIndexArray();
40   SparseIndexArray(const SparseIndexArray&) = default;
41   SparseIndexArray(SparseIndexArray&&) = default;
42   SparseIndexArray& operator=(const SparseIndexArray&) = default;
43   SparseIndexArray& operator=(SparseIndexArray&&) = default;
44 
45   // Constructs a SparseIndexArray that can hold up to `max_indices` sparse
46   // indices, with an initial contents obtained from the given array.  The rank
47   // is taken from the minor dimension of the array.  The major dimension of the
48   // array must not exceed `max_indices`.
49   SparseIndexArray(int64 max_indices, const Array2D<int64>& indices);
50 
51   // Like above, but the array is flattened.  For example, the following are
52   // equivalent:
53   //
54   //  SparseIndexArray(10, 3,
55   //                   Array2D{
56   //                     {0, 1, 2},
57   //                     {3, 4, 5},
58   //                     {6, 7, 8},
59   //                     {9, 10, 11},
60   //                   })
61   //
62   //  SparseIndexArray(10, 3,
63   //                   {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11})
64   //
65   SparseIndexArray(int64 max_indices, int64 rank,
66                    std::vector<int64> indices = {});
67   SparseIndexArray(int64 max_indices, int64 rank,
68                    absl::Span<const int64> indices);
69 
70   // Returns the number of elements represented by the indices stored in the
71   // array.
72   int64 index_count() const;
73 
74   // Returns a slice that refers to the given sparse index number. The argument
75   // must be in the range [0, element_count()).
76   absl::Span<const int64> At(int64 sparse_element_number) const;
77   absl::Span<int64> At(int64 sparse_element_number);
78 
79   // Adds the given index at the end of the array.  The new size of the
80   // SparseIndexArray must not exceed `max_indices`.
81   void Append(absl::Span<const int64> index);
82 
83   // Removes all indices from the array.
84   void Clear();
85 
86   // Resizes the array to contain the given number of sparse indices.  The new
87   // size must be smaller than `max_indices`.  If the new size is larger than
88   // the old size, the value of the new indices is not specified.
89   void Resize(int64 num_indices);
90 
91   // Returns true iff all indices are unique and occur in sorted order, and are
92   // valid for the given shape.
93   bool Validate(const Shape& shape) const;
94 
rank()95   int64 rank() const { return rank_; }
max_indices()96   int64 max_indices() const { return max_indices_; }
97 
98   // Returns a pointer to the int64 array that holds the sparse indices.
mutable_data()99   absl::Span<int64> mutable_data() { return absl::MakeSpan(indices_); }
data()100   absl::Span<const int64> data() const { return indices_; }
101 
102   // Sorts this sparse index array along with the set of corresponding values.
103   // The indices and values are sorted in the lexicographic order of the
104   // indices, from smallest to largest.
105   //
106   // For example:
107   //
108   //   std::vector<float> v{10.0, 11.0, 12.0};
109   //   SparseIndexArray a(10, 3,
110   //                      {{3, 4, 5},
111   //                       {1, 2, 3},
112   //                       {2, 3, 4}});
113   //   a.SortWithValues(&v);
114   //   // Prints "11.0, 12.0, 10.0":
115   //   std::cout << v[0] << ", " << v[1] << ", " << v[2] << std::endl;
116   //
117   template <typename NativeT>
118   void SortWithValues(absl::Span<NativeT> values);
119 
120  private:
121   std::vector<int64> indices_;
122   int64 rank_;
123   int64 max_indices_;
124 };
125 
126 template <typename NativeT>
SortWithValues(absl::Span<NativeT> values)127 void SparseIndexArray::SortWithValues(absl::Span<NativeT> values) {
128   int64 num_elements = index_count();
129   CHECK_EQ(values.size(), num_elements);
130   std::vector<int64> sort_order;
131   sort_order.reserve(num_elements);
132   for (int64 i = 0; i < num_elements; ++i) {
133     sort_order.push_back(i);
134   }
135   auto sort_order_less = [this](int64 lhs, int64 rhs) {
136     return IndexUtil::CompareIndices(At(lhs), At(rhs)) < 0;
137   };
138   absl::c_sort(sort_order, sort_order_less);
139 
140   // Reorder the array elements according to sort_order.  Work through the array
141   // and follow cycles so we can do the reorder in-place.
142   absl::InlinedVector<int64, 8> saved_index(rank());
143   for (int64 i = 0; i < num_elements; ++i) {
144     // sort_order[i] == -1 indicates the element has already been copied.
145     if (sort_order[i] < 0) {
146       continue;
147     } else if (i == sort_order[i]) {
148       // The element is already in sorted order.
149       sort_order[i] = -1;
150       continue;
151     }
152 
153     std::copy_n(At(i).begin(), rank(), saved_index.begin());
154     NativeT saved_value = values[i];
155     int64 j = i;
156     for (;;) {
157       if (sort_order[j] == i) {
158         std::copy_n(saved_index.begin(), rank(), At(j).begin());
159         values[j] = saved_value;
160         sort_order[j] = -1;
161         break;
162       }
163 
164       std::copy_n(At(sort_order[j]).begin(), rank(), At(j).begin());
165       values[j] = values[sort_order[j]];
166 
167       int64 k = sort_order[j];
168       sort_order[j] = -1;
169       j = k;
170     }
171   }
172 }
173 
174 }  // namespace xla
175 
176 #endif  // TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_
177