/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/kernels/determinant_op.h"

#include <complex>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {
namespace functor {

typedef Eigen::GpuDevice GPUDevice;
namespace {
__device__ int PermutationOrder(int n, const int* pivots) {
  // Compute the order of the permutation from the number of transpositions
  // encoded in the pivot array, see:
  // http://icl.cs.utk.edu/lapack-forum/viewtopic.php?f=2&t=340
  int order = 0;
  for (int i = 0; i < n - 1; ++i) {
    // Notice: Internally, the cuBLAS code uses Fortran convention (1-based)
    // indexing, so we expect pivots[i] == i + 1 for rows that were not moved.
    order += pivots[i] != (i + 1);
  }
  return order;
}
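// Example: for n = 3 and pivots = {2, 2, 3} (1-based), only pivots[0]
// differs from its index, so the factorization applied one transposition,
// order == 1, and the permutation contributes a sign of (-1)^1 = -1.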

#if defined(__CUDACC__)
// Hack around missing support for complex in NVCC.
template <typename T>
__device__ inline std::complex<T> complex_multiply(const std::complex<T>& a,
                                                   const std::complex<T>& b) {
  const T a_real = Eigen::numext::real(a);
  const T a_imag = Eigen::numext::imag(a);
  const T b_real = Eigen::numext::real(b);
  const T b_imag = Eigen::numext::imag(b);
  return std::complex<T>(a_real * b_real - a_imag * b_imag,
                         a_real * b_imag + a_imag * b_real);
}
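// The operator overloads below are found by argument-dependent lookup, so
// expressions in the kernel such as prod_sign * (lu_factor[i_idx] / abs_i)
// compile for complex64/complex128 without device-side std::complex support.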
__device__ inline complex64 operator*(const complex64& a, const complex64& b) {
  return complex_multiply<float>(a, b);
}
__device__ inline complex64 operator*(const complex64& a, const float& b) {
  return complex64(Eigen::numext::real(a) * b, Eigen::numext::imag(a) * b);
}
__device__ inline complex64 operator/(const complex64& a, const float& b) {
  const float inv_b = 1.0f / b;
  return a * inv_b;
}
__device__ inline complex128 operator*(const complex128& a,
                                       const complex128& b) {
  return complex_multiply<double>(a, b);
}
__device__ inline complex128 operator*(const complex128& a, const double& b) {
  return complex128(Eigen::numext::real(a) * b, Eigen::numext::imag(a) * b);
}
__device__ inline complex128 operator/(const complex128& a, const double& b) {
  const double inv_b = 1.0 / b;
  return a * inv_b;
}
#endif
}  // namespace

// This kernel computes either determinant or log_abs_determinant, depending
// on the value of the template parameter. If compute_log_abs_det is false,
// the sign argument is ignored.
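// For a pivoted LU factorization P * A = L * U with unit-diagonal L,
// det(A) = (-1)^order * prod_i u_ii, where order is the number of row
// transpositions in P. We accumulate log|u_ii| and a unit-modulus sign
// separately so that large products do not overflow.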
template <typename Scalar, bool compute_log_abs_det = true>
__global__ void DeterminantFromPivotedLUKernel(int nthreads, int n,
                                               const Scalar* lu_factor,
                                               const int* all_pivots,
                                               Scalar* sign,
                                               Scalar* log_abs_det) {
  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
  const int matrix_size = n * n;
  const int stride = n + 1;
  // We only parallelize over batches here. Performance is not critical,
  // since this cheap O(n) kernel always follows an O(n^3) LU factorization.
  // The main purpose is to avoid having to copy the LU decomposition to
  // host memory.
  CUDA_1D_KERNEL_LOOP(o_idx, nthreads) {
    // Initialize sign to (-1)^order.
    const int order = PermutationOrder(n, all_pivots + o_idx * n);
    Scalar prod_sign = order % 2 ? Scalar(-1) : Scalar(1);
    RealScalar sum_log_abs_det = RealScalar(0);
    int i_idx = matrix_size * o_idx;
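    // Walk the diagonal of this batch entry's packed LU factor (stride n + 1
    // between diagonal elements). The diagonal belongs to U, since L has an
    // implicit unit diagonal.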
    for (int i = 0; i < n; ++i, i_idx += stride) {
      const RealScalar abs_i = Eigen::numext::abs(lu_factor[i_idx]);
      sum_log_abs_det += Eigen::numext::log(abs_i);
      prod_sign = prod_sign * (lu_factor[i_idx] / abs_i);
    }
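    // If some u_ii is zero (log == -infinity) or the input was not finite,
    // normalize the result to a zero sign and a +/-infinite log-magnitude
    // rather than propagating a NaN sign from the division above.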
    if (!Eigen::numext::isfinite(sum_log_abs_det)) {
      prod_sign = Scalar(0);
      sum_log_abs_det = sum_log_abs_det > 0 ? -Eigen::numext::log(RealScalar(0))
                                            : Eigen::numext::log(RealScalar(0));
    }
    if (compute_log_abs_det) {
      sign[o_idx] = prod_sign;
      log_abs_det[o_idx] = Scalar(sum_log_abs_det);
    } else {
      log_abs_det[o_idx] = prod_sign * Eigen::numext::exp(sum_log_abs_det);
    }
  }
}

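// Sketch of the intended call pattern (an assumption based on how the
// determinant op drives this functor; not compiled here):
//
//   // After the batched, pivoted LU factorization (e.g. a CudaSolver getrf
//   // call) has filled lu_factor and pivots on the device:
//   functor::DeterminantFromPivotedLUFunctor<GPUDevice, float> det;
//   det(device, lu_factor, pivots, output, info);
//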
template <typename Scalar>
struct DeterminantFromPivotedLUFunctor<GPUDevice, Scalar> {
  void operator()(const GPUDevice& device,
                  typename TTypes<Scalar, 3>::ConstTensor lu_factor,
                  const int* pivots, typename TTypes<Scalar, 1>::Tensor output,
                  int* info) {
    const int64 num_matrices = output.size();
    const int64 n = lu_factor.dimension(2);
    CudaLaunchConfig config = GetCudaLaunchConfig(num_matrices, device);

    TF_CHECK_OK(CudaLaunchKernel(
        DeterminantFromPivotedLUKernel<Scalar, /*compute_log_abs_det=*/false>,
        config.block_count, config.thread_per_block, 0, device.stream(),
        config.virtual_thread_count, n, lu_factor.data(), pivots, nullptr,
        output.data()));
  }
};

template struct DeterminantFromPivotedLUFunctor<GPUDevice, float>;
template struct DeterminantFromPivotedLUFunctor<GPUDevice, double>;
template struct DeterminantFromPivotedLUFunctor<GPUDevice, complex64>;
template struct DeterminantFromPivotedLUFunctor<GPUDevice, complex128>;

template <typename Scalar>
struct LogDeterminantFromPivotedLUFunctor<GPUDevice, Scalar> {
  void operator()(const GPUDevice& device,
                  typename TTypes<Scalar, 3>::ConstTensor lu_factor,
                  const int* pivots, typename TTypes<Scalar, 1>::Tensor sign,
                  typename TTypes<Scalar, 1>::Tensor log_abs_det) {
    const int64 num_matrices = sign.size();
    const int64 n = lu_factor.dimension(2);
    CudaLaunchConfig config = GetCudaLaunchConfig(num_matrices, device);
    TF_CHECK_OK(CudaLaunchKernel(
        DeterminantFromPivotedLUKernel<Scalar, /*compute_log_abs_det=*/true>,
        config.block_count, config.thread_per_block, 0, device.stream(),
        config.virtual_thread_count, n, lu_factor.data(), pivots, sign.data(),
        log_abs_det.data()));
  }
};

template struct LogDeterminantFromPivotedLUFunctor<GPUDevice, float>;
template struct LogDeterminantFromPivotedLUFunctor<GPUDevice, double>;
template struct LogDeterminantFromPivotedLUFunctor<GPUDevice, complex64>;
template struct LogDeterminantFromPivotedLUFunctor<GPUDevice, complex128>;

}  // namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA