1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Base class for linear operators."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import abc
22import contextlib
23
24import numpy as np
25import six
26
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import tensor_shape
30from tensorflow.python.framework import tensor_util
31from tensorflow.python.module import module
32from tensorflow.python.ops import array_ops
33from tensorflow.python.ops import check_ops
34from tensorflow.python.ops import linalg_ops
35from tensorflow.python.ops import math_ops
36from tensorflow.python.ops.linalg import linalg_impl as linalg
37from tensorflow.python.ops.linalg import linear_operator_algebra
38from tensorflow.python.ops.linalg import linear_operator_util
39from tensorflow.python.platform import tf_logging as logging
40from tensorflow.python.util import deprecation
41from tensorflow.python.util import dispatch
42from tensorflow.python.util.tf_export import tf_export
43
44__all__ = ["LinearOperator"]
45
46
47# TODO(langmore) Use matrix_solve_ls for singular or non-square matrices.
48@tf_export("linalg.LinearOperator")
49@six.add_metaclass(abc.ABCMeta)
50class LinearOperator(module.Module):
51  """Base class defining a [batch of] linear operator[s].
52
53  Subclasses of `LinearOperator` provide access to common methods on a
54  (batch) matrix, without the need to materialize the matrix.  This allows:
55
56  * Matrix free computations
57  * Operators that take advantage of special structure, while providing a
58    consistent API to users.
59
60  #### Subclassing
61
62  To enable a public method, subclasses should implement the leading-underscore
63  version of the method.  The argument signature should be identical except for
64  the omission of `name="..."`.  For example, to enable
65  `matmul(x, adjoint=False, name="matmul")` a subclass should implement
66  `_matmul(x, adjoint=False)`.
67
68  #### Performance contract
69
70  Subclasses should only implement the assert methods
71  (e.g. `assert_non_singular`) if they can be done in less than `O(N^3)`
72  time.
73
74  Class docstrings should contain an explanation of computational complexity.
75  Since this is a high-performance library, attention should be paid to detail,
76  and explanations can include constants as well as Big-O notation.
77
78  #### Shape compatibility
79
80  `LinearOperator` subclasses should operate on a [batch] matrix with
81  compatible shape.  Class docstrings should define what is meant by compatible
82  shape.  Some subclasses may not support batching.
83
84  Examples:
85
86  `x` is a batch matrix with compatible shape for `matmul` if
87
88  ```
89  operator.shape = [B1,...,Bb] + [M, N],  b >= 0,
90  x.shape =   [B1,...,Bb] + [N, R]
91  ```
92
93  `rhs` is a batch matrix with compatible shape for `solve` if
94
95  ```
96  operator.shape = [B1,...,Bb] + [M, N],  b >= 0,
97  rhs.shape =   [B1,...,Bb] + [M, R]
98  ```
99
100  #### Example docstring for subclasses.
101
  This operator acts like a (batch) matrix `A` with shape
  `[B1,...,Bb, M, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `M x N` matrix.  Again, this matrix `A` may not be materialized, but for
  purposes of identifying and working with compatible arguments the shape is
  relevant.
108
109  Examples:
110
111  ```python
112  some_tensor = ... shape = ????
113  operator = MyLinOp(some_tensor)
114
  operator.shape
116  ==> [2, 4, 4]
117
118  operator.log_abs_determinant()
119  ==> Shape [2] Tensor
120
121  x = ... Shape [2, 4, 5] Tensor
122
123  operator.matmul(x)
124  ==> Shape [2, 4, 5] Tensor
125  ```
126
127  #### Shape compatibility
128
129  This operator acts on batch matrices with compatible shape.
130  FILL IN WHAT IS MEANT BY COMPATIBLE SHAPE
131
132  #### Performance
133
134  FILL THIS IN
135
136  #### Matrix property hints
137
138  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
139  for `X = non_singular, self_adjoint, positive_definite, square`.
140  These have the following meaning:
141
142  * If `is_X == True`, callers should expect the operator to have the
143    property `X`.  This is a promise that should be fulfilled, but is *not* a
144    runtime assert.  For example, finite floating point precision may result
145    in these promises being violated.
146  * If `is_X == False`, callers should expect the operator to not have `X`.
147  * If `is_X == None` (the default), callers should have no expectation either
148    way.
149
150  #### Initialization parameters
151
152  All subclasses of `LinearOperator` are expected to pass a `parameters`
153  argument to `super().__init__()`.  This should be a `dict` containing
154  the unadulterated arguments passed to the subclass `__init__`.  For example,
155  `MyLinearOperator` with an initializer should look like:
156
  ```python
  def __init__(self, operator, is_square=False, name=None):
    parameters = dict(
        operator=operator,
        is_square=is_square,
        name=name
    )
    ...
    super().__init__(..., parameters=parameters)
  ```

  Users can then access `my_linear_operator.parameters` to see all arguments
  passed to its initializer.
170  """
171
172  # TODO(b/143910018) Remove graph_parents in V3.
173  @deprecation.deprecated_args(None, "Do not pass `graph_parents`.  They will "
174                               " no longer be used.", "graph_parents")
175  def __init__(self,
176               dtype,
177               graph_parents=None,
178               is_non_singular=None,
179               is_self_adjoint=None,
180               is_positive_definite=None,
181               is_square=None,
182               name=None,
183               parameters=None):
184    r"""Initialize the `LinearOperator`.
185
186    **This is a private method for subclass use.**
187    **Subclasses should copy-paste this `__init__` documentation.**
188
189    Args:
190      dtype: The type of the this `LinearOperator`.  Arguments to `matmul` and
191        `solve` will have to be this type.
192      graph_parents: (Deprecated) Python list of graph prerequisites of this
193        `LinearOperator` Typically tensors that are passed during initialization
194      is_non_singular:  Expect that this operator is non-singular.
195      is_self_adjoint:  Expect that this operator is equal to its hermitian
196        transpose.  If `dtype` is real, this is equivalent to being symmetric.
197      is_positive_definite:  Expect that this operator is positive definite,
198        meaning the quadratic form `x^H A x` has positive real part for all
199        nonzero `x`.  Note that we do not require the operator to be
200        self-adjoint to be positive-definite.  See:
201        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
202      is_square:  Expect that this operator acts like square [batch] matrices.
203      name: A name for this `LinearOperator`.
204      parameters: Python `dict` of parameters used to instantiate this
205        `LinearOperator`.
206
207    Raises:
208      ValueError:  If any member of graph_parents is `None` or not a `Tensor`.
209      ValueError:  If hints are set incorrectly.
210    """
211    # Check and auto-set flags.
212    if is_positive_definite:
213      if is_non_singular is False:
214        raise ValueError("A positive definite matrix is always non-singular.")
215      is_non_singular = True
216
217    if is_non_singular:
218      if is_square is False:
219        raise ValueError("A non-singular matrix is always square.")
220      is_square = True
221
222    if is_self_adjoint:
223      if is_square is False:
224        raise ValueError("A self-adjoint matrix is always square.")
225      is_square = True
226
227    self._is_square_set_or_implied_by_hints = is_square
228
229    if graph_parents is not None:
230      self._set_graph_parents(graph_parents)
231    else:
232      self._graph_parents = []
233    self._dtype = dtypes.as_dtype(dtype).base_dtype if dtype else dtype
234    self._is_non_singular = is_non_singular
235    self._is_self_adjoint = is_self_adjoint
236    self._is_positive_definite = is_positive_definite
237    self._parameters = self._no_dependency(parameters)
238    self._parameters_sanitized = False
239    self._name = name or type(self).__name__
240
241  @contextlib.contextmanager
242  def _name_scope(self, name=None):
243    """Helper function to standardize op scope."""
244    full_name = self.name
245    if name is not None:
246      full_name += "/" + name
247    with ops.name_scope(full_name) as scope:
248      yield scope
249
250  @property
251  def parameters(self):
252    """Dictionary of parameters used to instantiate this `LinearOperator`."""
253    return dict(self._parameters)
254
255  @property
256  def dtype(self):
257    """The `DType` of `Tensor`s handled by this `LinearOperator`."""
258    return self._dtype
259
260  @property
261  def name(self):
262    """Name prepended to all ops created by this `LinearOperator`."""
263    return self._name
264
265  @property
266  @deprecation.deprecated(None, "Do not call `graph_parents`.")
267  def graph_parents(self):
268    """List of graph dependencies of this `LinearOperator`."""
269    return self._graph_parents
270
271  @property
272  def is_non_singular(self):
273    return self._is_non_singular
274
275  @property
276  def is_self_adjoint(self):
277    return self._is_self_adjoint
278
279  @property
280  def is_positive_definite(self):
281    return self._is_positive_definite
282
283  @property
284  def is_square(self):
285    """Return `True/False` depending on if this operator is square."""
286    # Static checks done after __init__.  Why?  Because domain/range dimension
287    # sometimes requires lots of work done in the derived class after init.
288    auto_square_check = self.domain_dimension == self.range_dimension
289    if self._is_square_set_or_implied_by_hints is False and auto_square_check:
290      raise ValueError(
291          "User set is_square hint to False, but the operator was square.")
292    if self._is_square_set_or_implied_by_hints is None:
293      return auto_square_check
294
295    return self._is_square_set_or_implied_by_hints
296
297  @abc.abstractmethod
298  def _shape(self):
299    # Write this in derived class to enable all static shape methods.
300    raise NotImplementedError("_shape is not implemented.")
301
302  @property
303  def shape(self):
304    """`TensorShape` of this `LinearOperator`.
305
306    If this operator acts like the batch matrix `A` with
307    `A.shape = [B1,...,Bb, M, N]`, then this returns
308    `TensorShape([B1,...,Bb, M, N])`, equivalent to `A.shape`.
309
310    Returns:
311      `TensorShape`, statically determined, may be undefined.
312    """
313    return self._shape()
314
315  def _shape_tensor(self):
316    # This is not an abstractmethod, since we want derived classes to be able to
317    # override this with optional kwargs, which can reduce the number of
318    # `convert_to_tensor` calls.  See derived classes for examples.
319    raise NotImplementedError("_shape_tensor is not implemented.")
320
321  def shape_tensor(self, name="shape_tensor"):
322    """Shape of this `LinearOperator`, determined at runtime.
323
324    If this operator acts like the batch matrix `A` with
325    `A.shape = [B1,...,Bb, M, N]`, then this returns a `Tensor` holding
326    `[B1,...,Bb, M, N]`, equivalent to `tf.shape(A)`.
327
328    Args:
329      name:  A name for this `Op`.
330
331    Returns:
332      `int32` `Tensor`
333    """
334    with self._name_scope(name):
335      # Prefer to use statically defined shape if available.
336      if self.shape.is_fully_defined():
337        return linear_operator_util.shape_tensor(self.shape.as_list())
338      else:
339        return self._shape_tensor()
340
341  @property
342  def batch_shape(self):
343    """`TensorShape` of batch dimensions of this `LinearOperator`.
344
345    If this operator acts like the batch matrix `A` with
346    `A.shape = [B1,...,Bb, M, N]`, then this returns
347    `TensorShape([B1,...,Bb])`, equivalent to `A.shape[:-2]`
348
349    Returns:
350      `TensorShape`, statically determined, may be undefined.
351    """
352    # Derived classes get this "for free" once .shape is implemented.
353    return self.shape[:-2]
354
355  def batch_shape_tensor(self, name="batch_shape_tensor"):
356    """Shape of batch dimensions of this operator, determined at runtime.
357
358    If this operator acts like the batch matrix `A` with
359    `A.shape = [B1,...,Bb, M, N]`, then this returns a `Tensor` holding
360    `[B1,...,Bb]`.
361
362    Args:
363      name:  A name for this `Op`.
364
365    Returns:
366      `int32` `Tensor`
367    """
368    # Derived classes get this "for free" once .shape() is implemented.
369    with self._name_scope(name):
370      return self._batch_shape_tensor()
371
372  def _batch_shape_tensor(self, shape=None):
373    # `shape` may be passed in if this can be pre-computed in a
374    # more efficient manner, e.g. without excessive Tensor conversions.
375    if self.batch_shape.is_fully_defined():
376      return linear_operator_util.shape_tensor(
377          self.batch_shape.as_list(), name="batch_shape")
378    else:
379      shape = self.shape_tensor() if shape is None else shape
380      return shape[:-2]
381
382  @property
383  def tensor_rank(self, name="tensor_rank"):
384    """Rank (in the sense of tensors) of matrix corresponding to this operator.
385
386    If this operator acts like the batch matrix `A` with
387    `A.shape = [B1,...,Bb, M, N]`, then this returns `b + 2`.
388
389    Args:
390      name:  A name for this `Op`.
391
392    Returns:
393      Python integer, or None if the tensor rank is undefined.
394    """
395    # Derived classes get this "for free" once .shape() is implemented.
396    with self._name_scope(name):
397      return self.shape.ndims
398
399  def tensor_rank_tensor(self, name="tensor_rank_tensor"):
400    """Rank (in the sense of tensors) of matrix corresponding to this operator.
401
402    If this operator acts like the batch matrix `A` with
403    `A.shape = [B1,...,Bb, M, N]`, then this returns `b + 2`.
404
405    Args:
406      name:  A name for this `Op`.
407
408    Returns:
409      `int32` `Tensor`, determined at runtime.
410    """
411    # Derived classes get this "for free" once .shape() is implemented.
412    with self._name_scope(name):
413      return self._tensor_rank_tensor()
414
415  def _tensor_rank_tensor(self, shape=None):
416    # `shape` may be passed in if this can be pre-computed in a
417    # more efficient manner, e.g. without excessive Tensor conversions.
418    if self.tensor_rank is not None:
419      return ops.convert_to_tensor_v2_with_dispatch(self.tensor_rank)
420    else:
421      shape = self.shape_tensor() if shape is None else shape
422      return array_ops.size(shape)
423
424  @property
425  def domain_dimension(self):
426    """Dimension (in the sense of vector spaces) of the domain of this operator.
427
428    If this operator acts like the batch matrix `A` with
429    `A.shape = [B1,...,Bb, M, N]`, then this returns `N`.
430
431    Returns:
432      `Dimension` object.
433    """
434    # Derived classes get this "for free" once .shape is implemented.
435    if self.shape.rank is None:
436      return tensor_shape.Dimension(None)
437    else:
438      return self.shape.dims[-1]
439
440  def domain_dimension_tensor(self, name="domain_dimension_tensor"):
441    """Dimension (in the sense of vector spaces) of the domain of this operator.
442
443    Determined at runtime.
444
445    If this operator acts like the batch matrix `A` with
446    `A.shape = [B1,...,Bb, M, N]`, then this returns `N`.
447
448    Args:
449      name:  A name for this `Op`.
450
451    Returns:
452      `int32` `Tensor`
453    """
454    # Derived classes get this "for free" once .shape() is implemented.
455    with self._name_scope(name):
456      return self._domain_dimension_tensor()
457
458  def _domain_dimension_tensor(self, shape=None):
459    # `shape` may be passed in if this can be pre-computed in a
460    # more efficient manner, e.g. without excessive Tensor conversions.
461    dim_value = tensor_shape.dimension_value(self.domain_dimension)
462    if dim_value is not None:
463      return ops.convert_to_tensor_v2_with_dispatch(dim_value)
464    else:
465      shape = self.shape_tensor() if shape is None else shape
466      return shape[-1]
467
468  @property
469  def range_dimension(self):
470    """Dimension (in the sense of vector spaces) of the range of this operator.
471
472    If this operator acts like the batch matrix `A` with
473    `A.shape = [B1,...,Bb, M, N]`, then this returns `M`.
474
475    Returns:
476      `Dimension` object.
477    """
478    # Derived classes get this "for free" once .shape is implemented.
479    if self.shape.dims:
480      return self.shape.dims[-2]
481    else:
482      return tensor_shape.Dimension(None)
483
484  def range_dimension_tensor(self, name="range_dimension_tensor"):
485    """Dimension (in the sense of vector spaces) of the range of this operator.
486
487    Determined at runtime.
488
489    If this operator acts like the batch matrix `A` with
490    `A.shape = [B1,...,Bb, M, N]`, then this returns `M`.
491
492    Args:
493      name:  A name for this `Op`.
494
495    Returns:
496      `int32` `Tensor`
497    """
498    # Derived classes get this "for free" once .shape() is implemented.
499    with self._name_scope(name):
500      return self._range_dimension_tensor()
501
502  def _range_dimension_tensor(self, shape=None):
503    # `shape` may be passed in if this can be pre-computed in a
504    # more efficient manner, e.g. without excessive Tensor conversions.
505    dim_value = tensor_shape.dimension_value(self.range_dimension)
506    if dim_value is not None:
507      return ops.convert_to_tensor_v2_with_dispatch(dim_value)
508    else:
509      shape = self.shape_tensor() if shape is None else shape
510      return shape[-2]
511
512  def _assert_non_singular(self):
513    """Private default implementation of _assert_non_singular."""
514    logging.warn(
515        "Using (possibly slow) default implementation of assert_non_singular."
516        "  Requires conversion to a dense matrix and O(N^3) operations.")
517    if self._can_use_cholesky():
518      return self.assert_positive_definite()
519    else:
520      singular_values = linalg_ops.svd(self.to_dense(), compute_uv=False)
521      # TODO(langmore) Add .eig and .cond as methods.
522      cond = (math_ops.reduce_max(singular_values, axis=-1) /
523              math_ops.reduce_min(singular_values, axis=-1))
524      return check_ops.assert_less(
525          cond,
526          self._max_condition_number_to_be_non_singular(),
527          message="Singular matrix up to precision epsilon.")
528
529  def _max_condition_number_to_be_non_singular(self):
530    """Return the maximum condition number that we consider nonsingular."""
531    with ops.name_scope("max_nonsingular_condition_number"):
532      dtype_eps = np.finfo(self.dtype.as_numpy_dtype).eps
533      eps = math_ops.cast(
534          math_ops.reduce_max([
535              100.,
536              math_ops.cast(self.range_dimension_tensor(), self.dtype),
537              math_ops.cast(self.domain_dimension_tensor(), self.dtype)
538          ]), self.dtype) * dtype_eps
539      return 1. / eps
540
541  def assert_non_singular(self, name="assert_non_singular"):
542    """Returns an `Op` that asserts this operator is non singular.
543
544    This operator is considered non-singular if
545
546    ```
547    ConditionNumber < max{100, range_dimension, domain_dimension} * eps,
548    eps := np.finfo(self.dtype.as_numpy_dtype).eps
549    ```
550
551    Args:
552      name:  A string name to prepend to created ops.
553
554    Returns:
555      An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if
556        the operator is singular.
557    """
558    with self._name_scope(name):
559      return self._assert_non_singular()
560
561  def _assert_positive_definite(self):
562    """Default implementation of _assert_positive_definite."""
563    logging.warn(
564        "Using (possibly slow) default implementation of "
565        "assert_positive_definite."
566        "  Requires conversion to a dense matrix and O(N^3) operations.")
567    # If the operator is self-adjoint, then checking that
568    # Cholesky decomposition succeeds + results in positive diag is necessary
569    # and sufficient.
570    if self.is_self_adjoint:
571      return check_ops.assert_positive(
572          array_ops.matrix_diag_part(linalg_ops.cholesky(self.to_dense())),
573          message="Matrix was not positive definite.")
574    # We have no generic check for positive definite.
575    raise NotImplementedError("assert_positive_definite is not implemented.")
576
577  def assert_positive_definite(self, name="assert_positive_definite"):
578    """Returns an `Op` that asserts this operator is positive definite.
579
580    Here, positive definite means that the quadratic form `x^H A x` has positive
581    real part for all nonzero `x`.  Note that we do not require the operator to
582    be self-adjoint to be positive definite.
583
584    Args:
585      name:  A name to give this `Op`.
586
587    Returns:
588      An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if
589        the operator is not positive definite.
590    """
591    with self._name_scope(name):
592      return self._assert_positive_definite()
593
594  def _assert_self_adjoint(self):
595    dense = self.to_dense()
596    logging.warn(
597        "Using (possibly slow) default implementation of assert_self_adjoint."
598        "  Requires conversion to a dense matrix.")
599    return check_ops.assert_equal(
600        dense,
601        linalg.adjoint(dense),
602        message="Matrix was not equal to its adjoint.")
603
604  def assert_self_adjoint(self, name="assert_self_adjoint"):
605    """Returns an `Op` that asserts this operator is self-adjoint.
606
607    Here we check that this operator is *exactly* equal to its hermitian
608    transpose.
609
610    Args:
611      name:  A string name to prepend to created ops.
612
613    Returns:
614      An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if
615        the operator is not self-adjoint.
616    """
617    with self._name_scope(name):
618      return self._assert_self_adjoint()
619
620  def _check_input_dtype(self, arg):
621    """Check that arg.dtype == self.dtype."""
622    if arg.dtype.base_dtype != self.dtype:
623      raise TypeError(
624          "Expected argument to have dtype %s.  Found: %s in tensor %s" %
625          (self.dtype, arg.dtype, arg))
626
627  @abc.abstractmethod
628  def _matmul(self, x, adjoint=False, adjoint_arg=False):
629    raise NotImplementedError("_matmul is not implemented.")
630
631  def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
632    """Transform [batch] matrix `x` with left multiplication:  `x --> Ax`.
633
634    ```python
635    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
636    operator = LinearOperator(...)
637    operator.shape = [..., M, N]
638
639    X = ... # shape [..., N, R], batch matrix, R > 0.
640
641    Y = operator.matmul(X)
642    Y.shape
643    ==> [..., M, R]
644
645    Y[..., :, r] = sum_j A[..., :, j] X[j, r]
646    ```
647
648    Args:
649      x: `LinearOperator` or `Tensor` with compatible shape and same `dtype` as
650        `self`. See class docstring for definition of compatibility.
651      adjoint: Python `bool`.  If `True`, left multiply by the adjoint: `A^H x`.
652      adjoint_arg:  Python `bool`.  If `True`, compute `A x^H` where `x^H` is
653        the hermitian transpose (transposition and complex conjugation).
654      name:  A name for this `Op`.
655
656    Returns:
657      A `LinearOperator` or `Tensor` with shape `[..., M, R]` and same `dtype`
658        as `self`.
659    """
660    if isinstance(x, LinearOperator):
661      left_operator = self.adjoint() if adjoint else self
662      right_operator = x.adjoint() if adjoint_arg else x
663
664      if (right_operator.range_dimension is not None and
665          left_operator.domain_dimension is not None and
666          right_operator.range_dimension != left_operator.domain_dimension):
667        raise ValueError(
668            "Operators are incompatible. Expected `x` to have dimension"
669            " {} but got {}.".format(
670                left_operator.domain_dimension, right_operator.range_dimension))
671      with self._name_scope(name):
672        return linear_operator_algebra.matmul(left_operator, right_operator)
673
674    with self._name_scope(name):
675      x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
676      self._check_input_dtype(x)
677
678      self_dim = -2 if adjoint else -1
679      arg_dim = -1 if adjoint_arg else -2
680      tensor_shape.dimension_at_index(
681          self.shape, self_dim).assert_is_compatible_with(
682              x.shape[arg_dim])
683
684      return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
685
686  def __matmul__(self, other):
687    return self.matmul(other)
688
689  def _matvec(self, x, adjoint=False):
690    x_mat = array_ops.expand_dims(x, axis=-1)
691    y_mat = self.matmul(x_mat, adjoint=adjoint)
692    return array_ops.squeeze(y_mat, axis=-1)
693
694  def matvec(self, x, adjoint=False, name="matvec"):
695    """Transform [batch] vector `x` with left multiplication:  `x --> Ax`.
696
697    ```python
698    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
699    operator = LinearOperator(...)
700
701    X = ... # shape [..., N], batch vector
702
703    Y = operator.matvec(X)
704    Y.shape
705    ==> [..., M]
706
707    Y[..., :] = sum_j A[..., :, j] X[..., j]
708    ```
709
710    Args:
711      x: `Tensor` with compatible shape and same `dtype` as `self`.
712        `x` is treated as a [batch] vector meaning for every set of leading
713        dimensions, the last dimension defines a vector.
714        See class docstring for definition of compatibility.
715      adjoint: Python `bool`.  If `True`, left multiply by the adjoint: `A^H x`.
716      name:  A name for this `Op`.
717
718    Returns:
719      A `Tensor` with shape `[..., M]` and same `dtype` as `self`.
720    """
721    with self._name_scope(name):
722      x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
723      self._check_input_dtype(x)
724      self_dim = -2 if adjoint else -1
725      tensor_shape.dimension_at_index(
726          self.shape, self_dim).assert_is_compatible_with(x.shape[-1])
727      return self._matvec(x, adjoint=adjoint)
728
729  def _determinant(self):
730    logging.warn(
731        "Using (possibly slow) default implementation of determinant."
732        "  Requires conversion to a dense matrix and O(N^3) operations.")
733    if self._can_use_cholesky():
734      return math_ops.exp(self.log_abs_determinant())
735    return linalg_ops.matrix_determinant(self.to_dense())
736
737  def determinant(self, name="det"):
738    """Determinant for every batch member.
739
740    Args:
741      name:  A name for this `Op`.
742
743    Returns:
744      `Tensor` with shape `self.batch_shape` and same `dtype` as `self`.
745
746    Raises:
747      NotImplementedError:  If `self.is_square` is `False`.
748    """
749    if self.is_square is False:
750      raise NotImplementedError(
751          "Determinant not implemented for an operator that is expected to "
752          "not be square.")
753    with self._name_scope(name):
754      return self._determinant()
755
756  def _log_abs_determinant(self):
757    logging.warn(
758        "Using (possibly slow) default implementation of determinant."
759        "  Requires conversion to a dense matrix and O(N^3) operations.")
760    if self._can_use_cholesky():
761      diag = array_ops.matrix_diag_part(linalg_ops.cholesky(self.to_dense()))
762      return 2 * math_ops.reduce_sum(math_ops.log(diag), axis=[-1])
763    _, log_abs_det = linalg.slogdet(self.to_dense())
764    return log_abs_det
765
766  def log_abs_determinant(self, name="log_abs_det"):
767    """Log absolute value of determinant for every batch member.
768
769    Args:
770      name:  A name for this `Op`.
771
772    Returns:
773      `Tensor` with shape `self.batch_shape` and same `dtype` as `self`.
774
775    Raises:
776      NotImplementedError:  If `self.is_square` is `False`.
777    """
778    if self.is_square is False:
779      raise NotImplementedError(
780          "Determinant not implemented for an operator that is expected to "
781          "not be square.")
782    with self._name_scope(name):
783      return self._log_abs_determinant()
784
785  def _dense_solve(self, rhs, adjoint=False, adjoint_arg=False):
786    """Solve by conversion to a dense matrix."""
787    if self.is_square is False:  # pylint: disable=g-bool-id-comparison
788      raise NotImplementedError(
789          "Solve is not yet implemented for non-square operators.")
790    rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
791    if self._can_use_cholesky():
792      return linalg_ops.cholesky_solve(
793          linalg_ops.cholesky(self.to_dense()), rhs)
794    return linear_operator_util.matrix_solve_with_broadcast(
795        self.to_dense(), rhs, adjoint=adjoint)
796
797  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
798    """Default implementation of _solve."""
799    logging.warn(
800        "Using (possibly slow) default implementation of solve."
801        "  Requires conversion to a dense matrix and O(N^3) operations.")
802    return self._dense_solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
803
804  def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
805    """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.
806
807    The returned `Tensor` will be close to an exact solution if `A` is well
808    conditioned. Otherwise closeness will vary. See class docstring for details.
809
810    Examples:
811
812    ```python
813    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
814    operator = LinearOperator(...)
815    operator.shape = [..., M, N]
816
817    # Solve R > 0 linear systems for every member of the batch.
818    RHS = ... # shape [..., M, R]
819
820    X = operator.solve(RHS)
821    # X[..., :, r] is the solution to the r'th linear system
822    # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]
823
824    operator.matmul(X)
825    ==> RHS
826    ```
827
828    Args:
829      rhs: `Tensor` with same `dtype` as this operator and compatible shape.
830        `rhs` is treated like a [batch] matrix meaning for every set of leading
831        dimensions, the last two dimensions defines a matrix.
832        See class docstring for definition of compatibility.
833      adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
834        of this `LinearOperator`:  `A^H X = rhs`.
835      adjoint_arg:  Python `bool`.  If `True`, solve `A X = rhs^H` where `rhs^H`
836        is the hermitian transpose (transposition and complex conjugation).
837      name:  A name scope to use for ops added by this method.
838
839    Returns:
840      `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`.
841
842    Raises:
843      NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
844    """
845    if self.is_non_singular is False:
846      raise NotImplementedError(
847          "Exact solve not implemented for an operator that is expected to "
848          "be singular.")
849    if self.is_square is False:
850      raise NotImplementedError(
851          "Exact solve not implemented for an operator that is expected to "
852          "not be square.")
853    if isinstance(rhs, LinearOperator):
854      left_operator = self.adjoint() if adjoint else self
855      right_operator = rhs.adjoint() if adjoint_arg else rhs
856
857      if (right_operator.range_dimension is not None and
858          left_operator.domain_dimension is not None and
859          right_operator.range_dimension != left_operator.domain_dimension):
860        raise ValueError(
861            "Operators are incompatible. Expected `rhs` to have dimension"
862            " {} but got {}.".format(
863                left_operator.domain_dimension, right_operator.range_dimension))
864      with self._name_scope(name):
865        return linear_operator_algebra.solve(left_operator, right_operator)
866
867    with self._name_scope(name):
868      rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
869      self._check_input_dtype(rhs)
870
871      self_dim = -1 if adjoint else -2
872      arg_dim = -1 if adjoint_arg else -2
873      tensor_shape.dimension_at_index(
874          self.shape, self_dim).assert_is_compatible_with(
875              rhs.shape[arg_dim])
876
877      return self._solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
878
879  def _solvevec(self, rhs, adjoint=False):
880    """Default implementation of _solvevec."""
881    rhs_mat = array_ops.expand_dims(rhs, axis=-1)
882    solution_mat = self.solve(rhs_mat, adjoint=adjoint)
883    return array_ops.squeeze(solution_mat, axis=-1)
884
885  def solvevec(self, rhs, adjoint=False, name="solve"):
886    """Solve single equation with best effort: `A X = rhs`.
887
888    The returned `Tensor` will be close to an exact solution if `A` is well
889    conditioned. Otherwise closeness will vary. See class docstring for details.
890
891    Examples:
892
893    ```python
894    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
895    operator = LinearOperator(...)
896    operator.shape = [..., M, N]
897
898    # Solve one linear system for every member of the batch.
899    RHS = ... # shape [..., M]
900
901    X = operator.solvevec(RHS)
902    # X is the solution to the linear system
903    # sum_j A[..., :, j] X[..., j] = RHS[..., :]
904
905    operator.matvec(X)
906    ==> RHS
907    ```
908
909    Args:
910      rhs: `Tensor` with same `dtype` as this operator.
911        `rhs` is treated like a [batch] vector meaning for every set of leading
912        dimensions, the last dimension defines a vector.  See class docstring
913        for definition of compatibility regarding batch dimensions.
914      adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
915        of this `LinearOperator`:  `A^H X = rhs`.
916      name:  A name scope to use for ops added by this method.
917
918    Returns:
919      `Tensor` with shape `[...,N]` and same `dtype` as `rhs`.
920
921    Raises:
922      NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
923    """
924    with self._name_scope(name):
925      rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
926      self._check_input_dtype(rhs)
927      self_dim = -1 if adjoint else -2
928      tensor_shape.dimension_at_index(
929          self.shape, self_dim).assert_is_compatible_with(rhs.shape[-1])
930
931      return self._solvevec(rhs, adjoint=adjoint)
932
933  def adjoint(self, name="adjoint"):
934    """Returns the adjoint of the current `LinearOperator`.
935
936    Given `A` representing this `LinearOperator`, return `A*`.
937    Note that calling `self.adjoint()` and `self.H` are equivalent.
938
939    Args:
940      name:  A name for this `Op`.
941
942    Returns:
943      `LinearOperator` which represents the adjoint of this `LinearOperator`.
944    """
945    if self.is_self_adjoint is True:  # pylint: disable=g-bool-id-comparison
946      return self
947    with self._name_scope(name):
948      return linear_operator_algebra.adjoint(self)
949
950  # self.H is equivalent to self.adjoint().
951  H = property(adjoint, None)
952
953  def inverse(self, name="inverse"):
954    """Returns the Inverse of this `LinearOperator`.
955
956    Given `A` representing this `LinearOperator`, return a `LinearOperator`
957    representing `A^-1`.
958
959    Args:
960      name: A name scope to use for ops added by this method.
961
962    Returns:
963      `LinearOperator` representing inverse of this matrix.
964
965    Raises:
966      ValueError: When the `LinearOperator` is not hinted to be `non_singular`.
967    """
968    if self.is_square is False:  # pylint: disable=g-bool-id-comparison
969      raise ValueError("Cannot take the Inverse: This operator represents "
970                       "a non square matrix.")
971    if self.is_non_singular is False:  # pylint: disable=g-bool-id-comparison
972      raise ValueError("Cannot take the Inverse: This operator represents "
973                       "a singular matrix.")
974
975    with self._name_scope(name):
976      return linear_operator_algebra.inverse(self)
977
978  def cholesky(self, name="cholesky"):
979    """Returns a Cholesky factor as a `LinearOperator`.
980
981    Given `A` representing this `LinearOperator`, if `A` is positive definite
982    self-adjoint, return `L`, where `A = L L^T`, i.e. the cholesky
983    decomposition.
984
985    Args:
986      name:  A name for this `Op`.
987
988    Returns:
989      `LinearOperator` which represents the lower triangular matrix
990      in the Cholesky decomposition.
991
992    Raises:
993      ValueError: When the `LinearOperator` is not hinted to be positive
994        definite and self adjoint.
995    """
996
997    if not self._can_use_cholesky():
998      raise ValueError("Cannot take the Cholesky decomposition: "
999                       "Not a positive definite self adjoint matrix.")
1000    with self._name_scope(name):
1001      return linear_operator_algebra.cholesky(self)
1002
1003  def _to_dense(self):
1004    """Generic and often inefficient implementation.  Override often."""
1005    if self.batch_shape.is_fully_defined():
1006      batch_shape = self.batch_shape
1007    else:
1008      batch_shape = self.batch_shape_tensor()
1009
1010    dim_value = tensor_shape.dimension_value(self.domain_dimension)
1011    if dim_value is not None:
1012      n = dim_value
1013    else:
1014      n = self.domain_dimension_tensor()
1015
1016    eye = linalg_ops.eye(num_rows=n, batch_shape=batch_shape, dtype=self.dtype)
1017    return self.matmul(eye)
1018
1019  def to_dense(self, name="to_dense"):
1020    """Return a dense (batch) matrix representing this operator."""
1021    with self._name_scope(name):
1022      return self._to_dense()
1023
1024  def _diag_part(self):
1025    """Generic and often inefficient implementation.  Override often."""
1026    return array_ops.matrix_diag_part(self.to_dense())
1027
1028  def diag_part(self, name="diag_part"):
1029    """Efficiently get the [batch] diagonal part of this operator.
1030
1031    If this operator has shape `[B1,...,Bb, M, N]`, this returns a
1032    `Tensor` `diagonal`, of shape `[B1,...,Bb, min(M, N)]`, where
1033    `diagonal[b1,...,bb, i] = self.to_dense()[b1,...,bb, i, i]`.
1034
1035    ```
1036    my_operator = LinearOperatorDiag([1., 2.])
1037
1038    # Efficiently get the diagonal
1039    my_operator.diag_part()
1040    ==> [1., 2.]
1041
1042    # Equivalent, but inefficient method
1043    tf.linalg.diag_part(my_operator.to_dense())
1044    ==> [1., 2.]
1045    ```
1046
1047    Args:
1048      name:  A name for this `Op`.
1049
1050    Returns:
1051      diag_part:  A `Tensor` of same `dtype` as self.
1052    """
1053    with self._name_scope(name):
1054      return self._diag_part()
1055
1056  def _trace(self):
1057    return math_ops.reduce_sum(self.diag_part(), axis=-1)
1058
1059  def trace(self, name="trace"):
1060    """Trace of the linear operator, equal to sum of `self.diag_part()`.
1061
1062    If the operator is square, this is also the sum of the eigenvalues.
1063
1064    Args:
1065      name:  A name for this `Op`.
1066
1067    Returns:
1068      Shape `[B1,...,Bb]` `Tensor` of same `dtype` as `self`.
1069    """
1070    with self._name_scope(name):
1071      return self._trace()
1072
1073  def _add_to_tensor(self, x):
1074    # Override if a more efficient implementation is available.
1075    return self.to_dense() + x
1076
1077  def add_to_tensor(self, x, name="add_to_tensor"):
1078    """Add matrix represented by this operator to `x`.  Equivalent to `A + x`.
1079
1080    Args:
1081      x:  `Tensor` with same `dtype` and shape broadcastable to `self.shape`.
1082      name:  A name to give this `Op`.
1083
1084    Returns:
1085      A `Tensor` with broadcast shape and same `dtype` as `self`.
1086    """
1087    with self._name_scope(name):
1088      x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
1089      self._check_input_dtype(x)
1090      return self._add_to_tensor(x)
1091
1092  def _eigvals(self):
1093    return linalg_ops.self_adjoint_eigvals(self.to_dense())
1094
1095  def eigvals(self, name="eigvals"):
1096    """Returns the eigenvalues of this linear operator.
1097
1098    If the operator is marked as self-adjoint (via `is_self_adjoint`)
1099    this computation can be more efficient.
1100
1101    Note: This currently only supports self-adjoint operators.
1102
1103    Args:
1104      name:  A name for this `Op`.
1105
1106    Returns:
1107      Shape `[B1,...,Bb, N]` `Tensor` of same `dtype` as `self`.
1108    """
1109    if not self.is_self_adjoint:
1110      raise NotImplementedError("Only self-adjoint matrices are supported.")
1111    with self._name_scope(name):
1112      return self._eigvals()
1113
1114  def _cond(self):
1115    if not self.is_self_adjoint:
1116      # In general the condition number is the ratio of the
1117      # absolute value of the largest and smallest singular values.
1118      vals = linalg_ops.svd(self.to_dense(), compute_uv=False)
1119    else:
1120      # For self-adjoint matrices, and in general normal matrices,
1121      # we can use eigenvalues.
1122      vals = math_ops.abs(self._eigvals())
1123
1124    return (math_ops.reduce_max(vals, axis=-1) /
1125            math_ops.reduce_min(vals, axis=-1))
1126
1127  def cond(self, name="cond"):
1128    """Returns the condition number of this linear operator.
1129
1130    Args:
1131      name:  A name for this `Op`.
1132
1133    Returns:
1134      Shape `[B1,...,Bb]` `Tensor` of same `dtype` as `self`.
1135    """
1136    with self._name_scope(name):
1137      return self._cond()
1138
1139  def _can_use_cholesky(self):
1140    return self.is_self_adjoint and self.is_positive_definite
1141
1142  def _set_graph_parents(self, graph_parents):
1143    """Set self._graph_parents.  Called during derived class init.
1144
1145    This method allows derived classes to set graph_parents, without triggering
1146    a deprecation warning (which is invoked if `graph_parents` is passed during
1147    `__init__`.
1148
1149    Args:
1150      graph_parents: Iterable over Tensors.
1151    """
1152    # TODO(b/143910018) Remove this function in V3.
1153    graph_parents = [] if graph_parents is None else graph_parents
1154    for i, t in enumerate(graph_parents):
1155      if t is None or not (linear_operator_util.is_ref(t) or
1156                           tensor_util.is_tf_type(t)):
1157        raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
1158    self._graph_parents = graph_parents
1159
1160
# Overrides for tf.linalg functions. These allow a LinearOperator to be used in
# place of a Tensor.
# For instance, tf.linalg.trace(linop) and linop.trace() both work.
1164
1165
@dispatch.dispatch_for_types(linalg.adjoint, LinearOperator)
def _adjoint(matrix, name=None):
  """`tf.linalg.adjoint` override that calls `LinearOperator.adjoint`."""
  adjoint_operator = matrix.adjoint(name)
  return adjoint_operator
