# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Create a Block Diagonal operator from one or more `LinearOperators`."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import common_shapes
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import tensor_shape
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import check_ops
27from tensorflow.python.ops import control_flow_ops
28from tensorflow.python.ops.linalg import linear_operator
29from tensorflow.python.ops.linalg import linear_operator_algebra
30from tensorflow.python.ops.linalg import linear_operator_util
31from tensorflow.python.util.tf_export import tf_export
32
33__all__ = ["LinearOperatorBlockDiag"]
34
35
36@tf_export("linalg.LinearOperatorBlockDiag")
class LinearOperatorBlockDiag(linear_operator.LinearOperator):
  """Combines one or more `LinearOperators` into a Block Diagonal matrix.

  This operator combines one or more linear operators `[op1,...,opJ]`,
  building a new `LinearOperator`, whose underlying matrix representation is
  square and has each operator `opi` on the main diagonal, and zeros elsewhere.

  #### Shape compatibility

  If `opj` acts like a [batch] square matrix `Aj`, then `op_combined` acts like
  the [batch] square matrix formed by having each matrix `Aj` on the main
  diagonal.

  Each `opj` is required to represent a square matrix, and hence will have
  shape `batch_shape_j + [M_j, M_j]`.

  If `opj` has shape `batch_shape_j + [M_j, M_j]`, then the combined operator
  has shape `broadcast_batch_shape + [sum M_j, sum M_j]`, where
  `broadcast_batch_shape` is the mutual broadcast of `batch_shape_j`,
  `j = 1,...,J`, assuming the intermediate batch shapes broadcast.
  Even if the combined shape is well defined, the combined operator's
  methods may fail due to lack of broadcasting ability in the defining
  operators' methods.

  Arguments to `matmul`, `matvec`, `solve`, and `solvevec` may either be single
  `Tensor`s or lists of `Tensor`s that are interpreted as blocks. The `j`th
  element of a blockwise list of `Tensor`s must have dimensions that match
  `opj` for the given method. If a list of blocks is passed in, a list of
  blocks is returned as well.

  ```python
  # Create a 4 x 4 linear operator composed of two 2 x 2 operators.
  operator_1 = LinearOperatorFullMatrix([[1., 2.], [3., 4.]])
  operator_2 = LinearOperatorFullMatrix([[1., 0.], [0., 1.]])
  operator = LinearOperatorBlockDiag([operator_1, operator_2])

  operator.to_dense()
  ==> [[1., 2., 0., 0.],
       [3., 4., 0., 0.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.]]

  operator.shape
  ==> [4, 4]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x1 = ... # Shape [2, 2] Tensor
  x2 = ... # Shape [2, 2] Tensor
  x = tf.concat([x1, x2], 0)  # Shape [4, 2] Tensor
  operator.matmul(x)
  ==> tf.concat([operator_1.matmul(x1), operator_2.matmul(x2)], axis=0)

  # Create a [2, 3] batch of 4 x 4 linear operators.
  matrix_44 = tf.random.normal(shape=[2, 3, 4, 4])
  operator_44 = LinearOperatorFullMatrix(matrix_44)

  # Create a [1, 3] batch of 5 x 5 linear operators.
  matrix_55 = tf.random.normal(shape=[1, 3, 5, 5])
  operator_55 = LinearOperatorFullMatrix(matrix_55)

  # Combine to create a [2, 3] batch of 9 x 9 operators.
  operator_99 = LinearOperatorBlockDiag([operator_44, operator_55])

  # Create a shape [2, 3, 9] vector.
  x = tf.random.normal(shape=[2, 3, 9])
  operator_99.matvec(x)
  ==> Shape [2, 3, 9] Tensor

  # Create a blockwise list of vectors.
  x = [tf.random.normal(shape=[2, 3, 4]), tf.random.normal(shape=[2, 3, 5])]
  operator_99.matvec(x)
  ==> [Shape [2, 3, 4] Tensor, Shape [2, 3, 5] Tensor]
  ```
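
  A sketch of a blockwise `solve`, reusing `operator_99` from above (the other
  methods accept blockwise arguments in the same way):

  ```python
  # A list of blocks in, a list of blocks out.
  rhs = [tf.random.normal(shape=[2, 3, 4, 1]),
         tf.random.normal(shape=[2, 3, 5, 1])]
  operator_99.solve(rhs)
  ==> [Shape [2, 3, 4, 1] Tensor, Shape [2, 3, 5, 1] Tensor]
  ```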

  #### Performance

  The cost of any operation on `LinearOperatorBlockDiag` is the sum of the
  cost of the same operation on each of the component operators.


  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
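
  For example, hints are inferred from the component operators where possible
  (a small sketch; see the auto-set logic in `__init__`):

  ```python
  op1 = LinearOperatorDiag([1., 2.], is_non_singular=True)
  op2 = LinearOperatorDiag([3.], is_non_singular=True)
  operator = LinearOperatorBlockDiag([op1, op2])
  operator.is_non_singular
  ==> True  # The direct sum of non-singular operators is non-singular.
  ```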
132  """
133
134  def __init__(self,
135               operators,
136               is_non_singular=None,
137               is_self_adjoint=None,
138               is_positive_definite=None,
139               is_square=True,
140               name=None):
141    r"""Initialize a `LinearOperatorBlockDiag`.
142
143    `LinearOperatorBlockDiag` is initialized with a list of operators
144    `[op_1,...,op_J]`.
145
146    Args:
147      operators:  Iterable of `LinearOperator` objects, each with
148        the same `dtype` and composable shape.
149      is_non_singular:  Expect that this operator is non-singular.
150      is_self_adjoint:  Expect that this operator is equal to its hermitian
151        transpose.
152      is_positive_definite:  Expect that this operator is positive definite,
153        meaning the quadratic form `x^H A x` has positive real part for all
154        nonzero `x`.  Note that we do not require the operator to be
155        self-adjoint to be positive-definite.  See:
156        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
157      is_square:  Expect that this operator acts like square [batch] matrices.
158        This is true by default, and will raise a `ValueError` otherwise.
159      name: A name for this `LinearOperator`.  Default is the individual
160        operators names joined with `_o_`.
161
162    Raises:
163      TypeError:  If all operators do not have the same `dtype`.
164      ValueError:  If `operators` is empty or are non-square.
165    """
    parameters = dict(
        operators=operators,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    # Validate operators.
    check_ops.assert_proper_iterable(operators)
    operators = list(operators)
    if not operators:
      raise ValueError(
          "Expected a non-empty list of operators. Found: %s" % operators)
    self._operators = operators

    # Define diagonal operators, for functions that are shared across blockwise
    # `LinearOperator` types.
    self._diagonal_operators = operators

    # Validate dtype.
    dtype = operators[0].dtype
    for operator in operators:
      if operator.dtype != dtype:
        name_type = (str((o.name, o.dtype)) for o in operators)
        raise TypeError(
            "Expected all operators to have the same dtype.  Found %s"
            % "   ".join(name_type))

    # Auto-set and check hints.
    if all(operator.is_non_singular for operator in operators):
      if is_non_singular is False:
        raise ValueError(
            "The direct sum of non-singular operators is always non-singular.")
      is_non_singular = True

    if all(operator.is_self_adjoint for operator in operators):
      if is_self_adjoint is False:
        raise ValueError(
            "The direct sum of self-adjoint operators is always self-adjoint.")
      is_self_adjoint = True

    if all(operator.is_positive_definite for operator in operators):
      if is_positive_definite is False:
        raise ValueError(
            "The direct sum of positive definite operators is always "
            "positive definite.")
      is_positive_definite = True

    if not (is_square and all(operator.is_square for operator in operators)):
      raise ValueError(
          "Can only represent a block diagonal of square matrices.")

    # Initialization.
    graph_parents = []
    for operator in operators:
      graph_parents.extend(operator.graph_parents)

    if name is None:
      # Using ds to mean direct sum.
      name = "_ds_".join(operator.name for operator in operators)
    with ops.name_scope(name, values=graph_parents):
      super(LinearOperatorBlockDiag, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=True,
          parameters=parameters,
          name=name)

    # TODO(b/143910018) Remove graph_parents in V3.
    self._set_graph_parents(graph_parents)

  @property
  def operators(self):
    return self._operators

  def _block_range_dimensions(self):
    return [op.range_dimension for op in self._diagonal_operators]

  def _block_domain_dimensions(self):
    return [op.domain_dimension for op in self._diagonal_operators]

  def _block_range_dimension_tensors(self):
    return [op.range_dimension_tensor() for op in self._diagonal_operators]

  def _block_domain_dimension_tensors(self):
    return [op.domain_dimension_tensor() for op in self._diagonal_operators]

  def _shape(self):
    # Get final matrix shape.
    domain_dimension = sum(self._block_domain_dimensions())
    range_dimension = sum(self._block_range_dimensions())
    matrix_shape = tensor_shape.TensorShape([domain_dimension, range_dimension])

    # Get broadcast batch shape.
    # broadcast_shape checks for compatibility.
    batch_shape = self.operators[0].batch_shape
    for operator in self.operators[1:]:
      batch_shape = common_shapes.broadcast_shape(
          batch_shape, operator.batch_shape)

    return batch_shape.concatenate(matrix_shape)

  def _shape_tensor(self):
    # Avoid messy broadcasting if possible.
    if self.shape.is_fully_defined():
      return ops.convert_to_tensor_v2_with_dispatch(
          self.shape.as_list(), dtype=dtypes.int32, name="shape")

    domain_dimension = sum(self._block_domain_dimension_tensors())
    range_dimension = sum(self._block_range_dimension_tensors())
    matrix_shape = array_ops.stack([domain_dimension, range_dimension])

    # Dummy Tensor of zeros.  Will never be materialized.
    zeros = array_ops.zeros(shape=self.operators[0].batch_shape_tensor())
    for operator in self.operators[1:]:
      zeros += array_ops.zeros(shape=operator.batch_shape_tensor())
    batch_shape = array_ops.shape(zeros)

    return array_ops.concat((batch_shape, matrix_shape), 0)

  def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
291    """Transform [batch] matrix `x` with left multiplication:  `x --> Ax`.
292
293    ```python
294    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
295    operator = LinearOperator(...)
296    operator.shape = [..., M, N]
297
298    X = ... # shape [..., N, R], batch matrix, R > 0.
299
300    Y = operator.matmul(X)
301    Y.shape
302    ==> [..., M, R]
303
304    Y[..., :, r] = sum_j A[..., :, j] X[j, r]
305    ```
306
307    Args:
308      x: `LinearOperator`, `Tensor` with compatible shape and same `dtype` as
309        `self`, or a blockwise iterable of `LinearOperator`s or `Tensor`s. See
310        class docstring for definition of shape compatibility.
311      adjoint: Python `bool`.  If `True`, left multiply by the adjoint: `A^H x`.
312      adjoint_arg:  Python `bool`.  If `True`, compute `A x^H` where `x^H` is
313        the hermitian transpose (transposition and complex conjugation).
314      name:  A name for this `Op`.
315
316    Returns:
317      A `LinearOperator` or `Tensor` with shape `[..., M, R]` and same `dtype`
318        as `self`, or if `x` is blockwise, a list of `Tensor`s with shapes that
319        concatenate to `[..., M, R]`.
320    """
321    if isinstance(x, linear_operator.LinearOperator):
322      left_operator = self.adjoint() if adjoint else self
323      right_operator = x.adjoint() if adjoint_arg else x
324
325      if (right_operator.range_dimension is not None and
326          left_operator.domain_dimension is not None and
327          right_operator.range_dimension != left_operator.domain_dimension):
328        raise ValueError(
329            "Operators are incompatible. Expected `x` to have dimension"
330            " {} but got {}.".format(
331                left_operator.domain_dimension, right_operator.range_dimension))
332      with self._name_scope(name):
333        return linear_operator_algebra.matmul(left_operator, right_operator)
334
335    with self._name_scope(name):
336      arg_dim = -1 if adjoint_arg else -2
337      block_dimensions = (self._block_range_dimensions() if adjoint
338                          else self._block_domain_dimensions())
339      if linear_operator_util.arg_is_blockwise(block_dimensions, x, arg_dim):
340        for i, block in enumerate(x):
341          if not isinstance(block, linear_operator.LinearOperator):
342            block = ops.convert_to_tensor_v2_with_dispatch(block)
343            self._check_input_dtype(block)
344            block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
345            x[i] = block
346      else:
347        x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
348        self._check_input_dtype(x)
349        op_dimension = (self.range_dimension if adjoint
350                        else self.domain_dimension)
351        op_dimension.assert_is_compatible_with(x.shape[arg_dim])
352      return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
353
354  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    arg_dim = -1 if adjoint_arg else -2
    block_dimensions = (self._block_range_dimensions() if adjoint
                        else self._block_domain_dimensions())
    blockwise_arg = linear_operator_util.arg_is_blockwise(
        block_dimensions, x, arg_dim)
    if blockwise_arg:
      split_x = x
    else:
      split_dim = -1 if adjoint_arg else -2
      # Split `x` into blocks along its rows, or along its columns if
      # `adjoint_arg` is `True`.
      split_x = linear_operator_util.split_arg_into_blocks(
          self._block_domain_dimensions(),
          self._block_domain_dimension_tensors,
          x, axis=split_dim)

    result_list = []
    for index, operator in enumerate(self.operators):
      result_list += [operator.matmul(
          split_x[index], adjoint=adjoint, adjoint_arg=adjoint_arg)]

    if blockwise_arg:
      return result_list

    result_list = linear_operator_util.broadcast_matrix_batch_dims(
        result_list)
    return array_ops.concat(result_list, axis=-2)

  def matvec(self, x, adjoint=False, name="matvec"):
383    """Transform [batch] vector `x` with left multiplication:  `x --> Ax`.
384
385    ```python
386    # Make an operator acting like batch matric A.  Assume A.shape = [..., M, N]
387    operator = LinearOperator(...)
388
389    X = ... # shape [..., N], batch vector
390
391    Y = operator.matvec(X)
392    Y.shape
393    ==> [..., M]
394
395    Y[..., :] = sum_j A[..., :, j] X[..., j]
396    ```
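
    If the operator is blockwise, `x` may instead be a list of blocks, and a
    list of blocks is returned. A sketch, reusing `operator_99` from the class
    docstring:

    ```python
    x = [tf.random.normal(shape=[2, 3, 4]), tf.random.normal(shape=[2, 3, 5])]
    operator_99.matvec(x)
    ==> [Shape [2, 3, 4] Tensor, Shape [2, 3, 5] Tensor]
    ```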

    Args:
      x: `Tensor` with compatible shape and same `dtype` as `self`, or an
        iterable of `Tensor`s (for blockwise operators). `Tensor`s are treated
        as [batch] vectors, meaning for every set of leading dimensions, the
        last dimension defines a vector.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`.  If `True`, left multiply by the adjoint: `A^H x`.
      name:  A name for this `Op`.

    Returns:
      A `Tensor` with shape `[..., M]` and same `dtype` as `self`, or if `x` is
        blockwise, a list of `Tensor`s with shapes that concatenate to
        `[..., M]`.
    """
    with self._name_scope(name):
      block_dimensions = (self._block_range_dimensions() if adjoint
                          else self._block_domain_dimensions())
      if linear_operator_util.arg_is_blockwise(block_dimensions, x, -1):
        for i, block in enumerate(x):
          if not isinstance(block, linear_operator.LinearOperator):
            block = ops.convert_to_tensor_v2_with_dispatch(block)
            self._check_input_dtype(block)
            block_dimensions[i].assert_is_compatible_with(block.shape[-1])
            x[i] = block
        x_mat = [block[..., array_ops.newaxis] for block in x]
        y_mat = self.matmul(x_mat, adjoint=adjoint)
        return [array_ops.squeeze(y, axis=-1) for y in y_mat]

      x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
      self._check_input_dtype(x)
      op_dimension = (self.range_dimension if adjoint
                      else self.domain_dimension)
      op_dimension.assert_is_compatible_with(x.shape[-1])
      x_mat = x[..., array_ops.newaxis]
      y_mat = self.matmul(x_mat, adjoint=adjoint)
      return array_ops.squeeze(y_mat, axis=-1)
  def _determinant(self):
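    # The determinant of a block diagonal matrix is the product of the
    # blocks' determinants.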
    result = self.operators[0].determinant()
    for operator in self.operators[1:]:
      result *= operator.determinant()
    return result

  def _log_abs_determinant(self):
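    # log|det(A)| of a block diagonal A is the sum of log|det| over the blocks.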
    result = self.operators[0].log_abs_determinant()
    for operator in self.operators[1:]:
      result += operator.log_abs_determinant()
    return result

  def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
446    """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.
447
448    The returned `Tensor` will be close to an exact solution if `A` is well
449    conditioned. Otherwise closeness will vary. See class docstring for details.
450
451    Examples:
452
453    ```python
454    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
455    operator = LinearOperator(...)
456    operator.shape = [..., M, N]
457
458    # Solve R > 0 linear systems for every member of the batch.
459    RHS = ... # shape [..., M, R]
460
461    X = operator.solve(RHS)
462    # X[..., :, r] is the solution to the r'th linear system
463    # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]
464
465    operator.matmul(X)
466    ==> RHS
467    ```
468
469    Args:
470      rhs: `Tensor` with same `dtype` as this operator and compatible shape,
471        or a list of `Tensor`s (for blockwise operators). `Tensor`s are treated
472        like a [batch] matrices meaning for every set of leading dimensions, the
473        last two dimensions defines a matrix.
474        See class docstring for definition of compatibility.
475      adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
476        of this `LinearOperator`:  `A^H X = rhs`.
477      adjoint_arg:  Python `bool`.  If `True`, solve `A X = rhs^H` where `rhs^H`
478        is the hermitian transpose (transposition and complex conjugation).
479      name:  A name scope to use for ops added by this method.
480
481    Returns:
482      `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`.
483
484    Raises:
485      NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
486    """
487    if self.is_non_singular is False:
488      raise NotImplementedError(
489          "Exact solve not implemented for an operator that is expected to "
490          "be singular.")
491    if self.is_square is False:
492      raise NotImplementedError(
493          "Exact solve not implemented for an operator that is expected to "
494          "not be square.")
495    if isinstance(rhs, linear_operator.LinearOperator):
496      left_operator = self.adjoint() if adjoint else self
497      right_operator = rhs.adjoint() if adjoint_arg else rhs
498
499      if (right_operator.range_dimension is not None and
500          left_operator.domain_dimension is not None and
501          right_operator.range_dimension != left_operator.domain_dimension):
502        raise ValueError(
503            "Operators are incompatible. Expected `rhs` to have dimension"
504            " {} but got {}.".format(
505                left_operator.domain_dimension, right_operator.range_dimension))
506      with self._name_scope(name):
507        return linear_operator_algebra.solve(left_operator, right_operator)
508
509    with self._name_scope(name):
510      block_dimensions = (self._block_domain_dimensions() if adjoint
511                          else self._block_range_dimensions())
512      arg_dim = -1 if adjoint_arg else -2
513      blockwise_arg = linear_operator_util.arg_is_blockwise(
514          block_dimensions, rhs, arg_dim)
515
516      if blockwise_arg:
517        split_rhs = rhs
518        for i, block in enumerate(split_rhs):
519          if not isinstance(block, linear_operator.LinearOperator):
520            block = ops.convert_to_tensor_v2_with_dispatch(block)
521            self._check_input_dtype(block)
522            block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
523            split_rhs[i] = block
524      else:
525        rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
526        self._check_input_dtype(rhs)
527        op_dimension = (self.domain_dimension if adjoint
528                        else self.range_dimension)
529        op_dimension.assert_is_compatible_with(rhs.shape[arg_dim])
530        split_dim = -1 if adjoint_arg else -2
531        # Split input by rows normally, and otherwise columns.
532        split_rhs = linear_operator_util.split_arg_into_blocks(
533            self._block_domain_dimensions(),
534            self._block_domain_dimension_tensors,
535            rhs, axis=split_dim)
536
537      solution_list = []
538      for index, operator in enumerate(self.operators):
539        solution_list += [operator.solve(
540            split_rhs[index], adjoint=adjoint, adjoint_arg=adjoint_arg)]
541
542      if blockwise_arg:
543        return solution_list
544
545      solution_list = linear_operator_util.broadcast_matrix_batch_dims(
546          solution_list)
547      return array_ops.concat(solution_list, axis=-2)
548
549  def solvevec(self, rhs, adjoint=False, name="solve"):
550    """Solve single equation with best effort: `A X = rhs`.
551
552    The returned `Tensor` will be close to an exact solution if `A` is well
553    conditioned. Otherwise closeness will vary. See class docstring for details.
554
555    Examples:
556
557    ```python
558    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
559    operator = LinearOperator(...)
560    operator.shape = [..., M, N]
561
562    # Solve one linear system for every member of the batch.
563    RHS = ... # shape [..., M]
564
565    X = operator.solvevec(RHS)
566    # X is the solution to the linear system
567    # sum_j A[..., :, j] X[..., j] = RHS[..., :]
568
569    operator.matvec(X)
570    ==> RHS
571    ```
572
573    Args:
574      rhs: `Tensor` with same `dtype` as this operator, or list of `Tensor`s
575        (for blockwise operators). `Tensor`s are treated as [batch] vectors,
576        meaning for every set of leading dimensions, the last dimension defines
577        a vector.  See class docstring for definition of compatibility regarding
578        batch dimensions.
579      adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
580        of this `LinearOperator`:  `A^H X = rhs`.
581      name:  A name scope to use for ops added by this method.
582
583    Returns:
584      `Tensor` with shape `[...,N]` and same `dtype` as `rhs`.
585
586    Raises:
587      NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
588    """
589    with self._name_scope(name):
590      block_dimensions = (self._block_domain_dimensions() if adjoint
591                          else self._block_range_dimensions())
592      if linear_operator_util.arg_is_blockwise(block_dimensions, rhs, -1):
593        for i, block in enumerate(rhs):
594          if not isinstance(block, linear_operator.LinearOperator):
595            block = ops.convert_to_tensor_v2_with_dispatch(block)
596            self._check_input_dtype(block)
597            block_dimensions[i].assert_is_compatible_with(block.shape[-1])
598            rhs[i] = block
599        rhs_mat = [array_ops.expand_dims(block, axis=-1) for block in rhs]
600        solution_mat = self.solve(rhs_mat, adjoint=adjoint)
601        return [array_ops.squeeze(x, axis=-1) for x in solution_mat]
602
603      rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
604      self._check_input_dtype(rhs)
605      op_dimension = (self.domain_dimension if adjoint
606                      else self.range_dimension)
607      op_dimension.assert_is_compatible_with(rhs.shape[-1])
608      rhs_mat = array_ops.expand_dims(rhs, axis=-1)
609      solution_mat = self.solve(rhs_mat, adjoint=adjoint)
610      return array_ops.squeeze(solution_mat, axis=-1)
611
612  def _diag_part(self):
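    # The diagonal of a block diagonal matrix is the concatenation of the
    # blocks' diagonals.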
    diag_list = []
    for operator in self.operators:
      # Extend the axis for broadcasting.
      diag_list += [operator.diag_part()[..., array_ops.newaxis]]
    diag_list = linear_operator_util.broadcast_matrix_batch_dims(diag_list)
    diagonal = array_ops.concat(diag_list, axis=-2)
    return array_ops.squeeze(diagonal, axis=-1)

  def _trace(self):
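    # The trace of a block diagonal matrix is the sum of the blocks' traces.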
    result = self.operators[0].trace()
    for operator in self.operators[1:]:
      result += operator.trace()
    return result

  def _to_dense(self):
    num_cols = 0
    rows = []
    broadcasted_blocks = [operator.to_dense() for operator in self.operators]
    broadcasted_blocks = linear_operator_util.broadcast_matrix_batch_dims(
        broadcasted_blocks)
    for block in broadcasted_blocks:
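      # Each block contributes a band of rows: zeros for the columns of the
      # blocks before it, the block itself, then zeros for the remaining
      # columns.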
      batch_row_shape = array_ops.shape(block)[:-1]

      zeros_to_pad_before_shape = array_ops.concat(
          [batch_row_shape, [num_cols]], axis=-1)
      zeros_to_pad_before = array_ops.zeros(
          shape=zeros_to_pad_before_shape, dtype=block.dtype)
      num_cols += array_ops.shape(block)[-1]
      zeros_to_pad_after_shape = array_ops.concat(
          [batch_row_shape,
           [self.domain_dimension_tensor() - num_cols]], axis=-1)
      zeros_to_pad_after = array_ops.zeros(
          shape=zeros_to_pad_after_shape, dtype=block.dtype)

      rows.append(array_ops.concat(
          [zeros_to_pad_before, block, zeros_to_pad_after], axis=-1))

    mat = array_ops.concat(rows, axis=-2)
    mat.set_shape(self.shape)
    return mat

  def _assert_non_singular(self):
    return control_flow_ops.group([
        operator.assert_non_singular() for operator in self.operators])

  def _assert_self_adjoint(self):
    return control_flow_ops.group([
        operator.assert_self_adjoint() for operator in self.operators])

  def _assert_positive_definite(self):
    return control_flow_ops.group([
        operator.assert_positive_definite() for operator in self.operators])

  def _eigvals(self):
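    # The eigenvalues of a block diagonal matrix are the concatenation of the
    # blocks' eigenvalues.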
    eig_list = []
    for operator in self.operators:
      # Extend the axis for broadcasting.
      eig_list += [operator.eigvals()[..., array_ops.newaxis]]
    eig_list = linear_operator_util.broadcast_matrix_batch_dims(eig_list)
    eigs = array_ops.concat(eig_list, axis=-2)
    return array_ops.squeeze(eigs, axis=-1)