1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_
18 
19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
20 #include "tensorflow/core/platform/types.h"
21 
extern "C" {

// Sorts a group of buffers along their middle dimension using a
// caller-supplied comparison function.
//
// Each entry in 'values' points to a buffer representing a 3-dimensional shape
// with dimensions [a, b, c], where 'a', 'b' and 'c' are the sizes of those
// dimensions. The 'b' dimension of each shape is sorted into ascending order
// according to the results of comparisons made with the provided 'less_than'
// function.
// NOTE(review): presumably the sort happens in place (the buffers are passed
// as non-const 'char*'), and a single permutation is applied to every entry so
// that keys and values stay aligned. Confirm against the implementation.
//
// Parameters:
// - a, b, c: dimension sizes of the shape; the same [a, b, c] applies to
//   every entry in 'values'.
// - values: 'values_count' buffer pointers, one per shape.
// - values_count: number of entries in both 'values' and
//   'values_primitive_type_size_in_bytes'; must be > 0.
// - values_primitive_type_size_in_bytes: the primitive type of the i-th shape
//   is exactly 'values_primitive_type_size_in_bytes[i]' bytes wide.
// - is_stable: whether the sort must be stable (elements that compare equal
//   keep their original relative order).
// - run_options, prof_counters: passed through unchanged to 'less_than'.
// - less_than: comparison function. It expects the following arguments:
//   - pointer to the return value buffer (char*)
//   - xla::ExecutableRunOptions = 'run_options' (char*)
//   - pointers to the parameter buffers (char**)
//   - pointers to the buffer tables = nullptr for thread local functions
//     (char**)
//   - profile counters = 'prof_counters' (int64*)
//   NOTE(review): presumably the parameter buffers hold the pair of elements
//   being compared, and the result is written to the return value buffer as a
//   boolean. Confirm against the code that emits 'less_than'.
//
// The declaration has C linkage, so the symbol name is not mangled. Keep the
// signature in sync with the implementation and with every caller that refers
// to this symbol by name.
extern void __xla_cpu_runtime_KeyValueSort(
    tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
    char** values, tensorflow::int32 values_count,
    tensorflow::int32* values_primitive_type_size_in_bytes, bool is_stable,
    char* run_options, tensorflow::int64* prof_counters,
    void (*less_than)(char*, char*, char**, char**, tensorflow::int64*));
}  // extern "C"
45 
46 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_
47