1169
1170
@dispatch.dispatch_for_types(linalg.cholesky, LinearOperator)
def _cholesky(input, name=None):   # pylint:disable=redefined-builtin
  """`tf.linalg.cholesky` override that calls `LinearOperator.cholesky`."""
  factor = input.cholesky(name)
  return factor
1174
1175
# The signature has to match the one in python/ops/array_ops.py.  Only the
# main diagonal (`k=0`) has a structured LinearOperator implementation; any
# other `k` falls back to the dense matrix so that the result stays correct.
@dispatch.dispatch_for_types(linalg.diag_part, LinearOperator)
def _diag_part(
    input,  # pylint:disable=redefined-builtin
    name="diag_part",
    k=0,
    padding_value=0,
    align="RIGHT_LEFT"):
  """`tf.linalg.diag_part` override for `LinearOperator` inputs.

  Args:
    input: The `LinearOperator` whose diagonal(s) are requested.
    name: A name for the ops created.
    k: Diagonal offset(s), as accepted by `tf.linalg.diag_part`.
    padding_value: Padding used when `k` selects a band of diagonals.
    align: Alignment used when `k` selects a band of diagonals.

  Returns:
    `input.diag_part(name)` for the main diagonal; otherwise
    `tf.linalg.diag_part` applied to `input.to_dense()` with all arguments.
  """
  # `padding_value` and `align` only matter for bands of several diagonals,
  # so they are irrelevant to the single main diagonal.
  if isinstance(k, int) and k == 0:
    return input.diag_part(name)
  # Ignoring `k` would silently return the main diagonal.  `to_dense()` yields
  # a plain Tensor, so this call does not re-enter this dispatcher.
  return linalg.diag_part(
      input.to_dense(),
      name=name,
      k=k,
      padding_value=padding_value,
      align=align)
