/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 // See docs in ../ops/linalg_ops.cc.
17 
18 #include "tensorflow/core/framework/kernel_def_builder.h"
19 #include "tensorflow/core/framework/op_kernel.h"
20 #include "tensorflow/core/framework/register_types.h"
21 #include "tensorflow/core/framework/tensor_shape.h"
22 #include "tensorflow/core/kernels/linalg_ops_common.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/platform/types.h"
25 
namespace tensorflow {

// Error message reported whenever a zero pivot is encountered, i.e. the
// tridiagonal system is singular and has no unique solution.
static const char kErrMsg[] = "The matrix is not invertible.";
29 
30 template <class Scalar>
31 class TridiagonalSolveOp : public LinearAlgebraOp<Scalar> {
32  public:
33   INHERIT_LINALG_TYPEDEFS(Scalar);
34 
TridiagonalSolveOp(OpKernelConstruction * context)35   explicit TridiagonalSolveOp(OpKernelConstruction* context) : Base(context) {}
36 
ValidateInputMatrixShapes(OpKernelContext * context,const TensorShapes & input_matrix_shapes) const37   void ValidateInputMatrixShapes(
38       OpKernelContext* context,
39       const TensorShapes& input_matrix_shapes) const final {
40     auto num_inputs = input_matrix_shapes.size();
41     OP_REQUIRES(context, num_inputs == 2,
42                 errors::InvalidArgument("Expected two input matrices, got ",
43                                         num_inputs, "."));
44 
45     auto num_diags = input_matrix_shapes[0].dim_size(0);
46     OP_REQUIRES(
47         context, num_diags == 3,
48         errors::InvalidArgument("Expected diagonals to be provided as a "
49                                 "matrix with 3 rows, got ",
50                                 num_diags, " rows."));
51 
52     auto num_eqs_left = input_matrix_shapes[0].dim_size(1);
53     auto num_eqs_right = input_matrix_shapes[1].dim_size(0);
54     OP_REQUIRES(
55         context, num_eqs_left == num_eqs_right,
56         errors::InvalidArgument("Expected the same number of left-hand sides "
57                                 "and right-hand sides, got ",
58                                 num_eqs_left, " and ", num_eqs_right, "."));
59   }
60 
GetOutputMatrixShapes(const TensorShapes & input_matrix_shapes) const61   TensorShapes GetOutputMatrixShapes(
62       const TensorShapes& input_matrix_shapes) const final {
63     return TensorShapes({input_matrix_shapes[1]});
64   }
65 
GetCostPerUnit(const TensorShapes & input_matrix_shapes) const66   int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
67     const int num_eqs = static_cast<int>(input_matrix_shapes[0].dim_size(1));
68     const int num_rhss = static_cast<int>(input_matrix_shapes[1].dim_size(0));
69 
70     const double add_cost = Eigen::TensorOpCost::AddCost<Scalar>();
71     const double mult_cost = Eigen::TensorOpCost::MulCost<Scalar>();
72     const double div_cost = Eigen::TensorOpCost::DivCost<Scalar>();
73 
74     // Assuming cases with and without row interchange are equiprobable.
75     const double cost =
76         num_eqs * (div_cost * (num_rhss + 1) +
77                    (add_cost + mult_cost) * (2.5 * num_rhss + 1.5));
78     return cost >= static_cast<double>(kint64max) ? kint64max
79                                                   : static_cast<int64>(cost);
80   }
81 
ComputeMatrix(OpKernelContext * context,const ConstMatrixMaps & inputs,MatrixMaps * outputs)82   void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
83                      MatrixMaps* outputs) final {
84     const auto diagonals = inputs[0];
85 
86     // Subdiagonal elements, first is ignored.
87     const auto& superdiag = diagonals.row(0);
88     // Diagonal elements.
89     const auto& diag = diagonals.row(1);
90     // Superdiagonal elements, n-th is ignored.
91     const auto& subdiag = diagonals.row(2);
92     // Right-hand sides (transposed - necessary for GPU impl).
93     const auto& rhs = inputs[1];
94 
95     const int n = diag.size();
96     MatrixMap& x = outputs->at(0);
97     const Scalar zero(0);
98 
99     if (n == 0) {
100       return;
101     }
102     if (n == 1) {
103       OP_REQUIRES(context, diag(0) != zero, errors::InvalidArgument(kErrMsg));
104       x.row(0) = rhs.row(0) / diag(0);
105       return;
106     }
107 
108     // The three columns in u are the diagonal, superdiagonal, and second
109     // superdiagonal, respectively, of the U matrix in the LU decomposition of
110     // the input matrix (subject to row exchanges due to pivoting). For pivoted
111     // tridiagonal matrix, the U matrix has at most two non-zero superdiagonals.
112     Eigen::Array<Scalar, Eigen::Dynamic, 3> u(n, 3);
113 
114     // The code below roughly follows LAPACK's dgtsv routine, with main
115     // difference being not overwriting the input.
116     u(0, 0) = diag(0);
117     u(0, 1) = superdiag(0);
118     x.row(0) = rhs.row(0);
119     for (int i = 0; i < n - 1; ++i) {
120       if (std::abs(u(i)) >= std::abs(subdiag(i + 1))) {
121         // No row interchange.
122         OP_REQUIRES(context, u(i) != zero, errors::InvalidArgument(kErrMsg));
123         const Scalar factor = subdiag(i + 1) / u(i, 0);
124         u(i + 1, 0) = diag(i + 1) - factor * u(i, 1);
125         x.row(i + 1) = rhs.row(i + 1) - factor * x.row(i);
126         if (i != n - 2) {
127           u(i + 1, 1) = superdiag(i + 1);
128           u(i, 2) = 0;
129         }
130       } else {
131         // Interchange rows i and i + 1.
132         const Scalar factor = u(i, 0) / subdiag(i + 1);
133         u(i, 0) = subdiag(i + 1);
134         u(i + 1, 0) = u(i, 1) - factor * diag(i + 1);
135         u(i, 1) = diag(i + 1);
136         x.row(i + 1) = x.row(i) - factor * rhs.row(i + 1);
137         x.row(i) = rhs.row(i + 1);
138         if (i != n - 2) {
139           u(i, 2) = superdiag(i + 1);
140           u(i + 1, 1) = -factor * superdiag(i + 1);
141         }
142       }
143     }
144     x.row(n - 1) /= u(n - 1, 0);
145     x.row(n - 2) = (x.row(n - 2) - u(n - 2, 1) * x.row(n - 1)) / u(n - 2, 0);
146     for (int i = n - 3; i >= 0; --i) {
147       x.row(i) = (x.row(i) - u(i, 1) * x.row(i + 1) - u(i, 2) * x.row(i + 2)) /
148                  u(i, 0);
149     }
150   }
151 
152  private:
153   TF_DISALLOW_COPY_AND_ASSIGN(TridiagonalSolveOp);
154 };
155 
// CPU kernel registrations for all scalar types supported by the op.
REGISTER_LINALG_OP_CPU("TridiagonalSolve", (TridiagonalSolveOp<float>), float);
REGISTER_LINALG_OP_CPU("TridiagonalSolve", (TridiagonalSolveOp<double>),
                       double);
REGISTER_LINALG_OP_CPU("TridiagonalSolve", (TridiagonalSolveOp<complex64>),
                       complex64);
REGISTER_LINALG_OP_CPU("TridiagonalSolve", (TridiagonalSolveOp<complex128>),
                       complex128);
163 }  // namespace tensorflow
164