1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/linalg_ops.cc.
17 
18 #ifdef GOOGLE_CUDA
19 
20 #define EIGEN_USE_GPU
21 
22 #include "tensorflow/core/framework/op_kernel.h"
23 #include "tensorflow/core/framework/register_types.h"
24 #include "tensorflow/core/framework/tensor_shape.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
27 #include "tensorflow/core/kernels/transpose_functor.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/util/cuda_solvers.h"
30 #include "tensorflow/core/util/cuda_sparse.h"
31 #include "tensorflow/core/util/gpu_device_functions.h"
32 #include "tensorflow/core/util/gpu_kernel_helper.h"
33 #include "tensorflow/core/util/gpu_launch_config.h"
34 
35 namespace tensorflow {
36 
// Error text used by SolveForSizeOneOrTwo when the system has more than one
// equation and the kernel flags the matrix as singular.
static const char kNotInvertibleMsg[] = "The matrix is not invertible.";

// Error text used by SolveForSizeOneOrTwo when the system has a single
// equation (m == 1) and the kernel flags the matrix as singular.
static const char kNotInvertibleScalarMsg[] =
    "The matrix is not invertible: it is a scalar with value zero.";
41 
42 template <typename Scalar>
SolveForSizeOneOrTwoKernel(const int m,const Scalar * __restrict__ diags,const Scalar * __restrict__ rhs,const int num_rhs,Scalar * __restrict__ x,bool * __restrict__ not_invertible)43 __global__ void SolveForSizeOneOrTwoKernel(const int m,
44                                            const Scalar* __restrict__ diags,
45                                            const Scalar* __restrict__ rhs,
46                                            const int num_rhs,
47                                            Scalar* __restrict__ x,
48                                            bool* __restrict__ not_invertible) {
49   if (m == 1) {
50     if (diags[1] == Scalar(0)) {
51       *not_invertible = true;
52       return;
53     }
54     for (int i : GpuGridRangeX(num_rhs)) {
55       x[i] = rhs[i] / diags[1];
56     }
57   } else {
58     Scalar det = diags[2] * diags[3] - diags[0] * diags[5];
59     if (det == Scalar(0)) {
60       *not_invertible = true;
61       return;
62     }
63     for (int i : GpuGridRangeX(num_rhs)) {
64       x[i] = (diags[3] * rhs[i] - diags[0] * rhs[i + num_rhs]) / det;
65       x[i + num_rhs] = (diags[2] * rhs[i + num_rhs] - diags[5] * rhs[i]) / det;
66     }
67   }
68 }
69 
70 template <typename Scalar>
AsDeviceMemory(const Scalar * cuda_memory)71 se::DeviceMemory<Scalar> AsDeviceMemory(const Scalar* cuda_memory) {
72   se::DeviceMemoryBase wrapped(const_cast<Scalar*>(cuda_memory));
73   se::DeviceMemory<Scalar> typed(wrapped);
74   return typed;
75 }
76 
77 template <typename Scalar>
CopyDeviceToDevice(OpKernelContext * context,const Scalar * src,Scalar * dst,const int num_elements)78 void CopyDeviceToDevice(OpKernelContext* context, const Scalar* src,
79                         Scalar* dst, const int num_elements) {
80   auto src_device_mem = AsDeviceMemory(src);
81   auto dst_device_mem = AsDeviceMemory(dst);
82   auto* stream = context->op_device_context()->stream();
83   bool copy_status = stream
84                          ->ThenMemcpyD2D(&dst_device_mem, src_device_mem,
85                                          sizeof(Scalar) * num_elements)
86                          .ok();
87 
88   if (!copy_status) {
89     context->SetStatus(errors::Internal("Copying device-to-device failed."));
90   }
91 }
92 
93 // This implementation is used in cases when the batching mechanism of
94 // LinearAlgebraOp is suitable. See TridiagonalSolveOpGpu below.
95 template <class Scalar>
96 class TridiagonalSolveOpGpuLinalg : public LinearAlgebraOp<Scalar> {
97  public:
98   INHERIT_LINALG_TYPEDEFS(Scalar);
99 
TridiagonalSolveOpGpuLinalg(OpKernelConstruction * context)100   explicit TridiagonalSolveOpGpuLinalg(OpKernelConstruction* context)
101       : Base(context) {
102     OP_REQUIRES_OK(context, context->GetAttr("partial_pivoting", &pivoting_));
103   }
104 
ValidateInputMatrixShapes(OpKernelContext * context,const TensorShapes & input_matrix_shapes) const105   void ValidateInputMatrixShapes(
106       OpKernelContext* context,
107       const TensorShapes& input_matrix_shapes) const final {
108     auto num_inputs = input_matrix_shapes.size();
109     OP_REQUIRES(context, num_inputs == 2,
110                 errors::InvalidArgument("Expected two input matrices, got ",
111                                         num_inputs, "."));
112 
113     auto num_diags = input_matrix_shapes[0].dim_size(0);
114     OP_REQUIRES(
115         context, num_diags == 3,
116         errors::InvalidArgument("Expected diagonals to be provided as a "
117                                 "matrix with 3 columns, got ",
118                                 num_diags, " columns."));
119 
120     auto num_rows1 = input_matrix_shapes[0].dim_size(1);
121     auto num_rows2 = input_matrix_shapes[1].dim_size(0);
122     OP_REQUIRES(context, num_rows1 == num_rows2,
123                 errors::InvalidArgument("Expected same number of rows in both "
124                                         "arguments, got ",
125                                         num_rows1, " and ", num_rows2, "."));
126   }
127 
EnableInputForwarding() const128   bool EnableInputForwarding() const final { return false; }
129 
GetOutputMatrixShapes(const TensorShapes & input_matrix_shapes) const130   TensorShapes GetOutputMatrixShapes(
131       const TensorShapes& input_matrix_shapes) const final {
132     return TensorShapes({input_matrix_shapes[1]});
133   }
134 
ComputeMatrix(OpKernelContext * context,const ConstMatrixMaps & inputs,MatrixMaps * outputs)135   void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
136                      MatrixMaps* outputs) final {
137     const auto diagonals = inputs[0];
138     // Superdiagonal elements, first is ignored.
139     const auto& superdiag = diagonals.row(0);
140     // Diagonal elements.
141     const auto& diag = diagonals.row(1);
142     // Subdiagonal elements, last is ignored.
143     const auto& subdiag = diagonals.row(2);
144     // Right-hand sides.
145     const auto& rhs = inputs[1];
146     MatrixMap& x = outputs->at(0);
147     const int m = diag.size();
148     const int k = rhs.cols();
149 
150     if (m == 0) {
151       return;
152     }
153     if (m < 3) {
154       // Cusparse gtsv routine requires m >= 3. Solving manually for m < 3.
155       SolveForSizeOneOrTwo(context, diagonals.data(), rhs.data(), x.data(), m,
156                            k);
157       return;
158     }
159     std::unique_ptr<GpuSparse> cusparse_solver(new GpuSparse(context));
160     OP_REQUIRES_OK(context, cusparse_solver->Initialize());
161     if (k == 1) {
162       // rhs is copied into x, then gtsv replaces x with solution.
163       CopyDeviceToDevice(context, rhs.data(), x.data(), m);
164       SolveWithGtsv(context, cusparse_solver, superdiag.data(), diag.data(),
165                     subdiag.data(), x.data(), m, 1);
166     } else {
167       // Gtsv expects rhs in column-major form, so we have to transpose.
168       // rhs is transposed into temp, gtsv replaces temp with solution, then
169       // temp is transposed into x.
170       std::unique_ptr<CudaSolver> cublas_solver(new CudaSolver(context));
171       Tensor temp;
172       TensorShape temp_shape({k, m});
173       OP_REQUIRES_OK(context,
174                      cublas_solver->allocate_scoped_tensor(
175                          DataTypeToEnum<Scalar>::value, temp_shape, &temp));
176       TransposeWithGeam(context, cublas_solver, rhs.data(),
177                         temp.flat<Scalar>().data(), m, k);
178       SolveWithGtsv(context, cusparse_solver, superdiag.data(), diag.data(),
179                     subdiag.data(), temp.flat<Scalar>().data(), m, k);
180       TransposeWithGeam(context, cublas_solver, temp.flat<Scalar>().data(),
181                         x.data(), k, m);
182     }
183   }
184 
185  private:
TransposeWithGeam(OpKernelContext * context,const std::unique_ptr<CudaSolver> & cublas_solver,const Scalar * src,Scalar * dst,const int src_rows,const int src_cols) const186   void TransposeWithGeam(OpKernelContext* context,
187                          const std::unique_ptr<CudaSolver>& cublas_solver,
188                          const Scalar* src, Scalar* dst, const int src_rows,
189                          const int src_cols) const {
190     const Scalar zero(0), one(1);
191     OP_REQUIRES_OK(context,
192                    cublas_solver->Geam(CUBLAS_OP_T, CUBLAS_OP_N, src_rows,
193                                        src_cols, &one, src, src_cols, &zero,
194                                        static_cast<const Scalar*>(nullptr),
195                                        src_rows, dst, src_rows));
196   }
197 
SolveWithGtsv(OpKernelContext * context,std::unique_ptr<GpuSparse> & cusparse_solver,const Scalar * superdiag,const Scalar * diag,const Scalar * subdiag,Scalar * rhs,const int num_eqs,const int num_rhs) const198   void SolveWithGtsv(OpKernelContext* context,
199                      std::unique_ptr<GpuSparse>& cusparse_solver,
200                      const Scalar* superdiag, const Scalar* diag,
201                      const Scalar* subdiag, Scalar* rhs, const int num_eqs,
202                      const int num_rhs) const {
203     auto buffer_function = pivoting_
204                                ? &GpuSparse::Gtsv2BufferSizeExt<Scalar>
205                                : &GpuSparse::Gtsv2NoPivotBufferSizeExt<Scalar>;
206     size_t buffer_size;
207     OP_REQUIRES_OK(context, (cusparse_solver.get()->*buffer_function)(
208                                 num_eqs, num_rhs, subdiag, diag, superdiag, rhs,
209                                 num_eqs, &buffer_size));
210     Tensor temp_tensor;
211     TensorShape temp_shape({static_cast<int64>(buffer_size)});
212     OP_REQUIRES_OK(context,
213                    context->allocate_temp(DT_UINT8, temp_shape, &temp_tensor));
214     void* buffer = temp_tensor.flat<std::uint8_t>().data();
215 
216     auto solver_function = pivoting_ ? &GpuSparse::Gtsv2<Scalar>
217                                      : &GpuSparse::Gtsv2NoPivot<Scalar>;
218     OP_REQUIRES_OK(context, (cusparse_solver.get()->*solver_function)(
219                                 num_eqs, num_rhs, subdiag, diag, superdiag, rhs,
220                                 num_eqs, buffer));
221   }
222 
SolveForSizeOneOrTwo(OpKernelContext * context,const Scalar * diagonals,const Scalar * rhs,Scalar * output,int m,int k)223   void SolveForSizeOneOrTwo(OpKernelContext* context, const Scalar* diagonals,
224                             const Scalar* rhs, Scalar* output, int m, int k) {
225     const Eigen::GpuDevice& device = context->eigen_device<Eigen::GpuDevice>();
226     GpuLaunchConfig cfg = GetGpuLaunchConfig(1, device);
227     bool* not_invertible_dev;
228     cudaMalloc(&not_invertible_dev, sizeof(bool));
229     TF_CHECK_OK(GpuLaunchKernel(SolveForSizeOneOrTwoKernel<Scalar>,
230                                 cfg.block_count, cfg.thread_per_block, 0,
231                                 device.stream(), m, diagonals, rhs, k, output,
232                                 not_invertible_dev));
233     bool not_invertible_host;
234     cudaMemcpy(&not_invertible_host, not_invertible_dev, sizeof(bool),
235                cudaMemcpyDeviceToHost);
236     cudaFree(not_invertible_dev);
237     OP_REQUIRES(context, !not_invertible_host,
238                 errors::InvalidArgument(m == 1 ? kNotInvertibleScalarMsg
239                                                : kNotInvertibleMsg));
240   }
241 
242   bool pivoting_;
243 };
244 
245 template <class Scalar>
246 class TridiagonalSolveOpGpu : public OpKernel {
247  public:
TridiagonalSolveOpGpu(OpKernelConstruction * context)248   explicit TridiagonalSolveOpGpu(OpKernelConstruction* context)
249       : OpKernel(context), linalgOp_(context) {
250     OP_REQUIRES_OK(context, context->GetAttr("partial_pivoting", &pivoting_));
251   }
252 
Compute(OpKernelContext * context)253   void Compute(OpKernelContext* context) final {
254     const Tensor& lhs = context->input(0);
255     const Tensor& rhs = context->input(1);
256     const int ndims = lhs.dims();
257     const int64 num_rhs = rhs.dim_size(rhs.dims() - 1);
258     const int64 matrix_size = lhs.dim_size(ndims - 1);
259     int64 batch_size = 1;
260     for (int i = 0; i < ndims - 2; i++) {
261       batch_size *= lhs.dim_size(i);
262     }
263 
264     // The batching mechanism of LinearAlgebraOp is used when it's not
265     // possible or desirable to use GtsvBatched.
266     const bool use_linalg_op =
267         pivoting_            // GtsvBatched doesn't do pivoting
268         || num_rhs > 1       // GtsvBatched doesn't support multiple rhs
269         || matrix_size < 3   // Not supported in cuSparse, use the custom kernel
270         || batch_size == 1;  // No point to use GtsvBatched
271 
272     if (use_linalg_op) {
273       linalgOp_.Compute(context);
274     } else {
275       ComputeWithGtsvBatched(context, lhs, rhs, batch_size);
276     }
277   }
278 
279  private:
280   TF_DISALLOW_COPY_AND_ASSIGN(TridiagonalSolveOpGpu);
281 
ComputeWithGtsvBatched(OpKernelContext * context,const Tensor & lhs,const Tensor & rhs,const int batch_size)282   void ComputeWithGtsvBatched(OpKernelContext* context, const Tensor& lhs,
283                               const Tensor& rhs, const int batch_size) {
284     const Scalar* rhs_data = rhs.flat<Scalar>().data();
285     const int ndims = lhs.dims();
286 
287     // To use GtsvBatched we need to transpose the left-hand side from shape
288     // [..., 3, M] into shape [3, ..., M]. With shape [..., 3, M] the stride
289     // between corresponding diagonal elements of consecutive batch components
290     // is 3 * M, while for the right-hand side the stride is M. Unfortunately,
291     // GtsvBatched requires the strides to be the same. For this reason we
292     // transpose into [3, ..., M], so that diagonals, superdiagonals, and
293     // and subdiagonals are separated from each other, and have stride M.
294     Tensor lhs_transposed;
295     TransposeLhsForGtsvBatched(context, lhs, lhs_transposed);
296     int matrix_size = lhs.dim_size(ndims - 1);
297     const Scalar* lhs_data = lhs_transposed.flat<Scalar>().data();
298     const Scalar* superdiag = lhs_data;
299     const Scalar* diag = lhs_data + matrix_size * batch_size;
300     const Scalar* subdiag = lhs_data + 2 * matrix_size * batch_size;
301 
302     // Copy right-hand side into the output. GtsvBatched will replace it with
303     // the solution.
304     Tensor* output;
305     OP_REQUIRES_OK(context, context->allocate_output(0, rhs.shape(), &output));
306     CopyDeviceToDevice(context, rhs_data, output->flat<Scalar>().data(),
307                        rhs.flat<Scalar>().size());
308     Scalar* x = output->flat<Scalar>().data();
309 
310     std::unique_ptr<GpuSparse> cusparse_solver(new GpuSparse(context));
311 
312     OP_REQUIRES_OK(context, cusparse_solver->Initialize());
313 
314     size_t buffer_size;
315     OP_REQUIRES_OK(context, cusparse_solver->Gtsv2StridedBatchBufferSizeExt(
316                                 matrix_size, subdiag, diag, superdiag, x,
317                                 batch_size, matrix_size, &buffer_size));
318     Tensor temp_tensor;
319     TensorShape temp_shape({static_cast<int64>(buffer_size)});
320     OP_REQUIRES_OK(context,
321                    context->allocate_temp(DT_UINT8, temp_shape, &temp_tensor));
322     void* buffer = temp_tensor.flat<std::uint8_t>().data();
323     OP_REQUIRES_OK(context, cusparse_solver->Gtsv2StridedBatch(
324                                 matrix_size, subdiag, diag, superdiag, x,
325                                 batch_size, matrix_size, buffer));
326   }
327 
TransposeLhsForGtsvBatched(OpKernelContext * context,const Tensor & lhs,Tensor & lhs_transposed)328   void TransposeLhsForGtsvBatched(OpKernelContext* context, const Tensor& lhs,
329                                   Tensor& lhs_transposed) {
330     const int ndims = lhs.dims();
331 
332     // Permutation of indices, transforming [..., 3, M] into [3, ..., M].
333     // E.g. for ndims = 6, it is [4, 0, 1, 2, 3, 5].
334     std::vector<int> perm(ndims);
335     perm[0] = ndims - 2;
336     for (int i = 0; i < ndims - 2; ++i) {
337       perm[i + 1] = i;
338     }
339     perm[ndims - 1] = ndims - 1;
340 
341     std::vector<int64> dims;
342     for (int index : perm) {
343       dims.push_back(lhs.dim_size(index));
344     }
345     TensorShape lhs_transposed_shape(
346         gtl::ArraySlice<int64>(dims.data(), ndims));
347 
348     std::unique_ptr<CudaSolver> cublas_solver(new CudaSolver(context));
349     OP_REQUIRES_OK(context, cublas_solver->allocate_scoped_tensor(
350                                 DataTypeToEnum<Scalar>::value,
351                                 lhs_transposed_shape, &lhs_transposed));
352     auto device = context->eigen_device<Eigen::GpuDevice>();
353     OP_REQUIRES_OK(
354         context,
355         DoTranspose(device, lhs, gtl::ArraySlice<int>(perm.data(), ndims),
356                     &lhs_transposed));
357   }
358 
359   TridiagonalSolveOpGpuLinalg<Scalar> linalgOp_;
360   bool pivoting_;
361 };
362 
// Register TridiagonalSolveOpGpu as the GPU kernel of "TridiagonalSolve" for
// each supported scalar type: float, double, complex64 and complex128.
REGISTER_LINALG_OP_GPU("TridiagonalSolve", (TridiagonalSolveOpGpu<float>),
                       float);
REGISTER_LINALG_OP_GPU("TridiagonalSolve", (TridiagonalSolveOpGpu<double>),
                       double);
REGISTER_LINALG_OP_GPU("TridiagonalSolve", (TridiagonalSolveOpGpu<complex64>),
                       complex64);
REGISTER_LINALG_OP_GPU("TridiagonalSolve", (TridiagonalSolveOpGpu<complex128>),
                       complex128);
371 
372 }  // namespace tensorflow
373 
374 #endif  // GOOGLE_CUDA
375