/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/Eigen/LU"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/eye_functor.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#endif

namespace tensorflow {

template <class Scalar>
class MatrixInverseOp : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit MatrixInverseOp(OpKernelConstruction* context) : Base(context) {
    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
  }

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    const ConstMatrixMap& input = inputs[0];
    if (input.rows() == 0) {
      // By definition, an empty matrix's inverse is an empty matrix.
      return;
    }
    Eigen::PartialPivLU<Matrix> lu_decomposition;
    if (adjoint_) {
      // TODO(rmlarsen): For Eigen 3.2, this creates a temporary copy.
      // Make sure to backport: https://bitbucket.org/eigen/eigen/commits/
      // bd2219a74c96dfe3f6bc2c23588749e36d2d8173
      lu_decomposition.compute(input.adjoint());
    } else {
      lu_decomposition.compute(input);
    }
    // TODO(rmlarsen): Add check based on condition number estimation.
    // PartialPivLU cannot give strong guarantees on invertibility, but
    // we can at least guard against exact zero pivots. This can occur as
    // a result of basic user mistakes, such as providing integer valued
    // matrices that are exactly singular, or due to underflow if this
    // code is run with denormals being flushed to zero.
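    // Note: the diagonal of matrixLU() holds the pivots of U in the
    // factorization P * A = L * U; an exactly zero pivot means U (and hence
    // the input) is singular, and the inverse computed below would contain
    // Infs/NaNs.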
    const RealScalar min_abs_pivot =
        lu_decomposition.matrixLU().diagonal().cwiseAbs().minCoeff();
    OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
                errors::InvalidArgument("Input is not invertible."));
    outputs->at(0).noalias() = lu_decomposition.inverse();
  }

 private:
  bool adjoint_;

  TF_DISALLOW_COPY_AND_ASSIGN(MatrixInverseOp);
};

#if GOOGLE_CUDA

typedef Eigen::GpuDevice GPUDevice;

template <class Scalar>
class MatrixInverseOpGpu : public AsyncOpKernel {
 public:
  explicit MatrixInverseOpGpu(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
  }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
    const Tensor& input = context->input(0);
    const int ndims = input.dims();
    const int64 n = input.dim_size(ndims - 1);
    // Validate inputs.
    OP_REQUIRES_ASYNC(
        context, ndims >= 2,
        errors::InvalidArgument("Input must have rank >= 2, got ", ndims),
        done);
    OP_REQUIRES_ASYNC(
        context, input.dim_size(ndims - 2) == n,
        errors::InvalidArgument("Input matrices must be square, got ",
                                input.dim_size(ndims - 2), " != ", n),
        done);

    // By definition, an empty matrix's inverse is an empty matrix.
    if (input.NumElements() == 0) {
      context->set_output(0, input);
      done();
      return;
    }

    // Allocate output.
    Tensor* output;
    OP_REQUIRES_OK_ASYNC(context,
                         context->forward_input_or_allocate_output(
                             {0}, 0, input.shape(), &output),
                         done);

    // TODO(rmlarsen): Convert to std::make_unique when available.
    std::unique_ptr<CudaSolver> solver(new CudaSolver(context));

    // Make a copy of the (possibly adjointed) input that we will use for the
    // factorization step.
    Tensor input_copy;
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->allocate_scoped_tensor(DataTypeToEnum<Scalar>::value,
                                       input.shape(), &input_copy),
        done);
    auto input_copy_reshaped = input_copy.template flat_inner_dims<Scalar, 3>();
    const GPUDevice& device = context->eigen_device<GPUDevice>();
    if (!adjoint_) {
      device.memcpy(input_copy.flat<Scalar>().data(),
                    input.flat<Scalar>().data(),
                    input.NumElements() * sizeof(Scalar));
    } else {
      OP_REQUIRES_OK_ASYNC(
          context, DoConjugateMatrixTranspose(device, input, &input_copy),
          done);
    }
    const int64 batch_size = input_copy_reshaped.dimension(0);

    Tensor pivots;
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->allocate_scoped_tensor(DataTypeToEnum<int>::value,
                                       TensorShape{batch_size, n}, &pivots),
        done);
    auto pivots_mat = pivots.template matrix<int>();
    auto input_copy_ptr_array = solver->GetScratchSpace<uint8>(
        sizeof(Scalar*) * batch_size, "input_copy_ptr_array",
        /* on_host */ true);
    auto output_ptr_array = solver->GetScratchSpace<uint8>(
        sizeof(Scalar*) * batch_size, "output_copy_ptr_array",
        /* on_host */ true);
    auto output_reshaped = output->template flat_inner_dims<Scalar, 3>();
    std::vector<DeviceLapackInfo> dev_info;
    if (n < 32 || batch_size > n) {
      // For small matrices or very large batch sizes, we use the batched
      // interfaces in cuBlas to avoid being dominated by kernel launch
      // overhead.
      // TODO(rmlarsen): Come up with a better heuristic based on a simple
      // cost model.
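      // The batched wrappers in CudaSolver take one device pointer per matrix
      // in the batch, so first populate the host-side scratch arrays with the
      // per-matrix addresses of the factorization input and the output.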
      const Scalar** input_copy_ptr_array_base =
          reinterpret_cast<const Scalar**>(input_copy_ptr_array.mutable_data());
      const Scalar** output_ptr_array_base =
          reinterpret_cast<const Scalar**>(output_ptr_array.mutable_data());
      for (int batch = 0; batch < batch_size; ++batch) {
        input_copy_ptr_array_base[batch] = &input_copy_reshaped(batch, 0, 0);
        output_ptr_array_base[batch] = &output_reshaped(batch, 0, 0);
      }

      if (n < 32) {
        // MatInvBatched only supports n < 32.
        dev_info.push_back(
            solver->GetDeviceLapackInfo(batch_size, "MatInvBatched"));
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->MatInvBatched(n, input_copy_ptr_array_base, n,
                                  output_ptr_array_base, n, &dev_info.back(),
                                  batch_size),
            done);
      } else {
        // For larger matrices and large batch sizes, we use the batched
        // GETRF/GETRI kernels.
        // 1. Compute the LU factorization(s).
        dev_info.push_back(
            solver->GetDeviceLapackInfo(batch_size, "GetrfBatched"));
        OP_REQUIRES_OK_ASYNC(context,
                             solver->GetrfBatched(n, input_copy_ptr_array_base,
                                                  n, pivots_mat.data(),
                                                  &dev_info.back(), batch_size),
                             done);
        // 2. Compute the inverse(s).
        dev_info.push_back(
            solver->GetDeviceLapackInfo(batch_size, "GetriBatched"));
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->GetriBatched(n, input_copy_ptr_array_base, n,
                                 pivots_mat.data(), output_ptr_array_base, n,
                                 &dev_info.back(), batch_size),
            done);
      }
    } else {
      // For large matrices, we compute the inverse of each matrix in the
      // batch sequentially. Here we use the cuSolver methods GETRF/GETRS
      // because they are MUCH faster than their batched cuBlas equivalents
      // for large matrices.
      dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrf"));
      for (int batch = 0; batch < batch_size; ++batch) {
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->Getrf(n, n, &input_copy_reshaped(batch, 0, 0), n,
                          &pivots_mat(batch, 0), &dev_info.back()(batch)),
            done);
      }

      // Set all right-hand sides to the identity.
      functor::EyeFunctor<GPUDevice, Scalar> eye;
      eye(device, output_reshaped);

      // Solve A X = I.
      dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrs"));
      for (int batch = 0; batch < batch_size; ++batch) {
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->Getrs(CUBLAS_OP_N, n, n, &input_copy_reshaped(batch, 0, 0),
                          n, &pivots_mat(batch, 0),
                          &output_reshaped(batch, 0, 0), n,
                          &dev_info.back()(batch)),
            done);
      }
    }
    // Callback for checking info after kernels finish.
    auto info_checker = [context, done](
                            const Status& status,
                            const std::vector<HostLapackInfo>& host_infos) {
      if (!status.ok() && errors::IsInvalidArgument(status)) {
        for (const auto& host_info : host_infos) {
          for (int i = 0; i < host_info.size(); ++i) {
            // Match the CPU error message for singular matrices. Otherwise
            // just print the original error message from the call itself
            // below.
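            // For the LU/inversion routines used above, info == 0 means
            // success and a positive value is the 1-based index of the first
            // exactly zero pivot, i.e. the matrix is singular.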
            OP_REQUIRES_ASYNC(
                context, host_info(i) <= 0,
                errors::InvalidArgument("Input is not invertible."), done);
          }
        }
      }
      OP_REQUIRES_OK_ASYNC(context, status, done);
      done();
    };
    CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info,
                                                    std::move(info_checker));
  }

 private:
  bool adjoint_;
};

REGISTER_LINALG_OP_GPU("MatrixInverse", (MatrixInverseOpGpu<float>), float);
REGISTER_LINALG_OP_GPU("MatrixInverse", (MatrixInverseOpGpu<double>), double);
REGISTER_LINALG_OP_GPU("MatrixInverse", (MatrixInverseOpGpu<complex64>),
                       complex64);
REGISTER_LINALG_OP_GPU("MatrixInverse", (MatrixInverseOpGpu<complex128>),
                       complex128);

#endif  // GOOGLE_CUDA

REGISTER_LINALG_OP("MatrixInverse", (MatrixInverseOp<float>), float);
REGISTER_LINALG_OP("MatrixInverse", (MatrixInverseOp<double>), double);
REGISTER_LINALG_OP("MatrixInverse", (MatrixInverseOp<complex64>), complex64);
REGISTER_LINALG_OP("MatrixInverse", (MatrixInverseOp<complex128>), complex128);
REGISTER_LINALG_OP("BatchMatrixInverse", (MatrixInverseOp<float>), float);
REGISTER_LINALG_OP("BatchMatrixInverse", (MatrixInverseOp<double>), double);

}  // namespace tensorflow
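// Usage sketch (not part of this kernel, shown for orientation only): the
// registrations above expose these kernels through the "MatrixInverse" op.
// Assuming the C++ client API under tensorflow/cc, a caller would reach them
// roughly as follows; the variable names are illustrative.
//
//   #include "tensorflow/cc/client/client_session.h"
//   #include "tensorflow/cc/ops/standard_ops.h"
//
//   using namespace tensorflow;
//   Scope root = Scope::NewRootScope();
//   auto a = ops::Const(root, {{4.0f, 7.0f}, {2.0f, 6.0f}});
//   // Setting the "adjoint" attr would invert the (conjugate) transpose.
//   auto a_inv = ops::MatrixInverse(root, a);
//   ClientSession session(root);
//   std::vector<Tensor> outputs;
//   TF_CHECK_OK(session.Run({a_inv}, &outputs));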