1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Base class for linear operators."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import abc
22import contextlib
23
24import numpy as np
25import six
26
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import tensor_shape
29from tensorflow.python.framework import tensor_util
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import check_ops
32from tensorflow.python.ops import linalg_ops
33from tensorflow.python.ops import math_ops
34from tensorflow.python.ops.linalg import linalg_impl as linalg
35from tensorflow.python.ops.linalg import linear_operator_algebra
36from tensorflow.python.ops.linalg import linear_operator_util
37from tensorflow.python.platform import tf_logging as logging
38from tensorflow.python.util.tf_export import tf_export
39
# Public API of this module: only the abstract base class is exported.
__all__ = ["LinearOperator"]
41
42
43# TODO(langmore) Use matrix_solve_ls for singular or non-square matrices.
44@tf_export("linalg.LinearOperator")
45@six.add_metaclass(abc.ABCMeta)
46class LinearOperator(object):
47  """Base class defining a [batch of] linear operator[s].
48
49  Subclasses of `LinearOperator` provide access to common methods on a
50  (batch) matrix, without the need to materialize the matrix.  This allows:
51
52  * Matrix free computations
53  * Operators that take advantage of special structure, while providing a
54    consistent API to users.
55
56  #### Subclassing
57
58  To enable a public method, subclasses should implement the leading-underscore
59  version of the method.  The argument signature should be identical except for
60  the omission of `name="..."`.  For example, to enable
61  `matmul(x, adjoint=False, name="matmul")` a subclass should implement
62  `_matmul(x, adjoint=False)`.
63
64  #### Performance contract
65
66  Subclasses should only implement the assert methods
67  (e.g. `assert_non_singular`) if they can be done in less than `O(N^3)`
68  time.
69
70  Class docstrings should contain an explanation of computational complexity.
71  Since this is a high-performance library, attention should be paid to detail,
72  and explanations can include constants as well as Big-O notation.
73
74  #### Shape compatibility
75
76  `LinearOperator` subclasses should operate on a [batch] matrix with
77  compatible shape.  Class docstrings should define what is meant by compatible
78  shape.  Some subclasses may not support batching.
79
80  Examples:
81
82  `x` is a batch matrix with compatible shape for `matmul` if
83
84  ```
85  operator.shape = [B1,...,Bb] + [M, N],  b >= 0,
86  x.shape =   [B1,...,Bb] + [N, R]
87  ```
88
89  `rhs` is a batch matrix with compatible shape for `solve` if
90
91  ```
92  operator.shape = [B1,...,Bb] + [M, N],  b >= 0,
93  rhs.shape =   [B1,...,Bb] + [M, R]
94  ```
95
96  #### Example docstring for subclasses.
97
98  This operator acts like a (batch) matrix `A` with shape
99  `[B1,...,Bb, M, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `M x N` matrix.  Again, this matrix `A` may not be materialized, but for
102  purposes of identifying and working with compatible arguments the shape is
103  relevant.
104
105  Examples:
106
107  ```python
108  some_tensor = ... shape = ????
109  operator = MyLinOp(some_tensor)
110
  operator.shape
112  ==> [2, 4, 4]
113
114  operator.log_abs_determinant()
115  ==> Shape [2] Tensor
116
117  x = ... Shape [2, 4, 5] Tensor
118
119  operator.matmul(x)
120  ==> Shape [2, 4, 5] Tensor
121  ```
122
123  #### Shape compatibility
124
125  This operator acts on batch matrices with compatible shape.
126  FILL IN WHAT IS MEANT BY COMPATIBLE SHAPE
127
128  #### Performance
129
130  FILL THIS IN
131
132  #### Matrix property hints
133
134  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
135  for `X = non_singular, self_adjoint, positive_definite, square`.
136  These have the following meaning:
137
138  * If `is_X == True`, callers should expect the operator to have the
139    property `X`.  This is a promise that should be fulfilled, but is *not* a
140    runtime assert.  For example, finite floating point precision may result
141    in these promises being violated.
142  * If `is_X == False`, callers should expect the operator to not have `X`.
143  * If `is_X == None` (the default), callers should have no expectation either
144    way.
145  """
146
147  def __init__(self,
148               dtype,
149               graph_parents=None,
150               is_non_singular=None,
151               is_self_adjoint=None,
152               is_positive_definite=None,
153               is_square=None,
154               name=None):
155    r"""Initialize the `LinearOperator`.
156
157    **This is a private method for subclass use.**
158    **Subclasses should copy-paste this `__init__` documentation.**
159
160    Args:
161      dtype: The type of the this `LinearOperator`.  Arguments to `matmul` and
162        `solve` will have to be this type.
163      graph_parents: Python list of graph prerequisites of this `LinearOperator`
164        Typically tensors that are passed during initialization.
165      is_non_singular:  Expect that this operator is non-singular.
166      is_self_adjoint:  Expect that this operator is equal to its hermitian
167        transpose.  If `dtype` is real, this is equivalent to being symmetric.
168      is_positive_definite:  Expect that this operator is positive definite,
169        meaning the quadratic form `x^H A x` has positive real part for all
170        nonzero `x`.  Note that we do not require the operator to be
171        self-adjoint to be positive-definite.  See:
172        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
173      is_square:  Expect that this operator acts like square [batch] matrices.
174      name: A name for this `LinearOperator`.
175
176    Raises:
177      ValueError:  If any member of graph_parents is `None` or not a `Tensor`.
178      ValueError:  If hints are set incorrectly.
179    """
180    # Check and auto-set flags.
181    if is_positive_definite:
182      if is_non_singular is False:
183        raise ValueError("A positive definite matrix is always non-singular.")
184      is_non_singular = True
185
186    if is_non_singular:
187      if is_square is False:
188        raise ValueError("A non-singular matrix is always square.")
189      is_square = True
190
191    if is_self_adjoint:
192      if is_square is False:
193        raise ValueError("A self-adjoint matrix is always square.")
194      is_square = True
195
196    self._is_square_set_or_implied_by_hints = is_square
197
198    graph_parents = [] if graph_parents is None else graph_parents
199    for i, t in enumerate(graph_parents):
200      if t is None or not tensor_util.is_tensor(t):
201        raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
202    self._dtype = dtype
203    self._graph_parents = graph_parents
204    self._is_non_singular = is_non_singular
205    self._is_self_adjoint = is_self_adjoint
206    self._is_positive_definite = is_positive_definite
207    self._name = name or type(self).__name__
208
209  @contextlib.contextmanager
210  def _name_scope(self, name=None, values=None):
211    """Helper function to standardize op scope."""
212    with ops.name_scope(self.name):
213      with ops.name_scope(
214          name, values=((values or []) + self._graph_parents)) as scope:
215        yield scope
216
217  @property
218  def dtype(self):
219    """The `DType` of `Tensor`s handled by this `LinearOperator`."""
220    return self._dtype
221
222  @property
223  def name(self):
224    """Name prepended to all ops created by this `LinearOperator`."""
225    return self._name
226
227  @property
228  def graph_parents(self):
229    """List of graph dependencies of this `LinearOperator`."""
230    return self._graph_parents
231
232  @property
233  def is_non_singular(self):
234    return self._is_non_singular
235
236  @property
237  def is_self_adjoint(self):
238    return self._is_self_adjoint
239
240  @property
241  def is_positive_definite(self):
242    return self._is_positive_definite
243
244  @property
245  def is_square(self):
246    """Return `True/False` depending on if this operator is square."""
247    # Static checks done after __init__.  Why?  Because domain/range dimension
248    # sometimes requires lots of work done in the derived class after init.
249    auto_square_check = self.domain_dimension == self.range_dimension
250    if self._is_square_set_or_implied_by_hints is False and auto_square_check:
251      raise ValueError(
252          "User set is_square hint to False, but the operator was square.")
253    if self._is_square_set_or_implied_by_hints is None:
254      return auto_square_check
255
256    return self._is_square_set_or_implied_by_hints
257
258  @abc.abstractmethod
259  def _shape(self):
260    # Write this in derived class to enable all static shape methods.
261    raise NotImplementedError("_shape is not implemented.")
262
263  @property
264  def shape(self):
265    """`TensorShape` of this `LinearOperator`.
266
267    If this operator acts like the batch matrix `A` with
268    `A.shape = [B1,...,Bb, M, N]`, then this returns
269    `TensorShape([B1,...,Bb, M, N])`, equivalent to `A.get_shape()`.
270
271    Returns:
272      `TensorShape`, statically determined, may be undefined.
273    """
274    return self._shape()
275
276  @abc.abstractmethod
277  def _shape_tensor(self):
278    raise NotImplementedError("_shape_tensor is not implemented.")
279
280  def shape_tensor(self, name="shape_tensor"):
281    """Shape of this `LinearOperator`, determined at runtime.
282
283    If this operator acts like the batch matrix `A` with
284    `A.shape = [B1,...,Bb, M, N]`, then this returns a `Tensor` holding
285    `[B1,...,Bb, M, N]`, equivalent to `tf.shape(A)`.
286
287    Args:
288      name:  A name for this `Op`.
289
290    Returns:
291      `int32` `Tensor`
292    """
293    with self._name_scope(name):
294      # Prefer to use statically defined shape if available.
295      if self.shape.is_fully_defined():
296        return linear_operator_util.shape_tensor(self.shape.as_list())
297      else:
298        return self._shape_tensor()
299
300  @property
301  def batch_shape(self):
302    """`TensorShape` of batch dimensions of this `LinearOperator`.
303
304    If this operator acts like the batch matrix `A` with
305    `A.shape = [B1,...,Bb, M, N]`, then this returns
306    `TensorShape([B1,...,Bb])`, equivalent to `A.get_shape()[:-2]`
307
308    Returns:
309      `TensorShape`, statically determined, may be undefined.
310    """
311    # Derived classes get this "for free" once .shape is implemented.
312    return self.shape[:-2]
313
314  def batch_shape_tensor(self, name="batch_shape_tensor"):
315    """Shape of batch dimensions of this operator, determined at runtime.
316
317    If this operator acts like the batch matrix `A` with
318    `A.shape = [B1,...,Bb, M, N]`, then this returns a `Tensor` holding
319    `[B1,...,Bb]`.
320
321    Args:
322      name:  A name for this `Op`.
323
324    Returns:
325      `int32` `Tensor`
326    """
327    # Derived classes get this "for free" once .shape() is implemented.
328    with self._name_scope(name):
329      # Prefer to use statically defined shape if available.
330      if self.batch_shape.is_fully_defined():
331        return linear_operator_util.shape_tensor(
332            self.batch_shape.as_list(), name="batch_shape")
333      else:
334        return self.shape_tensor()[:-2]
335
336  @property
337  def tensor_rank(self, name="tensor_rank"):
338    """Rank (in the sense of tensors) of matrix corresponding to this operator.
339
340    If this operator acts like the batch matrix `A` with
341    `A.shape = [B1,...,Bb, M, N]`, then this returns `b + 2`.
342
343    Args:
344      name:  A name for this `Op`.
345
346    Returns:
347      Python integer, or None if the tensor rank is undefined.
348    """
349    # Derived classes get this "for free" once .shape() is implemented.
350    with self._name_scope(name):
351      return self.shape.ndims
352
353  def tensor_rank_tensor(self, name="tensor_rank_tensor"):
354    """Rank (in the sense of tensors) of matrix corresponding to this operator.
355
356    If this operator acts like the batch matrix `A` with
357    `A.shape = [B1,...,Bb, M, N]`, then this returns `b + 2`.
358
359    Args:
360      name:  A name for this `Op`.
361
362    Returns:
363      `int32` `Tensor`, determined at runtime.
364    """
365    # Derived classes get this "for free" once .shape() is implemented.
366    with self._name_scope(name):
367      # Prefer to use statically defined shape if available.
368      if self.tensor_rank is not None:
369        return ops.convert_to_tensor(self.tensor_rank)
370      else:
371        return array_ops.size(self.shape_tensor())
372
373  @property
374  def domain_dimension(self):
375    """Dimension (in the sense of vector spaces) of the domain of this operator.
376
377    If this operator acts like the batch matrix `A` with
378    `A.shape = [B1,...,Bb, M, N]`, then this returns `N`.
379
380    Returns:
381      `Dimension` object.
382    """
383    # Derived classes get this "for free" once .shape is implemented.
384    if self.shape.rank is None:
385      return tensor_shape.Dimension(None)
386    else:
387      return self.shape.dims[-1]
388
389  def domain_dimension_tensor(self, name="domain_dimension_tensor"):
390    """Dimension (in the sense of vector spaces) of the domain of this operator.
391
392    Determined at runtime.
393
394    If this operator acts like the batch matrix `A` with
395    `A.shape = [B1,...,Bb, M, N]`, then this returns `N`.
396
397    Args:
398      name:  A name for this `Op`.
399
400    Returns:
401      `int32` `Tensor`
402    """
403    # Derived classes get this "for free" once .shape() is implemented.
404    with self._name_scope(name):
405      # Prefer to use statically defined shape if available.
406      dim_value = tensor_shape.dimension_value(self.domain_dimension)
407      if dim_value is not None:
408        return ops.convert_to_tensor(dim_value)
409      else:
410        return self.shape_tensor()[-1]
411
412  @property
413  def range_dimension(self):
414    """Dimension (in the sense of vector spaces) of the range of this operator.
415
416    If this operator acts like the batch matrix `A` with
417    `A.shape = [B1,...,Bb, M, N]`, then this returns `M`.
418
419    Returns:
420      `Dimension` object.
421    """
422    # Derived classes get this "for free" once .shape is implemented.
423    if self.shape.dims:
424      return self.shape.dims[-2]
425    else:
426      return tensor_shape.Dimension(None)
427
428  def range_dimension_tensor(self, name="range_dimension_tensor"):
429    """Dimension (in the sense of vector spaces) of the range of this operator.
430
431    Determined at runtime.
432
433    If this operator acts like the batch matrix `A` with
434    `A.shape = [B1,...,Bb, M, N]`, then this returns `M`.
435
436    Args:
437      name:  A name for this `Op`.
438
439    Returns:
440      `int32` `Tensor`
441    """
442    # Derived classes get this "for free" once .shape() is implemented.
443    with self._name_scope(name):
444      # Prefer to use statically defined shape if available.
445      dim_value = tensor_shape.dimension_value(self.range_dimension)
446      if dim_value is not None:
447        return ops.convert_to_tensor(dim_value)
448      else:
449        return self.shape_tensor()[-2]
450
451  def _assert_non_singular(self):
452    """Private default implementation of _assert_non_singular."""
453    logging.warn(
454        "Using (possibly slow) default implementation of assert_non_singular."
455        "  Requires conversion to a dense matrix and O(N^3) operations.")
456    if self._can_use_cholesky():
457      return self.assert_positive_definite()
458    else:
459      singular_values = linalg_ops.svd(self.to_dense(), compute_uv=False)
460      # TODO(langmore) Add .eig and .cond as methods.
461      cond = (math_ops.reduce_max(singular_values, axis=-1) /
462              math_ops.reduce_min(singular_values, axis=-1))
463      return check_ops.assert_less(
464          cond,
465          self._max_condition_number_to_be_non_singular(),
466          message="Singular matrix up to precision epsilon.")
467
468  def _max_condition_number_to_be_non_singular(self):
469    """Return the maximum condition number that we consider nonsingular."""
470    with ops.name_scope("max_nonsingular_condition_number"):
471      dtype_eps = np.finfo(self.dtype.as_numpy_dtype).eps
472      eps = math_ops.cast(
473          math_ops.reduce_max([
474              100.,
475              math_ops.cast(self.range_dimension_tensor(), self.dtype),
476              math_ops.cast(self.domain_dimension_tensor(), self.dtype)
477          ]), self.dtype) * dtype_eps
478      return 1. / eps
479
480  def assert_non_singular(self, name="assert_non_singular"):
481    """Returns an `Op` that asserts this operator is non singular.
482
483    This operator is considered non-singular if
484
485    ```
486    ConditionNumber < max{100, range_dimension, domain_dimension} * eps,
487    eps := np.finfo(self.dtype.as_numpy_dtype).eps
488    ```
489
490    Args:
491      name:  A string name to prepend to created ops.
492
493    Returns:
494      An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if
495        the operator is singular.
496    """
497    with self._name_scope(name):
498      return self._assert_non_singular()
499
500  def _assert_positive_definite(self):
501    """Default implementation of _assert_positive_definite."""
502    logging.warn(
503        "Using (possibly slow) default implementation of "
504        "assert_positive_definite."
505        "  Requires conversion to a dense matrix and O(N^3) operations.")
506    # If the operator is self-adjoint, then checking that
507    # Cholesky decomposition succeeds + results in positive diag is necessary
508    # and sufficient.
509    if self.is_self_adjoint:
510      return check_ops.assert_positive(
511          array_ops.matrix_diag_part(linalg_ops.cholesky(self.to_dense())),
512          message="Matrix was not positive definite.")
513    # We have no generic check for positive definite.
514    raise NotImplementedError("assert_positive_definite is not implemented.")
515
516  def assert_positive_definite(self, name="assert_positive_definite"):
517    """Returns an `Op` that asserts this operator is positive definite.
518
519    Here, positive definite means that the quadratic form `x^H A x` has positive
520    real part for all nonzero `x`.  Note that we do not require the operator to
521    be self-adjoint to be positive definite.
522
523    Args:
524      name:  A name to give this `Op`.
525
526    Returns:
527      An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if
528        the operator is not positive definite.
529    """
530    with self._name_scope(name):
531      return self._assert_positive_definite()
532
533  def _assert_self_adjoint(self):
534    dense = self.to_dense()
535    logging.warn(
536        "Using (possibly slow) default implementation of assert_self_adjoint."
537        "  Requires conversion to a dense matrix.")
538    return check_ops.assert_equal(
539        dense,
540        linalg.adjoint(dense),
541        message="Matrix was not equal to its adjoint.")
542
543  def assert_self_adjoint(self, name="assert_self_adjoint"):
544    """Returns an `Op` that asserts this operator is self-adjoint.
545
546    Here we check that this operator is *exactly* equal to its hermitian
547    transpose.
548
549    Args:
550      name:  A string name to prepend to created ops.
551
552    Returns:
553      An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if
554        the operator is not self-adjoint.
555    """
556    with self._name_scope(name):
557      return self._assert_self_adjoint()
558
559  def _check_input_dtype(self, arg):
560    """Check that arg.dtype == self.dtype."""
561    if arg.dtype != self.dtype:
562      raise TypeError(
563          "Expected argument to have dtype %s.  Found: %s in tensor %s" %
564          (self.dtype, arg.dtype, arg))
565
566  @abc.abstractmethod
567  def _matmul(self, x, adjoint=False, adjoint_arg=False):
568    raise NotImplementedError("_matmul is not implemented.")
569
570  def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
571    """Transform [batch] matrix `x` with left multiplication:  `x --> Ax`.
572
573    ```python
574    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
575    operator = LinearOperator(...)
576    operator.shape = [..., M, N]
577
578    X = ... # shape [..., N, R], batch matrix, R > 0.
579
580    Y = operator.matmul(X)
581    Y.shape
582    ==> [..., M, R]
583
584    Y[..., :, r] = sum_j A[..., :, j] X[j, r]
585    ```
586
587    Args:
588      x: `LinearOperator` or `Tensor` with compatible shape and same `dtype` as
589        `self`. See class docstring for definition of compatibility.
590      adjoint: Python `bool`.  If `True`, left multiply by the adjoint: `A^H x`.
591      adjoint_arg:  Python `bool`.  If `True`, compute `A x^H` where `x^H` is
592        the hermitian transpose (transposition and complex conjugation).
593      name:  A name for this `Op`.
594
595    Returns:
596      A `LinearOperator` or `Tensor` with shape `[..., M, R]` and same `dtype`
597        as `self`.
598    """
599    if isinstance(x, LinearOperator):
600      if adjoint or adjoint_arg:
601        raise ValueError(".matmul not supported with adjoints.")
602      if (x.range_dimension is not None and
603          self.domain_dimension is not None and
604          x.range_dimension != self.domain_dimension):
605        raise ValueError(
606            "Operators are incompatible. Expected `x` to have dimension"
607            " {} but got {}.".format(self.domain_dimension, x.range_dimension))
608      with self._name_scope(name):
609        return linear_operator_algebra.matmul(self, x)
610
611    with self._name_scope(name, values=[x]):
612      x = ops.convert_to_tensor(x, name="x")
613      self._check_input_dtype(x)
614
615      self_dim = -2 if adjoint else -1
616      arg_dim = -1 if adjoint_arg else -2
617      tensor_shape.dimension_at_index(
618          self.shape, self_dim).assert_is_compatible_with(
619              x.get_shape()[arg_dim])
620
621      return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
622
623  def _matvec(self, x, adjoint=False):
624    x_mat = array_ops.expand_dims(x, axis=-1)
625    y_mat = self.matmul(x_mat, adjoint=adjoint)
626    return array_ops.squeeze(y_mat, axis=-1)
627
628  def matvec(self, x, adjoint=False, name="matvec"):
629    """Transform [batch] vector `x` with left multiplication:  `x --> Ax`.
630
631    ```python
632    # Make an operator acting like batch matric A.  Assume A.shape = [..., M, N]
633    operator = LinearOperator(...)
634
635    X = ... # shape [..., N], batch vector
636
637    Y = operator.matvec(X)
638    Y.shape
639    ==> [..., M]
640
641    Y[..., :] = sum_j A[..., :, j] X[..., j]
642    ```
643
644    Args:
645      x: `Tensor` with compatible shape and same `dtype` as `self`.
646        `x` is treated as a [batch] vector meaning for every set of leading
647        dimensions, the last dimension defines a vector.
648        See class docstring for definition of compatibility.
649      adjoint: Python `bool`.  If `True`, left multiply by the adjoint: `A^H x`.
650      name:  A name for this `Op`.
651
652    Returns:
653      A `Tensor` with shape `[..., M]` and same `dtype` as `self`.
654    """
655    with self._name_scope(name, values=[x]):
656      x = ops.convert_to_tensor(x, name="x")
657      self._check_input_dtype(x)
658      self_dim = -2 if adjoint else -1
659      tensor_shape.dimension_at_index(
660          self.shape, self_dim).assert_is_compatible_with(x.get_shape()[-1])
661      return self._matvec(x, adjoint=adjoint)
662
663  def _determinant(self):
664    logging.warn(
665        "Using (possibly slow) default implementation of determinant."
666        "  Requires conversion to a dense matrix and O(N^3) operations.")
667    if self._can_use_cholesky():
668      return math_ops.exp(self.log_abs_determinant())
669    return linalg_ops.matrix_determinant(self.to_dense())
670
671  def determinant(self, name="det"):
672    """Determinant for every batch member.
673
674    Args:
675      name:  A name for this `Op`.
676
677    Returns:
678      `Tensor` with shape `self.batch_shape` and same `dtype` as `self`.
679
680    Raises:
681      NotImplementedError:  If `self.is_square` is `False`.
682    """
683    if self.is_square is False:
684      raise NotImplementedError(
685          "Determinant not implemented for an operator that is expected to "
686          "not be square.")
687    with self._name_scope(name):
688      return self._determinant()
689
690  def _log_abs_determinant(self):
691    logging.warn(
692        "Using (possibly slow) default implementation of determinant."
693        "  Requires conversion to a dense matrix and O(N^3) operations.")
694    if self._can_use_cholesky():
695      diag = array_ops.matrix_diag_part(linalg_ops.cholesky(self.to_dense()))
696      return 2 * math_ops.reduce_sum(math_ops.log(diag), axis=[-1])
697    _, log_abs_det = linalg.slogdet(self.to_dense())
698    return log_abs_det
699
700  def log_abs_determinant(self, name="log_abs_det"):
701    """Log absolute value of determinant for every batch member.
702
703    Args:
704      name:  A name for this `Op`.
705
706    Returns:
707      `Tensor` with shape `self.batch_shape` and same `dtype` as `self`.
708
709    Raises:
710      NotImplementedError:  If `self.is_square` is `False`.
711    """
712    if self.is_square is False:
713      raise NotImplementedError(
714          "Determinant not implemented for an operator that is expected to "
715          "not be square.")
716    with self._name_scope(name):
717      return self._log_abs_determinant()
718
719  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
720    """Default implementation of _solve."""
721    if self.is_square is False:
722      raise NotImplementedError(
723          "Solve is not yet implemented for non-square operators.")
724    logging.warn(
725        "Using (possibly slow) default implementation of solve."
726        "  Requires conversion to a dense matrix and O(N^3) operations.")
727    rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
728    if self._can_use_cholesky():
729      return linear_operator_util.cholesky_solve_with_broadcast(
730          linalg_ops.cholesky(self.to_dense()), rhs)
731    return linear_operator_util.matrix_solve_with_broadcast(
732        self.to_dense(), rhs, adjoint=adjoint)
733
734  def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
735    """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.
736
737    The returned `Tensor` will be close to an exact solution if `A` is well
738    conditioned. Otherwise closeness will vary. See class docstring for details.
739
740    Examples:
741
742    ```python
743    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
744    operator = LinearOperator(...)
745    operator.shape = [..., M, N]
746
747    # Solve R > 0 linear systems for every member of the batch.
748    RHS = ... # shape [..., M, R]
749
750    X = operator.solve(RHS)
751    # X[..., :, r] is the solution to the r'th linear system
752    # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]
753
754    operator.matmul(X)
755    ==> RHS
756    ```
757
758    Args:
759      rhs: `Tensor` with same `dtype` as this operator and compatible shape.
760        `rhs` is treated like a [batch] matrix meaning for every set of leading
761        dimensions, the last two dimensions defines a matrix.
762        See class docstring for definition of compatibility.
763      adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
764        of this `LinearOperator`:  `A^H X = rhs`.
765      adjoint_arg:  Python `bool`.  If `True`, solve `A X = rhs^H` where `rhs^H`
766        is the hermitian transpose (transposition and complex conjugation).
767      name:  A name scope to use for ops added by this method.
768
769    Returns:
770      `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`.
771
772    Raises:
773      NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
774    """
775    if self.is_non_singular is False:
776      raise NotImplementedError(
777          "Exact solve not implemented for an operator that is expected to "
778          "be singular.")
779    if self.is_square is False:
780      raise NotImplementedError(
781          "Exact solve not implemented for an operator that is expected to "
782          "not be square.")
783    with self._name_scope(name, values=[rhs]):
784      rhs = ops.convert_to_tensor(rhs, name="rhs")
785      self._check_input_dtype(rhs)
786
787      self_dim = -1 if adjoint else -2
788      arg_dim = -1 if adjoint_arg else -2
789      tensor_shape.dimension_at_index(
790          self.shape, self_dim).assert_is_compatible_with(
791              rhs.get_shape()[arg_dim])
792
793      return self._solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
794
795  def _solvevec(self, rhs, adjoint=False):
796    """Default implementation of _solvevec."""
797    rhs_mat = array_ops.expand_dims(rhs, axis=-1)
798    solution_mat = self.solve(rhs_mat, adjoint=adjoint)
799    return array_ops.squeeze(solution_mat, axis=-1)
800
801  def solvevec(self, rhs, adjoint=False, name="solve"):
802    """Solve single equation with best effort: `A X = rhs`.
803
804    The returned `Tensor` will be close to an exact solution if `A` is well
805    conditioned. Otherwise closeness will vary. See class docstring for details.
806
807    Examples:
808
809    ```python
810    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
811    operator = LinearOperator(...)
812    operator.shape = [..., M, N]
813
814    # Solve one linear system for every member of the batch.
815    RHS = ... # shape [..., M]
816
817    X = operator.solvevec(RHS)
818    # X is the solution to the linear system
819    # sum_j A[..., :, j] X[..., j] = RHS[..., :]
820
821    operator.matvec(X)
822    ==> RHS
823    ```
824
825    Args:
826      rhs: `Tensor` with same `dtype` as this operator.
827        `rhs` is treated like a [batch] vector meaning for every set of leading
828        dimensions, the last dimension defines a vector.  See class docstring
829        for definition of compatibility regarding batch dimensions.
830      adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
831        of this `LinearOperator`:  `A^H X = rhs`.
832      name:  A name scope to use for ops added by this method.
833
834    Returns:
835      `Tensor` with shape `[...,N]` and same `dtype` as `rhs`.
836
837    Raises:
838      NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
839    """
840    with self._name_scope(name, values=[rhs]):
841      rhs = ops.convert_to_tensor(rhs, name="rhs")
842      self._check_input_dtype(rhs)
843      self_dim = -1 if adjoint else -2
844      tensor_shape.dimension_at_index(
845          self.shape, self_dim).assert_is_compatible_with(
846              rhs.get_shape()[-1])
847
848      return self._solvevec(rhs, adjoint=adjoint)
849
850  def adjoint(self, name="adjoint"):
851    """Returns the adjoint of the current `LinearOperator`.
852
853    Given `A` representing this `LinearOperator`, return `A*`.
854    Note that calling `self.adjoint()` and `self.H` are equivalent.
855
856    Args:
857      name:  A name for this `Op`.
858
859    Returns:
860      `LinearOperator` which represents the adjoint of this `LinearOperator`.
861    """
862    if self.is_self_adjoint is True:  # pylint: disable=g-bool-id-comparison
863      return self
864    with self._name_scope(name):
865      return linear_operator_algebra.adjoint(self)
866
867  # self.H is equivalent to self.adjoint().
868  H = property(adjoint, None)
869
870  def inverse(self, name="inverse"):
871    """Returns the Inverse of this `LinearOperator`.
872
873    Given `A` representing this `LinearOperator`, return a `LinearOperator`
874    representing `A^-1`.
875
876    Args:
877      name: A name scope to use for ops added by this method.
878
879    Returns:
880      `LinearOperator` representing inverse of this matrix.
881
882    Raises:
883      ValueError: When the `LinearOperator` is not hinted to be `non_singular`.
884    """
885    if self.is_square is False:  # pylint: disable=g-bool-id-comparison
886      raise ValueError("Cannot take the Inverse: This operator represents "
887                       "a non square matrix.")
888    if self.is_non_singular is False:  # pylint: disable=g-bool-id-comparison
889      raise ValueError("Cannot take the Inverse: This operator represents "
890                       "a singular matrix.")
891
892    with self._name_scope(name):
893      return linear_operator_algebra.inverse(self)
894
895  def cholesky(self, name="cholesky"):
896    """Returns a Cholesky factor as a `LinearOperator`.
897
898    Given `A` representing this `LinearOperator`, if `A` is positive definite
899    self-adjoint, return `L`, where `A = L L^T`, i.e. the cholesky
900    decomposition.
901
902    Args:
903      name:  A name for this `Op`.
904
905    Returns:
906      `LinearOperator` which represents the lower triangular matrix
907      in the Cholesky decomposition.
908
909    Raises:
910      ValueError: When the `LinearOperator` is not hinted to be positive
911        definite and self adjoint.
912    """
913
914    if not self._can_use_cholesky():
915      raise ValueError("Cannot take the Cholesky decomposition: "
916                       "Not a positive definite self adjoint matrix.")
917    with self._name_scope(name):
918      return linear_operator_algebra.cholesky(self)
919
920  def _to_dense(self):
921    """Generic and often inefficient implementation.  Override often."""
922    logging.warn("Using (possibly slow) default implementation of to_dense."
923                 "  Converts by self.matmul(identity).")
924    if self.batch_shape.is_fully_defined():
925      batch_shape = self.batch_shape
926    else:
927      batch_shape = self.batch_shape_tensor()
928
929    dim_value = tensor_shape.dimension_value(self.domain_dimension)
930    if dim_value is not None:
931      n = dim_value
932    else:
933      n = self.domain_dimension_tensor()
934
935    eye = linalg_ops.eye(num_rows=n, batch_shape=batch_shape, dtype=self.dtype)
936    return self.matmul(eye)
937
938  def to_dense(self, name="to_dense"):
939    """Return a dense (batch) matrix representing this operator."""
940    with self._name_scope(name):
941      return self._to_dense()
942
943  def _diag_part(self):
944    """Generic and often inefficient implementation.  Override often."""
945    return array_ops.matrix_diag_part(self.to_dense())
946
947  def diag_part(self, name="diag_part"):
948    """Efficiently get the [batch] diagonal part of this operator.
949
950    If this operator has shape `[B1,...,Bb, M, N]`, this returns a
951    `Tensor` `diagonal`, of shape `[B1,...,Bb, min(M, N)]`, where
952    `diagonal[b1,...,bb, i] = self.to_dense()[b1,...,bb, i, i]`.
953
954    ```
955    my_operator = LinearOperatorDiag([1., 2.])
956
957    # Efficiently get the diagonal
958    my_operator.diag_part()
959    ==> [1., 2.]
960
961    # Equivalent, but inefficient method
962    tf.matrix_diag_part(my_operator.to_dense())
963    ==> [1., 2.]
964    ```
965
966    Args:
967      name:  A name for this `Op`.
968
969    Returns:
970      diag_part:  A `Tensor` of same `dtype` as self.
971    """
972    with self._name_scope(name):
973      return self._diag_part()
974
975  def _trace(self):
976    return math_ops.reduce_sum(self.diag_part(), axis=-1)
977
978  def trace(self, name="trace"):
979    """Trace of the linear operator, equal to sum of `self.diag_part()`.
980
981    If the operator is square, this is also the sum of the eigenvalues.
982
983    Args:
984      name:  A name for this `Op`.
985
986    Returns:
987      Shape `[B1,...,Bb]` `Tensor` of same `dtype` as `self`.
988    """
989    with self._name_scope(name):
990      return self._trace()
991
992  def _add_to_tensor(self, x):
993    # Override if a more efficient implementation is available.
994    return self.to_dense() + x
995
996  def add_to_tensor(self, x, name="add_to_tensor"):
997    """Add matrix represented by this operator to `x`.  Equivalent to `A + x`.
998
999    Args:
1000      x:  `Tensor` with same `dtype` and shape broadcastable to `self.shape`.
1001      name:  A name to give this `Op`.
1002
1003    Returns:
1004      A `Tensor` with broadcast shape and same `dtype` as `self`.
1005    """
1006    with self._name_scope(name, values=[x]):
1007      x = ops.convert_to_tensor(x, name="x")
1008      self._check_input_dtype(x)
1009      return self._add_to_tensor(x)
1010
1011  def _can_use_cholesky(self):
1012    return self.is_self_adjoint and self.is_positive_definite
1013