# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations for linear algebra."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export

# Linear algebra ops.
band_part = array_ops.matrix_band_part
cholesky = linalg_ops.cholesky
cholesky_solve = linalg_ops.cholesky_solve
det = linalg_ops.matrix_determinant
slogdet = gen_linalg_ops.log_matrix_determinant
tf_export('linalg.slogdet')(dispatch.add_dispatch_support(slogdet))
diag = array_ops.matrix_diag
diag_part = array_ops.matrix_diag_part
eigh = linalg_ops.self_adjoint_eig
eigvalsh = linalg_ops.self_adjoint_eigvals
einsum = special_math_ops.einsum
eye = linalg_ops.eye
inv = linalg_ops.matrix_inverse
logm = gen_linalg_ops.matrix_logarithm
lu = gen_linalg_ops.lu
tf_export('linalg.logm')(dispatch.add_dispatch_support(logm))
lstsq = linalg_ops.matrix_solve_ls
norm = linalg_ops.norm
qr = linalg_ops.qr
set_diag = array_ops.matrix_set_diag
solve = linalg_ops.matrix_solve
sqrtm = linalg_ops.matrix_square_root
svd = linalg_ops.svd
tensordot = math_ops.tensordot
trace = math_ops.trace
transpose = array_ops.matrix_transpose
triangular_solve = linalg_ops.matrix_triangular_solve


@tf_export('linalg.logdet')
@dispatch.add_dispatch_support
def logdet(matrix, name=None):
  """Computes log of the determinant of a Hermitian positive definite matrix.

  ```python
  # Compute the determinant of a matrix while reducing the chance of over- or
  # underflow:
  A = ...  # shape 10 x 10
  det = tf.exp(tf.linalg.logdet(A))  # scalar
  ```

  Args:
    matrix: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`,
      or `complex128` with shape `[..., M, M]`.
    name: A name to give this `Op`. Defaults to `logdet`.

  Returns:
    The natural log of the determinant of `matrix`.

  @compatibility(numpy)
  Equivalent to numpy.linalg.slogdet, although no sign is returned since only
  Hermitian positive definite matrices are supported.
  @end_compatibility
  """
  # This uses the property that log det(A) = 2 * sum(log(real(diag(C))))
  # where C is the Cholesky decomposition of A.
  with ops.name_scope(name, 'logdet', [matrix]):
    chol = gen_linalg_ops.cholesky(matrix)
    return 2.0 * math_ops.reduce_sum(
        math_ops.log(math_ops.real(array_ops.matrix_diag_part(chol))),
        axis=[-1])
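

# Usage sketch for `tf.linalg.logdet` (illustrative only; assumes eager
# execution and a Hermitian positive-definite input):
#
#   import tensorflow as tf
#   a = tf.constant([[4., 1.], [1., 3.]])
#   tf.linalg.logdet(a).numpy()            # ~= log(11.) ~= 2.398
#   tf.math.log(tf.linalg.det(a)).numpy()  # same value, but less stable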


@tf_export('linalg.adjoint')
@dispatch.add_dispatch_support
def adjoint(matrix, name=None):
  """Transposes the last two dimensions of and conjugates tensor `matrix`.

  For example:

  ```python
  x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j],
                   [4 + 4j, 5 + 5j, 6 + 6j]])
  tf.linalg.adjoint(x)  # [[1 - 1j, 4 - 4j],
                        #  [2 - 2j, 5 - 5j],
                        #  [3 - 3j, 6 - 6j]]
  ```

  Args:
    matrix: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`,
      or `complex128` with shape `[..., M, M]`.
    name: A name to give this `Op` (optional).

  Returns:
    The adjoint (a.k.a. Hermitian transpose a.k.a. conjugate transpose) of
    matrix.
  """
  with ops.name_scope(name, 'adjoint', [matrix]):
    matrix = ops.convert_to_tensor(matrix, name='matrix')
    return array_ops.matrix_transpose(matrix, conjugate=True)


# This section is ported nearly verbatim from Eigen's implementation:
# https://eigen.tuxfamily.org/dox/unsupported/MatrixExponential_8h_source.html
def _matrix_exp_pade3(matrix):
  """3rd-order Pade approximant for matrix exponential."""
  b = [120.0, 60.0, 12.0]
  b = [constant_op.constant(x, matrix.dtype) for x in b]
  ident = linalg_ops.eye(
      array_ops.shape(matrix)[-2],
      batch_shape=array_ops.shape(matrix)[:-2],
      dtype=matrix.dtype)
  matrix_2 = math_ops.matmul(matrix, matrix)
  tmp = matrix_2 + b[1] * ident
  matrix_u = math_ops.matmul(matrix, tmp)
  matrix_v = b[2] * matrix_2 + b[0] * ident
  return matrix_u, matrix_v


def _matrix_exp_pade5(matrix):
  """5th-order Pade approximant for matrix exponential."""
  b = [30240.0, 15120.0, 3360.0, 420.0, 30.0]
  b = [constant_op.constant(x, matrix.dtype) for x in b]
  ident = linalg_ops.eye(
      array_ops.shape(matrix)[-2],
      batch_shape=array_ops.shape(matrix)[:-2],
      dtype=matrix.dtype)
  matrix_2 = math_ops.matmul(matrix, matrix)
  matrix_4 = math_ops.matmul(matrix_2, matrix_2)
  tmp = matrix_4 + b[3] * matrix_2 + b[1] * ident
  matrix_u = math_ops.matmul(matrix, tmp)
  matrix_v = b[4] * matrix_4 + b[2] * matrix_2 + b[0] * ident
  return matrix_u, matrix_v


def _matrix_exp_pade7(matrix):
  """7th-order Pade approximant for matrix exponential."""
  b = [17297280.0, 8648640.0, 1995840.0, 277200.0, 25200.0, 1512.0, 56.0]
  b = [constant_op.constant(x, matrix.dtype) for x in b]
  ident = linalg_ops.eye(
      array_ops.shape(matrix)[-2],
      batch_shape=array_ops.shape(matrix)[:-2],
      dtype=matrix.dtype)
  matrix_2 = math_ops.matmul(matrix, matrix)
  matrix_4 = math_ops.matmul(matrix_2, matrix_2)
  matrix_6 = math_ops.matmul(matrix_4, matrix_2)
  tmp = matrix_6 + b[5] * matrix_4 + b[3] * matrix_2 + b[1] * ident
  matrix_u = math_ops.matmul(matrix, tmp)
  matrix_v = b[6] * matrix_6 + b[4] * matrix_4 + b[2] * matrix_2 + b[0] * ident
  return matrix_u, matrix_v


def _matrix_exp_pade9(matrix):
  """9th-order Pade approximant for matrix exponential."""
  b = [
      17643225600.0, 8821612800.0, 2075673600.0, 302702400.0, 30270240.0,
      2162160.0, 110880.0, 3960.0, 90.0
  ]
  b = [constant_op.constant(x, matrix.dtype) for x in b]
  ident = linalg_ops.eye(
      array_ops.shape(matrix)[-2],
      batch_shape=array_ops.shape(matrix)[:-2],
      dtype=matrix.dtype)
  matrix_2 = math_ops.matmul(matrix, matrix)
  matrix_4 = math_ops.matmul(matrix_2, matrix_2)
  matrix_6 = math_ops.matmul(matrix_4, matrix_2)
  matrix_8 = math_ops.matmul(matrix_6, matrix_2)
  tmp = (
      matrix_8 + b[7] * matrix_6 + b[5] * matrix_4 + b[3] * matrix_2 +
      b[1] * ident)
  matrix_u = math_ops.matmul(matrix, tmp)
  matrix_v = (
      b[8] * matrix_8 + b[6] * matrix_6 + b[4] * matrix_4 + b[2] * matrix_2 +
      b[0] * ident)
  return matrix_u, matrix_v


def _matrix_exp_pade13(matrix):
  """13th-order Pade approximant for matrix exponential."""
  b = [
      64764752532480000.0, 32382376266240000.0, 7771770303897600.0,
      1187353796428800.0, 129060195264000.0, 10559470521600.0, 670442572800.0,
      33522128640.0, 1323241920.0, 40840800.0, 960960.0, 16380.0, 182.0
  ]
  b = [constant_op.constant(x, matrix.dtype) for x in b]
  ident = linalg_ops.eye(
      array_ops.shape(matrix)[-2],
      batch_shape=array_ops.shape(matrix)[:-2],
      dtype=matrix.dtype)
  matrix_2 = math_ops.matmul(matrix, matrix)
  matrix_4 = math_ops.matmul(matrix_2, matrix_2)
  matrix_6 = math_ops.matmul(matrix_4, matrix_2)
  tmp_u = (
      math_ops.matmul(matrix_6, matrix_6 + b[11] * matrix_4 + b[9] * matrix_2) +
      b[7] * matrix_6 + b[5] * matrix_4 + b[3] * matrix_2 + b[1] * ident)
  matrix_u = math_ops.matmul(matrix, tmp_u)
  tmp_v = b[12] * matrix_6 + b[10] * matrix_4 + b[8] * matrix_2
  matrix_v = (
      math_ops.matmul(matrix_6, tmp_v) + b[6] * matrix_6 + b[4] * matrix_4 +
      b[2] * matrix_2 + b[0] * ident)
  return matrix_u, matrix_v
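

# Each `_matrix_exp_padeN` helper above returns a pair `(U, V)` such that the
# [N/N] Pade approximant of exp(A) is `(V - U)^{-1} (V + U)`; this is exactly
# how `matrix_exponential` below combines them, via
# `linalg_ops.matrix_solve(-u + v, u + v)`, followed by repeated squaring to
# undo the scaling. A rough sketch of the idea, assuming a small,
# already-scaled input (illustrative only; `small_matrix` is a placeholder):
#
#   u3, v3 = _matrix_exp_pade3(small_matrix)
#   approx_expm = linalg_ops.matrix_solve(v3 - u3, v3 + u3)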


@tf_export('linalg.expm')
@dispatch.add_dispatch_support
def matrix_exponential(input, name=None):  # pylint: disable=redefined-builtin
  r"""Computes the matrix exponential of one or more square matrices.

  $$exp(A) = \sum_{n=0}^\infty A^n/n!$$

  The exponential is computed using a combination of the scaling and squaring
  method and the Pade approximation. Details can be found in:
  Nicholas J. Higham, "The scaling and squaring method for the matrix
  exponential revisited," SIAM J. Matrix Anal. Applic., 26:1179-1193, 2005.

  The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
  form square matrices. The output is a tensor of the same shape as the input
  containing the exponential for all input submatrices `[..., :, :]`.

  Args:
    input: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`, or
      `complex128` with shape `[..., M, M]`.
    name: A name to give this `Op` (optional).

  Returns:
    The matrix exponential of the input.

  Raises:
    ValueError: An unsupported type is provided as input.

  @compatibility(scipy)
  Equivalent to scipy.linalg.expm
  @end_compatibility
  """
  with ops.name_scope(name, 'matrix_exponential', [input]):
    matrix = ops.convert_to_tensor(input, name='input')
    if matrix.shape[-2:] == [0, 0]:
      return matrix
    batch_shape = matrix.shape[:-2]
    if not batch_shape.is_fully_defined():
      batch_shape = array_ops.shape(matrix)[:-2]

    # Reshaping the batch makes the where statements work better.
    matrix = array_ops.reshape(
        matrix, array_ops.concat(([-1], array_ops.shape(matrix)[-2:]), axis=0))
    l1_norm = math_ops.reduce_max(
        math_ops.reduce_sum(
            math_ops.abs(matrix),
            axis=array_ops.size(array_ops.shape(matrix)) - 2),
        axis=-1)[..., array_ops.newaxis, array_ops.newaxis]

    const = lambda x: constant_op.constant(x, l1_norm.dtype)

    def _nest_where(vals, cases):
      assert len(vals) == len(cases) - 1
      if len(vals) == 1:
        return array_ops.where_v2(
            math_ops.less(l1_norm, const(vals[0])), cases[0], cases[1])
      else:
        return array_ops.where_v2(
            math_ops.less(l1_norm, const(vals[0])), cases[0],
            _nest_where(vals[1:], cases[1:]))

    if matrix.dtype in [dtypes.float16, dtypes.float32, dtypes.complex64]:
      maxnorm = const(3.925724783138660)
      squarings = math_ops.maximum(
          math_ops.floor(
              math_ops.log(l1_norm / maxnorm) / math_ops.log(const(2.0))), 0)
      u3, v3 = _matrix_exp_pade3(matrix)
      u5, v5 = _matrix_exp_pade5(matrix)
      u7, v7 = _matrix_exp_pade7(
          matrix /
          math_ops.cast(math_ops.pow(const(2.0), squarings), matrix.dtype))
      conds = (4.258730016922831e-001, 1.880152677804762e+000)
      u = _nest_where(conds, (u3, u5, u7))
      v = _nest_where(conds, (v3, v5, v7))
    elif matrix.dtype in [dtypes.float64, dtypes.complex128]:
      maxnorm = const(5.371920351148152)
      squarings = math_ops.maximum(
          math_ops.floor(
              math_ops.log(l1_norm / maxnorm) / math_ops.log(const(2.0))), 0)
      u3, v3 = _matrix_exp_pade3(matrix)
      u5, v5 = _matrix_exp_pade5(matrix)
      u7, v7 = _matrix_exp_pade7(matrix)
      u9, v9 = _matrix_exp_pade9(matrix)
      u13, v13 = _matrix_exp_pade13(
          matrix /
          math_ops.cast(math_ops.pow(const(2.0), squarings), matrix.dtype))
      conds = (1.495585217958292e-002, 2.539398330063230e-001,
               9.504178996162932e-001, 2.097847961257068e+000)
      u = _nest_where(conds, (u3, u5, u7, u9, u13))
      v = _nest_where(conds, (v3, v5, v7, v9, v13))
    else:
      raise ValueError('tf.linalg.expm does not support matrices of type %s' %
                       matrix.dtype)

    is_finite = math_ops.is_finite(math_ops.reduce_max(l1_norm))
    nan = constant_op.constant(np.nan, matrix.dtype)
    result = control_flow_ops.cond(
        is_finite, lambda: linalg_ops.matrix_solve(-u + v, u + v),
        lambda: array_ops.fill(array_ops.shape(matrix), nan))
    max_squarings = math_ops.reduce_max(squarings)
    i = const(0.0)

    def c(i, _):
      return control_flow_ops.cond(is_finite,
                                   lambda: math_ops.less(i, max_squarings),
                                   lambda: constant_op.constant(False))

    def b(i, r):
      return i + 1, array_ops.where_v2(
          math_ops.less(i, squarings), math_ops.matmul(r, r), r)

    _, result = control_flow_ops.while_loop(c, b, [i, result])
    if not matrix.shape.is_fully_defined():
      return array_ops.reshape(
          result,
          array_ops.concat((batch_shape, array_ops.shape(result)[-2:]), axis=0))
    return array_ops.reshape(result, batch_shape.concatenate(result.shape[-2:]))
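

# Usage sketch for `tf.linalg.expm` (illustrative only; assumes eager
# execution):
#
#   import tensorflow as tf
#   a = tf.constant([[0., 1.], [-1., 0.]])   # rotation generator
#   tf.linalg.expm(a)   # ~= [[cos 1, sin 1], [-sin 1, cos 1]]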


@tf_export('linalg.banded_triangular_solve', v1=[])
def banded_triangular_solve(
    bands,
    rhs,
    lower=True,
    adjoint=False,  # pylint: disable=redefined-outer-name
    name=None):
  r"""Solve triangular systems of equations with a banded solver.

  `bands` is a tensor of shape `[..., K, M]`, where `K` represents the number
  of bands stored. This corresponds to a batch of `M` by `M` matrices, whose
  `K` subdiagonals (when `lower` is `True`) are stored.

  This operator broadcasts the batch dimensions of `bands` and the batch
  dimensions of `rhs`.


  Examples:

  Storing 2 bands of a 3x3 matrix.
  Note that the first element in the second row is ignored due to
  the 'LEFT_RIGHT' padding.

  >>> x = [[2., 3., 4.], [1., 2., 3.]]
  >>> x2 = [[2., 3., 4.], [10000., 2., 3.]]
  >>> y = tf.zeros([3, 3])
  >>> z = tf.linalg.set_diag(y, x, align='LEFT_RIGHT', k=(-1, 0))
  >>> z
  <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
  array([[2., 0., 0.],
         [2., 3., 0.],
         [0., 3., 4.]], dtype=float32)>
  >>> soln = tf.linalg.banded_triangular_solve(x, tf.ones([3, 1]))
  >>> soln
  <tf.Tensor: shape=(3, 1), dtype=float32, numpy=
  array([[0.5 ],
         [0.  ],
         [0.25]], dtype=float32)>
  >>> are_equal = soln == tf.linalg.banded_triangular_solve(x2, tf.ones([3, 1]))
  >>> tf.reduce_all(are_equal).numpy()
  True
  >>> are_equal = soln == tf.linalg.triangular_solve(z, tf.ones([3, 1]))
  >>> tf.reduce_all(are_equal).numpy()
  True

  Storing 2 superdiagonals of a 4x4 matrix. Because of the 'LEFT_RIGHT' padding
  the last element of the first row is ignored.

  >>> x = [[2., 3., 4., 5.], [-1., -2., -3., -4.]]
  >>> y = tf.zeros([4, 4])
  >>> z = tf.linalg.set_diag(y, x, align='LEFT_RIGHT', k=(0, 1))
  >>> z
  <tf.Tensor: shape=(4, 4), dtype=float32, numpy=
  array([[-1.,  2.,  0.,  0.],
         [ 0., -2.,  3.,  0.],
         [ 0.,  0., -3.,  4.],
         [ 0.,  0., -0., -4.]], dtype=float32)>
  >>> soln = tf.linalg.banded_triangular_solve(x, tf.ones([4, 1]), lower=False)
  >>> soln
  <tf.Tensor: shape=(4, 1), dtype=float32, numpy=
  array([[-4.       ],
         [-1.5      ],
         [-0.6666667],
         [-0.25     ]], dtype=float32)>
  >>> are_equal = (soln == tf.linalg.triangular_solve(
  ...   z, tf.ones([4, 1]), lower=False))
  >>> tf.reduce_all(are_equal).numpy()
  True


  Args:
    bands: A `Tensor` describing the bands of the left hand side, with shape
      `[..., K, M]`. The `K` rows correspond to the diagonal to the `K - 1`-th
      diagonal (the diagonal is the top row) when `lower` is `True` and
      otherwise the `K - 1`-th superdiagonal to the diagonal (the diagonal is
      the bottom row) when `lower` is `False`. The bands are stored with
      'LEFT_RIGHT' alignment, where the superdiagonals are padded on the right
      and subdiagonals are padded on the left. This is the alignment cuSPARSE
      uses. See `tf.linalg.set_diag` for more details.
    rhs: A `Tensor` of shape [..., M] or [..., M, N] and with the same dtype as
      `diagonals`. Note that if the shape of `rhs` and/or `diags` isn't known
      statically, `rhs` will be treated as a matrix rather than a vector.
    lower: An optional `bool`. Defaults to `True`. Boolean indicating whether
      `bands` represents a lower or upper triangular matrix.
    adjoint: An optional `bool`. Defaults to `False`. Boolean indicating
      whether to solve with the matrix's block-wise adjoint.
    name: A name to give this `Op` (optional).

  Returns:
    A `Tensor` of shape [..., M] or [..., M, N] containing the solutions.
  """
  with ops.name_scope(name, 'banded_triangular_solve', [bands, rhs]):
    return gen_linalg_ops.banded_triangular_solve(
        bands, rhs, lower=lower, adjoint=adjoint)
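

# To go the other way (dense triangular matrix -> banded representation used
# above), `tf.linalg.diag_part` with the same 'LEFT_RIGHT' alignment can be
# used. Illustrative sketch (assumes eager execution):
#
#   dense = tf.constant([[2., 0., 0.],
#                        [2., 3., 0.],
#                        [0., 3., 4.]])
#   bands = tf.linalg.diag_part(dense, k=(-1, 0), align='LEFT_RIGHT')
#   # bands == [[2., 3., 4.], [0., 2., 3.]]
#   # (the leading 0. sits in the ignored padding slot of the subdiagonal row)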


@tf_export('linalg.tridiagonal_solve')
@dispatch.add_dispatch_support
def tridiagonal_solve(diagonals,
                      rhs,
                      diagonals_format='compact',
                      transpose_rhs=False,
                      conjugate_rhs=False,
                      name=None,
                      partial_pivoting=True):
  r"""Solves tridiagonal systems of equations.

  The input can be supplied in various formats: `matrix`, `sequence` and
  `compact`, specified by the `diagonals_format` arg.

  In `matrix` format, `diagonals` must be a tensor of shape `[..., M, M]`, with
  two inner-most dimensions representing the square tridiagonal matrices.
  Elements outside of the three diagonals will be ignored.

  In `sequence` format, `diagonals` are supplied as a tuple or list of three
  tensors of shapes `[..., N]`, `[..., M]`, `[..., N]` representing
  superdiagonals, diagonals, and subdiagonals, respectively. `N` can be either
  `M-1` or `M`; in the latter case, the last element of superdiagonal and the
  first element of subdiagonal will be ignored.

  In `compact` format the three diagonals are brought together into one tensor
  of shape `[..., 3, M]`, with last two dimensions containing superdiagonals,
  diagonals, and subdiagonals, in order. Similarly to `sequence` format,
  elements `diagonals[..., 0, M-1]` and `diagonals[..., 2, 0]` are ignored.

  The `compact` format is recommended as the one with the best performance. In
  case you need to cast a tensor into a compact format manually, use
  `tf.gather_nd`. An example for a tensor of shape [m, m]:

  ```python
  rhs = tf.constant([...])
  matrix = tf.constant([[...]])
  m = matrix.shape[0]
  dummy_idx = [0, 0]  # An arbitrary element to use as a dummy
  indices = [[[i, i + 1] for i in range(m - 1)] + [dummy_idx],  # Superdiagonal
             [[i, i] for i in range(m)],                        # Diagonal
             [dummy_idx] + [[i + 1, i] for i in range(m - 1)]]  # Subdiagonal
  diagonals = tf.gather_nd(matrix, indices)
  x = tf.linalg.tridiagonal_solve(diagonals, rhs)
  ```

  Regardless of the `diagonals_format`, `rhs` is a tensor of shape `[..., M]`
  or `[..., M, K]`. The latter allows solving `K` systems simultaneously with
  the same left-hand sides and `K` different right-hand sides. If
  `transpose_rhs` is set to `True` the expected shape is `[..., M]` or
  `[..., K, M]`.

  The batch dimensions, denoted as `...`, must be the same in `diagonals` and
  `rhs`.

  The output is a tensor of the same shape as `rhs`: either `[..., M]` or
  `[..., M, K]`.

  The op isn't guaranteed to raise an error if the input matrix is not
  invertible. `tf.debugging.check_numerics` can be applied to the output to
  detect invertibility problems.

  **Note**: with large batch sizes, the computation on the GPU may be slow, if
  either `partial_pivoting=True` or there are multiple right-hand sides
  (`K > 1`). If this issue arises, consider if it's possible to disable
  pivoting and have `K = 1`, or, alternatively, consider using CPU.

  On CPU, solution is computed via Gaussian elimination with or without partial
  pivoting, depending on the `partial_pivoting` parameter. On GPU, Nvidia's
  cuSPARSE library is used:
  https://docs.nvidia.com/cuda/cusparse/index.html#gtsv

  Args:
    diagonals: A `Tensor` or tuple of `Tensor`s describing left-hand sides. The
      shape depends on `diagonals_format`, see description above. Must be
      `float32`, `float64`, `complex64`, or `complex128`.
    rhs: A `Tensor` of shape [..., M] or [..., M, K] and with the same dtype as
      `diagonals`. Note that if the shape of `rhs` and/or `diags` isn't known
      statically, `rhs` will be treated as a matrix rather than a vector.
    diagonals_format: one of `matrix`, `sequence`, or `compact`. Default is
      `compact`.
    transpose_rhs: If `True`, `rhs` is transposed before solving (has no effect
      if the shape of rhs is [..., M]).
    conjugate_rhs: If `True`, `rhs` is conjugated before solving.
    name: A name to give this `Op` (optional).
    partial_pivoting: whether to perform partial pivoting. `True` by default.
      Partial pivoting makes the procedure more stable, but slower. Partial
      pivoting is unnecessary in some cases, including diagonally dominant and
      symmetric positive definite matrices (see e.g. theorem 9.12 in [1]).

  Returns:
    A `Tensor` of shape [..., M] or [..., M, K] containing the solutions.

  Raises:
    ValueError: An unsupported type is provided as input, or when the input
      tensors have incorrect shapes.
    UnimplementedError: Whenever `partial_pivoting` is true and the backend is
      XLA.

  [1] Nicholas J. Higham (2002). Accuracy and Stability of Numerical
    Algorithms: Second Edition. SIAM. p. 175. ISBN 978-0-89871-802-7.

  """
  if diagonals_format == 'compact':
    return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                             conjugate_rhs, partial_pivoting,
                                             name)

  if diagonals_format == 'sequence':
    if not isinstance(diagonals, (tuple, list)) or len(diagonals) != 3:
      raise ValueError('Expected diagonals to be a sequence of length 3.')

    superdiag, maindiag, subdiag = diagonals
    if (not subdiag.shape[:-1].is_compatible_with(maindiag.shape[:-1]) or
        not superdiag.shape[:-1].is_compatible_with(maindiag.shape[:-1])):
      raise ValueError(
          'Tensors representing the three diagonals must have the same shape, '
          'except for the last dimension, got {}, {}, {}'.format(
              subdiag.shape, maindiag.shape, superdiag.shape))

    m = tensor_shape.dimension_value(maindiag.shape[-1])

    def pad_if_necessary(t, name, last_dim_padding):
      n = tensor_shape.dimension_value(t.shape[-1])
      if not n or n == m:
        return t
      if n == m - 1:
        paddings = ([[0, 0] for _ in range(len(t.shape) - 1)] +
                    [last_dim_padding])
        return array_ops.pad(t, paddings)
      raise ValueError('Expected {} to have length {} or {}, got {}.'.format(
          name, m, m - 1, n))

    subdiag = pad_if_necessary(subdiag, 'subdiagonal', [1, 0])
    superdiag = pad_if_necessary(superdiag, 'superdiagonal', [0, 1])

    diagonals = array_ops.stack((superdiag, maindiag, subdiag), axis=-2)
    return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                             conjugate_rhs, partial_pivoting,
                                             name)

  if diagonals_format == 'matrix':
    m1 = tensor_shape.dimension_value(diagonals.shape[-1])
    m2 = tensor_shape.dimension_value(diagonals.shape[-2])
    if m1 and m2 and m1 != m2:
      raise ValueError(
          'Expected last two dimensions of diagonals to be same, got {} and {}'
          .format(m1, m2))
    m = m1 or m2
    diagonals = array_ops.matrix_diag_part(
        diagonals, k=(-1, 1), padding_value=0., align='LEFT_RIGHT')
    return _tridiagonal_solve_compact_format(
        diagonals, rhs, transpose_rhs, conjugate_rhs, partial_pivoting, name)

  raise ValueError(
      'Unrecognized diagonals_format: {}'.format(diagonals_format))
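

# Usage sketch for `tf.linalg.tridiagonal_solve` with the `compact` format
# (illustrative only; assumes eager execution). The matrix is
# [[2, 1, 0], [1, 2, 1], [0, 1, 2]]:
#
#   import tensorflow as tf
#   diagonals = tf.constant([[1., 1., 0.],   # superdiagonal (last entry ignored)
#                            [2., 2., 2.],   # main diagonal
#                            [0., 1., 1.]])  # subdiagonal (first entry ignored)
#   rhs = tf.constant([1., 1., 1.])
#   tf.linalg.tridiagonal_solve(diagonals, rhs)   # ~= [0.5, 0., 0.5]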


def _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                      conjugate_rhs, partial_pivoting, name):
  """Helper function used after the input has been cast to compact form."""
  diags_rank, rhs_rank = diagonals.shape.rank, rhs.shape.rank

  # If we know the rank of the diagonal tensor, do some static checking.
  if diags_rank:
    if diags_rank < 2:
      raise ValueError(
          'Expected diagonals to have rank at least 2, got {}'.format(
              diags_rank))
    if rhs_rank and rhs_rank != diags_rank and rhs_rank != diags_rank - 1:
      raise ValueError('Expected the rank of rhs to be {} or {}, got {}'.format(
          diags_rank - 1, diags_rank, rhs_rank))
    if (rhs_rank and not diagonals.shape[:-2].is_compatible_with(
        rhs.shape[:diags_rank - 2])):
      raise ValueError('Batch shapes {} and {} are incompatible'.format(
          diagonals.shape[:-2], rhs.shape[:diags_rank - 2]))

  if diagonals.shape[-2] and diagonals.shape[-2] != 3:
    raise ValueError('Expected 3 diagonals got {}'.format(diagonals.shape[-2]))

  def check_num_lhs_matches_num_rhs():
    if (diagonals.shape[-1] and rhs.shape[-2] and
        diagonals.shape[-1] != rhs.shape[-2]):
      raise ValueError('Expected number of left-hand sides and right-hand '
                       'sides to be equal, got {} and {}'.format(
                           diagonals.shape[-1], rhs.shape[-2]))

  if rhs_rank and diags_rank and rhs_rank == diags_rank - 1:
    # Rhs provided as a vector, ignoring transpose_rhs
    if conjugate_rhs:
      rhs = math_ops.conj(rhs)
    rhs = array_ops.expand_dims(rhs, -1)
    check_num_lhs_matches_num_rhs()
    return array_ops.squeeze(
        linalg_ops.tridiagonal_solve(diagonals, rhs, partial_pivoting, name),
        -1)

  if transpose_rhs:
    rhs = array_ops.matrix_transpose(rhs, conjugate=conjugate_rhs)
  elif conjugate_rhs:
    rhs = math_ops.conj(rhs)

  check_num_lhs_matches_num_rhs()
  return linalg_ops.tridiagonal_solve(diagonals, rhs, partial_pivoting, name)


@tf_export('linalg.tridiagonal_matmul')
@dispatch.add_dispatch_support
def tridiagonal_matmul(diagonals, rhs, diagonals_format='compact', name=None):
  r"""Multiplies tridiagonal matrix by matrix.

  `diagonals` is a representation of a tridiagonal `M` by `M` matrix, whose
  form depends on `diagonals_format`.

  In `matrix` format, `diagonals` must be a tensor of shape `[..., M, M]`, with
  two inner-most dimensions representing the square tridiagonal matrices.
  Elements outside of the three diagonals will be ignored.

  In `sequence` format, `diagonals` is a list or tuple of three tensors:
  `[superdiag, maindiag, subdiag]`, each having shape [..., M]. The last
  element of `superdiag` and the first element of `subdiag` are ignored.

  In `compact` format the three diagonals are brought together into one tensor
  of shape `[..., 3, M]`, with last two dimensions containing superdiagonals,
  diagonals, and subdiagonals, in order. Similarly to `sequence` format,
  elements `diagonals[..., 0, M-1]` and `diagonals[..., 2, 0]` are ignored.

  The `sequence` format is recommended as the one with the best performance.

  `rhs` is the matrix to the right of multiplication. It has shape
  `[..., M, N]`.

  Example:

  ```python
  superdiag = tf.constant([-1, -1, 0], dtype=tf.float64)
  maindiag = tf.constant([2, 2, 2], dtype=tf.float64)
  subdiag = tf.constant([0, -1, -1], dtype=tf.float64)
  diagonals = [superdiag, maindiag, subdiag]
  rhs = tf.constant([[1, 1], [1, 1], [1, 1]], dtype=tf.float64)
  x = tf.linalg.tridiagonal_matmul(diagonals, rhs, diagonals_format='sequence')
  ```

  Args:
    diagonals: A `Tensor` or tuple of `Tensor`s describing left-hand sides. The
      shape depends on `diagonals_format`, see description above. Must be
      `float32`, `float64`, `complex64`, or `complex128`.
    rhs: A `Tensor` of shape [..., M, N] and with the same dtype as `diagonals`.
    diagonals_format: one of `matrix`, `sequence`, or `compact`. Default is
      `compact`.
    name: A name to give this `Op` (optional).

  Returns:
    A `Tensor` of shape [..., M, N] containing the result of multiplication.

  Raises:
    ValueError: An unsupported type is provided as input, or when the input
      tensors have incorrect shapes.
  """
  if diagonals_format == 'compact':
    superdiag = diagonals[..., 0, :]
    maindiag = diagonals[..., 1, :]
    subdiag = diagonals[..., 2, :]
  elif diagonals_format == 'sequence':
    superdiag, maindiag, subdiag = diagonals
  elif diagonals_format == 'matrix':
    m1 = tensor_shape.dimension_value(diagonals.shape[-1])
    m2 = tensor_shape.dimension_value(diagonals.shape[-2])
    if m1 and m2 and m1 != m2:
      raise ValueError(
          'Expected last two dimensions of diagonals to be same, got {} and {}'
          .format(m1, m2))
    diags = array_ops.matrix_diag_part(
        diagonals, k=(-1, 1), padding_value=0., align='LEFT_RIGHT')
    superdiag = diags[..., 0, :]
    maindiag = diags[..., 1, :]
    subdiag = diags[..., 2, :]
  else:
    raise ValueError('Unrecognized diagonals_format: %s' % diagonals_format)

  # The C++ backend requires matrices.
  # Converting 1-dimensional vectors to matrices with 1 row.
  superdiag = array_ops.expand_dims(superdiag, -2)
  maindiag = array_ops.expand_dims(maindiag, -2)
  subdiag = array_ops.expand_dims(subdiag, -2)

  return linalg_ops.tridiagonal_mat_mul(superdiag, maindiag, subdiag, rhs, name)
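

# Round-trip sketch relating `tridiagonal_matmul` and `tridiagonal_solve`
# (illustrative only; assumes eager execution). Multiplying by a tridiagonal
# matrix and then solving with the same diagonals recovers the original
# right-hand side:
#
#   import tensorflow as tf
#   diagonals = tf.constant([[1., 1., 0.],   # superdiagonal (last entry ignored)
#                            [4., 4., 4.],   # main diagonal
#                            [0., 1., 1.]])  # subdiagonal (first entry ignored)
#   rhs = tf.constant([[1.], [2.], [3.]])
#   b = tf.linalg.tridiagonal_matmul(diagonals, rhs)   # shape [3, 1]
#   tf.linalg.tridiagonal_solve(diagonals, b)          # ~= rhs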


def _maybe_validate_matrix(a, validate_args):
  """Checks that input is a `float` matrix."""
  assertions = []
  if not a.dtype.is_floating:
    raise TypeError('Input `a` must have `float`-like `dtype` '
                    '(saw {}).'.format(a.dtype.name))
  if a.shape is not None and a.shape.rank is not None:
    if a.shape.rank < 2:
      raise ValueError('Input `a` must have at least 2 dimensions '
                       '(saw: {}).'.format(a.shape.rank))
  elif validate_args:
    assertions.append(
        check_ops.assert_rank_at_least(
            a, rank=2, message='Input `a` must have at least 2 dimensions.'))
  return assertions


@tf_export('linalg.matrix_rank')
@dispatch.add_dispatch_support
def matrix_rank(a, tol=None, validate_args=False, name=None):
  """Compute the matrix rank of one or more matrices.

  Args:
    a: (Batch of) `float`-like matrix-shaped `Tensor`(s) whose rank is to be
      computed.
    tol: Threshold below which the singular value is counted as 'zero'.
      Default value: `None` (i.e., `eps * max(rows, cols) * max(singular_val)`).
    validate_args: When `True`, additional assertions might be embedded in the
      graph.
      Default value: `False` (i.e., no graph assertions are added).
    name: Python `str` prefixed to ops created by this function.
      Default value: 'matrix_rank'.

  Returns:
    matrix_rank: (Batch of) `int32` scalars representing the number of non-zero
      singular values.
  """
  with ops.name_scope(name or 'matrix_rank'):
    a = ops.convert_to_tensor(a, dtype_hint=dtypes.float32, name='a')
    assertions = _maybe_validate_matrix(a, validate_args)
    if assertions:
      with ops.control_dependencies(assertions):
        a = array_ops.identity(a)
    s = svd(a, compute_uv=False)
    if tol is None:
      if (a.shape[-2:]).is_fully_defined():
        m = np.max(a.shape[-2:].as_list())
      else:
        m = math_ops.reduce_max(array_ops.shape(a)[-2:])
      eps = np.finfo(a.dtype.as_numpy_dtype).eps
      tol = (
          eps * math_ops.cast(m, a.dtype) *
          math_ops.reduce_max(s, axis=-1, keepdims=True))
    return math_ops.reduce_sum(math_ops.cast(s > tol, dtypes.int32), axis=-1)
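

# Usage sketch for `tf.linalg.matrix_rank` (illustrative only; assumes eager
# execution). The second row below is a multiple of the first, so the rank is 1:
#
#   import tensorflow as tf
#   a = tf.constant([[1., 2.],
#                    [2., 4.]])
#   tf.linalg.matrix_rank(a).numpy()           # 1
#   tf.linalg.matrix_rank(tf.eye(3)).numpy()   # 3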


@tf_export('linalg.pinv')
@dispatch.add_dispatch_support
def pinv(a, rcond=None, validate_args=False, name=None):
  """Compute the Moore-Penrose pseudo-inverse of one or more matrices.

  Calculate the [generalized inverse of a matrix](
  https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse) using its
  singular-value decomposition (SVD) and including all large singular values.

  The pseudo-inverse of a matrix `A`, is defined as: 'the matrix that 'solves'
  [the least-squares problem] `A @ x = b`,' i.e., if `x_hat` is a solution, then
  `A_pinv` is the matrix such that `x_hat = A_pinv @ b`. It can be shown that if
  `U @ Sigma @ V.T = A` is the singular value decomposition of `A`, then
  `A_pinv = V @ inv(Sigma) @ U^T`. [(Strang, 1980)][1]

  This function is analogous to [`numpy.linalg.pinv`](
  https://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.pinv.html).
  It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the
  default `rcond` is `1e-15`. Here the default is
  `10. * max(num_rows, num_cols) * np.finfo(dtype).eps`.

  Args:
    a: (Batch of) `float`-like matrix-shaped `Tensor`(s) which are to be
      pseudo-inverted.
    rcond: `Tensor` of small singular value cutoffs. Singular values smaller
      (in modulus) than `rcond` * largest_singular_value (again, in modulus)
      are set to zero. Must broadcast against `tf.shape(a)[:-2]`.
      Default value: `10. * max(num_rows, num_cols) * np.finfo(a.dtype).eps`.
    validate_args: When `True`, additional assertions might be embedded in the
      graph.
      Default value: `False` (i.e., no graph assertions are added).
    name: Python `str` prefixed to ops created by this function.
      Default value: 'pinv'.

  Returns:
    a_pinv: (Batch of) pseudo-inverse of input `a`. Has same shape as `a` except
      rightmost two dimensions are transposed.

  Raises:
    TypeError: if input `a` does not have `float`-like `dtype`.
    ValueError: if input `a` has fewer than 2 dimensions.

  #### Examples

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp

  a = tf.constant([[1.,  0.4,  0.5],
                   [0.4, 0.2,  0.25],
                   [0.5, 0.25, 0.35]])
  tf.matmul(tf.linalg.pinv(a), a)
  # ==> array([[1., 0., 0.],
  #            [0., 1., 0.],
  #            [0., 0., 1.]], dtype=float32)

  a = tf.constant([[1.,  0.4,  0.5,  1.],
                   [0.4, 0.2,  0.25, 2.],
                   [0.5, 0.25, 0.35, 3.]])
  tf.matmul(tf.linalg.pinv(a), a)
  # ==> array([[ 0.76,  0.37,  0.21, -0.02],
  #            [ 0.37,  0.43, -0.33,  0.02],
  #            [ 0.21, -0.33,  0.81,  0.01],
  #            [-0.02,  0.02,  0.01,  1.  ]], dtype=float32)
  ```

  #### References

  [1]: G. Strang. 'Linear Algebra and Its Applications, 2nd Ed.' Academic Press,
       Inc., 1980, pp. 139-142.
  """
  with ops.name_scope(name or 'pinv'):
    a = ops.convert_to_tensor(a, name='a')

    assertions = _maybe_validate_matrix(a, validate_args)
    if assertions:
      with ops.control_dependencies(assertions):
        a = array_ops.identity(a)

    dtype = a.dtype.as_numpy_dtype

    if rcond is None:

      def get_dim_size(dim):
        dim_val = tensor_shape.dimension_value(a.shape[dim])
        if dim_val is not None:
          return dim_val
        return array_ops.shape(a)[dim]

      num_rows = get_dim_size(-2)
      num_cols = get_dim_size(-1)
      if isinstance(num_rows, int) and isinstance(num_cols, int):
        max_rows_cols = float(max(num_rows, num_cols))
      else:
        max_rows_cols = math_ops.cast(
            math_ops.maximum(num_rows, num_cols), dtype)
      rcond = 10. * max_rows_cols * np.finfo(dtype).eps

    rcond = ops.convert_to_tensor(rcond, dtype=dtype, name='rcond')

    # Calculate pseudo inverse via SVD.
    # Note: if a is Hermitian then u == v. (We might observe additional
    # performance by explicitly setting `v = u` in such cases.)
    [
        singular_values,  # Sigma
        left_singular_vectors,  # U
        right_singular_vectors,  # V
    ] = svd(
        a, full_matrices=False, compute_uv=True)

    # Saturate small singular values to inf. This has the effect of making
    # `1. / s = 0.` while not resulting in `NaN` gradients.
    cutoff = rcond * math_ops.reduce_max(singular_values, axis=-1)
    singular_values = array_ops.where_v2(
        singular_values > array_ops.expand_dims_v2(cutoff, -1), singular_values,
        np.array(np.inf, dtype))

    # By the definition of the SVD, `a == u @ s @ v^H`, and the pseudo-inverse
    # is defined as `pinv(a) == v @ inv(s) @ u^H`.
    a_pinv = math_ops.matmul(
        right_singular_vectors / array_ops.expand_dims_v2(singular_values, -2),
        left_singular_vectors,
        adjoint_b=True)

    if a.shape is not None and a.shape.rank is not None:
      a_pinv.set_shape(a.shape[:-2].concatenate([a.shape[-1], a.shape[-2]]))

    return a_pinv
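

# Least-squares sketch using `tf.linalg.pinv` (illustrative only; assumes
# eager execution). For an overdetermined system `A @ x = b`, `pinv(A) @ b`
# gives the minimum-norm least-squares solution, matching `tf.linalg.lstsq`:
#
#   import tensorflow as tf
#   a = tf.constant([[1., 0.], [0., 1.], [1., 1.]])
#   b = tf.constant([[1.], [2.], [4.]])
#   x_pinv = tf.linalg.pinv(a) @ b
#   x_lstsq = tf.linalg.lstsq(a, b)   # close to x_pinv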


@tf_export('linalg.lu_solve')
@dispatch.add_dispatch_support
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
  """Solves systems of linear eqns `A X = RHS`, given LU factorizations.

  Note: this function does not verify the implied matrix is actually invertible
  nor is this condition checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P,
      matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if `matmul(P, matmul(L, U)) =
      X` then `perm = argmax(P)`.
    rhs: Matrix-shaped float `Tensor` representing targets for which to solve;
      `A X = RHS`. To handle vector cases, use:
      `lu_solve(..., rhs[..., tf.newaxis])[..., 0]`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness. Note: this function does not verify the implied matrix
      is actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_solve').

  Returns:
    x: The `X` in `A @ X = RHS`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[1., 2],
        [3, 4]],
       [[7, 8],
        [3, 4]]]
  inv_x = tf.linalg.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2))
  tf.assert_near(tf.matrix_inverse(x), inv_x)
  # ==> True
  ```

  """

  with ops.name_scope(name or 'lu_solve'):
    lower_upper = ops.convert_to_tensor(
        lower_upper, dtype_hint=dtypes.float32, name='lower_upper')
    perm = ops.convert_to_tensor(perm, dtype_hint=dtypes.int32, name='perm')
    rhs = ops.convert_to_tensor(rhs, dtype_hint=lower_upper.dtype, name='rhs')

    assertions = _lu_solve_assertions(lower_upper, perm, rhs, validate_args)
    if assertions:
      with ops.control_dependencies(assertions):
        lower_upper = array_ops.identity(lower_upper)
        perm = array_ops.identity(perm)
        rhs = array_ops.identity(rhs)

    if (rhs.shape.rank == 2 and perm.shape.rank == 1):
      # Both rhs and perm have scalar batch_shape.
      permuted_rhs = array_ops.gather(rhs, perm, axis=-2)
    else:
      # Either rhs or perm have non-scalar batch_shape or we can't determine
      # this information statically.
      rhs_shape = array_ops.shape(rhs)
      broadcast_batch_shape = array_ops.broadcast_dynamic_shape(
          rhs_shape[:-2],
          array_ops.shape(perm)[:-1])
      d, m = rhs_shape[-2], rhs_shape[-1]
      rhs_broadcast_shape = array_ops.concat([broadcast_batch_shape, [d, m]],
                                             axis=0)

      # Tile out rhs.
      broadcast_rhs = array_ops.broadcast_to(rhs, rhs_broadcast_shape)
      broadcast_rhs = array_ops.reshape(broadcast_rhs, [-1, d, m])

      # Tile out perm and add batch indices.
      broadcast_perm = array_ops.broadcast_to(perm, rhs_broadcast_shape[:-1])
      broadcast_perm = array_ops.reshape(broadcast_perm, [-1, d])
      broadcast_batch_size = math_ops.reduce_prod(broadcast_batch_shape)
      broadcast_batch_indices = array_ops.broadcast_to(
          math_ops.range(broadcast_batch_size)[:, array_ops.newaxis],
          [broadcast_batch_size, d])
      broadcast_perm = array_ops.stack(
          [broadcast_batch_indices, broadcast_perm], axis=-1)

      permuted_rhs = array_ops.gather_nd(broadcast_rhs, broadcast_perm)
      permuted_rhs = array_ops.reshape(permuted_rhs, rhs_broadcast_shape)

    lower = set_diag(
        band_part(lower_upper, num_lower=-1, num_upper=0),
        array_ops.ones(
            array_ops.shape(lower_upper)[:-1], dtype=lower_upper.dtype))
    return triangular_solve(
        lower_upper,  # Only upper is accessed.
        triangular_solve(lower, permuted_rhs),
        lower=False)
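

# Sketch: reusing one LU factorization for several right-hand sides
# (illustrative only; assumes eager execution):
#
#   import tensorflow as tf
#   a = tf.constant([[4., 3.], [6., 3.]])
#   lu, perm = tf.linalg.lu(a)
#   x1 = tf.linalg.lu_solve(lu, perm, rhs=tf.constant([[1.], [0.]]))
#   x2 = tf.linalg.lu_solve(lu, perm, rhs=tf.constant([[0.], [1.]]))
#   # Each x_i satisfies a @ x_i ~= rhs_i without refactorizing `a`.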


@tf_export('linalg.lu_matrix_inverse')
@dispatch.add_dispatch_support
def lu_matrix_inverse(lower_upper, perm, validate_args=False, name=None):
  """Computes the inverse given the LU decomposition(s) of one or more matrices.

  This op is conceptually identical to,

  ```python
  inv_X = tf.linalg.lu_matrix_inverse(*tf.linalg.lu(X))
  tf.assert_near(tf.matrix_inverse(X), inv_X)
  # ==> True
  ```

  Note: this function does not verify the implied matrix is actually invertible
  nor is this condition checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P,
      matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if `matmul(P, matmul(L, U)) =
      X` then `perm = argmax(P)`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness. Note: this function does not verify the implied matrix
      is actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_matrix_inverse').

  Returns:
    inv_x: The matrix_inv, i.e.,
      `tf.matrix_inverse(tf.linalg.lu_reconstruct(lu, perm))`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  inv_x = tf.linalg.lu_matrix_inverse(*tf.linalg.lu(x))
  tf.assert_near(tf.matrix_inverse(x), inv_x)
  # ==> True
  ```

  """

  with ops.name_scope(name or 'lu_matrix_inverse'):
    lower_upper = ops.convert_to_tensor(
        lower_upper, dtype_hint=dtypes.float32, name='lower_upper')
    perm = ops.convert_to_tensor(perm, dtype_hint=dtypes.int32, name='perm')
    assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)
    if assertions:
      with ops.control_dependencies(assertions):
        lower_upper = array_ops.identity(lower_upper)
        perm = array_ops.identity(perm)
    shape = array_ops.shape(lower_upper)
    return lu_solve(
        lower_upper,
        perm,
        rhs=eye(shape[-1], batch_shape=shape[:-2], dtype=lower_upper.dtype),
        validate_args=False)


@tf_export('linalg.lu_reconstruct')
@dispatch.add_dispatch_support
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
  """Reconstructs one or more matrices from their LU decomposition(s).

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P,
      matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if `matmul(P, matmul(L, U)) =
      X` then `perm = argmax(P)`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_reconstruct').

  Returns:
    x: The original input to `tf.linalg.lu`, i.e., `x` as in,
      `lu_reconstruct(*tf.linalg.lu(x))`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  x_reconstructed = tf.linalg.lu_reconstruct(*tf.linalg.lu(x))
  tf.assert_near(x, x_reconstructed)
  # ==> True
  ```

  """
  with ops.name_scope(name or 'lu_reconstruct'):
    lower_upper = ops.convert_to_tensor(
        lower_upper, dtype_hint=dtypes.float32, name='lower_upper')
    perm = ops.convert_to_tensor(perm, dtype_hint=dtypes.int32, name='perm')

    assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)
    if assertions:
      with ops.control_dependencies(assertions):
        lower_upper = array_ops.identity(lower_upper)
        perm = array_ops.identity(perm)

    shape = array_ops.shape(lower_upper)

    lower = set_diag(
        band_part(lower_upper, num_lower=-1, num_upper=0),
        array_ops.ones(shape[:-1], dtype=lower_upper.dtype))
    upper = band_part(lower_upper, num_lower=0, num_upper=-1)
    x = math_ops.matmul(lower, upper)

    if (lower_upper.shape is None or lower_upper.shape.rank is None or
        lower_upper.shape.rank != 2):
      # We either don't know the batch rank or there are >0 batch dims.
      batch_size = math_ops.reduce_prod(shape[:-2])
      d = shape[-1]
      x = array_ops.reshape(x, [batch_size, d, d])
      perm = array_ops.reshape(perm, [batch_size, d])
      perm = map_fn.map_fn(array_ops.invert_permutation, perm)
      batch_indices = array_ops.broadcast_to(
          math_ops.range(batch_size)[:, array_ops.newaxis], [batch_size, d])
      x = array_ops.gather_nd(x, array_ops.stack([batch_indices, perm],
                                                 axis=-1))
      x = array_ops.reshape(x, shape)
    else:
      x = array_ops.gather(x, array_ops.invert_permutation(perm))

    x.set_shape(lower_upper.shape)
    return x


def lu_reconstruct_assertions(lower_upper, perm, validate_args):
  """Returns list of assertions related to `lu_reconstruct` assumptions."""
  assertions = []

  message = 'Input `lower_upper` must have at least 2 dimensions.'
  if lower_upper.shape.rank is not None and lower_upper.shape.rank < 2:
    raise ValueError(message)
  elif validate_args:
    assertions.append(
        check_ops.assert_rank_at_least_v2(lower_upper, rank=2, message=message))

  message = '`rank(lower_upper)` must equal `rank(perm) + 1`'
  if lower_upper.shape.rank is not None and perm.shape.rank is not None:
    if lower_upper.shape.rank != perm.shape.rank + 1:
      raise ValueError(message)
  elif validate_args:
    assertions.append(
        check_ops.assert_rank(
            lower_upper, rank=array_ops.rank(perm) + 1, message=message))

  message = '`lower_upper` must be square.'
  if lower_upper.shape[-2:].is_fully_defined():
    if lower_upper.shape[-2] != lower_upper.shape[-1]:
      raise ValueError(message)
  elif validate_args:
    m, n = array_ops.split(
        array_ops.shape(lower_upper)[-2:], num_or_size_splits=2)
    assertions.append(check_ops.assert_equal(m, n, message=message))

  return assertions


def _lu_solve_assertions(lower_upper, perm, rhs, validate_args):
  """Returns list of assertions related to `lu_solve` assumptions."""
  assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)

  message = 'Input `rhs` must have at least 2 dimensions.'
  if rhs.shape.ndims is not None:
    if rhs.shape.ndims < 2:
      raise ValueError(message)
  elif validate_args:
    assertions.append(
        check_ops.assert_rank_at_least(rhs, rank=2, message=message))

  message = '`lower_upper.shape[-1]` must equal `rhs.shape[-2]`.'
  if (lower_upper.shape[-1] is not None and rhs.shape[-2] is not None):
    if lower_upper.shape[-1] != rhs.shape[-2]:
      raise ValueError(message)
  elif validate_args:
    assertions.append(
        check_ops.assert_equal(
            array_ops.shape(lower_upper)[-1],
            array_ops.shape(rhs)[-2],
            message=message))

  return assertions