/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.

#include <cmath>

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/linalg/determinant_op.h"
#endif

#include "third_party/eigen3/Eigen/LU"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/util/cuda_solvers.h"
#endif

namespace tensorflow {

// A helper function to compute the sign and the log of the absolute value of
// the determinant of 'inputs' via a partially pivoted LU factorization.
//
// Returns the log of the absolute value of the determinant, and its sign in
// 'sign'.
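//
// With a partially pivoted factorization P * A = L * U (L unit lower
// triangular, det(P) = +/-1), det(A) = det(P) * prod(diag(U)), so
// log|det(A)| = sum_i log|U_ii| and sign(det(A)) = det(P) * prod_i U_ii/|U_ii|.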
template <class Scalar>
static typename Eigen::NumTraits<Scalar>::Real SLogDet(
    const Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>& inputs,
    Scalar* sign) {
  using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
  RealScalar log_abs_det = 0;
  *sign = 1;
  // An empty matrix's determinant is defined to be 1.
  // (https://en.wikipedia.org/wiki/Determinant)
  if (inputs.size() > 0) {
    // Compute the log determinant through a partially pivoted LU decomposition.
    using Eigen::Dynamic;
    Eigen::PartialPivLU<Eigen::Matrix<Scalar, Dynamic, Dynamic>> lu(inputs);
    Eigen::Matrix<Scalar, Dynamic, Dynamic> LU = lu.matrixLU();
    *sign = lu.permutationP().determinant();
    auto diag = LU.diagonal().array().eval();
    auto abs_diag = diag.cwiseAbs().eval();
    log_abs_det += abs_diag.log().sum();
    *sign *= (diag / abs_diag).prod();
  }
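  // A zero diagonal entry of U (a singular input) makes the sum of logs
  // -infinity and the sign ratio NaN; non-finite inputs can likewise
  // propagate infinities or NaNs here. In those cases report a sign of zero
  // and clamp log_abs_det to +/-infinity.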
  if (!Eigen::numext::isfinite(log_abs_det)) {
    *sign = 0;
    log_abs_det =
        log_abs_det > 0 ? -std::log(RealScalar(0)) : std::log(RealScalar(0));
  }
  return log_abs_det;
}

template <class Scalar>
class LogDeterminantOp : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit LogDeterminantOp(OpKernelConstruction* context) : Base(context) {}

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    return TensorShapes({TensorShape({}), TensorShape({})});
  }

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    Scalar sign;
    const RealScalar log_abs_det = SLogDet(
        Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>(inputs[0]),
        &sign);

    outputs->at(0)(0, 0) = sign;
    outputs->at(1)(0, 0) = log_abs_det;
  }
};

template <class Scalar>
class DeterminantOp : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit DeterminantOp(OpKernelConstruction* context) : Base(context) {}

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shape) const final {
    return TensorShapes({TensorShape({})});
  }

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    Scalar sign;
    const RealScalar log_abs_det = SLogDet(
        Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>(inputs[0]),
        &sign);
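    // Reconstitute the determinant as sign * exp(log|det|). Accumulating in
    // log space avoids intermediate overflow/underflow in the product of the
    // diagonal entries of U, though the final exp can still overflow or
    // underflow for extreme values.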
    outputs->at(0)(0, 0) = sign * std::exp(log_abs_det);
  }
};

#if GOOGLE_CUDA

typedef Eigen::GpuDevice GPUDevice;

template <class Scalar>
class DeterminantOpGpu : public AsyncOpKernel {
 public:
  explicit DeterminantOpGpu(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
    const Tensor& input = context->input(0);
    const int ndims = input.dims();
    const int64 n = input.dim_size(ndims - 1);
    // Validate inputs.
    OP_REQUIRES_ASYNC(
        context, ndims >= 2,
        errors::InvalidArgument("Input must have rank >= 2, got ", ndims),
        done);
    OP_REQUIRES_ASYNC(
        context, input.dim_size(ndims - 2) == n,
        errors::InvalidArgument("Input matrices must be square, got ",
                                input.dim_size(ndims - 2), " != ", n),
        done);

    // Allocate output.
    TensorShape out_shape;
    for (int dim = 0; dim < ndims - 2; ++dim) {
      out_shape.AddDim(input.dim_size(dim));
    }
    out_shape.AppendShape(TensorShape({}));
    Tensor* out;
    OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, out_shape, &out),
                         done);

    // By definition, the determinant of an empty matrix is equal to one.
    const GPUDevice& d = context->eigen_device<GPUDevice>();
    if (input.NumElements() == 0) {
      functor::SetOneFunctor<GPUDevice, Scalar> f;
      f(d, out->template flat<Scalar>());
      done();
      return;
    }

    // TODO(rmlarsen): Convert to absl::make_unique when available.
    std::unique_ptr<CudaSolver> solver(new CudaSolver(context));

    // Reuse the input buffer or make a copy for the factorization step,
    // depending on whether this op owns it exclusively.
    Tensor input_copy;
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->forward_input_or_allocate_scoped_tensor(
            {0}, DataTypeToEnum<Scalar>::value, input.shape(), &input_copy),
        done);
    if (!input.SharesBufferWith(input_copy)) {
      d.memcpy(input_copy.flat<Scalar>().data(), input.flat<Scalar>().data(),
               input.NumElements() * sizeof(Scalar));
    }
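    // View the (possibly higher-rank) input as a [batch_size, n, n] tensor:
    // flat_inner_dims keeps the two innermost dimensions and collapses all
    // leading dimensions into a single batch dimension.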
    auto input_copy_reshaped = input_copy.template flat_inner_dims<Scalar, 3>();
    const int64 batch_size = input_copy_reshaped.dimension(0);

    // Allocate pivots on the device.
    Tensor pivots;
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->allocate_scoped_tensor(DataTypeToEnum<int>::value,
                                       TensorShape{batch_size, n}, &pivots),
        done);
    auto pivots_mat = pivots.template matrix<int>();

    // Prepare pointer arrays for cuBlas' batch interface.
    // TODO(rmlarsen): Find a way to encode pointer arrays in pinned host memory
    // without the ugly casting.
    auto input_copy_ptrs = solver->GetScratchSpace<uint8>(
        sizeof(Scalar*) * batch_size, "input_copy_ptrs",
        /* on_host */ true);
    auto output_reshaped = out->template flat_inner_dims<Scalar, 1>();

    // Compute the partially pivoted LU factorization(s) of the matrix/matrices.
    std::vector<DeviceLapackInfo> dev_info;
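    // Each Getrf{Batched} call reports a per-matrix LAPACK "info" status in
    // device memory; the statuses are collected in dev_info and validated
    // asynchronously by the info_checker callback registered at the end of
    // ComputeAsync.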
    if (n / batch_size <= 128) {
      // For small matrices or large batch sizes, we use the batched interface
      // from cuBlas.
      const Scalar** input_copy_ptrs_base =
          reinterpret_cast<const Scalar**>(input_copy_ptrs.mutable_data());
      for (int batch = 0; batch < batch_size; ++batch) {
        input_copy_ptrs_base[batch] = &input_copy_reshaped(batch, 0, 0);
      }
      dev_info.push_back(
          solver->GetDeviceLapackInfo(batch_size, "getrfBatched"));
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->GetrfBatched(n, input_copy_ptrs_base, n, pivots_mat.data(),
                               &dev_info.back(), batch_size),
          done);
    } else {
      // For large matrices or small batch sizes we use the non-batched
      // interface from cuSolver, which is much faster for large matrices.
      dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrf"));
      for (int batch = 0; batch < batch_size; ++batch) {
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->Getrf(n, n, &input_copy_reshaped(batch, 0, 0), n,
                          &pivots_mat(batch, 0), &dev_info.back()(batch)),
            done);
      }
    }

    // Compute the determinant for each batch as (-1)^s * prod(diag(U)),
    // where s is the number of row interchanges encoded in pivots and U is the
    // upper triangular factor of the LU factorization, which is written to
    // input_copy by the Getrf{Batched} kernel.
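    // (pivots follows the LAPACK convention: pivots(batch, i) is the 1-based
    // index of the row that row i was swapped with, so s counts the entries
    // with pivots(batch, i) != i + 1.)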
    functor::DeterminantFromPivotedLUFunctor<GPUDevice, Scalar> functor;
    functor(d,
            const_cast<const Tensor*>(&input_copy)
                ->template flat_inner_dims<Scalar, 3>(),
            pivots_mat.data(), output_reshaped, dev_info.back().mutable_data());

    // Register callback to check info after kernels finish.
    auto info_checker = [context, done](
                            const Status& status,
                            const std::vector<HostLapackInfo>& host_infos) {
      if (!status.ok() && errors::IsInvalidArgument(status) &&
          !host_infos.empty()) {
        for (int i = 0; i < host_infos[0].size(); ++i) {
          // It is OK for a matrix to be singular (signaled by info > 0),
          // corresponding to a determinant of zero, but we do want to catch
          // invalid arguments to Getrf{Batched}.
          OP_REQUIRES_ASYNC(
              context, host_infos[0](i) >= 0,
              errors::InvalidArgument("Invalid input argument no. ",
                                      host_infos[0].data()[i],
                                      " for batch index ", i, "."),
              done);
        }
      }
      done();
    };
    CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info,
                                                    std::move(info_checker));
  }
};

template <class Scalar>
class LogDeterminantOpGpu : public AsyncOpKernel {
 public:
  explicit LogDeterminantOpGpu(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
    const Tensor& input = context->input(0);
    const int ndims = input.dims();
    const int64 n = input.dim_size(ndims - 1);
    // Validate inputs.
    OP_REQUIRES_ASYNC(
        context, ndims >= 2,
        errors::InvalidArgument("Input must have rank >= 2, got ", ndims),
        done);
    OP_REQUIRES_ASYNC(
        context, input.dim_size(ndims - 2) == n,
        errors::InvalidArgument("Input matrices must be square, got ",
                                input.dim_size(ndims - 2), " != ", n),
        done);

    // Allocate output.
    TensorShape out_shape;
    for (int dim = 0; dim < ndims - 2; ++dim) {
      out_shape.AddDim(input.dim_size(dim));
    }
    out_shape.AppendShape(TensorShape({}));
    Tensor* sign;
    OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, out_shape, &sign),
                         done);
    Tensor* log_abs_det;
    OP_REQUIRES_OK_ASYNC(
        context, context->allocate_output(1, out_shape, &log_abs_det), done);

    // By definition, the determinant of an empty matrix is equal to one.
    const GPUDevice& d = context->eigen_device<GPUDevice>();
    if (input.NumElements() == 0) {
      functor::SetOneFunctor<GPUDevice, Scalar> one_func;
      one_func(d, sign->template flat<Scalar>());
      functor::SetZeroFunctor<GPUDevice, Scalar> zero_func;
      zero_func(d, log_abs_det->template flat<Scalar>());
      done();
      return;
    }

    // TODO(rmlarsen): Convert to absl::make_unique when available.
    std::unique_ptr<CudaSolver> solver(new CudaSolver(context));

    // Reuse the input buffer or make a copy for the factorization step,
    // depending on whether this op owns it exclusively.
    Tensor input_copy;
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->forward_input_or_allocate_scoped_tensor(
            {0}, DataTypeToEnum<Scalar>::value, input.shape(), &input_copy),
        done);
    if (!input.SharesBufferWith(input_copy)) {
      d.memcpy(input_copy.flat<Scalar>().data(), input.flat<Scalar>().data(),
               input.NumElements() * sizeof(Scalar));
    }
    auto input_copy_reshaped = input_copy.template flat_inner_dims<Scalar, 3>();
    const int64 batch_size = input_copy_reshaped.dimension(0);

    // Allocate pivots on the device.
    Tensor pivots;
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->allocate_scoped_tensor(DataTypeToEnum<int>::value,
                                       TensorShape{batch_size, n}, &pivots),
        done);
    auto pivots_mat = pivots.template matrix<int>();

    // Prepare pointer arrays for cuBlas' batch interface.
    // TODO(rmlarsen): Find a way to encode pointer arrays in pinned host memory
    // without the ugly casting.
    auto input_copy_ptrs = solver->GetScratchSpace<uint8>(
        sizeof(Scalar*) * batch_size, "input_copy_ptrs",
        /* on_host */ true);

    // Compute the partially pivoted LU factorization(s) of the matrix/matrices.
    std::vector<DeviceLapackInfo> dev_info;
    if (n / batch_size <= 128) {
      // For small matrices or large batch sizes, we use the batched interface
      // from cuBlas.
      const Scalar** input_copy_ptrs_base =
          reinterpret_cast<const Scalar**>(input_copy_ptrs.mutable_data());
      for (int batch = 0; batch < batch_size; ++batch) {
        input_copy_ptrs_base[batch] = &input_copy_reshaped(batch, 0, 0);
      }
      dev_info.push_back(
          solver->GetDeviceLapackInfo(batch_size, "getrfBatched"));
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->GetrfBatched(n, input_copy_ptrs_base, n, pivots_mat.data(),
                               &dev_info.back(), batch_size),
          done);
    } else {
      // For large matrices or small batch sizes we use the non-batched
      // interface from cuSolver, which is much faster for large matrices.
      dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrf"));
      for (int batch = 0; batch < batch_size; ++batch) {
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->Getrf(n, n, &input_copy_reshaped(batch, 0, 0), n,
                          &pivots_mat(batch, 0), &dev_info.back()(batch)),
            done);
      }
    }

    auto input_copy_reshaped_const =
        const_cast<const Tensor*>(&input_copy)
            ->template flat_inner_dims<Scalar, 3>();
    auto sign_reshaped = sign->flat<Scalar>();
    auto log_abs_det_reshaped = log_abs_det->flat<Scalar>();
    // Compute the sign and log of the absolute value of the determinant for
    // each batch from (-1)^s * prod(diag(U)), where s is the number of row
    // interchanges encoded in pivots and U is the upper triangular factor of
    // the LU factorization, which is written to input_copy by the
    // Getrf{Batched} kernel.
    functor::LogDeterminantFromPivotedLUFunctor<GPUDevice, Scalar> functor;
    functor(d, input_copy_reshaped_const, pivots_mat.data(), sign_reshaped,
            log_abs_det_reshaped);

    // Register callback to check info after kernels finish.
    auto info_checker = [context, done](
                            const Status& status,
                            const std::vector<HostLapackInfo>& host_infos) {
      if (!status.ok() && errors::IsInvalidArgument(status) &&
          !host_infos.empty()) {
        for (int i = 0; i < host_infos[0].size(); ++i) {
          // It is OK for a matrix to be singular (signaled by info > 0),
          // corresponding to a determinant of zero, but we do want to catch
          // invalid arguments to Getrf{Batched}.
          OP_REQUIRES_ASYNC(
              context, host_infos[0](i) >= 0,
              errors::InvalidArgument("Invalid input argument no. ",
                                      host_infos[0].data()[i],
                                      " for batch index ", i, "."),
              done);
        }
      }
      done();
    };
    CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info,
                                                    std::move(info_checker));
  }
};

REGISTER_LINALG_OP_GPU("MatrixDeterminant", (DeterminantOpGpu<float>), float);
REGISTER_LINALG_OP_GPU("MatrixDeterminant", (DeterminantOpGpu<double>), double);
REGISTER_LINALG_OP_GPU("MatrixDeterminant", (DeterminantOpGpu<complex64>),
                       complex64);
REGISTER_LINALG_OP_GPU("MatrixDeterminant", (DeterminantOpGpu<complex128>),
                       complex128);

REGISTER_LINALG_OP_GPU("LogMatrixDeterminant", (LogDeterminantOpGpu<float>),
                       float);
REGISTER_LINALG_OP_GPU("LogMatrixDeterminant", (LogDeterminantOpGpu<double>),
                       double);
REGISTER_LINALG_OP_GPU("LogMatrixDeterminant", (LogDeterminantOpGpu<complex64>),
                       complex64);
REGISTER_LINALG_OP_GPU("LogMatrixDeterminant",
                       (LogDeterminantOpGpu<complex128>), complex128);
#endif  // GOOGLE_CUDA

REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<float>), float);
REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<double>), double);
REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<complex64>), complex64);
REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<complex128>),
                   complex128);
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<float>), float);
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<double>), double);
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<complex64>),
                   complex64);
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<complex128>),
                   complex128);

REGISTER_LINALG_OP("LogMatrixDeterminant", (LogDeterminantOp<float>), float);
REGISTER_LINALG_OP("LogMatrixDeterminant", (LogDeterminantOp<double>), double);
REGISTER_LINALG_OP("LogMatrixDeterminant", (LogDeterminantOp<complex64>),
                   complex64);
REGISTER_LINALG_OP("LogMatrixDeterminant", (LogDeterminantOp<complex128>),
                   complex128);
}  // namespace tensorflow