1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/service/cpu/runtime_topk.h"

#include <algorithm>
#include <cstring>
#include <limits>
#include <memory>
#include <numeric>
#include <vector>

#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/macros.h"
25
26 template <typename T>
TopK(tensorflow::int64 batch_size,tensorflow::int64 input_size,tensorflow::int64 k,const T * values,T * out_values,tensorflow::int32 * out_indices)27 static void TopK(tensorflow::int64 batch_size, tensorflow::int64 input_size,
28 tensorflow::int64 k, const T* values, T* out_values,
29 tensorflow::int32* out_indices) {
30 // 'values' is managed by the JIT code, so msan can't tell they are
31 // initialized.
32 TF_ANNOTATE_MEMORY_IS_INITIALIZED(values,
33 input_size * batch_size * sizeof(T));
34
35 std::vector<tensorflow::int32> temp_indices(input_size);
36 for (tensorflow::int64 batch = 0; batch != batch_size; ++batch) {
37 std::iota(temp_indices.begin(), temp_indices.end(), 0);
38
39 const T* values_batch = values + batch * input_size;
40
41 auto convert_to_int = [](T value) {
42 tensorflow::uint32 x;
43 std::memcpy(&x, &value, sizeof(x));
44 return static_cast<tensorflow::int32>(x) < 0
45 ? std::numeric_limits<tensorflow::int32>::max() - x
46 : x;
47 };
48
49 auto kth_element = temp_indices.begin() + k;
50 std::partial_sort(temp_indices.begin(), kth_element, temp_indices.end(),
51 [&](size_t i1, size_t i2) {
52 // Do the comparison in integers to enforce a total
53 // order of -NaN < -Inf < -0 < +0 < +Inf < +NaN.
54 tensorflow::int32 v1 = convert_to_int(values_batch[i1]);
55 tensorflow::int32 v2 = convert_to_int(values_batch[i2]);
56 if (v1 == v2) {
57 return i1 < i2; // Stabilize sorting.
58 }
59 return v1 > v2;
60 });
61
62 T* out_values_batch = out_values + batch * k;
63 tensorflow::int32* out_indices_batch = out_indices + batch * k;
64 std::copy(temp_indices.begin(), kth_element, out_indices_batch);
65 for (tensorflow::int64 i = 0; i < k; i++) {
66 out_values_batch[i] = values_batch[temp_indices[i]];
67 }
68 }
69 }
70
__xla_cpu_runtime_TopKF32(tensorflow::int64 batch_size,tensorflow::int64 input_size,tensorflow::int64 k,const float * values,float * out_values,tensorflow::int32 * out_indices)71 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_TopKF32(
72 tensorflow::int64 batch_size, tensorflow::int64 input_size,
73 tensorflow::int64 k, const float* values, float* out_values,
74 tensorflow::int32* out_indices) {
75 TopK(batch_size, input_size, k, values, out_values, out_indices);
76 }
77