1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Registrations for LinearOperator.cholesky.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.python.ops import linalg_ops 22from tensorflow.python.ops import math_ops 23from tensorflow.python.ops.linalg import linear_operator 24from tensorflow.python.ops.linalg import linear_operator_algebra 25from tensorflow.python.ops.linalg import linear_operator_block_diag 26from tensorflow.python.ops.linalg import linear_operator_diag 27from tensorflow.python.ops.linalg import linear_operator_identity 28from tensorflow.python.ops.linalg import linear_operator_kronecker 29from tensorflow.python.ops.linalg import linear_operator_lower_triangular 30 31 32# By default, compute the Cholesky of the dense matrix, and return a 33# LowerTriangular operator. Methods below specialize this registration. 
# Each function below is registered for one operator type; the most specific
# registration for an operator's class is the one that gets used.


@linear_operator_algebra.RegisterCholesky(linear_operator.LinearOperator)
def _cholesky_linear_operator(linop):
  """Fallback Cholesky registration for any `LinearOperator`.

  Materializes `linop` as a dense matrix, factors it with
  `linalg_ops.cholesky`, and wraps the factor in a
  `LinearOperatorLowerTriangular`.

  Args:
    linop: The operator to factor. Presumably the caller has already checked
      that it is self-adjoint and positive definite; nothing is verified here.

  Returns:
    A `LinearOperatorLowerTriangular` holding the dense Cholesky factor.
  """
  return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
      linalg_ops.cholesky(linop.to_dense()),
      is_non_singular=True,
      # A lower-triangular Cholesky factor is generally not self-adjoint.
      is_self_adjoint=False,
      is_square=True)


@linear_operator_algebra.RegisterCholesky(
    linear_operator_diag.LinearOperatorDiag)
def _cholesky_diag(diag_operator):
  """Cholesky factor of a `LinearOperatorDiag`.

  The Cholesky factor of a diagonal matrix is the diagonal matrix of the
  element-wise square roots, so no dense matrix is built.

  Args:
    diag_operator: A `LinearOperatorDiag`. Its diagonal is assumed to be
      positive (positive-definite operator); this is not checked here.

  Returns:
    A `LinearOperatorDiag` whose diagonal is `sqrt(diag_operator.diag)`.
  """
  return linear_operator_diag.LinearOperatorDiag(
      math_ops.sqrt(diag_operator.diag),
      is_non_singular=True,
      is_self_adjoint=True,
      is_positive_definite=True,
      is_square=True)


@linear_operator_algebra.RegisterCholesky(
    linear_operator_identity.LinearOperatorIdentity)
def _cholesky_identity(identity_operator):
  """Cholesky factor of a `LinearOperatorIdentity`.

  The identity is its own Cholesky factor, so a new identity operator with the
  same number of rows, batch shape and dtype is returned.

  Args:
    identity_operator: A `LinearOperatorIdentity`.

  Returns:
    A `LinearOperatorIdentity` matching `identity_operator`'s size, batch shape
    and dtype.
  """
  return linear_operator_identity.LinearOperatorIdentity(
      num_rows=identity_operator._num_rows,  # pylint: disable=protected-access
      batch_shape=identity_operator.batch_shape,
      dtype=identity_operator.dtype,
      is_non_singular=True,
      is_self_adjoint=True,
      is_positive_definite=True,
      is_square=True)


@linear_operator_algebra.RegisterCholesky(
    linear_operator_identity.LinearOperatorScaledIdentity)
def _cholesky_scaled_identity(identity_operator):
  """Cholesky factor of a `LinearOperatorScaledIdentity`.

  For `c * I` the Cholesky factor is `sqrt(c) * I`.

  Args:
    identity_operator: A `LinearOperatorScaledIdentity`. Its multiplier is
      assumed to be positive; this is not checked here.

  Returns:
    A `LinearOperatorScaledIdentity` with multiplier
    `sqrt(identity_operator.multiplier)` and the same number of rows.
  """
  return linear_operator_identity.LinearOperatorScaledIdentity(
      num_rows=identity_operator._num_rows,  # pylint: disable=protected-access
      multiplier=math_ops.sqrt(identity_operator.multiplier),
      is_non_singular=True,
      is_self_adjoint=True,
      is_positive_definite=True,
      is_square=True)


@linear_operator_algebra.RegisterCholesky(
    linear_operator_block_diag.LinearOperatorBlockDiag)
def _cholesky_block_diag(block_diag_operator):
  """Cholesky factor of a `LinearOperatorBlockDiag`.

  The Cholesky factor of a block-diagonal matrix is the block-diagonal matrix
  of the Cholesky factors of its blocks.

  Args:
    block_diag_operator: A `LinearOperatorBlockDiag`.

  Returns:
    A `LinearOperatorBlockDiag` whose blocks are `op.cholesky()` for each block
    `op` of `block_diag_operator`.
  """
  # We take the cholesky of each block on the diagonal.
  return linear_operator_block_diag.LinearOperatorBlockDiag(
      operators=[
          operator.cholesky() for operator in block_diag_operator.operators],
      is_non_singular=True,
      # Let the operators passed in decide. If every block factor is
      # self-adjoint (e.g. the Cholesky of a diagonal or identity block), the
      # direct sum is self-adjoint too, and the constructor rejects an explicit
      # `False` hint in that case.
      is_self_adjoint=None,
      is_square=True)


@linear_operator_algebra.RegisterCholesky(
    linear_operator_kronecker.LinearOperatorKronecker)
def _cholesky_kronecker(kronecker_operator):
  """Cholesky factor of a `LinearOperatorKronecker`.

  If `A = L_A L_A^H` and `B = L_B L_B^H`, then
  `A (x) B = (L_A (x) L_B) (L_A (x) L_B)^H`, and `L_A (x) L_B` is lower
  triangular. So the factor is the Kronecker product of the factors.

  Args:
    kronecker_operator: A `LinearOperatorKronecker`.

  Returns:
    A `LinearOperatorKronecker` of `op.cholesky()` for each factor `op` of
    `kronecker_operator`.
  """
  # Cholesky decomposition of a Kronecker product is the Kronecker product
  # of cholesky decompositions.
  return linear_operator_kronecker.LinearOperatorKronecker(
      operators=[
          operator.cholesky() for operator in kronecker_operator.operators],
      is_non_singular=True,
      # Let the operators passed in decide. If every factor is self-adjoint,
      # the Kronecker product is self-adjoint too, and the constructor rejects
      # an explicit `False` hint in that case.
      is_self_adjoint=None,
      is_square=True)