# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""`LinearOperator` acting like a lower triangular matrix."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.util.tf_export import tf_export

__all__ = [
    "LinearOperatorLowerTriangular",
]


@tf_export("linalg.LinearOperatorLowerTriangular")
class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
  """`LinearOperator` acting like a [batch] square lower triangular matrix.

  This operator acts like a [batch] lower triangular matrix `A` with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `N x N` matrix.

  `LinearOperatorLowerTriangular` is initialized with a `Tensor` having
  dimensions `[B1,...,Bb, N, N]`.  The strictly upper triangle of the last two
  dimensions is ignored; the diagonal and everything below it are used.

  ```python
  # Create a 2 x 2 lower-triangular linear operator.
  tril = [[1., 2.], [3., 4.]]
  operator = LinearOperatorLowerTriangular(tril)

  # The upper triangle is ignored.
  operator.to_dense()
  ==> [[1., 0.]
       [3., 4.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor

  # Create a [2, 3] batch of 4 x 4 linear operators.
  tril = tf.random.normal(shape=[2, 3, 4, 4])
  operator = LinearOperatorLowerTriangular(tril)
  ```

  #### Shape compatibility

  This operator acts on [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =        [B1,...,Bb] + [N, R],  with R >= 0.
  ```

  #### Performance

  Suppose `operator` is a `LinearOperatorLowerTriangular` of shape `[N, N]`,
  and `x.shape = [N, R]`.  Then

  * `operator.matmul(x)` involves `N^2 * R` multiplications.
  * `operator.solve(x)` involves `R` triangular substitutions of size `N`
    (roughly `N^2 * R / 2` multiplications).
  * `operator.determinant()` involves a size `N` `reduce_prod`.

  If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
  `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               tril,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorLowerTriangular"):
    r"""Initialize a `LinearOperatorLowerTriangular`.

    Args:
      tril:  Shape `[B1,...,Bb, N, N]` with `b >= 0`, `N >= 0`.
        The lower triangular part of `tril` defines this operator.  The strictly
        upper triangle is ignored.
      is_non_singular:  Expect that this operator is non-singular.
        This operator is non-singular if and only if its diagonal elements are
        all non-zero.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  This operator is self-adjoint only if it is diagonal with
        real-valued diagonal entries.  In this case it is advised to use
        `LinearOperatorDiag`.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.

    Raises:
      ValueError:  If `is_square` is `False`.
      ValueError:  If `tril` is statically known to have fewer than 2
        dimensions.
    """
    # Record the constructor arguments exactly as the caller passed them
    # (before `is_square` is forced to True below).
    parameters = dict(
        tril=tril,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    if is_square is False:
      raise ValueError(
          "Only square lower triangular operators supported at this time.")
    # A lower triangular operator built from an `[..., N, N]` tensor is always
    # square, so `None` is upgraded to `True`.
    is_square = True

    with ops.name_scope(name, values=[tril]):
      self._tril = linear_operator_util.convert_nonref_to_tensor(tril,
                                                                 name="tril")
      self._check_tril(self._tril)

      super(LinearOperatorLowerTriangular, self).__init__(
          dtype=self._tril.dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)
      self._set_graph_parents([self._tril])

  def _check_tril(self, tril):
    """Static check of the `tril` argument.

    Only checks the rank when it is statically known; an unknown rank passes.

    Raises:
      ValueError:  If `tril` has a known rank smaller than 2.
    """
    if tril.shape.ndims is not None and tril.shape.ndims < 2:
      raise ValueError(
          "Argument tril must have at least 2 dimensions.  Found: %s"
          % tril)

  def _get_tril(self):
    """Gets the `tril` kwarg, with the strictly upper part zeroed out."""
    # band_part(x, -1, 0) keeps every sub-diagonal and the main diagonal.
    return array_ops.matrix_band_part(self._tril, -1, 0)

  def _get_diag(self):
    """Gets the diagonal part of the `tril` kwarg."""
    # The diagonal is untouched by the band-part masking, so it can be read
    # straight from the unmasked tensor.
    return array_ops.matrix_diag_part(self._tril)

  def _shape(self):
    return self._tril.shape

  def _shape_tensor(self):
    return array_ops.shape(self._tril)

  def _assert_non_singular(self):
    # A triangular matrix is singular iff some diagonal entry is zero.
    return linear_operator_util.assert_no_entries_with_modulus_zero(
        self._get_diag(),
        message="Singular operator:  Diagonal contained zero values.")

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    # Multiply with the densified, masked matrix so the ignored upper
    # triangle never contributes.
    return math_ops.matmul(
        self._get_tril(), x, adjoint_a=adjoint, adjoint_b=adjoint_arg)

  def _determinant(self):
    # det of a triangular matrix is the product of its diagonal.
    return math_ops.reduce_prod(self._get_diag(), axis=[-1])

  def _log_abs_determinant(self):
    # log|det| = sum(log|diag|); avoids overflow/underflow of the raw product.
    return math_ops.reduce_sum(
        math_ops.log(math_ops.abs(self._get_diag())), axis=[-1])

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    # triangular_solve only exposes an adjoint flag for the matrix, so the
    # adjoint of the right-hand side is materialized up front.
    rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
    return linalg.triangular_solve(
        self._get_tril(), rhs, lower=True, adjoint=adjoint)

  def _to_dense(self):
    return self._get_tril()

  def _eigvals(self):
    # The eigenvalues of a triangular matrix are its diagonal entries.
    return self._get_diag()