1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""`LinearOperator` acting like a lower triangular matrix."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import ops
22from tensorflow.python.ops import array_ops
23from tensorflow.python.ops import math_ops
24from tensorflow.python.ops.linalg import linalg_impl as linalg
25from tensorflow.python.ops.linalg import linear_operator
26from tensorflow.python.ops.linalg import linear_operator_util
27from tensorflow.python.util.tf_export import tf_export
28
# Public API of this module: only the operator class is exported.
__all__ = ["LinearOperatorLowerTriangular"]
32
33
@tf_export("linalg.LinearOperatorLowerTriangular")
class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
  """`LinearOperator` acting like a [batch] square lower triangular matrix.

  The operator behaves like a [batch] lower triangular matrix `A` of shape
  `[B1,...,Bb, N, N]` for some `b >= 0`.  The leading `b` indices select a
  batch member, and for each batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]`
  is an `N x N` matrix.

  A `LinearOperatorLowerTriangular` is built from a `Tensor` with dimensions
  `[B1,...,Bb, N, N]`.  Whatever is stored in the upper triangle of the last
  two dimensions is ignored.

  ```python
  # Create a 2 x 2 lower-triangular linear operator.
  tril = [[1., 2.], [3., 4.]]
  operator = LinearOperatorLowerTriangular(tril)

  # The upper triangle is ignored.
  operator.to_dense()
  ==> [[1., 0.]
       [3., 4.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor

  # Create a [2, 3] batch of 4 x 4 linear operators.
  tril = tf.random.normal(shape=[2, 3, 4, 4])
  operator = LinearOperatorLowerTriangular(tril)
  ```

  #### Shape compatibility

  This operator acts on a [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =        [B1,...,Bb] + [N, R],  with R >= 0.
  ```

  #### Performance

  Suppose `operator` is a `LinearOperatorLowerTriangular` of shape `[N, N]`,
  and `x.shape = [N, R]`.  Then

  * `operator.matmul(x)` involves `N^2 * R` multiplications.
  * `operator.solve(x)` involves `N * R` size `N` back-substitutions.
  * `operator.determinant()` involves a size `N` `reduce_prod`.

  If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
  `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  Their meaning is:

  * `is_X == True`: callers should expect the operator to have the property
    `X`.  This is a promise that should be fulfilled, but it is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * `is_X == False`: callers should expect the operator to not have `X`.
  * `is_X == None` (the default): callers should have no expectation either
    way.
  """

  def __init__(self,
               tril,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorLowerTriangular"):
    r"""Initialize a `LinearOperatorLowerTriangular`.

    Args:
      tril:  Shape `[B1,...,Bb, N, N]` with `b >= 0`, `N >= 0`.
        This operator is defined by the lower triangular part of `tril`; the
        strictly upper triangle is ignored.
      is_non_singular:  Expect that this operator is non-singular.
        This operator is non-singular if and only if its diagonal elements are
        all non-zero.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  This operator is self-adjoint only if it is diagonal with
        real-valued diagonal entries.  In this case it is advised to use
        `LinearOperatorDiag`.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.

    Raises:
      ValueError:  If `is_square` is `False`.
      ValueError:  If the rank of `tril` is statically known and less than 2.
    """
    # Record the arguments exactly as the caller supplied them (in particular
    # the un-coerced `is_square`) before anything below is forced or converted.
    parameters = {
        "tril": tril,
        "is_non_singular": is_non_singular,
        "is_self_adjoint": is_self_adjoint,
        "is_positive_definite": is_positive_definite,
        "is_square": is_square,
        "name": name,
    }

    # `is_square` is tri-state (True / False / None); only an explicit False
    # is rejected.  Any other value is treated as square from here on.
    if is_square is False:
      raise ValueError(
          "Only square lower triangular operators supported at this time.")

    with ops.name_scope(name, values=[tril]):
      tril_tensor = linear_operator_util.convert_nonref_to_tensor(
          tril, name="tril")
      self._tril = tril_tensor
      self._check_tril(tril_tensor)

      super(LinearOperatorLowerTriangular, self).__init__(
          dtype=tril_tensor.dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=True,
          parameters=parameters,
          name=name)
      self._set_graph_parents([tril_tensor])

  def _check_tril(self, tril):
    """Static check of the `tril` argument.

    Args:
      tril:  The converted `tril` tensor.

    Raises:
      ValueError:  If the static rank of `tril` is known and is less than 2.
    """
    rank = tril.shape.ndims
    # An unknown static rank cannot be validated here, so it is let through.
    if rank is None or rank >= 2:
      return
    raise ValueError(
        "Argument tril must have at least 2 dimensions.  Found: %s"
        % tril)

  def _get_tril(self):
    """Gets the `tril` kwarg, with the upper part zeroed out."""
    # num_lower=-1 keeps the whole lower triangle; num_upper=0 keeps no
    # super-diagonals.
    return array_ops.matrix_band_part(self._tril, num_lower=-1, num_upper=0)

  def _get_diag(self):
    """Gets the diagonal part of the `tril` kwarg."""
    return array_ops.matrix_diag_part(self._tril)

  def _shape(self):
    """Returns the static shape of the stored `tril` tensor."""
    return self._tril.shape

  def _shape_tensor(self):
    """Returns the shape of the stored `tril` tensor as a `Tensor`."""
    return array_ops.shape(self._tril)

  def _assert_non_singular(self):
    """Returns a check that no diagonal entry of `tril` has zero modulus."""
    diag = self._get_diag()
    return linear_operator_util.assert_no_entries_with_modulus_zero(
        diag,
        message="Singular operator:  Diagonal contained zero values.")

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Multiplies the zeroed-upper `tril` by `x`, honoring the adjoint flags."""
    lower = self._get_tril()
    return math_ops.matmul(
        lower, x, adjoint_a=adjoint, adjoint_b=adjoint_arg)

  def _determinant(self):
    """Returns the product of the diagonal entries along the last axis."""
    diag = self._get_diag()
    return math_ops.reduce_prod(diag, axis=[-1])

  def _log_abs_determinant(self):
    """Returns the sum of `log(abs(diagonal))` along the last axis."""
    abs_diag = math_ops.abs(self._get_diag())
    log_abs_diag = math_ops.log(abs_diag)
    return math_ops.reduce_sum(log_abs_diag, axis=[-1])

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solves against the lower triangle via `linalg.triangular_solve`."""
    # `triangular_solve` is only handed an `adjoint` flag for the matrix, so
    # the adjoint of the right-hand side is taken explicitly when requested.
    if adjoint_arg:
      rhs = linalg.adjoint(rhs)
    lower = self._get_tril()
    return linalg.triangular_solve(lower, rhs, lower=True, adjoint=adjoint)

  def _to_dense(self):
    """Returns `tril` with its upper triangle zeroed out."""
    return self._get_tril()

  def _eigvals(self):
    """Returns the diagonal of `tril` as the eigenvalues of this operator."""
    return self._get_diag()
217