1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/tridiagonal.h"
17 
18 #include <numeric>
19 #include <string>
20 #include <vector>
21 
22 #include "absl/types/span.h"
23 #include "tensorflow/compiler/xla/client/lib/constants.h"
24 #include "tensorflow/compiler/xla/client/lib/loops.h"
25 #include "tensorflow/compiler/xla/client/lib/slicing.h"
26 #include "tensorflow/compiler/xla/client/xla_builder.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/status.h"
29 #include "tensorflow/compiler/xla/status_macros.h"
30 #include "tensorflow/compiler/xla/statusor.h"
31 
32 namespace xla {
33 namespace tridiagonal {
34 
35 namespace {
36 
CheckSecondToLastDimension(const Shape & op_shape,int64 rank,int64 expected,const std::string & op_name)37 Status CheckSecondToLastDimension(const Shape& op_shape, int64 rank,
38                                   int64 expected, const std::string& op_name) {
39   const auto actual_num_dims = ShapeUtil::GetDimension(op_shape, rank - 2);
40 
41   if (actual_num_dims != expected) {
42     return InvalidArgument(
43         "Second to last dimension of %s should be %d but is %d.", op_name,
44         expected, actual_num_dims);
45   }
46 
47   return Status::OK();
48 }
49 
CheckSystemAndReturnNumEquations(XlaOp lower_diagonal,XlaOp main_diagonal,XlaOp upper_diagonal,XlaOp rhs)50 StatusOr<int64> CheckSystemAndReturnNumEquations(XlaOp lower_diagonal,
51                                                  XlaOp main_diagonal,
52                                                  XlaOp upper_diagonal,
53                                                  XlaOp rhs) {
54   XlaBuilder* builder = lower_diagonal.builder();
55 
56   TF_ASSIGN_OR_RETURN(Shape lower_diagonal_shape,
57                       builder->GetShape(lower_diagonal));
58   TF_ASSIGN_OR_RETURN(Shape main_diagonal_shape,
59                       builder->GetShape(main_diagonal));
60   TF_ASSIGN_OR_RETURN(Shape upper_diagonal_shape,
61                       builder->GetShape(upper_diagonal));
62   TF_ASSIGN_OR_RETURN(Shape rhs_shape, builder->GetShape(rhs));
63 
64   const auto lower_diagonal_rank = lower_diagonal_shape.rank();
65   const auto main_diagonal_rank = main_diagonal_shape.rank();
66   const auto upper_diagonal_rank = upper_diagonal_shape.rank();
67   const auto rhs_rank = rhs_shape.rank();
68   if (!((lower_diagonal_rank == main_diagonal_rank) &&
69         (lower_diagonal_rank == upper_diagonal_rank) &&
70         (lower_diagonal_rank == rhs_rank))) {
71     return InvalidArgument(
72         "All inputs should have the same rank but got rank "
73         "%d for lower diagonal, %d for diagonal, %d for upper diagonal, "
74         "%d for rhs",
75         lower_diagonal_rank, main_diagonal_rank, upper_diagonal_rank, rhs_rank);
76   }
77   const auto rank = lower_diagonal_rank;
78   if (rank < 2) {
79     return InvalidArgument("Arguments must have rank >=2; got rank %d.", rank);
80   }
81 
82   const auto lower_diagonal_num_eqs =
83       ShapeUtil::GetDimension(lower_diagonal_shape, rank - 1);
84   const auto main_diagonal_num_eqs =
85       ShapeUtil::GetDimension(main_diagonal_shape, rank - 1);
86   const auto upper_diagonal_num_eqs =
87       ShapeUtil::GetDimension(upper_diagonal_shape, rank - 1);
88   const auto rhs_num_eqs = ShapeUtil::GetDimension(rhs_shape, rank - 1);
89   if (!((lower_diagonal_num_eqs == main_diagonal_num_eqs) &&
90         (lower_diagonal_num_eqs == upper_diagonal_num_eqs) &&
91         (lower_diagonal_num_eqs == rhs_num_eqs))) {
92     return InvalidArgument(
93         "All inputs should have the same innermost dimension but got "
94         "%d for lower diagonal, %d for diagonal, %d for upper diagonal, "
95         "%d for rhs",
96         lower_diagonal_num_eqs, main_diagonal_num_eqs, upper_diagonal_num_eqs,
97         rhs_num_eqs);
98   }
99   const auto num_equations = lower_diagonal_num_eqs;
100 
101   TF_RETURN_IF_ERROR(CheckSecondToLastDimension(lower_diagonal_shape, rank, 1,
102                                                 "lower diagonal"));
103   TF_RETURN_IF_ERROR(
104       CheckSecondToLastDimension(main_diagonal_shape, rank, 1, "diagonal"));
105   TF_RETURN_IF_ERROR(CheckSecondToLastDimension(upper_diagonal_shape, rank, 1,
106                                                 "upper diagonal"));
107 
108   return num_equations;
109 }
110 
Coefficient(XlaOp operand,int32 i)111 XlaOp Coefficient(XlaOp operand, int32 i) {
112   return DynamicSliceInMinorDims(operand,
113                                  /*starts=*/{ConstantR0(operand.builder(), i)},
114                                  /*sizes=*/{1});
115 }
116 
Coefficient(XlaOp operand,XlaOp i)117 XlaOp Coefficient(XlaOp operand, XlaOp i) {
118   return DynamicSliceInMinorDims(operand,
119                                  /*starts=*/{i}, /*sizes=*/{1});
120 }
121 
UpdateEq(XlaOp updated,int32 i,XlaOp update)122 XlaOp UpdateEq(XlaOp updated, int32 i, XlaOp update) {
123   return DynamicUpdateSliceInMinorDims(
124       updated, update, /*starts=*/{ConstantR0(updated.builder(), i)});
125 }
126 
UpdateEq(XlaOp updated,XlaOp i,XlaOp update)127 XlaOp UpdateEq(XlaOp updated, XlaOp i, XlaOp update) {
128   return DynamicUpdateSliceInMinorDims(updated, update, /*starts=*/{i});
129 }
130 
131 }  // namespace
132 
133 // Applies Thomas algorithm to solve a linear system where the linear operand
134 // is a tri-diagonal matrix.
135 // See https://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm for a simple
136 // reference on the Thomas algorithm.
137 // It is expected that the three diagonals are represented as tensors of shape
138 // [..., 1, num_equations] where num_equations is the number of dimensions of
139 // the unknowns considered in the linear systems.
140 // The first innermost dimension of `lower_diagonal` (`lower_diagonal[..., :,
141 // 0]`) will be ignored. The last innermost dimension of `upper_diagonal`
142 // (`upper_diagonal[..., :, num_equations - 1]`) will be ignored. The shape of
143 // the right-hand-side `rhs` should be [..., num_rhs, num_equations]. The
144 // solution will have the shape [..., num_rhs, num_equations].
ThomasSolver(XlaOp lower_diagonal,XlaOp main_diagonal,XlaOp upper_diagonal,XlaOp rhs)145 StatusOr<XlaOp> ThomasSolver(XlaOp lower_diagonal, XlaOp main_diagonal,
146                              XlaOp upper_diagonal, XlaOp rhs) {
147   XlaBuilder* builder = lower_diagonal.builder();
148 
149   TF_ASSIGN_OR_RETURN(int64 num_eqs,
150                       CheckSystemAndReturnNumEquations(
151                           lower_diagonal, main_diagonal, upper_diagonal, rhs));
152 
153   XlaOp main_diag_after_elimination = ZerosLike(main_diagonal);
154   XlaOp rhs_after_elimination = ZerosLike(rhs);
155   XlaOp upper_diagonal_coeffs = ZerosLike(upper_diagonal);
156   XlaOp x_coeffs = ZerosLike(rhs);
157 
158   // main_diag_after_elimination[:, 0] = main_diagonal[:, 0];
159   main_diag_after_elimination =
160       UpdateEq(main_diag_after_elimination, 0, Coefficient(main_diagonal, 0));
161 
162   // rhs_after_elimination[:, 0] = rhs[:, 0];
163   rhs_after_elimination =
164       UpdateEq(rhs_after_elimination, 0, Coefficient(rhs, 0));
165 
166   auto preparation_body_fn =
167       [](XlaOp i, absl::Span<const XlaOp> values,
168          XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
169     auto upper_diagonal_coeffs = values[0];
170     auto upper_diagonal = values[1];
171     // upper_diagonal_coeffs[:, i] = upper_diagonal[:, i];
172     upper_diagonal_coeffs =
173         UpdateEq(upper_diagonal_coeffs, i, Coefficient(upper_diagonal, i));
174     return std::vector<XlaOp>{upper_diagonal_coeffs, upper_diagonal};
175   };
176   TF_ASSIGN_OR_RETURN(auto values_after_preparation,
177                       ForEachIndex(num_eqs - 1, S32, preparation_body_fn,
178                                    {upper_diagonal_coeffs, upper_diagonal},
179                                    "preparation", builder));
180   upper_diagonal_coeffs = values_after_preparation[0];
181 
182   // Forward transformation.
183   auto forward_transformation_fn =
184       [](XlaOp i_minus_one, absl::Span<const XlaOp> values,
185          XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
186     auto lower_diagonal = values[0];
187     auto main_diagonal = values[1];
188     auto rhs = values[2];
189     auto main_diag_after_elimination = values[3];
190     auto upper_diagonal_coeffs = values[4];
191     auto rhs_after_elimination = values[5];
192 
193     auto one = ScalarLike(i_minus_one, 1);
194     auto i = i_minus_one + one;
195     auto lower_diagonal_i = Coefficient(lower_diagonal, i);
196     auto main_diagonal_i = Coefficient(main_diagonal, i);
197     auto rhs_i = Coefficient(rhs, i);
198 
199     auto w_i =
200         lower_diagonal_i / Coefficient(main_diag_after_elimination, i - one);
201 
202     // main_diag_after_elimination[:, i] =
203     //     main_diagonal_i - w_i * upper_diagonal_coeffs[:, i - 1];
204     main_diag_after_elimination = UpdateEq(
205         main_diag_after_elimination, i,
206         main_diagonal_i - w_i * Coefficient(upper_diagonal_coeffs, i - one));
207     // rhs_after_elimination[:, i] =
208     //     rhs_i - w_i * rhs_after_elimination[:, i - 1];
209     rhs_after_elimination =
210         UpdateEq(rhs_after_elimination, i,
211                  rhs_i - w_i * Coefficient(rhs_after_elimination, i - one));
212 
213     return std::vector<XlaOp>{lower_diagonal,
214                               main_diagonal,
215                               rhs,
216                               main_diag_after_elimination,
217                               upper_diagonal_coeffs,
218                               rhs_after_elimination};
219   };
220   TF_ASSIGN_OR_RETURN(
221       auto values_after_fwd_transformation,
222       ForEachIndex(
223           num_eqs - 1, S32, forward_transformation_fn,
224           {lower_diagonal, main_diagonal, rhs, main_diag_after_elimination,
225            upper_diagonal_coeffs, rhs_after_elimination},
226           "forward_transformation", builder));
227   lower_diagonal = values_after_fwd_transformation[0];
228   main_diagonal = values_after_fwd_transformation[1];
229   rhs = values_after_fwd_transformation[2];
230   main_diag_after_elimination = values_after_fwd_transformation[3];
231   upper_diagonal_coeffs = values_after_fwd_transformation[4];
232   rhs_after_elimination = values_after_fwd_transformation[5];
233 
234   // Backward reduction.
235   // x_coeffs[:, num_eqs - 1] = rhs_after_elimination[:, num_eqs - 1] /
236   //                              main_diag_after_elimination[:, num_eqs - 1];
237   x_coeffs =
238       UpdateEq(x_coeffs, num_eqs - 1,
239                Coefficient(rhs_after_elimination, num_eqs - 1) /
240                    Coefficient(main_diag_after_elimination, num_eqs - 1));
241   auto bwd_reduction_fn =
242       [num_eqs](XlaOp j, absl::Span<const XlaOp> values,
243                 XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
244     auto x_coeffs = values[0];
245     auto rhs_after_elimination = values[1];
246     auto upper_diagonal_coeffs = values[2];
247     auto main_diag_after_elimination = values[3];
248     auto n = ScalarLike(j, num_eqs - 2);
249     auto one = ScalarLike(j, 1);
250     auto i = n - j;
251     // for (int i = num_eqs - 2; i >= 0; i--)
252     //   x_coeffs[:, i] = (rhs_after_elimination[:, i] -
253     //     upper_diagonal_coeffs[:, i] * x_coeffs[:, i + 1]) /
254     //       main_diag_after_elimination[:, i];
255     x_coeffs = UpdateEq(x_coeffs, i,
256                         (Coefficient(rhs_after_elimination, i) -
257                          Coefficient(upper_diagonal_coeffs, i) *
258                              Coefficient(x_coeffs, i + one)) /
259                             Coefficient(main_diag_after_elimination, i));
260     return std::vector<XlaOp>{x_coeffs, rhs_after_elimination,
261                               upper_diagonal_coeffs,
262                               main_diag_after_elimination};
263   };
264 
265   TF_ASSIGN_OR_RETURN(
266       auto values_after_bwd_reduction,
267       ForEachIndex(num_eqs - 1, S32, bwd_reduction_fn,
268                    {x_coeffs, rhs_after_elimination, upper_diagonal_coeffs,
269                     main_diag_after_elimination},
270                    "backward_reduction", builder));
271   x_coeffs = values_after_bwd_reduction[0];
272 
273   return x_coeffs;
274 }
275 
276 // Applies Thomas algorithm to solve a linear system where the linear operand
277 // is a tri-diagonal matrix.
278 // It is expected that the tree diagonals are stacked into a tensors of shape
279 // [..., 3, num_equations] where num_equations is the number of spatial
280 // dimensions considered in the system.
281 // diagonals[..., 0, :] represents the upper diagonal whose last inner
282 // dimension will be ignored.
283 // diagonals[..., 1, :] represents the main diagonal.
284 // diagonals[..., 2, :] represents the lower diagonal whose first inner
285 // dimension will be ignored.
286 // The right-hand-side d is expected to have dimension
287 // [..., num_rhs, num_equations].
288 // The solution will have size [..., num_rhs, num_equations].
ThomasSolver(XlaOp diagonals,XlaOp rhs)289 StatusOr<XlaOp> ThomasSolver(XlaOp diagonals, XlaOp rhs) {
290   XlaBuilder* builder = diagonals.builder();
291   TF_ASSIGN_OR_RETURN(Shape diagonals_shape, builder->GetShape(diagonals));
292   const int64 rank = diagonals_shape.rank();
293 
294   auto upper_diagonal =
295       SliceInDim(diagonals, /*start_index=*/0, /*limit_index=*/1,
296                  /*stride=*/1, /*dimno=*/rank - 2);
297   auto main_diagonal =
298       SliceInDim(diagonals, /*start_index=*/1, /*limit_index=*/2,
299                  /*stride=*/1, /*dimno=*/rank - 2);
300   auto lower_diagonal =
301       SliceInDim(diagonals, /*start_index=*/2, /*limit_index=*/3,
302                  /*stride=*/1, /*dimno=*/rank - 2);
303 
304   // TODO(belletti): Get rid of the transposes here.
305   std::vector<int64> transpose_order(rank);
306   std::iota(transpose_order.begin(), transpose_order.end(), 0);
307   transpose_order[rank - 2] = rank - 1;
308   transpose_order[rank - 1] = rank - 2;
309   // Swap the last two dimensions.
310   rhs = Transpose(rhs, transpose_order);
311 
312   TF_ASSIGN_OR_RETURN(XlaOp x, ThomasSolver(lower_diagonal, main_diagonal,
313                                             upper_diagonal, rhs));
314   return Transpose(x, transpose_order);
315 }
316 
317 }  // namespace tridiagonal
318 }  // namespace xla
319