1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Create a Block Diagonal operator from one or more `LinearOperators`.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.python.framework import common_shapes 22from tensorflow.python.framework import dtypes 23from tensorflow.python.framework import ops 24from tensorflow.python.framework import tensor_shape 25from tensorflow.python.ops import array_ops 26from tensorflow.python.ops import check_ops 27from tensorflow.python.ops import control_flow_ops 28from tensorflow.python.ops.linalg import linear_operator 29from tensorflow.python.ops.linalg import linear_operator_algebra 30from tensorflow.python.ops.linalg import linear_operator_util 31from tensorflow.python.util.tf_export import tf_export 32 33__all__ = ["LinearOperatorBlockDiag"] 34 35 36@tf_export("linalg.LinearOperatorBlockDiag") 37class LinearOperatorBlockDiag(linear_operator.LinearOperator): 38 """Combines one or more `LinearOperators` in to a Block Diagonal matrix. 39 40 This operator combines one or more linear operators `[op1,...,opJ]`, 41 building a new `LinearOperator`, whose underlying matrix representation is 42 square and has each operator `opi` on the main diagonal, and zero's elsewhere. 43 44 #### Shape compatibility 45 46 If `opj` acts like a [batch] square matrix `Aj`, then `op_combined` acts like 47 the [batch] square matrix formed by having each matrix `Aj` on the main 48 diagonal. 49 50 Each `opj` is required to represent a square matrix, and hence will have 51 shape `batch_shape_j + [M_j, M_j]`. 52 53 If `opj` has shape `batch_shape_j + [M_j, M_j]`, then the combined operator 54 has shape `broadcast_batch_shape + [sum M_j, sum M_j]`, where 55 `broadcast_batch_shape` is the mutual broadcast of `batch_shape_j`, 56 `j = 1,...,J`, assuming the intermediate batch shapes broadcast. 57 Even if the combined shape is well defined, the combined operator's 58 methods may fail due to lack of broadcasting ability in the defining 59 operators' methods. 60 61 Arguments to `matmul`, `matvec`, `solve`, and `solvevec` may either be single 62 `Tensor`s or lists of `Tensor`s that are interpreted as blocks. The `j`th 63 element of a blockwise list of `Tensor`s must have dimensions that match 64 `opj` for the given method. If a list of blocks is input, then a list of 65 blocks is returned as well. 66 67 ```python 68 # Create a 4 x 4 linear operator combined of two 2 x 2 operators. 69 operator_1 = LinearOperatorFullMatrix([[1., 2.], [3., 4.]]) 70 operator_2 = LinearOperatorFullMatrix([[1., 0.], [0., 1.]]) 71 operator = LinearOperatorBlockDiag([operator_1, operator_2]) 72 73 operator.to_dense() 74 ==> [[1., 2., 0., 0.], 75 [3., 4., 0., 0.], 76 [0., 0., 1., 0.], 77 [0., 0., 0., 1.]] 78 79 operator.shape 80 ==> [4, 4] 81 82 operator.log_abs_determinant() 83 ==> scalar Tensor 84 85 x1 = ... # Shape [2, 2] Tensor 86 x2 = ... # Shape [2, 2] Tensor 87 x = tf.concat([x1, x2], 0) # Shape [2, 4] Tensor 88 operator.matmul(x) 89 ==> tf.concat([operator_1.matmul(x1), operator_2.matmul(x2)]) 90 91 # Create a [2, 3] batch of 4 x 4 linear operators. 92 matrix_44 = tf.random.normal(shape=[2, 3, 4, 4]) 93 operator_44 = LinearOperatorFullMatrix(matrix) 94 95 # Create a [1, 3] batch of 5 x 5 linear operators. 96 matrix_55 = tf.random.normal(shape=[1, 3, 5, 5]) 97 operator_55 = LinearOperatorFullMatrix(matrix_55) 98 99 # Combine to create a [2, 3] batch of 9 x 9 operators. 100 operator_99 = LinearOperatorBlockDiag([operator_44, operator_55]) 101 102 # Create a shape [2, 3, 9] vector. 103 x = tf.random.normal(shape=[2, 3, 9]) 104 operator_99.matmul(x) 105 ==> Shape [2, 3, 9] Tensor 106 107 # Create a blockwise list of vectors. 108 x = [tf.random.normal(shape=[2, 3, 4]), tf.random.normal(shape=[2, 3, 5])] 109 operator_99.matmul(x) 110 ==> [Shape [2, 3, 4] Tensor, Shape [2, 3, 5] Tensor] 111 ``` 112 113 #### Performance 114 115 The performance of `LinearOperatorBlockDiag` on any operation is equal to 116 the sum of the individual operators' operations. 117 118 119 #### Matrix property hints 120 121 This `LinearOperator` is initialized with boolean flags of the form `is_X`, 122 for `X = non_singular, self_adjoint, positive_definite, square`. 123 These have the following meaning: 124 125 * If `is_X == True`, callers should expect the operator to have the 126 property `X`. This is a promise that should be fulfilled, but is *not* a 127 runtime assert. For example, finite floating point precision may result 128 in these promises being violated. 129 * If `is_X == False`, callers should expect the operator to not have `X`. 130 * If `is_X == None` (the default), callers should have no expectation either 131 way. 132 """ 133 134 def __init__(self, 135 operators, 136 is_non_singular=None, 137 is_self_adjoint=None, 138 is_positive_definite=None, 139 is_square=True, 140 name=None): 141 r"""Initialize a `LinearOperatorBlockDiag`. 142 143 `LinearOperatorBlockDiag` is initialized with a list of operators 144 `[op_1,...,op_J]`. 145 146 Args: 147 operators: Iterable of `LinearOperator` objects, each with 148 the same `dtype` and composable shape. 149 is_non_singular: Expect that this operator is non-singular. 150 is_self_adjoint: Expect that this operator is equal to its hermitian 151 transpose. 152 is_positive_definite: Expect that this operator is positive definite, 153 meaning the quadratic form `x^H A x` has positive real part for all 154 nonzero `x`. Note that we do not require the operator to be 155 self-adjoint to be positive-definite. See: 156 https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices 157 is_square: Expect that this operator acts like square [batch] matrices. 158 This is true by default, and will raise a `ValueError` otherwise. 159 name: A name for this `LinearOperator`. Default is the individual 160 operators names joined with `_o_`. 161 162 Raises: 163 TypeError: If all operators do not have the same `dtype`. 164 ValueError: If `operators` is empty or are non-square. 165 """ 166 parameters = dict( 167 operators=operators, 168 is_non_singular=is_non_singular, 169 is_self_adjoint=is_self_adjoint, 170 is_positive_definite=is_positive_definite, 171 is_square=is_square, 172 name=name 173 ) 174 175 # Validate operators. 176 check_ops.assert_proper_iterable(operators) 177 operators = list(operators) 178 if not operators: 179 raise ValueError( 180 "Expected a non-empty list of operators. Found: %s" % operators) 181 self._operators = operators 182 183 # Define diagonal operators, for functions that are shared across blockwise 184 # `LinearOperator` types. 185 self._diagonal_operators = operators 186 187 # Validate dtype. 188 dtype = operators[0].dtype 189 for operator in operators: 190 if operator.dtype != dtype: 191 name_type = (str((o.name, o.dtype)) for o in operators) 192 raise TypeError( 193 "Expected all operators to have the same dtype. Found %s" 194 % " ".join(name_type)) 195 196 # Auto-set and check hints. 197 if all(operator.is_non_singular for operator in operators): 198 if is_non_singular is False: 199 raise ValueError( 200 "The direct sum of non-singular operators is always non-singular.") 201 is_non_singular = True 202 203 if all(operator.is_self_adjoint for operator in operators): 204 if is_self_adjoint is False: 205 raise ValueError( 206 "The direct sum of self-adjoint operators is always self-adjoint.") 207 is_self_adjoint = True 208 209 if all(operator.is_positive_definite for operator in operators): 210 if is_positive_definite is False: 211 raise ValueError( 212 "The direct sum of positive definite operators is always " 213 "positive definite.") 214 is_positive_definite = True 215 216 if not (is_square and all(operator.is_square for operator in operators)): 217 raise ValueError( 218 "Can only represent a block diagonal of square matrices.") 219 220 # Initialization. 221 graph_parents = [] 222 for operator in operators: 223 graph_parents.extend(operator.graph_parents) 224 225 if name is None: 226 # Using ds to mean direct sum. 227 name = "_ds_".join(operator.name for operator in operators) 228 with ops.name_scope(name, values=graph_parents): 229 super(LinearOperatorBlockDiag, self).__init__( 230 dtype=dtype, 231 is_non_singular=is_non_singular, 232 is_self_adjoint=is_self_adjoint, 233 is_positive_definite=is_positive_definite, 234 is_square=True, 235 parameters=parameters, 236 name=name) 237 238 # TODO(b/143910018) Remove graph_parents in V3. 239 self._set_graph_parents(graph_parents) 240 241 @property 242 def operators(self): 243 return self._operators 244 245 def _block_range_dimensions(self): 246 return [op.range_dimension for op in self._diagonal_operators] 247 248 def _block_domain_dimensions(self): 249 return [op.domain_dimension for op in self._diagonal_operators] 250 251 def _block_range_dimension_tensors(self): 252 return [op.range_dimension_tensor() for op in self._diagonal_operators] 253 254 def _block_domain_dimension_tensors(self): 255 return [op.domain_dimension_tensor() for op in self._diagonal_operators] 256 257 def _shape(self): 258 # Get final matrix shape. 259 domain_dimension = sum(self._block_domain_dimensions()) 260 range_dimension = sum(self._block_range_dimensions()) 261 matrix_shape = tensor_shape.TensorShape([domain_dimension, range_dimension]) 262 263 # Get broadcast batch shape. 264 # broadcast_shape checks for compatibility. 265 batch_shape = self.operators[0].batch_shape 266 for operator in self.operators[1:]: 267 batch_shape = common_shapes.broadcast_shape( 268 batch_shape, operator.batch_shape) 269 270 return batch_shape.concatenate(matrix_shape) 271 272 def _shape_tensor(self): 273 # Avoid messy broadcasting if possible. 274 if self.shape.is_fully_defined(): 275 return ops.convert_to_tensor_v2_with_dispatch( 276 self.shape.as_list(), dtype=dtypes.int32, name="shape") 277 278 domain_dimension = sum(self._block_domain_dimension_tensors()) 279 range_dimension = sum(self._block_range_dimension_tensors()) 280 matrix_shape = array_ops.stack([domain_dimension, range_dimension]) 281 282 # Dummy Tensor of zeros. Will never be materialized. 283 zeros = array_ops.zeros(shape=self.operators[0].batch_shape_tensor()) 284 for operator in self.operators[1:]: 285 zeros += array_ops.zeros(shape=operator.batch_shape_tensor()) 286 batch_shape = array_ops.shape(zeros) 287 288 return array_ops.concat((batch_shape, matrix_shape), 0) 289 290 def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"): 291 """Transform [batch] matrix `x` with left multiplication: `x --> Ax`. 292 293 ```python 294 # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N] 295 operator = LinearOperator(...) 296 operator.shape = [..., M, N] 297 298 X = ... # shape [..., N, R], batch matrix, R > 0. 299 300 Y = operator.matmul(X) 301 Y.shape 302 ==> [..., M, R] 303 304 Y[..., :, r] = sum_j A[..., :, j] X[j, r] 305 ``` 306 307 Args: 308 x: `LinearOperator`, `Tensor` with compatible shape and same `dtype` as 309 `self`, or a blockwise iterable of `LinearOperator`s or `Tensor`s. See 310 class docstring for definition of shape compatibility. 311 adjoint: Python `bool`. If `True`, left multiply by the adjoint: `A^H x`. 312 adjoint_arg: Python `bool`. If `True`, compute `A x^H` where `x^H` is 313 the hermitian transpose (transposition and complex conjugation). 314 name: A name for this `Op`. 315 316 Returns: 317 A `LinearOperator` or `Tensor` with shape `[..., M, R]` and same `dtype` 318 as `self`, or if `x` is blockwise, a list of `Tensor`s with shapes that 319 concatenate to `[..., M, R]`. 320 """ 321 if isinstance(x, linear_operator.LinearOperator): 322 left_operator = self.adjoint() if adjoint else self 323 right_operator = x.adjoint() if adjoint_arg else x 324 325 if (right_operator.range_dimension is not None and 326 left_operator.domain_dimension is not None and 327 right_operator.range_dimension != left_operator.domain_dimension): 328 raise ValueError( 329 "Operators are incompatible. Expected `x` to have dimension" 330 " {} but got {}.".format( 331 left_operator.domain_dimension, right_operator.range_dimension)) 332 with self._name_scope(name): 333 return linear_operator_algebra.matmul(left_operator, right_operator) 334 335 with self._name_scope(name): 336 arg_dim = -1 if adjoint_arg else -2 337 block_dimensions = (self._block_range_dimensions() if adjoint 338 else self._block_domain_dimensions()) 339 if linear_operator_util.arg_is_blockwise(block_dimensions, x, arg_dim): 340 for i, block in enumerate(x): 341 if not isinstance(block, linear_operator.LinearOperator): 342 block = ops.convert_to_tensor_v2_with_dispatch(block) 343 self._check_input_dtype(block) 344 block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim]) 345 x[i] = block 346 else: 347 x = ops.convert_to_tensor_v2_with_dispatch(x, name="x") 348 self._check_input_dtype(x) 349 op_dimension = (self.range_dimension if adjoint 350 else self.domain_dimension) 351 op_dimension.assert_is_compatible_with(x.shape[arg_dim]) 352 return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg) 353 354 def _matmul(self, x, adjoint=False, adjoint_arg=False): 355 arg_dim = -1 if adjoint_arg else -2 356 block_dimensions = (self._block_range_dimensions() if adjoint 357 else self._block_domain_dimensions()) 358 blockwise_arg = linear_operator_util.arg_is_blockwise( 359 block_dimensions, x, arg_dim) 360 if blockwise_arg: 361 split_x = x 362 else: 363 split_dim = -1 if adjoint_arg else -2 364 # Split input by rows normally, and otherwise columns. 365 split_x = linear_operator_util.split_arg_into_blocks( 366 self._block_domain_dimensions(), 367 self._block_domain_dimension_tensors, 368 x, axis=split_dim) 369 370 result_list = [] 371 for index, operator in enumerate(self.operators): 372 result_list += [operator.matmul( 373 split_x[index], adjoint=adjoint, adjoint_arg=adjoint_arg)] 374 375 if blockwise_arg: 376 return result_list 377 378 result_list = linear_operator_util.broadcast_matrix_batch_dims( 379 result_list) 380 return array_ops.concat(result_list, axis=-2) 381 382 def matvec(self, x, adjoint=False, name="matvec"): 383 """Transform [batch] vector `x` with left multiplication: `x --> Ax`. 384 385 ```python 386 # Make an operator acting like batch matric A. Assume A.shape = [..., M, N] 387 operator = LinearOperator(...) 388 389 X = ... # shape [..., N], batch vector 390 391 Y = operator.matvec(X) 392 Y.shape 393 ==> [..., M] 394 395 Y[..., :] = sum_j A[..., :, j] X[..., j] 396 ``` 397 398 Args: 399 x: `Tensor` with compatible shape and same `dtype` as `self`, or an 400 iterable of `Tensor`s (for blockwise operators). `Tensor`s are treated 401 a [batch] vectors, meaning for every set of leading dimensions, the last 402 dimension defines a vector. 403 See class docstring for definition of compatibility. 404 adjoint: Python `bool`. If `True`, left multiply by the adjoint: `A^H x`. 405 name: A name for this `Op`. 406 407 Returns: 408 A `Tensor` with shape `[..., M]` and same `dtype` as `self`. 409 """ 410 with self._name_scope(name): 411 block_dimensions = (self._block_range_dimensions() if adjoint 412 else self._block_domain_dimensions()) 413 if linear_operator_util.arg_is_blockwise(block_dimensions, x, -1): 414 for i, block in enumerate(x): 415 if not isinstance(block, linear_operator.LinearOperator): 416 block = ops.convert_to_tensor_v2_with_dispatch(block) 417 self._check_input_dtype(block) 418 block_dimensions[i].assert_is_compatible_with(block.shape[-1]) 419 x[i] = block 420 x_mat = [block[..., array_ops.newaxis] for block in x] 421 y_mat = self.matmul(x_mat, adjoint=adjoint) 422 return [array_ops.squeeze(y, axis=-1) for y in y_mat] 423 424 x = ops.convert_to_tensor_v2_with_dispatch(x, name="x") 425 self._check_input_dtype(x) 426 op_dimension = (self.range_dimension if adjoint 427 else self.domain_dimension) 428 op_dimension.assert_is_compatible_with(x.shape[-1]) 429 x_mat = x[..., array_ops.newaxis] 430 y_mat = self.matmul(x_mat, adjoint=adjoint) 431 return array_ops.squeeze(y_mat, axis=-1) 432 433 def _determinant(self): 434 result = self.operators[0].determinant() 435 for operator in self.operators[1:]: 436 result *= operator.determinant() 437 return result 438 439 def _log_abs_determinant(self): 440 result = self.operators[0].log_abs_determinant() 441 for operator in self.operators[1:]: 442 result += operator.log_abs_determinant() 443 return result 444 445 def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"): 446 """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`. 447 448 The returned `Tensor` will be close to an exact solution if `A` is well 449 conditioned. Otherwise closeness will vary. See class docstring for details. 450 451 Examples: 452 453 ```python 454 # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N] 455 operator = LinearOperator(...) 456 operator.shape = [..., M, N] 457 458 # Solve R > 0 linear systems for every member of the batch. 459 RHS = ... # shape [..., M, R] 460 461 X = operator.solve(RHS) 462 # X[..., :, r] is the solution to the r'th linear system 463 # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r] 464 465 operator.matmul(X) 466 ==> RHS 467 ``` 468 469 Args: 470 rhs: `Tensor` with same `dtype` as this operator and compatible shape, 471 or a list of `Tensor`s (for blockwise operators). `Tensor`s are treated 472 like a [batch] matrices meaning for every set of leading dimensions, the 473 last two dimensions defines a matrix. 474 See class docstring for definition of compatibility. 475 adjoint: Python `bool`. If `True`, solve the system involving the adjoint 476 of this `LinearOperator`: `A^H X = rhs`. 477 adjoint_arg: Python `bool`. If `True`, solve `A X = rhs^H` where `rhs^H` 478 is the hermitian transpose (transposition and complex conjugation). 479 name: A name scope to use for ops added by this method. 480 481 Returns: 482 `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`. 483 484 Raises: 485 NotImplementedError: If `self.is_non_singular` or `is_square` is False. 486 """ 487 if self.is_non_singular is False: 488 raise NotImplementedError( 489 "Exact solve not implemented for an operator that is expected to " 490 "be singular.") 491 if self.is_square is False: 492 raise NotImplementedError( 493 "Exact solve not implemented for an operator that is expected to " 494 "not be square.") 495 if isinstance(rhs, linear_operator.LinearOperator): 496 left_operator = self.adjoint() if adjoint else self 497 right_operator = rhs.adjoint() if adjoint_arg else rhs 498 499 if (right_operator.range_dimension is not None and 500 left_operator.domain_dimension is not None and 501 right_operator.range_dimension != left_operator.domain_dimension): 502 raise ValueError( 503 "Operators are incompatible. Expected `rhs` to have dimension" 504 " {} but got {}.".format( 505 left_operator.domain_dimension, right_operator.range_dimension)) 506 with self._name_scope(name): 507 return linear_operator_algebra.solve(left_operator, right_operator) 508 509 with self._name_scope(name): 510 block_dimensions = (self._block_domain_dimensions() if adjoint 511 else self._block_range_dimensions()) 512 arg_dim = -1 if adjoint_arg else -2 513 blockwise_arg = linear_operator_util.arg_is_blockwise( 514 block_dimensions, rhs, arg_dim) 515 516 if blockwise_arg: 517 split_rhs = rhs 518 for i, block in enumerate(split_rhs): 519 if not isinstance(block, linear_operator.LinearOperator): 520 block = ops.convert_to_tensor_v2_with_dispatch(block) 521 self._check_input_dtype(block) 522 block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim]) 523 split_rhs[i] = block 524 else: 525 rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs") 526 self._check_input_dtype(rhs) 527 op_dimension = (self.domain_dimension if adjoint 528 else self.range_dimension) 529 op_dimension.assert_is_compatible_with(rhs.shape[arg_dim]) 530 split_dim = -1 if adjoint_arg else -2 531 # Split input by rows normally, and otherwise columns. 532 split_rhs = linear_operator_util.split_arg_into_blocks( 533 self._block_domain_dimensions(), 534 self._block_domain_dimension_tensors, 535 rhs, axis=split_dim) 536 537 solution_list = [] 538 for index, operator in enumerate(self.operators): 539 solution_list += [operator.solve( 540 split_rhs[index], adjoint=adjoint, adjoint_arg=adjoint_arg)] 541 542 if blockwise_arg: 543 return solution_list 544 545 solution_list = linear_operator_util.broadcast_matrix_batch_dims( 546 solution_list) 547 return array_ops.concat(solution_list, axis=-2) 548 549 def solvevec(self, rhs, adjoint=False, name="solve"): 550 """Solve single equation with best effort: `A X = rhs`. 551 552 The returned `Tensor` will be close to an exact solution if `A` is well 553 conditioned. Otherwise closeness will vary. See class docstring for details. 554 555 Examples: 556 557 ```python 558 # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N] 559 operator = LinearOperator(...) 560 operator.shape = [..., M, N] 561 562 # Solve one linear system for every member of the batch. 563 RHS = ... # shape [..., M] 564 565 X = operator.solvevec(RHS) 566 # X is the solution to the linear system 567 # sum_j A[..., :, j] X[..., j] = RHS[..., :] 568 569 operator.matvec(X) 570 ==> RHS 571 ``` 572 573 Args: 574 rhs: `Tensor` with same `dtype` as this operator, or list of `Tensor`s 575 (for blockwise operators). `Tensor`s are treated as [batch] vectors, 576 meaning for every set of leading dimensions, the last dimension defines 577 a vector. See class docstring for definition of compatibility regarding 578 batch dimensions. 579 adjoint: Python `bool`. If `True`, solve the system involving the adjoint 580 of this `LinearOperator`: `A^H X = rhs`. 581 name: A name scope to use for ops added by this method. 582 583 Returns: 584 `Tensor` with shape `[...,N]` and same `dtype` as `rhs`. 585 586 Raises: 587 NotImplementedError: If `self.is_non_singular` or `is_square` is False. 588 """ 589 with self._name_scope(name): 590 block_dimensions = (self._block_domain_dimensions() if adjoint 591 else self._block_range_dimensions()) 592 if linear_operator_util.arg_is_blockwise(block_dimensions, rhs, -1): 593 for i, block in enumerate(rhs): 594 if not isinstance(block, linear_operator.LinearOperator): 595 block = ops.convert_to_tensor_v2_with_dispatch(block) 596 self._check_input_dtype(block) 597 block_dimensions[i].assert_is_compatible_with(block.shape[-1]) 598 rhs[i] = block 599 rhs_mat = [array_ops.expand_dims(block, axis=-1) for block in rhs] 600 solution_mat = self.solve(rhs_mat, adjoint=adjoint) 601 return [array_ops.squeeze(x, axis=-1) for x in solution_mat] 602 603 rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs") 604 self._check_input_dtype(rhs) 605 op_dimension = (self.domain_dimension if adjoint 606 else self.range_dimension) 607 op_dimension.assert_is_compatible_with(rhs.shape[-1]) 608 rhs_mat = array_ops.expand_dims(rhs, axis=-1) 609 solution_mat = self.solve(rhs_mat, adjoint=adjoint) 610 return array_ops.squeeze(solution_mat, axis=-1) 611 612 def _diag_part(self): 613 diag_list = [] 614 for operator in self.operators: 615 # Extend the axis for broadcasting. 616 diag_list += [operator.diag_part()[..., array_ops.newaxis]] 617 diag_list = linear_operator_util.broadcast_matrix_batch_dims(diag_list) 618 diagonal = array_ops.concat(diag_list, axis=-2) 619 return array_ops.squeeze(diagonal, axis=-1) 620 621 def _trace(self): 622 result = self.operators[0].trace() 623 for operator in self.operators[1:]: 624 result += operator.trace() 625 return result 626 627 def _to_dense(self): 628 num_cols = 0 629 rows = [] 630 broadcasted_blocks = [operator.to_dense() for operator in self.operators] 631 broadcasted_blocks = linear_operator_util.broadcast_matrix_batch_dims( 632 broadcasted_blocks) 633 for block in broadcasted_blocks: 634 batch_row_shape = array_ops.shape(block)[:-1] 635 636 zeros_to_pad_before_shape = array_ops.concat( 637 [batch_row_shape, [num_cols]], axis=-1) 638 zeros_to_pad_before = array_ops.zeros( 639 shape=zeros_to_pad_before_shape, dtype=block.dtype) 640 num_cols += array_ops.shape(block)[-1] 641 zeros_to_pad_after_shape = array_ops.concat( 642 [batch_row_shape, 643 [self.domain_dimension_tensor() - num_cols]], axis=-1) 644 zeros_to_pad_after = array_ops.zeros( 645 shape=zeros_to_pad_after_shape, dtype=block.dtype) 646 647 rows.append(array_ops.concat( 648 [zeros_to_pad_before, block, zeros_to_pad_after], axis=-1)) 649 650 mat = array_ops.concat(rows, axis=-2) 651 mat.set_shape(self.shape) 652 return mat 653 654 def _assert_non_singular(self): 655 return control_flow_ops.group([ 656 operator.assert_non_singular() for operator in self.operators]) 657 658 def _assert_self_adjoint(self): 659 return control_flow_ops.group([ 660 operator.assert_self_adjoint() for operator in self.operators]) 661 662 def _assert_positive_definite(self): 663 return control_flow_ops.group([ 664 operator.assert_positive_definite() for operator in self.operators]) 665 666 def _eigvals(self): 667 eig_list = [] 668 for operator in self.operators: 669 # Extend the axis for broadcasting. 670 eig_list += [operator.eigvals()[..., array_ops.newaxis]] 671 eig_list = linear_operator_util.broadcast_matrix_batch_dims(eig_list) 672 eigs = array_ops.concat(eig_list, axis=-2) 673 return array_ops.squeeze(eigs, axis=-1) 674