/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.

#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

static const char kErrMsg[] = "The matrix is not invertible.";

template <class Scalar>
class TridiagonalSolveOp : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit TridiagonalSolveOp(OpKernelConstruction* context) : Base(context) {}

  void ValidateInputMatrixShapes(
      OpKernelContext* context,
      const TensorShapes& input_matrix_shapes) const final {
    auto num_inputs = input_matrix_shapes.size();
    OP_REQUIRES(context, num_inputs == 2,
                errors::InvalidArgument("Expected two input matrices, got ",
                                        num_inputs, "."));

    auto num_diags = input_matrix_shapes[0].dim_size(0);
    OP_REQUIRES(
        context, num_diags == 3,
        errors::InvalidArgument("Expected diagonals to be provided as a "
                                "matrix with 3 rows, got ",
                                num_diags, " rows."));

    auto num_eqs_left = input_matrix_shapes[0].dim_size(1);
    auto num_eqs_right = input_matrix_shapes[1].dim_size(0);
    OP_REQUIRES(
        context, num_eqs_left == num_eqs_right,
        errors::InvalidArgument("Expected the same number of left-hand sides "
                                "and right-hand sides, got ",
                                num_eqs_left, " and ", num_eqs_right, "."));
  }

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    return TensorShapes({input_matrix_shapes[1]});
  }

  int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
    const int num_eqs = static_cast<int>(input_matrix_shapes[0].dim_size(1));
    const int num_rhss = static_cast<int>(input_matrix_shapes[1].dim_size(0));

    const double add_cost = Eigen::TensorOpCost::AddCost<Scalar>();
    const double mult_cost = Eigen::TensorOpCost::MulCost<Scalar>();
    const double div_cost = Eigen::TensorOpCost::DivCost<Scalar>();

    // Assuming cases with and without row interchange are equiprobable.
    const double cost =
        num_eqs * (div_cost * (num_rhss + 1) +
                   (add_cost + mult_cost) * (2.5 * num_rhss + 1.5));
    return cost >= static_cast<double>(kint64max) ? kint64max
                                                  : static_cast<int64>(cost);
  }
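
  // As a worked instance of the estimate above (with hypothetical unit costs,
  // not the actual Eigen values): for num_eqs = 1000, num_rhss = 1 and
  // add = mul = div = 1, the cost is
  // 1000 * (1 * (1 + 1) + (1 + 1) * (2.5 * 1 + 1.5)) = 10000 units,
  // i.e. linear in both num_eqs and num_rhss, unlike a dense solve.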

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    const auto diagonals = inputs[0];

    // Superdiagonal elements, n-th is ignored.
    const auto& superdiag = diagonals.row(0);
    // Diagonal elements.
    const auto& diag = diagonals.row(1);
    // Subdiagonal elements, first is ignored.
    const auto& subdiag = diagonals.row(2);
    // Right-hand sides (transposed - necessary for GPU impl).
    const auto& rhs = inputs[1];

    const int n = diag.size();
    MatrixMap& x = outputs->at(0);
    const Scalar zero(0);

    if (n == 0) {
      return;
    }
    if (n == 1) {
      OP_REQUIRES(context, diag(0) != zero, errors::InvalidArgument(kErrMsg));
      x.row(0) = rhs.row(0) / diag(0);
      return;
    }

    // The three columns in u are the diagonal, superdiagonal, and second
    // superdiagonal, respectively, of the U matrix in the LU decomposition of
    // the input matrix (subject to row exchanges due to pivoting). For a
    // pivoted tridiagonal matrix, the U matrix has at most two non-zero
    // superdiagonals.
    Eigen::Array<Scalar, Eigen::Dynamic, 3> u(n, 3);

    // The code below roughly follows LAPACK's dgtsv routine, the main
    // difference being that it does not overwrite the input.
    u(0, 0) = diag(0);
    u(0, 1) = superdiag(0);
    x.row(0) = rhs.row(0);
    for (int i = 0; i < n - 1; ++i) {
      if (std::abs(u(i, 0)) >= std::abs(subdiag(i + 1))) {
        // No row interchange.
        OP_REQUIRES(context, u(i, 0) != zero, errors::InvalidArgument(kErrMsg));
        const Scalar factor = subdiag(i + 1) / u(i, 0);
        u(i + 1, 0) = diag(i + 1) - factor * u(i, 1);
        x.row(i + 1) = rhs.row(i + 1) - factor * x.row(i);
        if (i != n - 2) {
          u(i + 1, 1) = superdiag(i + 1);
          u(i, 2) = 0;
        }
      } else {
        // Interchange rows i and i + 1.
        const Scalar factor = u(i, 0) / subdiag(i + 1);
        u(i, 0) = subdiag(i + 1);
        u(i + 1, 0) = u(i, 1) - factor * diag(i + 1);
        u(i, 1) = diag(i + 1);
        x.row(i + 1) = x.row(i) - factor * rhs.row(i + 1);
        x.row(i) = rhs.row(i + 1);
        if (i != n - 2) {
          u(i, 2) = superdiag(i + 1);
          u(i + 1, 1) = -factor * superdiag(i + 1);
        }
      }
    }
    x.row(n - 1) /= u(n - 1, 0);
    x.row(n - 2) = (x.row(n - 2) - u(n - 2, 1) * x.row(n - 1)) / u(n - 2, 0);
    for (int i = n - 3; i >= 0; --i) {
      x.row(i) = (x.row(i) - u(i, 1) * x.row(i + 1) - u(i, 2) * x.row(i + 2)) /
                 u(i, 0);
    }
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(TridiagonalSolveOp);
};

REGISTER_LINALG_OP_CPU("TridiagonalSolve", (TridiagonalSolveOp<float>), float);
REGISTER_LINALG_OP_CPU("TridiagonalSolve", (TridiagonalSolveOp<double>),
                       double);
REGISTER_LINALG_OP_CPU("TridiagonalSolve", (TridiagonalSolveOp<complex64>),
                       complex64);
REGISTER_LINALG_OP_CPU("TridiagonalSolve", (TridiagonalSolveOp<complex128>),
                       complex128);
}  // namespace tensorflow
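
// For reference, a minimal sketch of the same elimination scheme on plain
// std::vector storage, restricted to a single right-hand side and to the
// "No row interchange" branch (valid, e.g., for diagonally dominant
// matrices). It is kept in a comment so it does not enter the build; the
// function name and signature are illustrative only, not TensorFlow API.
//
//   #include <vector>
//
//   // Solves A*x = rhs for tridiagonal A given as three diagonals laid out
//   // like the 3-row input above: sup[n-1] and sub[0] are ignored.
//   std::vector<double> ThomasSolve(const std::vector<double>& sup,
//                                   const std::vector<double>& d,
//                                   const std::vector<double>& sub,
//                                   std::vector<double> rhs) {
//     const int n = static_cast<int>(d.size());
//     if (n == 0) return rhs;
//     std::vector<double> u(d);  // Diagonal of U in the LU decomposition.
//     // Forward elimination, mirroring the no-interchange branch above.
//     for (int i = 0; i < n - 1; ++i) {
//       const double factor = sub[i + 1] / u[i];
//       u[i + 1] = d[i + 1] - factor * sup[i];
//       rhs[i + 1] -= factor * rhs[i];
//     }
//     // Back substitution; without interchanges U has one superdiagonal,
//     // so the second-superdiagonal terms handled above vanish here.
//     rhs[n - 1] /= u[n - 1];
//     for (int i = n - 2; i >= 0; --i) {
//       rhs[i] = (rhs[i] - sup[i] * rhs[i + 1]) / u[i];
//     }
//     return rhs;
//   }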