# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations for linear algebra."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export

# Linear algebra ops.
band_part = array_ops.matrix_band_part
cholesky = linalg_ops.cholesky
cholesky_solve = linalg_ops.cholesky_solve
det = linalg_ops.matrix_determinant
slogdet = gen_linalg_ops.log_matrix_determinant
tf_export('linalg.slogdet')(dispatch.add_dispatch_support(slogdet))
diag = array_ops.matrix_diag
diag_part = array_ops.matrix_diag_part
eigh = linalg_ops.self_adjoint_eig
eigvalsh = linalg_ops.self_adjoint_eigvals
einsum = special_math_ops.einsum
eye = linalg_ops.eye
inv = linalg_ops.matrix_inverse
logm = gen_linalg_ops.matrix_logarithm
lu = gen_linalg_ops.lu
tf_export('linalg.logm')(dispatch.add_dispatch_support(logm))
lstsq = linalg_ops.matrix_solve_ls
norm = linalg_ops.norm
qr = linalg_ops.qr
set_diag = array_ops.matrix_set_diag
solve = linalg_ops.matrix_solve
sqrtm = linalg_ops.matrix_square_root
svd = linalg_ops.svd
tensordot = math_ops.tensordot
trace = math_ops.trace
transpose = array_ops.matrix_transpose
triangular_solve = linalg_ops.matrix_triangular_solve


@tf_export('linalg.logdet')
@dispatch.add_dispatch_support
def logdet(matrix, name=None):
  """Computes log of the determinant of a Hermitian positive definite matrix.

  ```python
  # Compute the determinant of a matrix while reducing the chance of over- or
  # underflow:
  A = ...  # shape 10 x 10
  det = tf.exp(tf.linalg.logdet(A))  # scalar
  ```

  Args:
    matrix:  A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`,
      or `complex128` with shape `[..., M, M]`.
    name:  A name to give this `Op`.  Defaults to `logdet`.

  Returns:
    The natural log of the determinant of `matrix`.

  @compatibility(numpy)
  Equivalent to numpy.linalg.slogdet, although no sign is returned since only
  Hermitian positive definite matrices are supported.
  @end_compatibility
  """
  # This uses the property that the log det(A) = 2*sum(log(real(diag(C))))
  # where C is the cholesky decomposition of A.
  with ops.name_scope(name, 'logdet', [matrix]):
    chol = gen_linalg_ops.cholesky(matrix)
    return 2.0 * math_ops.reduce_sum(
        math_ops.log(math_ops.real(array_ops.matrix_diag_part(chol))),
        axis=[-1])


@tf_export('linalg.adjoint')
@dispatch.add_dispatch_support
def adjoint(matrix, name=None):
  """Transposes the last two dimensions of, and conjugates, tensor `matrix`.

  For example:

  ```python
  x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j],
                   [4 + 4j, 5 + 5j, 6 + 6j]])
  tf.linalg.adjoint(x)  # [[1 - 1j, 4 - 4j],
                        #  [2 - 2j, 5 - 5j],
                        #  [3 - 3j, 6 - 6j]]
  ```

  Args:
    matrix:  A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`,
      or `complex128` with shape `[..., M, M]`.
    name:  A name to give this `Op` (optional).

  Returns:
    The adjoint (a.k.a. Hermitian transpose a.k.a. conjugate transpose) of
    matrix.
  """
  with ops.name_scope(name, 'adjoint', [matrix]):
    matrix = ops.convert_to_tensor(matrix, name='matrix')
    return array_ops.matrix_transpose(matrix, conjugate=True)


# This section is ported nearly verbatim from Eigen's implementation:
# https://eigen.tuxfamily.org/dox/unsupported/MatrixExponential_8h_source.html
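# Each `_matrix_exp_padeN` helper below returns the pair (U, V) of the order-N
# Pade approximant, where U collects the odd powers of `matrix` and V the even
# powers; the approximant itself, exp(A) ~= (V - U)^{-1} (U + V), is what
# `matrix_exponential` solves for further down.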
def _matrix_exp_pade3(matrix):
  """3rd-order Pade approximant for matrix exponential."""
  b = [120.0, 60.0, 12.0]
  b = [constant_op.constant(x, matrix.dtype) for x in b]
  ident = linalg_ops.eye(
      array_ops.shape(matrix)[-2],
      batch_shape=array_ops.shape(matrix)[:-2],
      dtype=matrix.dtype)
  matrix_2 = math_ops.matmul(matrix, matrix)
  tmp = matrix_2 + b[1] * ident
  matrix_u = math_ops.matmul(matrix, tmp)
  matrix_v = b[2] * matrix_2 + b[0] * ident
  return matrix_u, matrix_v


def _matrix_exp_pade5(matrix):
  """5th-order Pade approximant for matrix exponential."""
  b = [30240.0, 15120.0, 3360.0, 420.0, 30.0]
  b = [constant_op.constant(x, matrix.dtype) for x in b]
  ident = linalg_ops.eye(
      array_ops.shape(matrix)[-2],
      batch_shape=array_ops.shape(matrix)[:-2],
      dtype=matrix.dtype)
  matrix_2 = math_ops.matmul(matrix, matrix)
  matrix_4 = math_ops.matmul(matrix_2, matrix_2)
  tmp = matrix_4 + b[3] * matrix_2 + b[1] * ident
  matrix_u = math_ops.matmul(matrix, tmp)
  matrix_v = b[4] * matrix_4 + b[2] * matrix_2 + b[0] * ident
  return matrix_u, matrix_v


def _matrix_exp_pade7(matrix):
  """7th-order Pade approximant for matrix exponential."""
  b = [17297280.0, 8648640.0, 1995840.0, 277200.0, 25200.0, 1512.0, 56.0]
  b = [constant_op.constant(x, matrix.dtype) for x in b]
  ident = linalg_ops.eye(
      array_ops.shape(matrix)[-2],
      batch_shape=array_ops.shape(matrix)[:-2],
      dtype=matrix.dtype)
  matrix_2 = math_ops.matmul(matrix, matrix)
  matrix_4 = math_ops.matmul(matrix_2, matrix_2)
  matrix_6 = math_ops.matmul(matrix_4, matrix_2)
  tmp = matrix_6 + b[5] * matrix_4 + b[3] * matrix_2 + b[1] * ident
  matrix_u = math_ops.matmul(matrix, tmp)
  matrix_v = b[6] * matrix_6 + b[4] * matrix_4 + b[2] * matrix_2 + b[0] * ident
  return matrix_u, matrix_v


def _matrix_exp_pade9(matrix):
  """9th-order Pade approximant for matrix exponential."""
  b = [
      17643225600.0, 8821612800.0, 2075673600.0, 302702400.0, 30270240.0,
      2162160.0, 110880.0, 3960.0, 90.0
  ]
  b = [constant_op.constant(x, matrix.dtype) for x in b]
  ident = linalg_ops.eye(
      array_ops.shape(matrix)[-2],
      batch_shape=array_ops.shape(matrix)[:-2],
      dtype=matrix.dtype)
  matrix_2 = math_ops.matmul(matrix, matrix)
  matrix_4 = math_ops.matmul(matrix_2, matrix_2)
  matrix_6 = math_ops.matmul(matrix_4, matrix_2)
  matrix_8 = math_ops.matmul(matrix_6, matrix_2)
  tmp = (
      matrix_8 + b[7] * matrix_6 + b[5] * matrix_4 + b[3] * matrix_2 +
      b[1] * ident)
  matrix_u = math_ops.matmul(matrix, tmp)
  matrix_v = (
      b[8] * matrix_8 + b[6] * matrix_6 + b[4] * matrix_4 + b[2] * matrix_2 +
      b[0] * ident)
  return matrix_u, matrix_v


def _matrix_exp_pade13(matrix):
  """13th-order Pade approximant for matrix exponential."""
  b = [
      64764752532480000.0, 32382376266240000.0, 7771770303897600.0,
      1187353796428800.0, 129060195264000.0, 10559470521600.0, 670442572800.0,
      33522128640.0, 1323241920.0, 40840800.0, 960960.0, 16380.0, 182.0
  ]
  b = [constant_op.constant(x, matrix.dtype) for x in b]
  ident = linalg_ops.eye(
      array_ops.shape(matrix)[-2],
      batch_shape=array_ops.shape(matrix)[:-2],
      dtype=matrix.dtype)
  matrix_2 = math_ops.matmul(matrix, matrix)
  matrix_4 = math_ops.matmul(matrix_2, matrix_2)
  matrix_6 = math_ops.matmul(matrix_4, matrix_2)
  tmp_u = (
      math_ops.matmul(matrix_6, matrix_6 + b[11] * matrix_4 + b[9] * matrix_2) +
      b[7] * matrix_6 + b[5] * matrix_4 + b[3] * matrix_2 + b[1] * ident)
  matrix_u = math_ops.matmul(matrix, tmp_u)
  tmp_v = b[12] * matrix_6 + b[10] * matrix_4 + b[8] * matrix_2
  matrix_v = (
      math_ops.matmul(matrix_6, tmp_v) + b[6] * matrix_6 + b[4] * matrix_4 +
      b[2] * matrix_2 + b[0] * ident)
  return matrix_u, matrix_v


@tf_export('linalg.expm')
@dispatch.add_dispatch_support
def matrix_exponential(input, name=None):  # pylint: disable=redefined-builtin
  r"""Computes the matrix exponential of one or more square matrices.

  $$exp(A) = \sum_{n=0}^\infty A^n/n!$$

  The exponential is computed using a combination of the scaling and squaring
  method and the Pade approximation. Details can be found in:
  Nicholas J. Higham, "The scaling and squaring method for the matrix
  exponential revisited," SIAM J. Matrix Anal. Applic., 26:1179-1193, 2005.

  The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
  form square matrices. The output is a tensor of the same shape as the input
  containing the exponential for all input submatrices `[..., :, :]`.

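  For example, since the matrix exponential of the zero matrix is the identity,
  a minimal illustration is:

  ```python
  A = tf.zeros([2, 2])
  tf.linalg.expm(A)  # ==> [[1., 0.], [0., 1.]]
  ```
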
  Args:
    input: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`, or
      `complex128` with shape `[..., M, M]`.
    name:  A name to give this `Op` (optional).

  Returns:
    The matrix exponential of the input.

  Raises:
    ValueError: An unsupported type is provided as input.

  @compatibility(scipy)
  Equivalent to scipy.linalg.expm
  @end_compatibility
  """
  with ops.name_scope(name, 'matrix_exponential', [input]):
    matrix = ops.convert_to_tensor(input, name='input')
    if matrix.shape[-2:] == [0, 0]:
      return matrix
    batch_shape = matrix.shape[:-2]
    if not batch_shape.is_fully_defined():
      batch_shape = array_ops.shape(matrix)[:-2]

    # reshaping the batch makes the where statements work better
    matrix = array_ops.reshape(
        matrix, array_ops.concat(([-1], array_ops.shape(matrix)[-2:]), axis=0))
    l1_norm = math_ops.reduce_max(
        math_ops.reduce_sum(
            math_ops.abs(matrix),
            axis=array_ops.size(array_ops.shape(matrix)) - 2),
        axis=-1)[..., array_ops.newaxis, array_ops.newaxis]

    const = lambda x: constant_op.constant(x, l1_norm.dtype)

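    # `_nest_where` picks, for each batch element, the lowest-order Pade result
    # in `cases` whose accuracy threshold in `vals` exceeds that element's L1
    # norm, falling back to the last (scaled, highest-order) approximant
    # otherwise.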
    def _nest_where(vals, cases):
      assert len(vals) == len(cases) - 1
      if len(vals) == 1:
        return array_ops.where_v2(
            math_ops.less(l1_norm, const(vals[0])), cases[0], cases[1])
      else:
        return array_ops.where_v2(
            math_ops.less(l1_norm, const(vals[0])), cases[0],
            _nest_where(vals[1:], cases[1:]))

    if matrix.dtype in [dtypes.float16, dtypes.float32, dtypes.complex64]:
      maxnorm = const(3.925724783138660)
      squarings = math_ops.maximum(
          math_ops.floor(
              math_ops.log(l1_norm / maxnorm) / math_ops.log(const(2.0))), 0)
      u3, v3 = _matrix_exp_pade3(matrix)
      u5, v5 = _matrix_exp_pade5(matrix)
      u7, v7 = _matrix_exp_pade7(
          matrix /
          math_ops.cast(math_ops.pow(const(2.0), squarings), matrix.dtype))
      conds = (4.258730016922831e-001, 1.880152677804762e+000)
      u = _nest_where(conds, (u3, u5, u7))
      v = _nest_where(conds, (v3, v5, v7))
    elif matrix.dtype in [dtypes.float64, dtypes.complex128]:
      maxnorm = const(5.371920351148152)
      squarings = math_ops.maximum(
          math_ops.floor(
              math_ops.log(l1_norm / maxnorm) / math_ops.log(const(2.0))), 0)
      u3, v3 = _matrix_exp_pade3(matrix)
      u5, v5 = _matrix_exp_pade5(matrix)
      u7, v7 = _matrix_exp_pade7(matrix)
      u9, v9 = _matrix_exp_pade9(matrix)
      u13, v13 = _matrix_exp_pade13(
          matrix /
          math_ops.cast(math_ops.pow(const(2.0), squarings), matrix.dtype))
      conds = (1.495585217958292e-002, 2.539398330063230e-001,
               9.504178996162932e-001, 2.097847961257068e+000)
      u = _nest_where(conds, (u3, u5, u7, u9, u13))
      v = _nest_where(conds, (v3, v5, v7, v9, v13))
    else:
      raise ValueError('tf.linalg.expm does not support matrices of type %s' %
                       matrix.dtype)

    is_finite = math_ops.is_finite(math_ops.reduce_max(l1_norm))
    nan = constant_op.constant(np.nan, matrix.dtype)
    result = control_flow_ops.cond(
        is_finite, lambda: linalg_ops.matrix_solve(-u + v, u + v),
        lambda: array_ops.fill(array_ops.shape(matrix), nan))
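    # Undo the scaling: exp(A) = (exp(A / 2^s))^(2^s). The loop below runs for
    # the largest number of squarings in the batch, squaring each result only
    # while it still has squarings left to apply.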
    max_squarings = math_ops.reduce_max(squarings)
    i = const(0.0)

    def c(i, _):
      return control_flow_ops.cond(is_finite,
                                   lambda: math_ops.less(i, max_squarings),
                                   lambda: constant_op.constant(False))

    def b(i, r):
      return i + 1, array_ops.where_v2(
          math_ops.less(i, squarings), math_ops.matmul(r, r), r)

    _, result = control_flow_ops.while_loop(c, b, [i, result])
    if not matrix.shape.is_fully_defined():
      return array_ops.reshape(
          result,
          array_ops.concat((batch_shape, array_ops.shape(result)[-2:]), axis=0))
    return array_ops.reshape(result, batch_shape.concatenate(result.shape[-2:]))


@tf_export('linalg.banded_triangular_solve', v1=[])
def banded_triangular_solve(
    bands,
    rhs,
    lower=True,
    adjoint=False,  # pylint: disable=redefined-outer-name
    name=None):
  r"""Solve triangular systems of equations with a banded solver.

  `bands` is a tensor of shape `[..., K, M]`, where `K` represents the number
  of bands stored. This corresponds to a batch of `M` by `M` matrices, whose
  `K` subdiagonals (when `lower` is `True`) are stored.

  This operator broadcasts the batch dimensions of `bands` and the batch
  dimensions of `rhs`.


  Examples:

  Storing 2 bands of a 3x3 matrix.
  Note that the first element in the second row is ignored due to
  the 'LEFT_RIGHT' padding.

  >>> x = [[2., 3., 4.], [1., 2., 3.]]
  >>> x2 = [[2., 3., 4.], [10000., 2., 3.]]
  >>> y = tf.zeros([3, 3])
  >>> z = tf.linalg.set_diag(y, x, align='LEFT_RIGHT', k=(-1, 0))
  >>> z
  <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
  array([[2., 0., 0.],
         [2., 3., 0.],
         [0., 3., 4.]], dtype=float32)>
  >>> soln = tf.linalg.banded_triangular_solve(x, tf.ones([3, 1]))
  >>> soln
  <tf.Tensor: shape=(3, 1), dtype=float32, numpy=
  array([[0.5 ],
         [0.  ],
         [0.25]], dtype=float32)>
  >>> are_equal = soln == tf.linalg.banded_triangular_solve(x2, tf.ones([3, 1]))
  >>> tf.reduce_all(are_equal).numpy()
  True
  >>> are_equal = soln == tf.linalg.triangular_solve(z, tf.ones([3, 1]))
  >>> tf.reduce_all(are_equal).numpy()
  True

  Storing 2 bands of a 4x4 matrix (the diagonal and the first superdiagonal).
  Because of the 'LEFT_RIGHT' padding the last element of the first row is
  ignored.

  >>> x = [[2., 3., 4., 5.], [-1., -2., -3., -4.]]
  >>> y = tf.zeros([4, 4])
  >>> z = tf.linalg.set_diag(y, x, align='LEFT_RIGHT', k=(0, 1))
  >>> z
  <tf.Tensor: shape=(4, 4), dtype=float32, numpy=
  array([[-1.,  2.,  0.,  0.],
         [ 0., -2.,  3.,  0.],
         [ 0.,  0., -3.,  4.],
         [ 0.,  0., -0., -4.]], dtype=float32)>
  >>> soln = tf.linalg.banded_triangular_solve(x, tf.ones([4, 1]), lower=False)
  >>> soln
  <tf.Tensor: shape=(4, 1), dtype=float32, numpy=
  array([[-4.       ],
         [-1.5      ],
         [-0.6666667],
         [-0.25     ]], dtype=float32)>
  >>> are_equal = (soln == tf.linalg.triangular_solve(
  ...   z, tf.ones([4, 1]), lower=False))
  >>> tf.reduce_all(are_equal).numpy()
  True


  Args:
    bands: A `Tensor` describing the bands of the left-hand side, with shape
      `[..., K, M]`. When `lower` is `True`, the `K` rows correspond to the
      main diagonal through the `K - 1`-th subdiagonal (the main diagonal is
      the top row); when `lower` is `False`, they correspond to the `K - 1`-th
      superdiagonal through the main diagonal (the main diagonal is the bottom
      row). The bands are stored with 'LEFT_RIGHT' alignment, where the
      superdiagonals are padded on the right and subdiagonals are padded on
      the left. This is the alignment cuSPARSE uses. See `tf.linalg.set_diag`
      for more details.
    rhs: A `Tensor` of shape [..., M] or [..., M, N] and with the same dtype as
      `bands`. Note that if the shape of `rhs` and/or `bands` isn't known
      statically, `rhs` will be treated as a matrix rather than a vector.
    lower: An optional `bool`. Defaults to `True`. Boolean indicating whether
      `bands` represents a lower or upper triangular matrix.
    adjoint: An optional `bool`. Defaults to `False`. Boolean indicating whether
      to solve with the matrix's block-wise adjoint.
    name:  A name to give this `Op` (optional).

  Returns:
    A `Tensor` of shape [..., M] or [..., M, N] containing the solutions.
  """
  with ops.name_scope(name, 'banded_triangular_solve', [bands, rhs]):
    return gen_linalg_ops.banded_triangular_solve(
        bands, rhs, lower=lower, adjoint=adjoint)


@tf_export('linalg.tridiagonal_solve')
@dispatch.add_dispatch_support
def tridiagonal_solve(diagonals,
                      rhs,
                      diagonals_format='compact',
                      transpose_rhs=False,
                      conjugate_rhs=False,
                      name=None,
                      partial_pivoting=True):
  r"""Solves tridiagonal systems of equations.

  The input can be supplied in various formats: `matrix`, `sequence` and
  `compact`, specified by the `diagonals_format` arg.

  In `matrix` format, `diagonals` must be a tensor of shape `[..., M, M]`, with
  two inner-most dimensions representing the square tridiagonal matrices.
  Elements outside of the three diagonals will be ignored.

  In `sequence` format, `diagonals` are supplied as a tuple or list of three
  tensors of shapes `[..., N]`, `[..., M]`, `[..., N]` representing
  superdiagonals, diagonals, and subdiagonals, respectively. `N` can be either
  `M-1` or `M`; in the latter case, the last element of superdiagonal and the
  first element of subdiagonal will be ignored.

  In `compact` format the three diagonals are brought together into one tensor
  of shape `[..., 3, M]`, with last two dimensions containing superdiagonals,
  diagonals, and subdiagonals, in order. Similarly to `sequence` format,
  elements `diagonals[..., 0, M-1]` and `diagonals[..., 2, 0]` are ignored.

  The `compact` format is recommended as the one with the best performance. In
  case you need to cast a tensor into the compact format manually, use
  `tf.gather_nd`. An example for a tensor of shape [m, m]:

  ```python
  rhs = tf.constant([...])
  matrix = tf.constant([[...]])
  m = matrix.shape[0]
  dummy_idx = [0, 0]  # An arbitrary element to use as a dummy.
  indices = [[[i, i + 1] for i in range(m - 1)] + [dummy_idx],  # Superdiagonal
             [[i, i] for i in range(m)],                        # Diagonal
             [dummy_idx] + [[i + 1, i] for i in range(m - 1)]]  # Subdiagonal
  diagonals = tf.gather_nd(matrix, indices)
  x = tf.linalg.tridiagonal_solve(diagonals, rhs)
  ```
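
  For instance, here is a minimal sketch (with illustrative values) of solving
  a single 3x3 system supplied directly in `compact` format:

  ```python
  # Tridiagonal matrix [[2, 1, 0], [1, 2, 1], [0, 1, 2]] in compact form; the
  # ignored corner entries are set to 0.
  diagonals = tf.constant([[1., 1., 0.],   # superdiagonal
                           [2., 2., 2.],   # diagonal
                           [0., 1., 1.]])  # subdiagonal
  rhs = tf.constant([1., 1., 1.])
  x = tf.linalg.tridiagonal_solve(diagonals, rhs)  # ==> [0.5, 0., 0.5]
  ```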

  Regardless of the `diagonals_format`, `rhs` is a tensor of shape `[..., M]`
  or `[..., M, K]`. The latter allows solving `K` systems simultaneously with
  the same left-hand side and `K` different right-hand sides. If
  `transpose_rhs` is set to `True` the expected shape is `[..., M]` or
  `[..., K, M]`.

  The batch dimensions, denoted as `...`, must be the same in `diagonals` and
  `rhs`.

  The output is a tensor of the same shape as `rhs`: either `[..., M]` or
  `[..., M, K]`.

  The op isn't guaranteed to raise an error if the input matrix is not
  invertible. `tf.debugging.check_numerics` can be applied to the output to
  detect invertibility problems.

  **Note**: with large batch sizes, the computation on the GPU may be slow if
  either `partial_pivoting=True` or there are multiple right-hand sides
  (`K > 1`). If this issue arises, consider whether it's possible to disable
  pivoting and have `K = 1`, or, alternatively, consider using the CPU.

  On CPU, the solution is computed via Gaussian elimination with or without
  partial pivoting, depending on the `partial_pivoting` parameter. On GPU,
  Nvidia's cuSPARSE library is used:
  https://docs.nvidia.com/cuda/cusparse/index.html#gtsv

  Args:
    diagonals: A `Tensor` or tuple of `Tensor`s describing left-hand sides. The
      shape depends on `diagonals_format`, see description above. Must be
      `float32`, `float64`, `complex64`, or `complex128`.
    rhs: A `Tensor` of shape [..., M] or [..., M, K] and with the same dtype as
      `diagonals`. Note that if the shape of `rhs` and/or `diagonals` isn't
      known statically, `rhs` will be treated as a matrix rather than a vector.
    diagonals_format: one of `matrix`, `sequence`, or `compact`. Default is
      `compact`.
    transpose_rhs: If `True`, `rhs` is transposed before solving (has no effect
      if the shape of rhs is [..., M]).
    conjugate_rhs: If `True`, `rhs` is conjugated before solving.
    name:  A name to give this `Op` (optional).
    partial_pivoting: whether to perform partial pivoting. `True` by default.
      Partial pivoting makes the procedure more stable, but slower. Partial
      pivoting is unnecessary in some cases, including diagonally dominant and
      symmetric positive definite matrices (see e.g. theorem 9.12 in [1]).

  Returns:
    A `Tensor` of shape [..., M] or [..., M, K] containing the solutions.

  Raises:
    ValueError: An unsupported type is provided as input, or when the input
      tensors have incorrect shapes.
    UnimplementedError: Whenever `partial_pivoting` is true and the backend is
      XLA.

  [1] Nicholas J. Higham (2002). Accuracy and Stability of Numerical Algorithms:
  Second Edition. SIAM. p. 175. ISBN 978-0-89871-802-7.

  """
  if diagonals_format == 'compact':
    return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                             conjugate_rhs, partial_pivoting,
                                             name)

  if diagonals_format == 'sequence':
    if not isinstance(diagonals, (tuple, list)) or len(diagonals) != 3:
      raise ValueError('Expected diagonals to be a sequence of length 3.')

    superdiag, maindiag, subdiag = diagonals
    if (not subdiag.shape[:-1].is_compatible_with(maindiag.shape[:-1]) or
        not superdiag.shape[:-1].is_compatible_with(maindiag.shape[:-1])):
      raise ValueError(
          'Tensors representing the three diagonals must have the same shape, '
          'except for the last dimension, got {}, {}, {}'.format(
              subdiag.shape, maindiag.shape, superdiag.shape))

    m = tensor_shape.dimension_value(maindiag.shape[-1])

    def pad_if_necessary(t, name, last_dim_padding):
      n = tensor_shape.dimension_value(t.shape[-1])
      if not n or n == m:
        return t
      if n == m - 1:
        paddings = ([[0, 0] for _ in range(len(t.shape) - 1)] +
                    [last_dim_padding])
        return array_ops.pad(t, paddings)
      raise ValueError('Expected {} to have length {} or {}, got {}.'.format(
          name, m, m - 1, n))

    subdiag = pad_if_necessary(subdiag, 'subdiagonal', [1, 0])
    superdiag = pad_if_necessary(superdiag, 'superdiagonal', [0, 1])

    diagonals = array_ops.stack((superdiag, maindiag, subdiag), axis=-2)
    return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                             conjugate_rhs, partial_pivoting,
                                             name)

  if diagonals_format == 'matrix':
    m1 = tensor_shape.dimension_value(diagonals.shape[-1])
    m2 = tensor_shape.dimension_value(diagonals.shape[-2])
    if m1 and m2 and m1 != m2:
      raise ValueError(
          'Expected last two dimensions of diagonals to be the same, '
          'got {} and {}'.format(m1, m2))
    m = m1 or m2
    diagonals = array_ops.matrix_diag_part(
        diagonals, k=(-1, 1), padding_value=0., align='LEFT_RIGHT')
    return _tridiagonal_solve_compact_format(
        diagonals, rhs, transpose_rhs, conjugate_rhs, partial_pivoting, name)

  raise ValueError('Unrecognized diagonals_format: {}'.format(diagonals_format))

def _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                      conjugate_rhs, partial_pivoting, name):
  """Helper function used after the input has been cast to compact form."""
  diags_rank, rhs_rank = diagonals.shape.rank, rhs.shape.rank

  # If we know the rank of the diagonal tensor, do some static checking.
  if diags_rank:
    if diags_rank < 2:
      raise ValueError(
          'Expected diagonals to have rank at least 2, got {}'.format(
              diags_rank))
    if rhs_rank and rhs_rank != diags_rank and rhs_rank != diags_rank - 1:
      raise ValueError('Expected the rank of rhs to be {} or {}, got {}'.format(
          diags_rank - 1, diags_rank, rhs_rank))
    if (rhs_rank and not diagonals.shape[:-2].is_compatible_with(
        rhs.shape[:diags_rank - 2])):
      raise ValueError('Batch shapes {} and {} are incompatible'.format(
          diagonals.shape[:-2], rhs.shape[:diags_rank - 2]))

  if diagonals.shape[-2] and diagonals.shape[-2] != 3:
    raise ValueError('Expected 3 diagonals, got {}'.format(diagonals.shape[-2]))

  def check_num_lhs_matches_num_rhs():
    if (diagonals.shape[-1] and rhs.shape[-2] and
        diagonals.shape[-1] != rhs.shape[-2]):
      raise ValueError('Expected number of left-hand sides and right-hand '
                       'sides to be equal, got {} and {}'.format(
                           diagonals.shape[-1], rhs.shape[-2]))

  if rhs_rank and diags_rank and rhs_rank == diags_rank - 1:
    # Rhs provided as a vector, ignoring transpose_rhs
    if conjugate_rhs:
      rhs = math_ops.conj(rhs)
    rhs = array_ops.expand_dims(rhs, -1)
    check_num_lhs_matches_num_rhs()
    return array_ops.squeeze(
        linalg_ops.tridiagonal_solve(diagonals, rhs, partial_pivoting, name),
        -1)

  if transpose_rhs:
    rhs = array_ops.matrix_transpose(rhs, conjugate=conjugate_rhs)
  elif conjugate_rhs:
    rhs = math_ops.conj(rhs)

  check_num_lhs_matches_num_rhs()
  return linalg_ops.tridiagonal_solve(diagonals, rhs, partial_pivoting, name)

@tf_export('linalg.tridiagonal_matmul')
@dispatch.add_dispatch_support
def tridiagonal_matmul(diagonals, rhs, diagonals_format='compact', name=None):
  r"""Multiplies a tridiagonal matrix by a matrix.

  `diagonals` is a representation of a tridiagonal matrix; its form depends on
  `diagonals_format`.

  In `matrix` format, `diagonals` must be a tensor of shape `[..., M, M]`, with
  two inner-most dimensions representing the square tridiagonal matrices.
  Elements outside of the three diagonals will be ignored.

  In `sequence` format, `diagonals` is a list or tuple of three tensors:
  `[superdiag, maindiag, subdiag]`, each having shape [..., M]. The last
  element of `superdiag` and the first element of `subdiag` are ignored.

  In `compact` format the three diagonals are brought together into one tensor
  of shape `[..., 3, M]`, with last two dimensions containing superdiagonals,
  diagonals, and subdiagonals, in order. Similarly to `sequence` format,
  elements `diagonals[..., 0, M-1]` and `diagonals[..., 2, 0]` are ignored.

  The `sequence` format is recommended as the one with the best performance.

  `rhs` is the matrix on the right-hand side of the multiplication. It has
  shape `[..., M, N]`.

  Example:

  ```python
  superdiag = tf.constant([-1, -1, 0], dtype=tf.float64)
  maindiag = tf.constant([2, 2, 2], dtype=tf.float64)
  subdiag = tf.constant([0, -1, -1], dtype=tf.float64)
  diagonals = [superdiag, maindiag, subdiag]
  rhs = tf.constant([[1, 1], [1, 1], [1, 1]], dtype=tf.float64)
  x = tf.linalg.tridiagonal_matmul(diagonals, rhs, diagonals_format='sequence')
  ```

  Args:
    diagonals: A `Tensor` or tuple of `Tensor`s describing left-hand sides. The
      shape depends on `diagonals_format`, see description above. Must be
      `float32`, `float64`, `complex64`, or `complex128`.
    rhs: A `Tensor` of shape [..., M, N] and with the same dtype as `diagonals`.
    diagonals_format: one of `matrix`, `sequence`, or `compact`. Default is
      `compact`.
    name:  A name to give this `Op` (optional).

  Returns:
    A `Tensor` of shape [..., M, N] containing the result of multiplication.

  Raises:
    ValueError: An unsupported type is provided as input, or when the input
      tensors have incorrect shapes.
  """
  if diagonals_format == 'compact':
    superdiag = diagonals[..., 0, :]
    maindiag = diagonals[..., 1, :]
    subdiag = diagonals[..., 2, :]
  elif diagonals_format == 'sequence':
    superdiag, maindiag, subdiag = diagonals
  elif diagonals_format == 'matrix':
    m1 = tensor_shape.dimension_value(diagonals.shape[-1])
    m2 = tensor_shape.dimension_value(diagonals.shape[-2])
    if m1 and m2 and m1 != m2:
      raise ValueError(
          'Expected last two dimensions of diagonals to be the same, '
          'got {} and {}'.format(m1, m2))
    diags = array_ops.matrix_diag_part(
        diagonals, k=(-1, 1), padding_value=0., align='LEFT_RIGHT')
    superdiag = diags[..., 0, :]
    maindiag = diags[..., 1, :]
    subdiag = diags[..., 2, :]
  else:
    raise ValueError('Unrecognized diagonals_format: %s' % diagonals_format)

  # C++ backend requires matrices.
  # Converting 1-dimensional vectors to matrices with 1 row.
  superdiag = array_ops.expand_dims(superdiag, -2)
  maindiag = array_ops.expand_dims(maindiag, -2)
  subdiag = array_ops.expand_dims(subdiag, -2)

  return linalg_ops.tridiagonal_mat_mul(superdiag, maindiag, subdiag, rhs, name)


def _maybe_validate_matrix(a, validate_args):
  """Checks that input is a `float` matrix."""
  assertions = []
  if not a.dtype.is_floating:
    raise TypeError('Input `a` must have `float`-like `dtype` '
                    '(saw {}).'.format(a.dtype.name))
  if a.shape is not None and a.shape.rank is not None:
    if a.shape.rank < 2:
      raise ValueError('Input `a` must have at least 2 dimensions '
                       '(saw: {}).'.format(a.shape.rank))
  elif validate_args:
    assertions.append(
        check_ops.assert_rank_at_least(
            a, rank=2, message='Input `a` must have at least 2 dimensions.'))
  return assertions


@tf_export('linalg.matrix_rank')
@dispatch.add_dispatch_support
def matrix_rank(a, tol=None, validate_args=False, name=None):
  """Compute the matrix rank of one or more matrices.

  Args:
    a: (Batch of) `float`-like matrix-shaped `Tensor`(s) whose rank is to be
      computed.
    tol: Threshold below which the singular value is counted as 'zero'.
      Default value: `None` (i.e., `eps * max(rows, cols) * max(singular_val)`).
    validate_args: When `True`, additional assertions might be embedded in the
      graph.
      Default value: `False` (i.e., no graph assertions are added).
    name: Python `str` prefixed to ops created by this function.
      Default value: 'matrix_rank'.

  Returns:
    matrix_rank: (Batch of) `int32` scalars representing the number of non-zero
      singular values.
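
  #### Examples

  A small illustration of the expected behavior (no assumptions beyond the
  default `tol`):

  ```python
  tf.linalg.matrix_rank(tf.eye(3))        # ==> 3 (full rank)
  tf.linalg.matrix_rank(tf.ones([3, 3]))  # ==> 1 (rank-1 outer product)
  ```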
  """
  with ops.name_scope(name or 'matrix_rank'):
    a = ops.convert_to_tensor(a, dtype_hint=dtypes.float32, name='a')
    assertions = _maybe_validate_matrix(a, validate_args)
    if assertions:
      with ops.control_dependencies(assertions):
        a = array_ops.identity(a)
    s = svd(a, compute_uv=False)
    if tol is None:
      if (a.shape[-2:]).is_fully_defined():
        m = np.max(a.shape[-2:].as_list())
      else:
        m = math_ops.reduce_max(array_ops.shape(a)[-2:])
      eps = np.finfo(a.dtype.as_numpy_dtype).eps
      tol = (
          eps * math_ops.cast(m, a.dtype) *
          math_ops.reduce_max(s, axis=-1, keepdims=True))
    return math_ops.reduce_sum(math_ops.cast(s > tol, dtypes.int32), axis=-1)


@tf_export('linalg.pinv')
@dispatch.add_dispatch_support
def pinv(a, rcond=None, validate_args=False, name=None):
  """Compute the Moore-Penrose pseudo-inverse of one or more matrices.

  Calculate the [generalized inverse of a matrix](
  https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse) using its
  singular-value decomposition (SVD) and including all large singular values.

  The pseudo-inverse of a matrix `A` is defined as: 'the matrix that 'solves'
  [the least-squares problem] `A @ x = b`,' i.e., if `x_hat` is a solution, then
  `A_pinv` is the matrix such that `x_hat = A_pinv @ b`. It can be shown that if
  `U @ Sigma @ V.T = A` is the singular value decomposition of `A`, then
  `A_pinv = V @ inv(Sigma) @ U^T`. [(Strang, 1980)][1]

  This function is analogous to [`numpy.linalg.pinv`](
  https://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.pinv.html).
  It differs only in the default value of `rcond`. In `numpy.linalg.pinv`, the
  default `rcond` is `1e-15`. Here the default is
  `10. * max(num_rows, num_cols) * np.finfo(dtype).eps`.

  Args:
    a: (Batch of) `float`-like matrix-shaped `Tensor`(s) which are to be
      pseudo-inverted.
    rcond: `Tensor` of small singular value cutoffs.  Singular values smaller
      (in modulus) than `rcond` * largest_singular_value (again, in modulus) are
      set to zero. Must broadcast against `tf.shape(a)[:-2]`.
      Default value: `10. * max(num_rows, num_cols) * np.finfo(a.dtype).eps`.
    validate_args: When `True`, additional assertions might be embedded in the
      graph.
      Default value: `False` (i.e., no graph assertions are added).
    name: Python `str` prefixed to ops created by this function.
      Default value: 'pinv'.

  Returns:
    a_pinv: (Batch of) pseudo-inverse of input `a`. Has same shape as `a` except
      rightmost two dimensions are transposed.

  Raises:
    TypeError: if input `a` does not have `float`-like `dtype`.
    ValueError: if input `a` has fewer than 2 dimensions.

  #### Examples

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp

  a = tf.constant([[1.,  0.4,  0.5],
                   [0.4, 0.2,  0.25],
                   [0.5, 0.25, 0.35]])
  tf.matmul(tf.linalg.pinv(a), a)
  # ==> array([[1., 0., 0.],
               [0., 1., 0.],
               [0., 0., 1.]], dtype=float32)

  a = tf.constant([[1.,  0.4,  0.5,  1.],
                   [0.4, 0.2,  0.25, 2.],
                   [0.5, 0.25, 0.35, 3.]])
  tf.matmul(tf.linalg.pinv(a), a)
  # ==> array([[ 0.76,  0.37,  0.21, -0.02],
               [ 0.37,  0.43, -0.33,  0.02],
               [ 0.21, -0.33,  0.81,  0.01],
               [-0.02,  0.02,  0.01,  1.  ]], dtype=float32)
  ```

  #### References

  [1]: G. Strang. 'Linear Algebra and Its Applications, 2nd Ed.' Academic Press,
       Inc., 1980, pp. 139-142.
  """
  with ops.name_scope(name or 'pinv'):
    a = ops.convert_to_tensor(a, name='a')

    assertions = _maybe_validate_matrix(a, validate_args)
    if assertions:
      with ops.control_dependencies(assertions):
        a = array_ops.identity(a)

    dtype = a.dtype.as_numpy_dtype

    if rcond is None:

      def get_dim_size(dim):
        dim_val = tensor_shape.dimension_value(a.shape[dim])
        if dim_val is not None:
          return dim_val
        return array_ops.shape(a)[dim]

      num_rows = get_dim_size(-2)
      num_cols = get_dim_size(-1)
      if isinstance(num_rows, int) and isinstance(num_cols, int):
        max_rows_cols = float(max(num_rows, num_cols))
      else:
        max_rows_cols = math_ops.cast(
            math_ops.maximum(num_rows, num_cols), dtype)
      rcond = 10. * max_rows_cols * np.finfo(dtype).eps

    rcond = ops.convert_to_tensor(rcond, dtype=dtype, name='rcond')

    # Calculate pseudo inverse via SVD.
    # Note: if a is Hermitian then u == v. (We might observe additional
    # performance by explicitly setting `v = u` in such cases.)
    [
        singular_values,  # Sigma
        left_singular_vectors,  # U
        right_singular_vectors,  # V
    ] = svd(
        a, full_matrices=False, compute_uv=True)

    # Saturate small singular values to inf. This has the effect of making
    # `1. / s = 0.` while not resulting in `NaN` gradients.
    cutoff = rcond * math_ops.reduce_max(singular_values, axis=-1)
    singular_values = array_ops.where_v2(
        singular_values > array_ops.expand_dims_v2(cutoff, -1), singular_values,
        np.array(np.inf, dtype))

    # By the definition of the SVD, `a == u @ s @ v^H`, and the pseudo-inverse
    # is defined as `pinv(a) == v @ inv(s) @ u^H`.
    a_pinv = math_ops.matmul(
        right_singular_vectors / array_ops.expand_dims_v2(singular_values, -2),
        left_singular_vectors,
        adjoint_b=True)

    if a.shape is not None and a.shape.rank is not None:
      a_pinv.set_shape(a.shape[:-2].concatenate([a.shape[-1], a.shape[-2]]))

    return a_pinv


@tf_export('linalg.lu_solve')
@dispatch.add_dispatch_support
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
  """Solves systems of linear equations `A X = RHS`, given LU factorizations.

  Note: this function does not verify the implied matrix is actually invertible
  nor is this condition checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P,
      matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if `matmul(P, matmul(L, U)) =
      X` then `perm = argmax(P)`.
    rhs: Matrix-shaped float `Tensor` representing targets for which to solve;
      `A X = RHS`. To handle vector cases, use: `lu_solve(..., rhs[...,
        tf.newaxis])[..., 0]`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness. Note: this function does not verify the implied matrix is
        actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_solve').

  Returns:
    x: The `X` in `A @ X = RHS`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[1., 2],
        [3, 4]],
       [[7, 8],
        [3, 4]]]
  inv_x = tf.linalg.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2))
  tf.assert_near(tf.matrix_inverse(x), inv_x)
  # ==> True
  ```

  """

  with ops.name_scope(name or 'lu_solve'):
    lower_upper = ops.convert_to_tensor(
        lower_upper, dtype_hint=dtypes.float32, name='lower_upper')
    perm = ops.convert_to_tensor(perm, dtype_hint=dtypes.int32, name='perm')
    rhs = ops.convert_to_tensor(rhs, dtype_hint=lower_upper.dtype, name='rhs')

    assertions = _lu_solve_assertions(lower_upper, perm, rhs, validate_args)
    if assertions:
      with ops.control_dependencies(assertions):
        lower_upper = array_ops.identity(lower_upper)
        perm = array_ops.identity(perm)
        rhs = array_ops.identity(rhs)

    if (rhs.shape.rank == 2 and perm.shape.rank == 1):
      # Both rhs and perm have scalar batch_shape.
      permuted_rhs = array_ops.gather(rhs, perm, axis=-2)
    else:
      # Either rhs or perm have non-scalar batch_shape or we can't determine
      # this information statically.
      rhs_shape = array_ops.shape(rhs)
      broadcast_batch_shape = array_ops.broadcast_dynamic_shape(
          rhs_shape[:-2],
          array_ops.shape(perm)[:-1])
      d, m = rhs_shape[-2], rhs_shape[-1]
      rhs_broadcast_shape = array_ops.concat([broadcast_batch_shape, [d, m]],
                                             axis=0)

      # Tile out rhs.
      broadcast_rhs = array_ops.broadcast_to(rhs, rhs_broadcast_shape)
      broadcast_rhs = array_ops.reshape(broadcast_rhs, [-1, d, m])

      # Tile out perm and add batch indices.
      broadcast_perm = array_ops.broadcast_to(perm, rhs_broadcast_shape[:-1])
      broadcast_perm = array_ops.reshape(broadcast_perm, [-1, d])
      broadcast_batch_size = math_ops.reduce_prod(broadcast_batch_shape)
      broadcast_batch_indices = array_ops.broadcast_to(
          math_ops.range(broadcast_batch_size)[:, array_ops.newaxis],
          [broadcast_batch_size, d])
      broadcast_perm = array_ops.stack(
          [broadcast_batch_indices, broadcast_perm], axis=-1)

      permuted_rhs = array_ops.gather_nd(broadcast_rhs, broadcast_perm)
      permuted_rhs = array_ops.reshape(permuted_rhs, rhs_broadcast_shape)

    lower = set_diag(
        band_part(lower_upper, num_lower=-1, num_upper=0),
        array_ops.ones(
            array_ops.shape(lower_upper)[:-1], dtype=lower_upper.dtype))
    return triangular_solve(
        lower_upper,  # Only upper is accessed.
        triangular_solve(lower, permuted_rhs),
        lower=False)


@tf_export('linalg.lu_matrix_inverse')
@dispatch.add_dispatch_support
def lu_matrix_inverse(lower_upper, perm, validate_args=False, name=None):
  """Computes the inverse given the LU decomposition(s) of one or more matrices.

  This op is conceptually identical to,

  ```python
  inv_X = tf.linalg.lu_matrix_inverse(*tf.linalg.lu(X))
  tf.assert_near(tf.matrix_inverse(X), inv_X)
  # ==> True
  ```

  Note: this function does not verify the implied matrix is actually invertible
  nor is this condition checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P,
      matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if `matmul(P, matmul(L, U)) =
      X` then `perm = argmax(P)`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness. Note: this function does not verify the implied matrix is
        actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_matrix_inverse').

  Returns:
    inv_x: The matrix_inv, i.e.,
      `tf.matrix_inverse(tf.linalg.lu_reconstruct(lu, perm))`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  inv_x = tf.linalg.lu_matrix_inverse(*tf.linalg.lu(x))
  tf.assert_near(tf.matrix_inverse(x), inv_x)
  # ==> True
  ```

  """

  with ops.name_scope(name or 'lu_matrix_inverse'):
    lower_upper = ops.convert_to_tensor(
        lower_upper, dtype_hint=dtypes.float32, name='lower_upper')
    perm = ops.convert_to_tensor(perm, dtype_hint=dtypes.int32, name='perm')
    assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)
    if assertions:
      with ops.control_dependencies(assertions):
        lower_upper = array_ops.identity(lower_upper)
        perm = array_ops.identity(perm)
    shape = array_ops.shape(lower_upper)
    return lu_solve(
        lower_upper,
        perm,
        rhs=eye(shape[-1], batch_shape=shape[:-2], dtype=lower_upper.dtype),
        validate_args=False)


@tf_export('linalg.lu_reconstruct')
@dispatch.add_dispatch_support
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
  """Reconstructs one or more matrices from their LU decomposition(s).

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P,
      matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if `matmul(P, matmul(L, U)) =
      X` then `perm = argmax(P)`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_reconstruct').

  Returns:
    x: The original input to `tf.linalg.lu`, i.e., `x` as in,
      `lu_reconstruct(*tf.linalg.lu(x))`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  x_reconstructed = tf.linalg.lu_reconstruct(*tf.linalg.lu(x))
  tf.assert_near(x, x_reconstructed)
  # ==> True
  ```

  """
  with ops.name_scope(name or 'lu_reconstruct'):
    lower_upper = ops.convert_to_tensor(
        lower_upper, dtype_hint=dtypes.float32, name='lower_upper')
    perm = ops.convert_to_tensor(perm, dtype_hint=dtypes.int32, name='perm')

    assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)
    if assertions:
      with ops.control_dependencies(assertions):
        lower_upper = array_ops.identity(lower_upper)
        perm = array_ops.identity(perm)

    shape = array_ops.shape(lower_upper)

    lower = set_diag(
        band_part(lower_upper, num_lower=-1, num_upper=0),
        array_ops.ones(shape[:-1], dtype=lower_upper.dtype))
    upper = band_part(lower_upper, num_lower=0, num_upper=-1)
    x = math_ops.matmul(lower, upper)

    if (lower_upper.shape is None or lower_upper.shape.rank is None or
        lower_upper.shape.rank != 2):
      # We either don't know the batch rank or there are >0 batch dims.
      batch_size = math_ops.reduce_prod(shape[:-2])
      d = shape[-1]
      x = array_ops.reshape(x, [batch_size, d, d])
      perm = array_ops.reshape(perm, [batch_size, d])
      perm = map_fn.map_fn(array_ops.invert_permutation, perm)
      batch_indices = array_ops.broadcast_to(
          math_ops.range(batch_size)[:, array_ops.newaxis], [batch_size, d])
      x = array_ops.gather_nd(x, array_ops.stack([batch_indices, perm],
                                                 axis=-1))
      x = array_ops.reshape(x, shape)
    else:
      x = array_ops.gather(x, array_ops.invert_permutation(perm))

    x.set_shape(lower_upper.shape)
    return x


def lu_reconstruct_assertions(lower_upper, perm, validate_args):
  """Returns list of assertions related to `lu_reconstruct` assumptions."""
  assertions = []

  message = 'Input `lower_upper` must have at least 2 dimensions.'
  if lower_upper.shape.rank is not None and lower_upper.shape.rank < 2:
    raise ValueError(message)
  elif validate_args:
    assertions.append(
        check_ops.assert_rank_at_least_v2(lower_upper, rank=2, message=message))

  message = '`rank(lower_upper)` must equal `rank(perm) + 1`'
  if lower_upper.shape.rank is not None and perm.shape.rank is not None:
    if lower_upper.shape.rank != perm.shape.rank + 1:
      raise ValueError(message)
  elif validate_args:
    assertions.append(
        check_ops.assert_rank(
            lower_upper, rank=array_ops.rank(perm) + 1, message=message))

  message = '`lower_upper` must be square.'
  if lower_upper.shape[-2:].is_fully_defined():
    if lower_upper.shape[-2] != lower_upper.shape[-1]:
      raise ValueError(message)
  elif validate_args:
    m, n = array_ops.split(
        array_ops.shape(lower_upper)[-2:], num_or_size_splits=2)
    assertions.append(check_ops.assert_equal(m, n, message=message))

  return assertions


def _lu_solve_assertions(lower_upper, perm, rhs, validate_args):
  """Returns list of assertions related to `lu_solve` assumptions."""
  assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)

  message = 'Input `rhs` must have at least 2 dimensions.'
  if rhs.shape.ndims is not None:
    if rhs.shape.ndims < 2:
      raise ValueError(message)
  elif validate_args:
    assertions.append(
        check_ops.assert_rank_at_least(rhs, rank=2, message=message))

  message = '`lower_upper.shape[-1]` must equal `rhs.shape[-2]`.'
  if (lower_upper.shape[-1] is not None and rhs.shape[-2] is not None):
    if lower_upper.shape[-1] != rhs.shape[-2]:
      raise ValueError(message)
  elif validate_args:
    assertions.append(
        check_ops.assert_equal(
            array_ops.shape(lower_upper)[-1],
            array_ops.shape(rhs)[-2],
            message=message))

  return assertions