1188
1189
@dispatch.dispatch_for_types(linalg.det, LinearOperator)
def _det(input, name=None):  # pylint:disable=redefined-builtin
  """`tf.linalg.det` override that calls `LinearOperator.determinant`."""
  determinant = input.determinant(name)
  return determinant
1193
1194
@dispatch.dispatch_for_types(linalg.inv, LinearOperator)
def _inverse(input, adjoint=False, name=None):   # pylint:disable=redefined-builtin
  """`tf.linalg.inv` override; `adjoint=True` returns the adjoint of `A^-1`."""
  inverse_operator = input.inverse(name)
  return inverse_operator.adjoint() if adjoint else inverse_operator
1201
1202
@dispatch.dispatch_for_types(linalg.logdet, LinearOperator)
def _logdet(matrix, name=None):
  """`tf.linalg.logdet` override, limited to self-adjoint PD operators.

  For such operators the determinant is positive, so `log_abs_determinant`
  equals the log-determinant.

  Raises:
    ValueError: If `matrix` is not hinted self-adjoint positive definite.
  """
  if not (matrix.is_positive_definite and matrix.is_self_adjoint):
    raise ValueError("Expected matrix to be self-adjoint positive definite.")
  return matrix.log_abs_determinant(name)
1208
1209
@dispatch.dispatch_for_types(math_ops.matmul, LinearOperator)
def _matmul(
    a,
    b,
    transpose_a=False,
    transpose_b=False,
    adjoint_a=False,
    adjoint_b=False,
    a_is_sparse=False,
    b_is_sparse=False,
    name=None):
  """`math_ops.matmul` override when `a` and/or `b` is a `LinearOperator`.

  Transposition and sparse hints are rejected with a `ValueError`; use the
  `adjoint_*` flags instead of the `transpose_*` flags.
  """
  if transpose_a or transpose_b:
    raise ValueError("Transposing not supported at this time.")
  if a_is_sparse or b_is_sparse:
    raise ValueError("Sparse methods not supported at this time.")
  if isinstance(a, LinearOperator):
    return a.matmul(b, adjoint=adjoint_a, adjoint_arg=adjoint_b, name=name)
  # Only `b` is an operator.  With op(x) meaning x or x^H per its flag,
  # compute (op(b)^H op(a)^H)^H, which equals op(a) op(b).
  flipped_product = b.matmul(
      a, adjoint=not adjoint_b, adjoint_arg=not adjoint_a, name=name)
  return linalg.adjoint(flipped_product)
1235
1236
@dispatch.dispatch_for_types(linalg.solve, LinearOperator)
def _solve(matrix, rhs, adjoint=False, name=None):
  """`tf.linalg.solve` override; `matrix` itself must be the `LinearOperator`.

  Raises:
    ValueError: If only `rhs` is a `LinearOperator`.
  """
  if isinstance(matrix, LinearOperator):
    return matrix.solve(rhs, adjoint=adjoint, name=name)
  raise ValueError("Passing in `matrix` as a Tensor and `rhs` as a "
                   "LinearOperator is not supported.")
1247
1248
@dispatch.dispatch_for_types(linalg.trace, LinearOperator)
def _trace(x, name=None):
  """`tf.linalg.trace` override that calls `LinearOperator.trace`."""
  trace_value = x.trace(name)
  return trace_value
1252