/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.

#include <cmath>

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/linalg/determinant_op.h"
#endif

#include "third_party/eigen3/Eigen/LU"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/util/cuda_solvers.h"
#endif

namespace tensorflow {

// A helper function to compute the sign and the log of the absolute value of
// the determinant of inputs via a partially pivoted LU factorization.
//
// Returns the log of the absolute value of the determinant, and its sign in
// 'sign'.
template <class Scalar>
static typename Eigen::NumTraits<Scalar>::Real SLogDet(
    const Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>& inputs,
    Scalar* sign) {
  using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
  RealScalar log_abs_det = 0;
  *sign = 1;
  // An empty matrix's determinant is defined to be 1.
  // (https://en.wikipedia.org/wiki/Determinant)
  if (inputs.size() > 0) {
    // Compute the log determinant through a Partially Pivoted LU decomposition
    using Eigen::Dynamic;
    Eigen::PartialPivLU<Eigen::Matrix<Scalar, Dynamic, Dynamic>> lu(inputs);
    Eigen::Matrix<Scalar, Dynamic, Dynamic> LU = lu.matrixLU();
    *sign = lu.permutationP().determinant();
    auto diag = LU.diagonal().array().eval();
    auto abs_diag = diag.cwiseAbs().eval();
    log_abs_det += abs_diag.log().sum();
    *sign *= (diag / abs_diag).prod();
  }
  if (!Eigen::numext::isfinite(log_abs_det)) {
    *sign = 0;
    log_abs_det =
        log_abs_det > 0 ? -std::log(RealScalar(0)) : std::log(RealScalar(0));
  }
  return log_abs_det;
}
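
// A minimal usage sketch of SLogDet (illustrative only; not part of the
// kernels below). For the diagonal matrix diag(4, 2) the determinant is 8,
// so SLogDet returns log(8) with a sign of +1, and the determinant can be
// recovered as sign * exp(log_abs_det):
//
//   Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> m(2, 2);
//   m << 4.0, 0.0,
//        0.0, 2.0;
//   double sign;
//   const double log_abs_det = SLogDet(m, &sign);     // ~2.0794, sign == 1.0
//   const double det = sign * std::exp(log_abs_det);  // ~8.0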

template <class Scalar>
class LogDeterminantOp : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit LogDeterminantOp(OpKernelConstruction* context) : Base(context) {}

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    return TensorShapes({TensorShape({}), TensorShape({})});
  }

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    Scalar sign;
    const RealScalar log_abs_det = SLogDet(
        Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>(inputs[0]),
        &sign);

    outputs->at(0)(0, 0) = sign;
    outputs->at(1)(0, 0) = log_abs_det;
  }
};
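
// For example (illustrative only): for the 2x2 input [[-2, 0], [0, 3]] the
// determinant is -6, so LogDeterminantOp produces sign == -1 and
// log_abs_det == log(6) ~= 1.7918. Returning the result in this split form
// keeps the op well behaved even when |det| would overflow or underflow the
// Scalar type.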

template <class Scalar>
class DeterminantOp : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit DeterminantOp(OpKernelConstruction* context) : Base(context) {}

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shape) const final {
    return TensorShapes({TensorShape({})});
  }

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    Scalar sign;
    const RealScalar log_abs_det = SLogDet(
        Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>(inputs[0]),
        &sign);
    outputs->at(0)(0, 0) = sign * std::exp(log_abs_det);
  }
};
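
// Note: DeterminantOp reconstructs det(A) as sign * exp(log_abs_det). The
// diagonal product is accumulated in the log domain, so only the final exp
// can overflow to +/-inf (or underflow to 0) when the determinant itself
// lies outside the finite range of Scalar; in that case the
// LogMatrixDeterminant op (LogDeterminantOp above) is the numerically safer
// choice.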

#if GOOGLE_CUDA

typedef Eigen::GpuDevice GPUDevice;

template <class Scalar>
class DeterminantOpGpu : public AsyncOpKernel {
 public:
  explicit DeterminantOpGpu(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
    const Tensor& input = context->input(0);
    const int ndims = input.dims();
    const int64 n = input.dim_size(ndims - 1);
    // Validate inputs.
    OP_REQUIRES_ASYNC(
        context, ndims >= 2,
        errors::InvalidArgument("Input must have rank >= 2, got ", ndims),
        done);
    OP_REQUIRES_ASYNC(
        context, input.dim_size(ndims - 2) == n,
        errors::InvalidArgument("Input matrices must be square, got ",
                                input.dim_size(ndims - 2), " != ", n),
        done);

    // Allocate output.
    TensorShape out_shape;
    for (int dim = 0; dim < ndims - 2; ++dim) {
      out_shape.AddDim(input.dim_size(dim));
    }
    out_shape.AppendShape(TensorShape({}));
    Tensor* out;
    OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, out_shape, &out),
                         done);

    // By definition, the determinant of an empty matrix is equal to one.
    const GPUDevice& d = context->eigen_device<GPUDevice>();
    if (input.NumElements() == 0) {
      functor::SetOneFunctor<GPUDevice, Scalar> f;
      f(d, out->template flat<Scalar>());
      done();
      return;
    }

    // TODO(rmlarsen): Convert to absl::make_unique when available.
    std::unique_ptr<CudaSolver> solver(new CudaSolver(context));

    // Reuse the input buffer or make a copy for the factorization step,
    // depending on whether this op owns it exclusively.
    Tensor input_copy;
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->forward_input_or_allocate_scoped_tensor(
            {0}, DataTypeToEnum<Scalar>::value, input.shape(), &input_copy),
        done);
    if (!input.SharesBufferWith(input_copy)) {
      d.memcpy(input_copy.flat<Scalar>().data(), input.flat<Scalar>().data(),
               input.NumElements() * sizeof(Scalar));
    }
    auto input_copy_reshaped = input_copy.template flat_inner_dims<Scalar, 3>();
    const int64 batch_size = input_copy_reshaped.dimension(0);

    // Allocate pivots on the device.
    Tensor pivots;
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->allocate_scoped_tensor(DataTypeToEnum<int>::value,
                                       TensorShape{batch_size, n}, &pivots),
        done);
    auto pivots_mat = pivots.template matrix<int>();

    // Prepare pointer arrays for cuBlas' batch interface.
    // TODO(rmlarsen): Find a way to encode pointer arrays in pinned host memory
    // without the ugly casting.
    auto input_copy_ptrs = solver->GetScratchSpace<uint8>(
        sizeof(Scalar*) * batch_size, "input_copy_ptrs",
        /* on_host */ true);
    auto output_reshaped = out->template flat_inner_dims<Scalar, 1>();

    // Compute the partially pivoted LU factorization(s) of the matrix/matrices.
    std::vector<DeviceLapackInfo> dev_info;
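    // The branch below chooses a factorization path based on the ratio
    // n / batch_size. For example (illustrative numbers only): a batch of 64
    // matrices of size 512 gives 512 / 64 = 8 <= 128 and takes the batched
    // cuBlas path, while a single 512 x 512 matrix gives 512 > 128 and takes
    // the per-matrix cuSolver path below.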
    if (n / batch_size <= 128) {
      // For small matrices or large batch sizes, we use the batched interface
      // from cuBlas.
      const Scalar** input_copy_ptrs_base =
          reinterpret_cast<const Scalar**>(input_copy_ptrs.mutable_data());
      for (int batch = 0; batch < batch_size; ++batch) {
        input_copy_ptrs_base[batch] = &input_copy_reshaped(batch, 0, 0);
      }
      dev_info.push_back(
          solver->GetDeviceLapackInfo(batch_size, "getrfBatched"));
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->GetrfBatched(n, input_copy_ptrs_base, n, pivots_mat.data(),
                               &dev_info.back(), batch_size),
          done);
    } else {
      // For large matrices or small batch sizes, we use the non-batched
      // interface from cuSolver, which is much faster for large matrices.
      dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrf"));
      for (int batch = 0; batch < batch_size; ++batch) {
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->Getrf(n, n, &input_copy_reshaped(batch, 0, 0), n,
                          &pivots_mat(batch, 0), &dev_info.back()(batch)),
            done);
      }
    }

    // Compute the determinant for each batch as (-1)^s * prod(diag(U)),
    // where s is the number of row interchanges (the parity of the
    // permutation) encoded in pivots, and U is the upper triangular factor of
    // the LU factorization, which is written to input_copy by the
    // Getrf{Batched} kernel.
    functor::DeterminantFromPivotedLUFunctor<GPUDevice, Scalar> functor;
    functor(d,
            const_cast<const Tensor*>(&input_copy)
                ->template flat_inner_dims<Scalar, 3>(),
            pivots_mat.data(), output_reshaped, dev_info.back().mutable_data());

    // Register callback to check info after kernels finish.
    auto info_checker = [context, done](
                            const Status& status,
                            const std::vector<HostLapackInfo>& host_infos) {
      if (!status.ok() && errors::IsInvalidArgument(status) &&
          !host_infos.empty()) {
        for (int i = 0; i < host_infos[0].size(); ++i) {
          // It is OK for a matrix to be singular (signaled by info > 0),
          // corresponding to determinant of zero, but we do want to catch
          // invalid arguments to Getrf{Batched}.
          OP_REQUIRES_ASYNC(
              context, host_infos[0](i) >= 0,
              errors::InvalidArgument("Invalid input argument no. ",
                                      host_infos[0].data()[i],
                                      " for batch index ", i, "."),
              done);
        }
      }
      done();
    };
    CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info,
                                                    std::move(info_checker));
  }
};

template <class Scalar>
class LogDeterminantOpGpu : public AsyncOpKernel {
 public:
  explicit LogDeterminantOpGpu(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
    const Tensor& input = context->input(0);
    const int ndims = input.dims();
    const int64 n = input.dim_size(ndims - 1);
    // Validate inputs.
    OP_REQUIRES_ASYNC(
        context, ndims >= 2,
        errors::InvalidArgument("Input must have rank >= 2, got ", ndims),
        done);
    OP_REQUIRES_ASYNC(
        context, input.dim_size(ndims - 2) == n,
        errors::InvalidArgument("Input matrices must be square, got ",
                                input.dim_size(ndims - 2), " != ", n),
        done);

    // Allocate output.
    TensorShape out_shape;
    for (int dim = 0; dim < ndims - 2; ++dim) {
      out_shape.AddDim(input.dim_size(dim));
    }
    out_shape.AppendShape(TensorShape({}));
    Tensor* sign;
    OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, out_shape, &sign),
                         done);
    Tensor* log_abs_det;
    OP_REQUIRES_OK_ASYNC(
        context, context->allocate_output(1, out_shape, &log_abs_det), done);

    // By definition, the determinant of an empty matrix is equal to one.
    const GPUDevice& d = context->eigen_device<GPUDevice>();
    if (input.NumElements() == 0) {
      functor::SetOneFunctor<GPUDevice, Scalar> one_func;
      one_func(d, sign->template flat<Scalar>());
      functor::SetZeroFunctor<GPUDevice, Scalar> zero_func;
      zero_func(d, log_abs_det->template flat<Scalar>());
      done();
      return;
    }

    // TODO(rmlarsen): Convert to absl::make_unique when available.
    std::unique_ptr<CudaSolver> solver(new CudaSolver(context));

    // Reuse the input buffer or make a copy for the factorization step,
    // depending on whether this op owns it exclusively.
    Tensor input_copy;
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->forward_input_or_allocate_scoped_tensor(
            {0}, DataTypeToEnum<Scalar>::value, input.shape(), &input_copy),
        done);
    if (!input.SharesBufferWith(input_copy)) {
      d.memcpy(input_copy.flat<Scalar>().data(), input.flat<Scalar>().data(),
               input.NumElements() * sizeof(Scalar));
    }
    auto input_copy_reshaped = input_copy.template flat_inner_dims<Scalar, 3>();
    const int64 batch_size = input_copy_reshaped.dimension(0);

    // Allocate pivots on the device.
    Tensor pivots;
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->allocate_scoped_tensor(DataTypeToEnum<int>::value,
                                       TensorShape{batch_size, n}, &pivots),
        done);
    auto pivots_mat = pivots.template matrix<int>();

    // Prepare pointer arrays for cuBlas' batch interface.
    // TODO(rmlarsen): Find a way to encode pointer arrays in pinned host memory
    // without the ugly casting.
    auto input_copy_ptrs = solver->GetScratchSpace<uint8>(
        sizeof(Scalar*) * batch_size, "input_copy_ptrs",
        /* on_host */ true);

    // Compute the partially pivoted LU factorization(s) of the matrix/matrices.
    std::vector<DeviceLapackInfo> dev_info;
    if (n / batch_size <= 128) {
      // For small matrices or large batch sizes, we use the batched interface
      // from cuBlas.
      const Scalar** input_copy_ptrs_base =
          reinterpret_cast<const Scalar**>(input_copy_ptrs.mutable_data());
      for (int batch = 0; batch < batch_size; ++batch) {
        input_copy_ptrs_base[batch] = &input_copy_reshaped(batch, 0, 0);
      }
      dev_info.push_back(
          solver->GetDeviceLapackInfo(batch_size, "getrfBatched"));
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->GetrfBatched(n, input_copy_ptrs_base, n, pivots_mat.data(),
                               &dev_info.back(), batch_size),
          done);
    } else {
      // For large matrices or small batch sizes we use the non-batched
      // interface from cuSolver, which is much faster for large matrices.
      dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrf"));
      for (int batch = 0; batch < batch_size; ++batch) {
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->Getrf(n, n, &input_copy_reshaped(batch, 0, 0), n,
                          &pivots_mat(batch, 0), &dev_info.back()(batch)),
            done);
      }
    }

    auto input_copy_reshaped_const =
        const_cast<const Tensor*>(&input_copy)
            ->template flat_inner_dims<Scalar, 3>();
    auto sign_reshaped = sign->flat<Scalar>();
    auto log_abs_det_reshaped = log_abs_det->flat<Scalar>();
    // Compute the determinant for each batch as (-1)^s * prod(diag(U)),
    // where s is the number of row interchanges (the parity of the
    // permutation) encoded in pivots, and U is the upper triangular factor of
    // the LU factorization, which is written to input_copy by the
    // Getrf{Batched} kernel.
    functor::LogDeterminantFromPivotedLUFunctor<GPUDevice, Scalar> functor;
    functor(d, input_copy_reshaped_const, pivots_mat.data(), sign_reshaped,
            log_abs_det_reshaped);

    // Register callback to check info after kernels finish.
    auto info_checker = [context, done](
                            const Status& status,
                            const std::vector<HostLapackInfo>& host_infos) {
      if (!status.ok() && errors::IsInvalidArgument(status) &&
          !host_infos.empty()) {
        for (int i = 0; i < host_infos[0].size(); ++i) {
          // It is OK for a matrix to be singular (signaled by info > 0),
          // corresponding to determinant of zero, but we do want to catch
          // invalid arguments to Getrf{Batched}.
          OP_REQUIRES_ASYNC(
              context, host_infos[0](i) >= 0,
              errors::InvalidArgument("Invalid input argument no. ",
                                      host_infos[0].data()[i],
                                      " for batch index ", i, "."),
              done);
        }
      }
      done();
    };
    CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info,
                                                    std::move(info_checker));
  }
};

REGISTER_LINALG_OP_GPU("MatrixDeterminant", (DeterminantOpGpu<float>), float);
REGISTER_LINALG_OP_GPU("MatrixDeterminant", (DeterminantOpGpu<double>), double);
REGISTER_LINALG_OP_GPU("MatrixDeterminant", (DeterminantOpGpu<complex64>),
                       complex64);
REGISTER_LINALG_OP_GPU("MatrixDeterminant", (DeterminantOpGpu<complex128>),
                       complex128);

REGISTER_LINALG_OP_GPU("LogMatrixDeterminant", (LogDeterminantOpGpu<float>),
                       float);
REGISTER_LINALG_OP_GPU("LogMatrixDeterminant", (LogDeterminantOpGpu<double>),
                       double);
REGISTER_LINALG_OP_GPU("LogMatrixDeterminant", (LogDeterminantOpGpu<complex64>),
                       complex64);
REGISTER_LINALG_OP_GPU("LogMatrixDeterminant",
                       (LogDeterminantOpGpu<complex128>), complex128);
#endif  // GOOGLE_CUDA

REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<float>), float);
REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<double>), double);
REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<complex64>), complex64);
REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<complex128>),
                   complex128);
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<float>), float);
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<double>), double);
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<complex64>),
                   complex64);
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<complex128>),
                   complex128);

REGISTER_LINALG_OP("LogMatrixDeterminant", (LogDeterminantOp<float>), float);
REGISTER_LINALG_OP("LogMatrixDeterminant", (LogDeterminantOp<double>), double);
REGISTER_LINALG_OP("LogMatrixDeterminant", (LogDeterminantOp<complex64>),
                   complex64);
REGISTER_LINALG_OP("LogMatrixDeterminant", (LogDeterminantOp<complex128>),
                   complex128);
}  // namespace tensorflow
450