1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Registrations for LinearOperator.cholesky."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.ops import linalg_ops
22from tensorflow.python.ops import math_ops
23from tensorflow.python.ops.linalg import linear_operator
24from tensorflow.python.ops.linalg import linear_operator_algebra
25from tensorflow.python.ops.linalg import linear_operator_block_diag
26from tensorflow.python.ops.linalg import linear_operator_diag
27from tensorflow.python.ops.linalg import linear_operator_identity
28from tensorflow.python.ops.linalg import linear_operator_kronecker
29from tensorflow.python.ops.linalg import linear_operator_lower_triangular
30
31
32# By default, compute the Cholesky of the dense matrix, and return a
33# LowerTriangular operator. Methods below specialize this registration.
@linear_operator_algebra.RegisterCholesky(linear_operator.LinearOperator)
def _cholesky_linear_operator(linop):
  """Generic Cholesky fallback used when no specialized registration applies.

  Converts `linop` to a dense matrix with `to_dense()`, factors it with
  `linalg_ops.cholesky`, and wraps the factor as a lower-triangular operator.

  Args:
    linop: The `LinearOperator` to factor.

  Returns:
    A `LinearOperatorLowerTriangular` holding the Cholesky factor, hinted as
    square, non-singular and not self-adjoint.
  """
  dense_matrix = linop.to_dense()
  lower_factor = linalg_ops.cholesky(dense_matrix)
  return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
      lower_factor,
      is_square=True,
      is_self_adjoint=False,
      is_non_singular=True)
41
42
@linear_operator_algebra.RegisterCholesky(
    linear_operator_diag.LinearOperatorDiag)
def _cholesky_diag(diag_operator):
  """Cholesky factor of a diagonal operator: the elementwise square root.

  Args:
    diag_operator: A `LinearOperatorDiag` to factor.

  Returns:
    A `LinearOperatorDiag` whose diagonal is `sqrt(diag_operator.diag)`.
  """
  sqrt_diag = math_ops.sqrt(diag_operator.diag)
  # Assumes the diagonal is positive (i.e. the operator is positive definite);
  # then the square-rooted diagonal is itself a valid lower-triangular factor
  # and is self-adjoint and positive definite as well.
  return linear_operator_diag.LinearOperatorDiag(
      sqrt_diag,
      is_square=True,
      is_non_singular=True,
      is_self_adjoint=True,
      is_positive_definite=True)
52
53
@linear_operator_algebra.RegisterCholesky(
    linear_operator_identity.LinearOperatorIdentity)
def _cholesky_identity(identity_operator):
  """Cholesky factor of the identity, which is the identity itself.

  Args:
    identity_operator: A `LinearOperatorIdentity` to factor.

  Returns:
    A new `LinearOperatorIdentity` with the same number of rows, batch shape
    and dtype as `identity_operator`.
  """
  num_rows = identity_operator._num_rows  # pylint: disable=protected-access
  return linear_operator_identity.LinearOperatorIdentity(
      num_rows=num_rows,
      batch_shape=identity_operator.batch_shape,
      dtype=identity_operator.dtype,
      is_square=True,
      is_non_singular=True,
      is_self_adjoint=True,
      is_positive_definite=True)
65
66
@linear_operator_algebra.RegisterCholesky(
    linear_operator_identity.LinearOperatorScaledIdentity)
def _cholesky_scaled_identity(identity_operator):
  """Cholesky factor of `c * I`: the scaled identity `sqrt(c) * I`.

  Args:
    identity_operator: A `LinearOperatorScaledIdentity` to factor.

  Returns:
    A `LinearOperatorScaledIdentity` with the same number of rows and
    multiplier `sqrt(identity_operator.multiplier)`.
  """
  num_rows = identity_operator._num_rows  # pylint: disable=protected-access
  root_multiplier = math_ops.sqrt(identity_operator.multiplier)
  # NOTE(review): no batch_shape is passed here; presumably the batch shape is
  # carried by the multiplier tensor — confirm.
  return linear_operator_identity.LinearOperatorScaledIdentity(
      num_rows=num_rows,
      multiplier=root_multiplier,
      is_square=True,
      is_non_singular=True,
      is_self_adjoint=True,
      is_positive_definite=True)
77
78
@linear_operator_algebra.RegisterCholesky(
    linear_operator_block_diag.LinearOperatorBlockDiag)
def _cholesky_block_diag(block_diag_operator):
  """Cholesky factor of a block-diagonal operator, computed block by block.

  If A = diag(A_1, ..., A_k) with A_i = L_i L_i^H, then
  A = diag(L_1, ..., L_k) diag(L_1, ..., L_k)^H, and diag(L_1, ..., L_k) is
  lower triangular. So the factor is the block-diagonal of the per-block
  factors.

  Args:
    block_diag_operator: A `LinearOperatorBlockDiag` to factor.

  Returns:
    A `LinearOperatorBlockDiag` whose blocks are `operator.cholesky()` for each
    block of `block_diag_operator`.
  """
  # We take the cholesky of each block on the diagonal.
  return linear_operator_block_diag.LinearOperatorBlockDiag(
      operators=[
          operator.cholesky() for operator in block_diag_operator.operators],
      is_non_singular=True,
      # Do not hard-code False: the per-block factors may be self-adjoint
      # (e.g. the Cholesky of a diagonal or identity block), in which case the
      # block-diagonal factor is self-adjoint too. Passing False would be a
      # false hint that contradicts the blocks. `LinearOperatorBlockDiag`
      # rejects that combination with a ValueError. None lets the hint be
      # derived from the blocks.
      is_self_adjoint=None,
      is_square=True)
89
90
@linear_operator_algebra.RegisterCholesky(
    linear_operator_kronecker.LinearOperatorKronecker)
def _cholesky_kronecker(kronecker_operator):
  """Cholesky factor of a Kronecker product, computed factor by factor.

  By the mixed-product property,
  (L_A L_A^H) kron (L_B L_B^H) = (L_A kron L_B)(L_A kron L_B)^H, and the
  Kronecker product of lower-triangular matrices is lower triangular.

  Args:
    kronecker_operator: A `LinearOperatorKronecker` to factor.

  Returns:
    A `LinearOperatorKronecker` of `operator.cholesky()` for each factor of
    `kronecker_operator`.
  """
  # Cholesky decomposition of a Kronecker product is the Kronecker product
  # of cholesky decompositions.
  return linear_operator_kronecker.LinearOperatorKronecker(
      operators=[
          operator.cholesky() for operator in kronecker_operator.operators],
      is_non_singular=True,
      # Do not hard-code False: if every factor's Cholesky is self-adjoint
      # (e.g. diagonal or identity factors), their Kronecker product is
      # self-adjoint too. Passing False would be a false hint that contradicts
      # the factors. `LinearOperatorKronecker` rejects that combination with a
      # ValueError. None lets the hint be derived from the factors.
      is_self_adjoint=None,
      is_square=True)
102