1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/sorting.h"
17 
18 #include "tensorflow/compiler/xla/client/lib/comparators.h"
19 #include "tensorflow/compiler/xla/client/lib/constants.h"
20 #include "tensorflow/compiler/xla/client/lib/loops.h"
21 #include "tensorflow/compiler/xla/client/lib/slicing.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/shape_util.h"
24 #include "tensorflow/compiler/xla/util.h"
25 
26 namespace xla {
27 
TopK(XlaOp input,int64 k)28 XlaOp TopK(XlaOp input, int64 k) {
29   XlaBuilder* const builder = input.builder();
30   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
31     TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
32     int last_dim = input_shape.dimensions_size() - 1;
33     int64 last_dim_size = input_shape.dimensions(last_dim);
34     // TODO(b/148796364): tune these constants for better performance.
35     const int64 kPerPartitionSize = 8192;        // 2^13
36     const int64 kLastDimSizeThreshold = 524288;  // 2^19
37     const int64 kMinNumPartitions = 8;
38     const int64 kMinimalK = 1000;
39     if ((k >= kMinimalK) && (k < kPerPartitionSize) &&
40         (kPerPartitionSize / k > 2) && last_dim_size >= kLastDimSizeThreshold) {
41       int64 num_partitions =
42           CeilOfRatio(last_dim_size - k, kPerPartitionSize - k);
43       if (num_partitions >= kMinNumPartitions) {
44         return TopKWithPartitions(input, k, num_partitions);
45       }
46     }
47 
48     Shape iota_shape =
49         ShapeUtil::MakeShape(S32, AsInt64Slice(input_shape.dimensions()));
50     XlaOp iota_s32 = Iota(builder, iota_shape, last_dim);
51     for (int64 i = 0; i < input_shape.rank(); ++i) {
52       if (input_shape.is_dynamic_dimension(i)) {
53         // Propagate dynamic dimension from inputs to iota.
54         iota_s32 = SetDimensionSize(iota_s32, GetDimensionSize(input, i), i);
55       }
56     }
57     auto input_dims = input_shape.dimensions();
58     XlaOp sort_result =
59         Sort({input, iota_s32},
60              CreateScalarGtComputation({input_shape.element_type(), S32},
61                                        iota_s32.builder()),
62              last_dim, /*is_stable=*/true);
63     std::vector<int64> start_indices(input_shape.dimensions_size(), 0);
64     std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
65     limit_indices[last_dim] = k;
66     std::vector<int64> strides(input_shape.dimensions_size(), 1);
67 
68     XlaOp values = Slice(GetTupleElement(sort_result, 0), start_indices,
69                          limit_indices, strides);
70     XlaOp indices = Slice(GetTupleElement(sort_result, 1), start_indices,
71                           limit_indices, strides);
72     return Tuple(builder, {values, indices});
73   });
74 }
75 
TopKWithPartitions(XlaOp input,int64 k,int64 num_partitions)76 XlaOp TopKWithPartitions(XlaOp input, int64 k, int64 num_partitions) {
77   XlaBuilder* const builder = input.builder();
78   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
79     TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
80     int last_dim = input_shape.dimensions_size() - 1;
81     // Calculate per partition size.
82     auto input_dims = input_shape.dimensions();
83     int64 last_dim_size = input_shape.dimensions(last_dim);
84     const int64 per_partition_size = CeilOfRatio(last_dim_size, num_partitions);
85     // Do normal TopK when per partition size is smaller than or equal to k.
86     if (k >= per_partition_size) {
87       return TopK(input, k);
88     }
89 
90     Shape iota_shape =
91         ShapeUtil::MakeShape(S32, AsInt64Slice(input_shape.dimensions()));
92     XlaOp iota_s32 = Iota(builder, iota_shape, last_dim);
93     for (int64 i = 0; i < input_shape.rank(); ++i) {
94       if (input_shape.is_dynamic_dimension(i)) {
95         // Propagate dynamic dimension from inputs to iota.
96         iota_s32 = SetDimensionSize(iota_s32, GetDimensionSize(input, i), i);
97       }
98     }
99 
100     auto topk_body_fn =
101         [&](XlaOp partition, absl::Span<const XlaOp> values_and_indices,
102             XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
103       auto values = values_and_indices[0];
104       auto indices = values_and_indices[1];
105       auto input = values_and_indices[2];
106       auto iota_s32 = values_and_indices[3];
107 
108       // Slice value and indices for this partition.
109       XlaOp start = Mul(Add(partition, ConstantR0<int32>(builder, 1)),
110                         ConstantR0<int32>(builder, per_partition_size));
111       XlaOp sliced_input =
112           DynamicSliceInMinorDims(input, {start}, {per_partition_size});
113       XlaOp sliced_indices =
114           DynamicSliceInMinorDims(iota_s32, {start}, {per_partition_size});
115       // Concat with previous results.
116       sliced_input = ConcatInDim(builder, {values, sliced_input}, last_dim);
117       sliced_indices =
118           ConcatInDim(builder, {indices, sliced_indices}, last_dim);
119       // Sort this slice
120       XlaOp sort_result =
121           Sort({sliced_input, sliced_indices},
122                CreateScalarGtComputation({input_shape.element_type(), S32},
123                                          sliced_indices.builder()),
124                last_dim, true);
125 
126       std::vector<int64> start_indices(input_shape.dimensions_size(), 0);
127       std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
128       std::vector<int64> strides(input_shape.dimensions_size(), 1);
129       // Slice topk.
130       start_indices[last_dim] = 0;
131       limit_indices[last_dim] = k;
132       values = Slice(GetTupleElement(sort_result, 0), start_indices,
133                      limit_indices, strides);
134       indices = Slice(GetTupleElement(sort_result, 1), start_indices,
135                       limit_indices, strides);
136       return std::vector<XlaOp>{values, indices, input, iota_s32};
137     };
138 
139     // Get the values and indices for the first topk so that they can
140     // be passed to the while loop.
141     std::vector<int64> start_indices(input_shape.dimensions_size(), 0);
142     std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
143     std::vector<int64> strides(input_shape.dimensions_size(), 1);
144     start_indices[last_dim] = 0;
145     limit_indices[last_dim] = per_partition_size;
146     // Slice value and indices for the first partition.
147     XlaOp sliced_input = Slice(input, start_indices, limit_indices, strides);
148     XlaOp sliced_indices =
149         Slice(iota_s32, start_indices, limit_indices, strides);
150     // Sort this slice
151     XlaOp sort_result =
152         Sort({sliced_input, sliced_indices},
153              CreateScalarGtComputation({input_shape.element_type(), S32},
154                                        sliced_indices.builder()),
155              last_dim, /*is_stable=*/true);
156 
157     // Slice topk.
158     start_indices[last_dim] = 0;
159     limit_indices[last_dim] = k;
160     XlaOp values = Slice(GetTupleElement(sort_result, 0), start_indices,
161                          limit_indices, strides);
162     XlaOp indices = Slice(GetTupleElement(sort_result, 1), start_indices,
163                           limit_indices, strides);
164 
165     // Pass the result of the first TopK to the while loop and do
166     // num_partition - 1 iterations.
167     TF_ASSIGN_OR_RETURN(auto values_and_indices,
168                         ForEachIndex(num_partitions - 1, S32, topk_body_fn,
169                                      {values, indices, input, iota_s32},
170                                      "topk_with_partition", builder));
171     return Tuple(builder, {values_and_indices[0], values_and_indices[1]});
172   });
173 }
174 
175 }  // namespace xla
176