1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LU_DECOMPOSITION_H_
17 #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LU_DECOMPOSITION_H_
18 
19 #include "tensorflow/compiler/xla/client/xla_builder.h"
20 #include "tensorflow/compiler/xla/xla_data.pb.h"
21 
22 namespace xla {
23 
24 // Computes the LU decomposition with partial pivoting of a batch of matrices.
25 //
26 // Given a (batched) matrix a with shape [..., m, n], computes the matrix
27 // decomposition A = P @ L @ U where P is a permutation matrix, L is a
28 // lower-triangular matrix with unit diagonal entries, and U is an
29 // upper-triangular matrix.
30 //
31 // L and U are returned as a single matrix [..., m, n] containing both L and U
32 // packed in the same array. The unit diagonal of L is not represented
33 // explicitly.
34 //
// The permutation matrix P is returned in two forms: as `pivots`, which is an
// s32[..., min(m, n)] array that describes a sequence of row swaps in the
// style of LAPACK's xGETRF API, and as `permutation`, which is an s32[..., m]
// array that gives the permutation to apply to the rows. We return both
// representations because each is useful for a different purpose: `pivots`
// is useful for computing the sign of a determinant, whereas `permutation` can
// be used via a Gather operation to permute the rows of a matrix.
42 //
43 // This method is only implemented on TPU at the moment.
44 // TODO(b/168208200): the implementation only supports F32 arrays. Handle the
45 // complex case.
struct LuDecompositionResult {
  // The LU decomposition, with both L and U packed into a single array with
  // shape [..., m, n]. L's unit diagonal is not stored explicitly, so the
  // diagonal entries of this array belong to U.
  XlaOp lu;
  // An array of shape s32[..., min(m, n)] containing the pivot rows, i.e. a
  // sequence of row swaps in the style of LAPACK's xGETRF.
  XlaOp pivots;
  // An array of shape s32[..., m] containing another representation of the
  // pivots: the permutation to apply to the rows (e.g. via a Gather).
  XlaOp permutation;
};
56 
// Computes the LU decomposition with partial pivoting of `a`, a (batched)
// matrix of shape [..., m, n]. The factors and both pivot representations are
// returned in a LuDecompositionResult. Currently implemented only on TPU and
// only for F32 inputs (see the TODO above).
LuDecompositionResult LuDecomposition(XlaOp a);
58 
59 }  // namespace xla
60 
61 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LU_DECOMPOSITION_H_
62