1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Base class for linear operators.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import abc 22import contextlib 23 24import numpy as np 25import six 26 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import ops 29from tensorflow.python.framework import tensor_shape 30from tensorflow.python.framework import tensor_util 31from tensorflow.python.module import module 32from tensorflow.python.ops import array_ops 33from tensorflow.python.ops import check_ops 34from tensorflow.python.ops import linalg_ops 35from tensorflow.python.ops import math_ops 36from tensorflow.python.ops.linalg import linalg_impl as linalg 37from tensorflow.python.ops.linalg import linear_operator_algebra 38from tensorflow.python.ops.linalg import linear_operator_util 39from tensorflow.python.platform import tf_logging as logging 40from tensorflow.python.util import deprecation 41from tensorflow.python.util import dispatch 42from tensorflow.python.util.tf_export import tf_export 43 44__all__ = ["LinearOperator"] 45 46 47# TODO(langmore) Use matrix_solve_ls for singular or non-square matrices. 
@tf_export("linalg.LinearOperator")
@six.add_metaclass(abc.ABCMeta)
class LinearOperator(module.Module):
  """Base class defining a [batch of] linear operator[s].

  Subclasses of `LinearOperator` provide access to common methods on a
  (batch) matrix, without the need to materialize the matrix. This allows:

  * Matrix free computations
  * Operators that take advantage of special structure, while providing a
    consistent API to users.

  #### Subclassing

  To enable a public method, subclasses should implement the leading-underscore
  version of the method. The argument signature should be identical except for
  the omission of `name="..."`. For example, to enable
  `matmul(x, adjoint=False, name="matmul")` a subclass should implement
  `_matmul(x, adjoint=False)`.

  #### Performance contract

  Subclasses should only implement the assert methods
  (e.g. `assert_non_singular`) if they can be done in less than `O(N^3)`
  time.

  Class docstrings should contain an explanation of computational complexity.
  Since this is a high-performance library, attention should be paid to detail,
  and explanations can include constants as well as Big-O notation.

  #### Shape compatibility

  `LinearOperator` subclasses should operate on a [batch] matrix with
  compatible shape. Class docstrings should define what is meant by compatible
  shape. Some subclasses may not support batching.

  Examples:

  `x` is a batch matrix with compatible shape for `matmul` if

  ```
  operator.shape = [B1,...,Bb] + [M, N],  b >= 0,
  x.shape =   [B1,...,Bb] + [N, R]
  ```

  `rhs` is a batch matrix with compatible shape for `solve` if

  ```
  operator.shape = [B1,...,Bb] + [M, N],  b >= 0,
  rhs.shape =   [B1,...,Bb] + [M, R]
  ```

  #### Example docstring for subclasses.

  This operator acts like a (batch) matrix `A` with shape
  `[B1,...,Bb, M, N]` for some `b >= 0`. The first `b` indices index a
  batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is
  an `M x N` matrix. Again, this matrix `A` may not be materialized, but for
  purposes of identifying and working with compatible arguments the shape is
  relevant.

  Examples:

  ```python
  some_tensor = ... shape = ????
  operator = MyLinOp(some_tensor)

  operator.shape
  ==> [2, 4, 4]

  operator.log_abs_determinant()
  ==> Shape [2] Tensor

  x = ... Shape [2, 4, 5] Tensor

  operator.matmul(x)
  ==> Shape [2, 4, 5] Tensor
  ```

  #### Shape compatibility

  This operator acts on batch matrices with compatible shape.
  FILL IN WHAT IS MEANT BY COMPATIBLE SHAPE

  #### Performance

  FILL THIS IN

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`. This is a promise that should be fulfilled, but is *not* a
    runtime assert. For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.

  #### Initialization parameters

  All subclasses of `LinearOperator` are expected to pass a `parameters`
  argument to `super().__init__()`. This should be a `dict` containing
  the unadulterated arguments passed to the subclass `__init__`. For example,
  `MyLinearOperator` with an initializer should look like:

  ```python
  def __init__(self, operator, is_square=False, name=None):
     parameters = dict(
         operator=operator,
         is_square=is_square,
         name=name
     )
     ...
     super().__init__(..., parameters=parameters)
  ```

  Users can then access `my_linear_operator.parameters` to see all arguments
  passed to its initializer.
  """

  # TODO(b/143910018) Remove graph_parents in V3.
  @deprecation.deprecated_args(None, "Do not pass `graph_parents`. They will "
                               " no longer be used.", "graph_parents")
  def __init__(self,
               dtype,
               graph_parents=None,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name=None,
               parameters=None):
    r"""Initialize the `LinearOperator`.

    **This is a private method for subclass use.**
    **Subclasses should copy-paste this `__init__` documentation.**

    Hints that imply other hints are propagated: positive definite implies
    non-singular, and non-singular or self-adjoint each imply square.

    Args:
      dtype: The type of this `LinearOperator`. Arguments to `matmul` and
        `solve` will have to be this type.
      graph_parents: (Deprecated) Python list of graph prerequisites of this
        `LinearOperator` Typically tensors that are passed during initialization
      is_non_singular: Expect that this operator is non-singular.
      is_self_adjoint: Expect that this operator is equal to its hermitian
        transpose. If `dtype` is real, this is equivalent to being symmetric.
      is_positive_definite: Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`. Note that we do not require the operator to be
        self-adjoint to be positive-definite. See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square: Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`. Defaults to the class name.
      parameters: Python `dict` of parameters used to instantiate this
        `LinearOperator`.

    Raises:
      ValueError: If any member of graph_parents is `None` or not a `Tensor`.
      ValueError: If hints are set incorrectly.
    """
    # Check and auto-set flags.
    if is_positive_definite:
      if is_non_singular is False:
        raise ValueError("A positive definite matrix is always non-singular.")
      is_non_singular = True

    if is_non_singular:
      if is_square is False:
        raise ValueError("A non-singular matrix is always square.")
      is_square = True

    if is_self_adjoint:
      if is_square is False:
        raise ValueError("A self-adjoint matrix is always square.")
      is_square = True

    self._is_square_set_or_implied_by_hints = is_square

    if graph_parents is not None:
      self._set_graph_parents(graph_parents)
    else:
      self._graph_parents = []
    self._dtype = dtypes.as_dtype(dtype).base_dtype if dtype else dtype
    self._is_non_singular = is_non_singular
    self._is_self_adjoint = is_self_adjoint
    self._is_positive_definite = is_positive_definite
    self._parameters = self._no_dependency(parameters)
    self._parameters_sanitized = False
    self._name = name or type(self).__name__

  @contextlib.contextmanager
  def _name_scope(self, name=None):
    """Helper function to standardize op scope.

    Opens a name scope called `self.name`, or `self.name + "/" + name` when
    `name` is given, and yields the resulting scope.
    """
    full_name = self.name
    if name is not None:
      full_name += "/" + name
    with ops.name_scope(full_name) as scope:
      yield scope

  @property
  def parameters(self):
    """Dictionary of parameters used to instantiate this `LinearOperator`.

    Returns:
      A new (shallow-copied) `dict`. If no `parameters` were passed to
      `__init__`, an empty `dict` is returned.
    """
    # `parameters` defaults to `None` in `__init__`; `dict(None)` would raise
    # a `TypeError`, so handle the "not provided" case explicitly.
    if self._parameters is None:
      return {}
    return dict(self._parameters)

  @property
  def dtype(self):
    """The `DType` of `Tensor`s handled by this `LinearOperator`."""
    return self._dtype

  @property
  def name(self):
    """Name prepended to all ops created by this `LinearOperator`."""
    return self._name

  @property
  @deprecation.deprecated(None, "Do not call `graph_parents`.")
  def graph_parents(self):
    """List of graph dependencies of this `LinearOperator`."""
    return self._graph_parents

  @property
  def is_non_singular(self):
    """The `is_non_singular` hint: `True`, `False` or `None`.

    This is `True` if `is_positive_definite` was truthy at initialization,
    otherwise the value passed to `__init__`.
    """
    return self._is_non_singular

  @property
  def is_self_adjoint(self):
    """The `is_self_adjoint` hint passed to `__init__`."""
    return self._is_self_adjoint

  @property
  def is_positive_definite(self):
    """The `is_positive_definite` hint passed to `__init__`."""
    return self._is_positive_definite

  @property
  def is_square(self):
    """Return `True/False` depending on if this operator is square.

    If the `is_square` hint was neither set nor implied by other hints, the
    result of comparing the static domain and range dimensions is returned.

    Raises:
      ValueError: If the hint is `False` but the static domain and range
        dimensions are equal.
    """
    # Static checks done after __init__. Why? Because domain/range dimension
    # sometimes requires lots of work done in the derived class after init.
    auto_square_check = self.domain_dimension == self.range_dimension
    if self._is_square_set_or_implied_by_hints is False and auto_square_check:
      raise ValueError(
          "User set is_square hint to False, but the operator was square.")
    if self._is_square_set_or_implied_by_hints is None:
      return auto_square_check

    return self._is_square_set_or_implied_by_hints

  @abc.abstractmethod
  def _shape(self):
    """Static shape of this operator; must be implemented by subclasses."""
    # Write this in derived class to enable all static shape methods.
    raise NotImplementedError("_shape is not implemented.")

  @property
  def shape(self):
    """`TensorShape` of this `LinearOperator`.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns
    `TensorShape([B1,...,Bb, M, N])`, equivalent to `A.shape`.

    Returns:
      `TensorShape`, statically determined, may be undefined.
    """
    return self._shape()

  def _shape_tensor(self):
    """Runtime shape of this operator; raises unless overridden."""
    # This is not an abstractmethod, since we want derived classes to be able to
    # override this with optional kwargs, which can reduce the number of
    # `convert_to_tensor` calls. See derived classes for examples.
    raise NotImplementedError("_shape_tensor is not implemented.")

  def shape_tensor(self, name="shape_tensor"):
    """Shape of this `LinearOperator`, determined at runtime.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns a `Tensor` holding
    `[B1,...,Bb, M, N]`, equivalent to `tf.shape(A)`.

    Args:
      name: A name for this `Op`.

    Returns:
      `int32` `Tensor`
    """
    with self._name_scope(name):
      # Prefer to use statically defined shape if available.
      if self.shape.is_fully_defined():
        return linear_operator_util.shape_tensor(self.shape.as_list())
      else:
        return self._shape_tensor()

  @property
  def batch_shape(self):
    """`TensorShape` of batch dimensions of this `LinearOperator`.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns
    `TensorShape([B1,...,Bb])`, equivalent to `A.shape[:-2]`

    Returns:
      `TensorShape`, statically determined, may be undefined.
    """
    # Derived classes get this "for free" once .shape is implemented.
    return self.shape[:-2]

  def batch_shape_tensor(self, name="batch_shape_tensor"):
    """Shape of batch dimensions of this operator, determined at runtime.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns a `Tensor` holding
    `[B1,...,Bb]`.

    Args:
      name: A name for this `Op`.

    Returns:
      `int32` `Tensor`
    """
    # Derived classes get this "for free" once .shape() is implemented.
    with self._name_scope(name):
      return self._batch_shape_tensor()

  def _batch_shape_tensor(self, shape=None):
    """Batch shape as a `Tensor`, preferring the static shape if defined.

    Args:
      shape: Optional pre-computed shape of this operator. Only used when the
        static batch shape is not fully defined.
    """
    # `shape` may be passed in if this can be pre-computed in a
    # more efficient manner, e.g. without excessive Tensor conversions.
    if self.batch_shape.is_fully_defined():
      return linear_operator_util.shape_tensor(
          self.batch_shape.as_list(), name="batch_shape")
    else:
      shape = self.shape_tensor() if shape is None else shape
      return shape[:-2]

  @property
  def tensor_rank(self, name="tensor_rank"):
    """Rank (in the sense of tensors) of matrix corresponding to this operator.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `b + 2`.

    Args:
      name: A name for this `Op`. Because this is a property, callers cannot
        supply it and the default is always used.

    Returns:
      Python integer, or None if the tensor rank is undefined.
    """
    # Derived classes get this "for free" once .shape() is implemented.
    with self._name_scope(name):
      return self.shape.ndims

  def tensor_rank_tensor(self, name="tensor_rank_tensor"):
    """Rank (in the sense of tensors) of matrix corresponding to this operator.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `b + 2`.

    Args:
      name: A name for this `Op`.

    Returns:
      `int32` `Tensor`, determined at runtime.
    """
    # Derived classes get this "for free" once .shape() is implemented.
    with self._name_scope(name):
      return self._tensor_rank_tensor()

  def _tensor_rank_tensor(self, shape=None):
    """Tensor rank as a `Tensor`, preferring the static rank if known.

    Args:
      shape: Optional pre-computed shape of this operator. Only used when the
        static rank is unknown.
    """
    # `shape` may be passed in if this can be pre-computed in a
    # more efficient manner, e.g. without excessive Tensor conversions.
    if self.tensor_rank is not None:
      return ops.convert_to_tensor_v2_with_dispatch(self.tensor_rank)
    else:
      shape = self.shape_tensor() if shape is None else shape
      return array_ops.size(shape)

  @property
  def domain_dimension(self):
    """Dimension (in the sense of vector spaces) of the domain of this operator.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `N`.

    Returns:
      `Dimension` object.
    """
    # Derived classes get this "for free" once .shape is implemented.
    if self.shape.rank is None:
      return tensor_shape.Dimension(None)
    else:
      return self.shape.dims[-1]

  def domain_dimension_tensor(self, name="domain_dimension_tensor"):
    """Dimension (in the sense of vector spaces) of the domain of this operator.

    Determined at runtime.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `N`.

    Args:
      name: A name for this `Op`.

    Returns:
      `int32` `Tensor`
    """
    # Derived classes get this "for free" once .shape() is implemented.
    with self._name_scope(name):
      return self._domain_dimension_tensor()

  def _domain_dimension_tensor(self, shape=None):
    """Domain dimension as a `Tensor`, preferring the static value if known.

    Args:
      shape: Optional pre-computed shape of this operator. Only used when the
        static domain dimension is unknown.
    """
    # `shape` may be passed in if this can be pre-computed in a
    # more efficient manner, e.g. without excessive Tensor conversions.
    dim_value = tensor_shape.dimension_value(self.domain_dimension)
    if dim_value is not None:
      return ops.convert_to_tensor_v2_with_dispatch(dim_value)
    else:
      shape = self.shape_tensor() if shape is None else shape
      return shape[-1]

  @property
  def range_dimension(self):
    """Dimension (in the sense of vector spaces) of the range of this operator.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `M`.

    Returns:
      `Dimension` object.
    """
    # Derived classes get this "for free" once .shape is implemented.
    if self.shape.dims:
      return self.shape.dims[-2]
    else:
      return tensor_shape.Dimension(None)

  def range_dimension_tensor(self, name="range_dimension_tensor"):
    """Dimension (in the sense of vector spaces) of the range of this operator.

    Determined at runtime.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `M`.

    Args:
      name: A name for this `Op`.

    Returns:
      `int32` `Tensor`
    """
    # Derived classes get this "for free" once .shape() is implemented.
    with self._name_scope(name):
      return self._range_dimension_tensor()

  def _range_dimension_tensor(self, shape=None):
    """Range dimension as a `Tensor`, preferring the static value if known.

    Args:
      shape: Optional pre-computed shape of this operator. Only used when the
        static range dimension is unknown.
    """
    # `shape` may be passed in if this can be pre-computed in a
    # more efficient manner, e.g. without excessive Tensor conversions.
    dim_value = tensor_shape.dimension_value(self.range_dimension)
    if dim_value is not None:
      return ops.convert_to_tensor_v2_with_dispatch(dim_value)
    else:
      shape = self.shape_tensor() if shape is None else shape
      return shape[-2]

  def _assert_non_singular(self):
    """Private default implementation of _assert_non_singular.

    Delegates to `assert_positive_definite` when `_can_use_cholesky()` is
    true; otherwise asserts that the ratio of the largest to the smallest
    singular value of the dense matrix is below
    `_max_condition_number_to_be_non_singular()`.
    """
    logging.warn(
        "Using (possibly slow) default implementation of assert_non_singular."
        " Requires conversion to a dense matrix and O(N^3) operations.")
    if self._can_use_cholesky():
      return self.assert_positive_definite()
    else:
      singular_values = linalg_ops.svd(self.to_dense(), compute_uv=False)
      # TODO(langmore) Add .eig and .cond as methods.
      cond = (math_ops.reduce_max(singular_values, axis=-1) /
              math_ops.reduce_min(singular_values, axis=-1))
      return check_ops.assert_less(
          cond,
          self._max_condition_number_to_be_non_singular(),
          message="Singular matrix up to precision epsilon.")

  def _max_condition_number_to_be_non_singular(self):
    """Return the maximum condition number that we consider nonsingular.

    Computed as `1 / (max{100, range_dimension, domain_dimension} * eps)`
    with `eps = np.finfo(self.dtype.as_numpy_dtype).eps`.
    """
    with ops.name_scope("max_nonsingular_condition_number"):
      dtype_eps = np.finfo(self.dtype.as_numpy_dtype).eps
      eps = math_ops.cast(
          math_ops.reduce_max([
              100.,
              math_ops.cast(self.range_dimension_tensor(), self.dtype),
              math_ops.cast(self.domain_dimension_tensor(), self.dtype)
          ]), self.dtype) * dtype_eps
      return 1. / eps

  def assert_non_singular(self, name="assert_non_singular"):
    """Returns an `Op` that asserts this operator is non singular.

    With the default implementation, this operator is considered non-singular
    if

    ```
    ConditionNumber < 1 / (max{100, range_dimension, domain_dimension} * eps),
    eps := np.finfo(self.dtype.as_numpy_dtype).eps
    ```

    Args:
      name: A string name to prepend to created ops.

    Returns:
      An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if
        the operator is singular.
    """
    with self._name_scope(name):
      return self._assert_non_singular()

  def _assert_positive_definite(self):
    """Default implementation of _assert_positive_definite.

    Raises:
      NotImplementedError: If this operator is not hinted to be self-adjoint,
        since no generic check exists in that case.
    """
    logging.warn(
        "Using (possibly slow) default implementation of "
        "assert_positive_definite."
        " Requires conversion to a dense matrix and O(N^3) operations.")
    # If the operator is self-adjoint, then checking that
    # Cholesky decomposition succeeds + results in positive diag is necessary
    # and sufficient.
    if self.is_self_adjoint:
      return check_ops.assert_positive(
          array_ops.matrix_diag_part(linalg_ops.cholesky(self.to_dense())),
          message="Matrix was not positive definite.")
    # We have no generic check for positive definite.
    raise NotImplementedError("assert_positive_definite is not implemented.")

  def assert_positive_definite(self, name="assert_positive_definite"):
    """Returns an `Op` that asserts this operator is positive definite.

    Here, positive definite means that the quadratic form `x^H A x` has positive
    real part for all nonzero `x`. Note that we do not require the operator to
    be self-adjoint to be positive definite.

    Args:
      name: A name to give this `Op`.

    Returns:
      An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if
        the operator is not positive definite.
    """
    with self._name_scope(name):
      return self._assert_positive_definite()

  def _assert_self_adjoint(self):
    """Default implementation: compare the dense matrix with its adjoint."""
    dense = self.to_dense()
    logging.warn(
        "Using (possibly slow) default implementation of assert_self_adjoint."
        " Requires conversion to a dense matrix.")
    return check_ops.assert_equal(
        dense,
        linalg.adjoint(dense),
        message="Matrix was not equal to its adjoint.")

  def assert_self_adjoint(self, name="assert_self_adjoint"):
    """Returns an `Op` that asserts this operator is self-adjoint.

    Here we check that this operator is *exactly* equal to its hermitian
    transpose.

    Args:
      name: A string name to prepend to created ops.

    Returns:
      An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if
        the operator is not self-adjoint.
    """
    with self._name_scope(name):
      return self._assert_self_adjoint()

  def _check_input_dtype(self, arg):
    """Check that arg.dtype == self.dtype.

    Raises:
      TypeError: If the base dtype of `arg` differs from `self.dtype`.
    """
    if arg.dtype.base_dtype != self.dtype:
      raise TypeError(
          "Expected argument to have dtype %s. Found: %s in tensor %s" %
          (self.dtype, arg.dtype, arg))

  @abc.abstractmethod
  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Matrix multiplication; must be implemented by subclasses."""
    raise NotImplementedError("_matmul is not implemented.")

  def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
    """Transform [batch] matrix `x` with left multiplication: `x --> Ax`.

    ```python
    # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    X = ... # shape [..., N, R], batch matrix, R > 0.

    Y = operator.matmul(X)
    Y.shape
    ==> [..., M, R]

    Y[..., :, r] = sum_j A[..., :, j] X[..., j, r]
    ```

    Args:
      x: `LinearOperator` or `Tensor` with compatible shape and same `dtype` as
        `self`. See class docstring for definition of compatibility.
      adjoint: Python `bool`. If `True`, left multiply by the adjoint: `A^H x`.
      adjoint_arg: Python `bool`. If `True`, compute `A x^H` where `x^H` is
        the hermitian transpose (transposition and complex conjugation).
      name: A name for this `Op`.

    Returns:
      A `LinearOperator` or `Tensor` with shape `[..., M, R]` and same `dtype`
        as `self`.

    Raises:
      ValueError: If `x` is a `LinearOperator` whose range dimension is
        statically known to differ from this operator's domain dimension.
      TypeError: If `x` is converted to a `Tensor` whose dtype differs from
        `self.dtype`.
    """
    if isinstance(x, LinearOperator):
      left_operator = self.adjoint() if adjoint else self
      right_operator = x.adjoint() if adjoint_arg else x

      if (right_operator.range_dimension is not None and
          left_operator.domain_dimension is not None and
          right_operator.range_dimension != left_operator.domain_dimension):
        raise ValueError(
            "Operators are incompatible. Expected `x` to have dimension"
            " {} but got {}.".format(
                left_operator.domain_dimension, right_operator.range_dimension))
      with self._name_scope(name):
        return linear_operator_algebra.matmul(left_operator, right_operator)

    with self._name_scope(name):
      x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
      self._check_input_dtype(x)

      # Pick the contracted dimension of `self` and of `x`, accounting for the
      # adjoint flags, and statically check that they are compatible.
      self_dim = -2 if adjoint else -1
      arg_dim = -1 if adjoint_arg else -2
      tensor_shape.dimension_at_index(
          self.shape, self_dim).assert_is_compatible_with(
              x.shape[arg_dim])

      return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)

  def __matmul__(self, other):
    """Support `operator @ other` by delegating to `matmul`."""
    return self.matmul(other)

  def _matvec(self, x, adjoint=False):
    """Default implementation: `matmul` on `x` expanded to a column matrix."""
    x_mat = array_ops.expand_dims(x, axis=-1)
    y_mat = self.matmul(x_mat, adjoint=adjoint)
    return array_ops.squeeze(y_mat, axis=-1)

  def matvec(self, x, adjoint=False, name="matvec"):
    """Transform [batch] vector `x` with left multiplication: `x --> Ax`.

    ```python
    # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
    operator = LinearOperator(...)

    X = ... # shape [..., N], batch vector

    Y = operator.matvec(X)
    Y.shape
    ==> [..., M]

    Y[..., :] = sum_j A[..., :, j] X[..., j]
    ```

    Args:
      x: `Tensor` with compatible shape and same `dtype` as `self`.
        `x` is treated as a [batch] vector meaning for every set of leading
        dimensions, the last dimension defines a vector.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`. If `True`, left multiply by the adjoint: `A^H x`.
      name: A name for this `Op`.

    Returns:
      A `Tensor` with shape `[..., M]` and same `dtype` as `self`.
    """
    with self._name_scope(name):
      x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
      self._check_input_dtype(x)
      self_dim = -2 if adjoint else -1
      tensor_shape.dimension_at_index(
          self.shape, self_dim).assert_is_compatible_with(x.shape[-1])
      return self._matvec(x, adjoint=adjoint)

  def _determinant(self):
    """Default implementation of `determinant` via the dense matrix."""
    logging.warn(
        "Using (possibly slow) default implementation of determinant."
        " Requires conversion to a dense matrix and O(N^3) operations.")
    if self._can_use_cholesky():
      return math_ops.exp(self.log_abs_determinant())
    return linalg_ops.matrix_determinant(self.to_dense())

  def determinant(self, name="det"):
    """Determinant for every batch member.

    Args:
      name: A name for this `Op`.

    Returns:
      `Tensor` with shape `self.batch_shape` and same `dtype` as `self`.

    Raises:
      NotImplementedError: If `self.is_square` is `False`.
    """
    if self.is_square is False:
      raise NotImplementedError(
          "Determinant not implemented for an operator that is expected to "
          "not be square.")
    with self._name_scope(name):
      return self._determinant()

  def _log_abs_determinant(self):
    """Default implementation of `log_abs_determinant` via the dense matrix.

    When `_can_use_cholesky()` is true, returns twice the sum of the logs of
    the Cholesky factor's diagonal; otherwise uses `slogdet`.
    """
    logging.warn(
        "Using (possibly slow) default implementation of determinant."
        " Requires conversion to a dense matrix and O(N^3) operations.")
    if self._can_use_cholesky():
      diag = array_ops.matrix_diag_part(linalg_ops.cholesky(self.to_dense()))
      return 2 * math_ops.reduce_sum(math_ops.log(diag), axis=[-1])
    _, log_abs_det = linalg.slogdet(self.to_dense())
    return log_abs_det

  def log_abs_determinant(self, name="log_abs_det"):
    """Log absolute value of determinant for every batch member.

    Args:
      name: A name for this `Op`.

    Returns:
      `Tensor` with shape `self.batch_shape` and same `dtype` as `self`.

    Raises:
      NotImplementedError: If `self.is_square` is `False`.
    """
    if self.is_square is False:
      raise NotImplementedError(
          "Determinant not implemented for an operator that is expected to "
          "not be square.")
    with self._name_scope(name):
      return self._log_abs_determinant()

  def _dense_solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solve by conversion to a dense matrix.

    Note that on the Cholesky path `adjoint` is not forwarded to the solver.

    Raises:
      NotImplementedError: If `self.is_square` is `False`.
    """
    if self.is_square is False:  # pylint: disable=g-bool-id-comparison
      raise NotImplementedError(
          "Solve is not yet implemented for non-square operators.")
    rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
    if self._can_use_cholesky():
      return linalg_ops.cholesky_solve(
          linalg_ops.cholesky(self.to_dense()), rhs)
    return linear_operator_util.matrix_solve_with_broadcast(
        self.to_dense(), rhs, adjoint=adjoint)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Default implementation of _solve."""
    logging.warn(
        "Using (possibly slow) default implementation of solve."
        " Requires conversion to a dense matrix and O(N^3) operations.")
    return self._dense_solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)

  def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
    """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for details.

    Examples:

    ```python
    # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve R > 0 linear systems for every member of the batch.
    RHS = ... # shape [..., M, R]

    X = operator.solve(RHS)
    # X[..., :, r] is the solution to the r'th linear system
    # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

    operator.matmul(X)
    ==> RHS
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator and compatible shape.
        `rhs` is treated like a [batch] matrix meaning for every set of leading
        dimensions, the last two dimensions defines a matrix.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`. If `True`, solve the system involving the adjoint
        of this `LinearOperator`: `A^H X = rhs`.
      adjoint_arg: Python `bool`. If `True`, solve `A X = rhs^H` where `rhs^H`
        is the hermitian transpose (transposition and complex conjugation).
      name: A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`.

    Raises:
      NotImplementedError: If `self.is_non_singular` or `is_square` is False.
      ValueError: If `rhs` is a `LinearOperator` whose range dimension is
        statically known to differ from this operator's domain dimension.
    """
    if self.is_non_singular is False:
      raise NotImplementedError(
          "Exact solve not implemented for an operator that is expected to "
          "be singular.")
    if self.is_square is False:
      raise NotImplementedError(
          "Exact solve not implemented for an operator that is expected to "
          "not be square.")
    if isinstance(rhs, LinearOperator):
      left_operator = self.adjoint() if adjoint else self
      right_operator = rhs.adjoint() if adjoint_arg else rhs

      if (right_operator.range_dimension is not None and
          left_operator.domain_dimension is not None and
          right_operator.range_dimension != left_operator.domain_dimension):
        raise ValueError(
            "Operators are incompatible. Expected `rhs` to have dimension"
            " {} but got {}.".format(
                left_operator.domain_dimension, right_operator.range_dimension))
      with self._name_scope(name):
        return linear_operator_algebra.solve(left_operator, right_operator)

    with self._name_scope(name):
      rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
      self._check_input_dtype(rhs)

      # Pick the dimension of `self` that must match the rows of `rhs`,
      # accounting for the adjoint flags, and statically check compatibility.
      self_dim = -1 if adjoint else -2
      arg_dim = -1 if adjoint_arg else -2
      tensor_shape.dimension_at_index(
          self.shape, self_dim).assert_is_compatible_with(
              rhs.shape[arg_dim])

      return self._solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)

  def _solvevec(self, rhs, adjoint=False):
    """Default implementation of _solvevec."""
    rhs_mat = array_ops.expand_dims(rhs, axis=-1)
    solution_mat = self.solve(rhs_mat, adjoint=adjoint)
    return array_ops.squeeze(solution_mat, axis=-1)

  def solvevec(self, rhs, adjoint=False, name="solve"):
    """Solve single equation with best effort: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for details.

    Examples:

    ```python
    # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve one linear system for every member of the batch.
    RHS = ... # shape [..., M]

    X = operator.solvevec(RHS)
    # X is the solution to the linear system
    # sum_j A[..., :, j] X[..., j] = RHS[..., :]

    operator.matvec(X)
    ==> RHS
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator.
        `rhs` is treated like a [batch] vector meaning for every set of leading
        dimensions, the last dimension defines a vector. See class docstring
        for definition of compatibility regarding batch dimensions.
      adjoint: Python `bool`. If `True`, solve the system involving the adjoint
        of this `LinearOperator`: `A^H X = rhs`.
      name: A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[...,N]` and same `dtype` as `rhs`.

    Raises:
      NotImplementedError: If `self.is_non_singular` or `is_square` is False.
    """
    with self._name_scope(name):
      rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
      self._check_input_dtype(rhs)
      self_dim = -1 if adjoint else -2
      tensor_shape.dimension_at_index(
          self.shape, self_dim).assert_is_compatible_with(rhs.shape[-1])

      return self._solvevec(rhs, adjoint=adjoint)

  def adjoint(self, name="adjoint"):
    """Returns the adjoint of the current `LinearOperator`.

    Given `A` representing this `LinearOperator`, return `A*`.
    Note that calling `self.adjoint()` and `self.H` are equivalent.

    Args:
      name: A name for this `Op`.

    Returns:
      `LinearOperator` which represents the adjoint of this `LinearOperator`.
      If this operator is hinted to be self-adjoint, `self` is returned.
    """
    if self.is_self_adjoint is True:  # pylint: disable=g-bool-id-comparison
      return self
    with self._name_scope(name):
      return linear_operator_algebra.adjoint(self)

  # self.H is equivalent to self.adjoint().
951 H = property(adjoint, None) 952 953 def inverse(self, name="inverse"): 954 """Returns the Inverse of this `LinearOperator`. 955 956 Given `A` representing this `LinearOperator`, return a `LinearOperator` 957 representing `A^-1`. 958 959 Args: 960 name: A name scope to use for ops added by this method. 961 962 Returns: 963 `LinearOperator` representing inverse of this matrix. 964 965 Raises: 966 ValueError: When the `LinearOperator` is not hinted to be `non_singular`. 967 """ 968 if self.is_square is False: # pylint: disable=g-bool-id-comparison 969 raise ValueError("Cannot take the Inverse: This operator represents " 970 "a non square matrix.") 971 if self.is_non_singular is False: # pylint: disable=g-bool-id-comparison 972 raise ValueError("Cannot take the Inverse: This operator represents " 973 "a singular matrix.") 974 975 with self._name_scope(name): 976 return linear_operator_algebra.inverse(self) 977 978 def cholesky(self, name="cholesky"): 979 """Returns a Cholesky factor as a `LinearOperator`. 980 981 Given `A` representing this `LinearOperator`, if `A` is positive definite 982 self-adjoint, return `L`, where `A = L L^T`, i.e. the cholesky 983 decomposition. 984 985 Args: 986 name: A name for this `Op`. 987 988 Returns: 989 `LinearOperator` which represents the lower triangular matrix 990 in the Cholesky decomposition. 991 992 Raises: 993 ValueError: When the `LinearOperator` is not hinted to be positive 994 definite and self adjoint. 995 """ 996 997 if not self._can_use_cholesky(): 998 raise ValueError("Cannot take the Cholesky decomposition: " 999 "Not a positive definite self adjoint matrix.") 1000 with self._name_scope(name): 1001 return linear_operator_algebra.cholesky(self) 1002 1003 def _to_dense(self): 1004 """Generic and often inefficient implementation. 
Override often.""" 1005 if self.batch_shape.is_fully_defined(): 1006 batch_shape = self.batch_shape 1007 else: 1008 batch_shape = self.batch_shape_tensor() 1009 1010 dim_value = tensor_shape.dimension_value(self.domain_dimension) 1011 if dim_value is not None: 1012 n = dim_value 1013 else: 1014 n = self.domain_dimension_tensor() 1015 1016 eye = linalg_ops.eye(num_rows=n, batch_shape=batch_shape, dtype=self.dtype) 1017 return self.matmul(eye) 1018 1019 def to_dense(self, name="to_dense"): 1020 """Return a dense (batch) matrix representing this operator.""" 1021 with self._name_scope(name): 1022 return self._to_dense() 1023 1024 def _diag_part(self): 1025 """Generic and often inefficient implementation. Override often.""" 1026 return array_ops.matrix_diag_part(self.to_dense()) 1027 1028 def diag_part(self, name="diag_part"): 1029 """Efficiently get the [batch] diagonal part of this operator. 1030 1031 If this operator has shape `[B1,...,Bb, M, N]`, this returns a 1032 `Tensor` `diagonal`, of shape `[B1,...,Bb, min(M, N)]`, where 1033 `diagonal[b1,...,bb, i] = self.to_dense()[b1,...,bb, i, i]`. 1034 1035 ``` 1036 my_operator = LinearOperatorDiag([1., 2.]) 1037 1038 # Efficiently get the diagonal 1039 my_operator.diag_part() 1040 ==> [1., 2.] 1041 1042 # Equivalent, but inefficient method 1043 tf.linalg.diag_part(my_operator.to_dense()) 1044 ==> [1., 2.] 1045 ``` 1046 1047 Args: 1048 name: A name for this `Op`. 1049 1050 Returns: 1051 diag_part: A `Tensor` of same `dtype` as self. 1052 """ 1053 with self._name_scope(name): 1054 return self._diag_part() 1055 1056 def _trace(self): 1057 return math_ops.reduce_sum(self.diag_part(), axis=-1) 1058 1059 def trace(self, name="trace"): 1060 """Trace of the linear operator, equal to sum of `self.diag_part()`. 1061 1062 If the operator is square, this is also the sum of the eigenvalues. 1063 1064 Args: 1065 name: A name for this `Op`. 1066 1067 Returns: 1068 Shape `[B1,...,Bb]` `Tensor` of same `dtype` as `self`. 
1069 """ 1070 with self._name_scope(name): 1071 return self._trace() 1072 1073 def _add_to_tensor(self, x): 1074 # Override if a more efficient implementation is available. 1075 return self.to_dense() + x 1076 1077 def add_to_tensor(self, x, name="add_to_tensor"): 1078 """Add matrix represented by this operator to `x`. Equivalent to `A + x`. 1079 1080 Args: 1081 x: `Tensor` with same `dtype` and shape broadcastable to `self.shape`. 1082 name: A name to give this `Op`. 1083 1084 Returns: 1085 A `Tensor` with broadcast shape and same `dtype` as `self`. 1086 """ 1087 with self._name_scope(name): 1088 x = ops.convert_to_tensor_v2_with_dispatch(x, name="x") 1089 self._check_input_dtype(x) 1090 return self._add_to_tensor(x) 1091 1092 def _eigvals(self): 1093 return linalg_ops.self_adjoint_eigvals(self.to_dense()) 1094 1095 def eigvals(self, name="eigvals"): 1096 """Returns the eigenvalues of this linear operator. 1097 1098 If the operator is marked as self-adjoint (via `is_self_adjoint`) 1099 this computation can be more efficient. 1100 1101 Note: This currently only supports self-adjoint operators. 1102 1103 Args: 1104 name: A name for this `Op`. 1105 1106 Returns: 1107 Shape `[B1,...,Bb, N]` `Tensor` of same `dtype` as `self`. 1108 """ 1109 if not self.is_self_adjoint: 1110 raise NotImplementedError("Only self-adjoint matrices are supported.") 1111 with self._name_scope(name): 1112 return self._eigvals() 1113 1114 def _cond(self): 1115 if not self.is_self_adjoint: 1116 # In general the condition number is the ratio of the 1117 # absolute value of the largest and smallest singular values. 1118 vals = linalg_ops.svd(self.to_dense(), compute_uv=False) 1119 else: 1120 # For self-adjoint matrices, and in general normal matrices, 1121 # we can use eigenvalues. 
1122 vals = math_ops.abs(self._eigvals()) 1123 1124 return (math_ops.reduce_max(vals, axis=-1) / 1125 math_ops.reduce_min(vals, axis=-1)) 1126 1127 def cond(self, name="cond"): 1128 """Returns the condition number of this linear operator. 1129 1130 Args: 1131 name: A name for this `Op`. 1132 1133 Returns: 1134 Shape `[B1,...,Bb]` `Tensor` of same `dtype` as `self`. 1135 """ 1136 with self._name_scope(name): 1137 return self._cond() 1138 1139 def _can_use_cholesky(self): 1140 return self.is_self_adjoint and self.is_positive_definite 1141 1142 def _set_graph_parents(self, graph_parents): 1143 """Set self._graph_parents. Called during derived class init. 1144 1145 This method allows derived classes to set graph_parents, without triggering 1146 a deprecation warning (which is invoked if `graph_parents` is passed during 1147 `__init__`. 1148 1149 Args: 1150 graph_parents: Iterable over Tensors. 1151 """ 1152 # TODO(b/143910018) Remove this function in V3. 1153 graph_parents = [] if graph_parents is None else graph_parents 1154 for i, t in enumerate(graph_parents): 1155 if t is None or not (linear_operator_util.is_ref(t) or 1156 tensor_util.is_tf_type(t)): 1157 raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t)) 1158 self._graph_parents = graph_parents 1159 1160 1161# Overrides for tf.linalg functions. This allows a LinearOperator to be used in 1162# place of a Tensor. 1163# For instance tf.trace(linop) and linop.trace() both work. 1164 1165 1166@dispatch.dispatch_for_types(linalg.adjoint, LinearOperator) 1167def _adjoint(matrix, name=None): 1168 return matrix.adjoint(name) 1169 1170 1171@dispatch.dispatch_for_types(linalg.cholesky, LinearOperator) 1172def _cholesky(input, name=None): # pylint:disable=redefined-builtin 1173 return input.cholesky(name) 1174 1175 1176# The signature has to match with the one in python/op/array_ops.py, 1177# so we have k, padding_value, and align even though we don't use them here. 
# pylint:disable=unused-argument
@dispatch.dispatch_for_types(linalg.diag_part, LinearOperator)
def _diag_part(
    input,  # pylint:disable=redefined-builtin
    name="diag_part",
    k=0,
    padding_value=0,
    align="RIGHT_LEFT"):
  """Dispatch override of `linalg.diag_part`: forwards to `input.diag_part`.

  NOTE: `k`, `padding_value` and `align` exist only so the signature matches
  the dispatched op. They are ignored here; `input.diag_part(name)` is
  returned whatever their values.
  """
  return input.diag_part(name)
# pylint:enable=unused-argument


@dispatch.dispatch_for_types(linalg.det, LinearOperator)
def _det(input, name=None):  # pylint:disable=redefined-builtin
  """Dispatch override of `linalg.det`: forwards to `input.determinant`."""
  return input.determinant(name)


@dispatch.dispatch_for_types(linalg.inv, LinearOperator)
def _inverse(input, adjoint=False, name=None):  # pylint:disable=redefined-builtin
  """Dispatch override of `linalg.inv`.

  Returns `input.inverse(name)`, followed by `.adjoint()` when `adjoint` is
  truthy.
  """
  inverse_operator = input.inverse(name)
  return inverse_operator.adjoint() if adjoint else inverse_operator


@dispatch.dispatch_for_types(linalg.logdet, LinearOperator)
def _logdet(matrix, name=None):
  """Dispatch override of `linalg.logdet`.

  Returns `matrix.log_abs_determinant(name)` when the operator is hinted both
  positive definite and self-adjoint.

  Raises:
    ValueError: If either hint is not truthy.
  """
  if not (matrix.is_positive_definite and matrix.is_self_adjoint):
    raise ValueError("Expected matrix to be self-adjoint positive definite.")
  return matrix.log_abs_determinant(name)


@dispatch.dispatch_for_types(math_ops.matmul, LinearOperator)
def _matmul(
    a,
    b,
    transpose_a=False,
    transpose_b=False,
    adjoint_a=False,
    adjoint_b=False,
    a_is_sparse=False,
    b_is_sparse=False,
    name=None):
  """Dispatch override of `math_ops.matmul` for `LinearOperator` arguments.

  When `a` is a `LinearOperator`, this is `a.matmul(b, ...)` with the adjoint
  flags forwarded. Otherwise the product is formed through `b.matmul` and the
  result is adjointed back.

  Raises:
    ValueError: If any of `transpose_a`, `transpose_b`, `a_is_sparse` or
      `b_is_sparse` is truthy.
  """
  if transpose_a or transpose_b:
    raise ValueError("Transposing not supported at this time.")
  if a_is_sparse or b_is_sparse:
    raise ValueError("Sparse methods not supported at this time.")

  if isinstance(a, LinearOperator):
    return a.matmul(
        b, adjoint=adjoint_a, adjoint_arg=adjoint_b, name=name)

  # `a` is not a LinearOperator, so this path assumes `b` is one.
  # We use the identity (B^HA^H)^H = AB
  product_adjoint = b.matmul(
      a,
      adjoint=(not adjoint_b),
      adjoint_arg=(not adjoint_a),
      name=name)
  return linalg.adjoint(product_adjoint)


@dispatch.dispatch_for_types(linalg.solve, LinearOperator)
def _solve(
    matrix,
    rhs,
    adjoint=False,
    name=None):
  """Dispatch override of `linalg.solve`: forwards to `matrix.solve`.

  Raises:
    ValueError: If `matrix` is not a `LinearOperator`.
  """
  if isinstance(matrix, LinearOperator):
    return matrix.solve(rhs, adjoint=adjoint, name=name)
  raise ValueError("Passing in `matrix` as a Tensor and `rhs` as a "
                   "LinearOperator is not supported.")


@dispatch.dispatch_for_types(linalg.trace, LinearOperator)
def _trace(x, name=None):
  """Dispatch override of `linalg.trace`: forwards to `x.trace`."""
  return x.trace(name)