1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Internal utilities for `LinearOperator` classes."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import check_ops
27from tensorflow.python.ops import control_flow_ops
28from tensorflow.python.ops import linalg_ops
29from tensorflow.python.ops import math_ops
30from tensorflow.python.ops.linalg import linalg_impl as linalg
31
32
def assert_no_entries_with_modulus_zero(
    x, message=None, name="assert_no_entries_with_modulus_zero"):
  """Builds an assertion that every entry of `x` has a nonzero modulus.

  The modulus `|x|` is compared against a zero of the matching *real* dtype,
  so the same check serves real, integer and complex inputs.

  Args:
    x:  Numeric `Tensor`, real, integer, or complex.
    message:  A string message to prepend to failure message.
    name:  A name to give this `Op`.

  Returns:
    An `Op` that fails if any entry of `x` has modulus zero.
  """
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    modulus = math_ops.abs(x)
    # `abs` of a complex tensor is real-valued, hence `real_dtype` here.
    real_zero = ops.convert_to_tensor(0, dtype=x.dtype.base_dtype.real_dtype)
    return check_ops.assert_less(real_zero, modulus, message=message)
51
52
def assert_zero_imag_part(x, message=None, name="assert_zero_imag_part"):
  """Returns `Op` that asserts Tensor `x` has no non-zero imaginary parts.

  Only complex dtypes can carry an imaginary part.  For every other dtype
  (floating, integer, bool) the assertion holds trivially, so a `no_op` is
  returned and no runtime check is built.

  Args:
    x:  Numeric `Tensor`, real, integer, or complex.
    message:  A string message to prepend to failure message.
    name:  A name to give this `Op`.

  Returns:
    An `Op` that asserts `x` has no non-zero imaginary parts.
  """
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    dtype = x.dtype.base_dtype

    # Non-complex dtypes have no imaginary part, so there is nothing to check.
    # Testing `is_complex` (rather than only `is_floating`) also covers the
    # integer inputs this function documents, instead of building an
    # `imag`-based check for them (the `Imag` kernel may not accept integer
    # dtypes in every TF version -- NOTE(review)).
    if not dtype.is_complex:
      return control_flow_ops.no_op()

    zero = ops.convert_to_tensor(0, dtype=dtype.real_dtype)
    return check_ops.assert_equal(zero, math_ops.imag(x), message=message)
73
74
def assert_compatible_matrix_dimensions(operator, x):
  """Runtime check that `x` has the right inner dimension for `operator`.

  Writing `operator.shape[-2:] = [M, N]` and `x.shape[-2:] = [Q, R]`, the
  product `operator.matmul(x)` is only defined when `N = Q`.  The returned
  `Assert` fires when the dynamic shapes violate this.  Static shape checks
  are the responsibility of the `LinearOperator` base class.

  Args:
    operator:  `LinearOperator`.
    x:  `Tensor`.

  Returns:
    `Assert` `Op`.
  """
  # Only the dynamic (tensor) comparison lives here; static checking was
  # already done by the base class.
  x_num_rows = array_ops.shape(x)[-2]
  operator_domain_dim = operator.domain_dimension_tensor()
  return check_ops.assert_equal(
      x_num_rows,
      operator_domain_dim,
      message=("Incompatible matrix dimensions.  "
               "shape[-2] of argument to be the same as this operator"))
98
99
def assert_is_batch_matrix(tensor):
  """Static assert that `tensor` has rank `2` or higher.

  Only the statically known rank is inspected; a tensor whose rank is unknown
  at graph-construction time passes unchecked.

  Args:
    tensor:  Object exposing `get_shape()` (e.g. a `Tensor`).

  Raises:
    ValueError:  If the static rank of `tensor` is known and below 2.
  """
  static_rank = tensor.get_shape().ndims
  if static_rank is None or static_rank >= 2:
    return
  raise ValueError(
      "Expected [batch] matrix to have at least two dimensions.  Found: "
      "%s" % tensor)
107
108
def shape_tensor(shape, name=None):
  """Convert `shape` to a `Tensor`, forcing `int32` for an empty list/tuple.

  Mirrors `random_ops._ShapeTensor`.  An empty Python list/tuple is given an
  explicit `int32` dtype so that an empty shape still yields an integer
  tensor; any other input has its dtype inferred by `convert_to_tensor`.
  """
  is_empty_sequence = isinstance(shape, (tuple, list)) and not shape
  forced_dtype = dtypes.int32 if is_empty_sequence else None
  return ops.convert_to_tensor(shape, dtype=forced_dtype, name=name)
117
118
119################################################################################
120# Broadcasting versions of common linear algebra functions.
121# TODO(b/77519145) Do this more efficiently in some special cases.
122################################################################################
123
124
def broadcast_matrix_batch_dims(batch_matrices, name=None):
  """Broadcast leading dimensions of zero or more [batch] matrices.

  Example broadcasting one batch dim of two simple matrices.

  ```python
  x = [[1, 2],
       [3, 4]]  # Shape [2, 2], no batch dims

  y = [[[1]]]   # Shape [1, 1, 1], 1 batch dim of shape [1]

  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc
  ==> [[[1, 2],
        [3, 4]]]  # Shape [1, 2, 2], 1 batch dim of shape [1].

  y_bc
  ==> same as y
  ```

  Example broadcasting many batch dims

  ```python
  x = tf.random_normal(shape=(2, 3, 1, 4, 4))
  y = tf.random_normal(shape=(1, 3, 2, 5, 5))
  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc.shape
  ==> (2, 3, 2, 4, 4)

  y_bc.shape
  ==> (2, 3, 2, 5, 5)
  ```

  Args:
    batch_matrices:  Iterable of `Tensor`s, each having two or more dimensions.
      Any iterable, including a one-shot generator, is accepted.
    name:  A string name to prepend to created ops.

  Returns:
    bcast_matrices: List of `Tensor`s, with `bcast_matrices[i]` containing
      the values from `batch_matrices[i]`, with possibly broadcast batch dims.

  Raises:
    TypeError:  If `batch_matrices` is rejected by
      `check_ops.assert_proper_iterable` (e.g. it is a single `Tensor`).
    ValueError:  If any input `Tensor` is statically determined to have less
      than two dimensions.
  """
  # Validate and materialize the iterable *before* entering `name_scope`.
  # In graph mode `name_scope` iterates over `values` to find the graph, which
  # would exhaust a generator and leave nothing for the code below.  `list`
  # also gives us a private copy, so the caller's container is never mutated.
  check_ops.assert_proper_iterable(batch_matrices)
  batch_matrices = list(batch_matrices)

  with ops.name_scope(
      name or "broadcast_matrix_batch_dims", values=batch_matrices):
    for i, mat in enumerate(batch_matrices):
      batch_matrices[i] = ops.convert_to_tensor(mat)
      assert_is_batch_matrix(batch_matrices[i])

    # Zero or one matrix: nothing to broadcast against.
    if len(batch_matrices) < 2:
      return batch_matrices

    # Try static broadcasting.
    # bcast_batch_shape is the broadcast batch shape of ALL matrices.
    # E.g. if batch_matrices = [x, y], with
    # x.shape =    [2, j, k]  (batch shape =    [2])
    # y.shape = [3, 1, l, m]  (batch shape = [3, 1])
    # ==> bcast_batch_shape = [3, 2]
    bcast_batch_shape = batch_matrices[0].get_shape()[:-2]
    for mat in batch_matrices[1:]:
      bcast_batch_shape = array_ops.broadcast_static_shape(
          bcast_batch_shape,
          mat.get_shape()[:-2])
    if bcast_batch_shape.is_fully_defined():
      # The [1, 1] at the end will broadcast with anything.
      bcast_shape = bcast_batch_shape.concatenate([1, 1])
      for i, mat in enumerate(batch_matrices):
        # Only matrices whose batch shape actually differs get replicated.
        if mat.get_shape()[:-2] != bcast_batch_shape:
          batch_matrices[i] = _broadcast_to_shape(mat, bcast_shape)
      return batch_matrices

    # Since static didn't work, do dynamic, which always copies data.
    bcast_batch_shape = array_ops.shape(batch_matrices[0])[:-2]
    for mat in batch_matrices[1:]:
      bcast_batch_shape = array_ops.broadcast_dynamic_shape(
          bcast_batch_shape,
          array_ops.shape(mat)[:-2])
    bcast_shape = array_ops.concat([bcast_batch_shape, [1, 1]], axis=0)
    for i, mat in enumerate(batch_matrices):
      batch_matrices[i] = _broadcast_to_shape(mat, bcast_shape)

    return batch_matrices
214
215
def _broadcast_to_shape(x, shape):
  """Broadcast `x` against `shape` by adding a zeros tensor of that shape.

  The result's shape is the broadcast of `x`'s shape with `shape`.  This
  always materializes a new tensor, i.e. the data of `x` is copied.
  """
  zeros_of_target_shape = array_ops.zeros(shape=shape, dtype=x.dtype)
  return x + zeros_of_target_shape
218
219
def cholesky_solve_with_broadcast(chol, rhs, name=None):
  """Solve linear systems from a Cholesky factor, broadcasting batch dims.

  The batch dims of `chol` and `rhs` are first broadcast against each other
  with `broadcast_matrix_batch_dims`, then `linalg_ops.cholesky_solve` is
  applied.

  Args:
    chol:  [batch] Cholesky factor(s), in the form expected by
      `linalg_ops.cholesky_solve`.
    rhs:  [batch] right-hand side(s).
    name:  Optional name scope for the created ops.

  Returns:
    The solution `Tensor` returned by `linalg_ops.cholesky_solve`.
  """
  with ops.name_scope(name, "CholeskySolveWithBroadcast", [chol, rhs]):
    bcast_chol, bcast_rhs = broadcast_matrix_batch_dims([chol, rhs])
    return linalg_ops.cholesky_solve(bcast_chol, bcast_rhs)
225
226
def matmul_with_broadcast(a,
                          b,
                          transpose_a=False,
                          transpose_b=False,
                          adjoint_a=False,
                          adjoint_b=False,
                          a_is_sparse=False,
                          b_is_sparse=False,
                          name=None):
  """Multiplies matrix `a` by matrix `b`, producing `a @ b`.

  Works identically to `tf.matmul`, but broadcasts batch dims
  of `a` and `b` if they are determined statically to be different, or if static
  shapes are not fully defined. Attempts are made to avoid unnecessary
  replication of data, but this is not always possible.

  The inputs must be matrices (or tensors of rank > 2, representing batches of
  matrices).

  Both matrices must be of the same type. The supported types are:
  `float16`, `float32`, `float64`, `int32`, `complex64`, `complex128`.

  Either matrix can be transposed or adjointed (conjugated and transposed) on
  the fly by setting one of the corresponding flag to `True`. These are `False`
  by default.

  If one or both of the matrices contain a lot of zeros, a more efficient
  multiplication algorithm can be used by setting the corresponding
  `a_is_sparse` or `b_is_sparse` flag to `True`. These are `False` by default.
  This optimization is only available for plain matrices (rank-2 tensors) with
  datatypes `bfloat16` or `float32`.

  For example:

  ```python
  # A 2-batch of 3x4 matrices
  a = tf.random_normal(shape=(2, 3, 4))

  # A single 4x5 matrix
  b = tf.random_normal(shape=(4, 5))

  result = matmul_with_broadcast(a, b)

  result.shape
  ==> (2, 3, 5)

  result[0,...]
  ==> tf.matmul(a[0,...], b)

  result[1,...]
  ==> tf.matmul(a[1,...], b)
  ```

  Args:
    a: `Tensor` of type `float16`, `float32`, `float64`, `int32`, `complex64`,
      `complex128` and `rank > 1`.
    b: `Tensor` with same type as `a` having compatible matrix dimensions and
      broadcastable batch dimensions.
    transpose_a: If `True`, `a` is transposed before multiplication.
    transpose_b: If `True`, `b` is transposed before multiplication.
    adjoint_a: If `True`, `a` is conjugated and transposed before
      multiplication.
    adjoint_b: If `True`, `b` is conjugated and transposed before
      multiplication.
    a_is_sparse: If `True`, `a` is treated as a sparse matrix.
    b_is_sparse: If `True`, `b` is treated as a sparse matrix.
    name: Name for the operation (optional).

  Returns:
    A `Tensor` of the same type as `a` and `b` where each inner-most matrix is
    the product of the corresponding matrices in `a` and `b`, e.g. if all
    transpose or adjoint attributes are `False`:

    The leading shape of `output` is the result of broadcasting the leading
    dimensions of `a` and `b`.

    `output`[..., i, j] = sum_k (`a`[..., i, k] * `b`[..., k, j]),
    for all indices i, j.

    Note: This is matrix product, not element-wise product.


  Raises:
    ValueError: If transpose_a and adjoint_a, or transpose_b and adjoint_b
      are both set to True.  This is checked before any op is built, so it
      is raised regardless of the input shapes.
  """
  # Validate the flag combinations up front.  When `_reshape_for_efficiency`
  # decides to "flip", it applies the adjoint/transpose itself (preferring the
  # adjoint) and `math_ops.matmul` is then called with every flag False, so
  # the conflict would otherwise go unreported for some input shapes.
  if transpose_a and adjoint_a:
    raise ValueError("Only one of transpose_a and adjoint_a can be True.")
  if transpose_b and adjoint_b:
    raise ValueError("Only one of transpose_b and adjoint_b can be True.")

  with ops.name_scope(name, "MatMulWithBroadcast", [a, b]):
    a = ops.convert_to_tensor(a, name="a")
    b = ops.convert_to_tensor(b, name="b", dtype=a.dtype)

    # If either a or b has extra dims, we can reshape to get rid of them.
    a, b, reshape_inv, still_need_to_transpose = _reshape_for_efficiency(
        a,
        b,
        transpose_a=transpose_a,
        transpose_b=transpose_b,
        adjoint_a=adjoint_a,
        adjoint_b=adjoint_b)

    # This will broadcast by brute force if we still need to.
    a, b = broadcast_matrix_batch_dims([a, b])

    # If the helper already applied transposes/adjoints, all flags are off.
    a_times_b = math_ops.matmul(
        a,
        b,
        transpose_a=transpose_a and still_need_to_transpose,
        transpose_b=transpose_b and still_need_to_transpose,
        adjoint_a=adjoint_a and still_need_to_transpose,
        adjoint_b=adjoint_b and still_need_to_transpose,
        a_is_sparse=a_is_sparse,
        b_is_sparse=b_is_sparse)

    return reshape_inv(a_times_b)
340
341
def matrix_solve_with_broadcast(matrix, rhs, adjoint=False, name=None):
  """Solve `matrix @ X = rhs` (or with `adjoint(matrix)`), broadcasting.

  Batch dims of `matrix` and `rhs` are made compatible, preferring to fold
  extra leading dims of `rhs` into its last dimension (see
  `_reshape_for_efficiency`) and otherwise replicating data.

  Args:
    matrix:  [batch] square matrix `Tensor`.
    rhs:  [batch] right-hand side `Tensor`; converted to `matrix.dtype`.
    adjoint:  Python bool.  If `True`, solve with the adjoint of `matrix`.
    name:  Optional name scope for the created ops.

  Returns:
    Solution `Tensor`, laid out like `rhs` with broadcast batch dims.
  """
  with ops.name_scope(name, "MatrixSolveWithBroadcast", [matrix, rhs]):
    matrix = ops.convert_to_tensor(matrix, name="matrix")
    rhs = ops.convert_to_tensor(rhs, name="rhs", dtype=matrix.dtype)

    # Possibly fold extra batch dims of `rhs` away.  When that happens the
    # adjoint has already been applied to `matrix`, signalled by
    # `adjoint_pending` being False.
    matrix, rhs, undo_reshape, adjoint_pending = _reshape_for_efficiency(
        matrix, rhs, adjoint_a=adjoint)

    # Any remaining batch mismatch is resolved by replication.
    matrix, rhs = broadcast_matrix_batch_dims([matrix, rhs])

    solution = linalg_ops.matrix_solve(
        matrix, rhs, adjoint=adjoint and adjoint_pending)
    return undo_reshape(solution)
359
360
def matrix_triangular_solve_with_broadcast(matrix,
                                           rhs,
                                           lower=True,
                                           adjoint=False,
                                           name=None):
  """Solves triangular systems of linear equations by backsubstitution.

  Works identically to `tf.matrix_triangular_solve`, but broadcasts batch dims
  of `matrix` and `rhs` if they are determined statically to be different, or
  if static shapes are not fully defined.  Extra leading batch dims of `rhs`
  may be folded into its last dimension to avoid replicating `matrix`; when
  that is not possible, data is replicated, which can be inefficient.

  Args:
    matrix: A Tensor. Must be one of the following types:
      `float64`, `float32`, `complex64`, `complex128`. Shape is `[..., M, M]`.
    rhs: A `Tensor`. Must have the same `dtype` as `matrix`.
      Shape is `[..., M, K]`.
    lower: An optional `bool`. Defaults to `True`. Indicates whether the
      innermost matrices in `matrix` are lower or upper triangular.
    adjoint: An optional `bool`. Defaults to `False`. Indicates whether to solve
      with matrix or its (block-wise) adjoint.
    name: A name for the operation (optional).

  Returns:
    `Tensor` with same `dtype` as `matrix` and shape `[..., M, K]`.
  """
  with ops.name_scope(name, "MatrixTriangularSolve", [matrix, rhs]):
    matrix = ops.convert_to_tensor(matrix, name="matrix")
    rhs = ops.convert_to_tensor(rhs, name="rhs", dtype=matrix.dtype)

    # Possibly fold extra batch dims of `rhs` away; if so, the adjoint has
    # already been applied to `matrix` and `adjoint_pending` is False.
    matrix, rhs, undo_reshape, adjoint_pending = _reshape_for_efficiency(
        matrix, rhs, adjoint_a=adjoint)

    # An explicitly adjointed triangular matrix has its triangle on the other
    # side, so the `lower` flag must be flipped to match.
    if not adjoint_pending and adjoint:
      lower = not lower

    # Any remaining batch mismatch is resolved by replication.
    matrix, rhs = broadcast_matrix_batch_dims([matrix, rhs])

    solution = linalg_ops.matrix_triangular_solve(
        matrix,
        rhs,
        lower=lower,
        adjoint=adjoint and adjoint_pending)
    return undo_reshape(solution)
410
411
def _reshape_for_efficiency(a,
                            b,
                            transpose_a=False,
                            transpose_b=False,
                            adjoint_a=False,
                            adjoint_b=False):
  """Maybe reshape a, b, and return an inverse map.  For matmul/solve.

  When `b` has more dims than `a`, the extra leading dims of `b` can be moved
  to the end and folded into its last dimension, so that `a` need not be
  replicated along them.  If this "flip" happens, any requested
  transpose/adjoint is applied here explicitly.

  Args:
    a:  `Tensor`, the left operand.
    b:  `Tensor`, the right operand.
    transpose_a, transpose_b, adjoint_a, adjoint_b:  Python bools with the
      same meaning as in `math_ops.matmul`.

  Returns:
    a:  `a`, possibly transposed/adjointed.
    b:  `b`, possibly transposed/adjointed and with its extra leading dims
      folded into its last dimension.
    reshape_inv:  Callable mapping a result computed from the returned
      `a`, `b` back to the layout the caller expects.
    still_need_to_transpose:  Python bool.  `True` if the caller must still
      honor the transpose/adjoint flags; `False` if they were applied here.
  """
  def identity(x):
    return x

  # At this point, we have not taken transpose/adjoint of a/b.
  still_need_to_transpose = True

  if a.shape.ndims is None or b.shape.ndims is None:
    return a, b, identity, still_need_to_transpose

  # This could be handled in the future, but seems less common.
  if a.shape.ndims >= b.shape.ndims:
    return a, b, identity, still_need_to_transpose

  # From now on, we might modify b, but will not modify a.

  # Suppose:
  #   a.shape =     C + [m, n]
  #   b.shape = S + C + [n, r]
  # i.e. b has `b_extra_ndims` extra leading dims, S.
  b_extra_ndims = b.shape.ndims - a.shape.ndims

  # No reason to flip unless the extra dims of b are big enough.  Why?
  # Assume adjoint/transpose = False.  Then...
  # By not flipping, we have to replicate a to shape
  #   b_extra_sh + a.shape,
  # which could use extra memory.  But in all cases, the final output has shape
  #   b_extra_sh + a.shape[:-1] + [b.shape[-1]]
  # So we only end up creating a larger object if the end dim of b is smaller
  # than the end dim of a.  This often happens, e.g. if b was a vector that was
  # expanded to a matrix (by appending a singleton).

  # Since adjoint/transpose may not be False, we must make adjustments here.
  # Static sizes are read via `as_list()` (Python int, or None if unknown);
  # both ranks are known here, so this is safe.  Plain indexing such as
  # `a.shape[-1]` may return a `Dimension` object in some TF versions, which
  # is never `None` even when its value is unknown, defeating the checks below.
  a_static = a.shape.as_list()
  b_static = b.shape.as_list()
  # The last dim of a once any transpose/adjoint is applied.
  a_domain_sz_ = a_static[-2 if adjoint_a or transpose_a else -1]
  # The dim of b that holds the multiple equations.
  b_eq_sz_ = b_static[-2 if adjoint_b or transpose_b else -1]
  b_extra_sz_ = (
      np.prod(b.shape[:b_extra_ndims].as_list())
      if b.shape[:b_extra_ndims].is_fully_defined() else None)
  if b_extra_sz_ is not None:
    # Folding zero or one extra batch element into b's last dim gains nothing.
    if b_extra_sz_ < 2:
      return a, b, identity, still_need_to_transpose
    # Per the reasoning above, replicating a is then no bigger than the output.
    if (a_domain_sz_ is not None and b_eq_sz_ is not None and
        a_domain_sz_ <= b_eq_sz_):
      return a, b, identity, still_need_to_transpose

  # At this point, we're flipping for sure!
  # Any transposes/adjoints will happen here explicitly, rather than in calling
  # code.  Why?  To avoid having to write separate complex code for each case.
  if adjoint_a:
    a = linalg.adjoint(a)
  elif transpose_a:
    a = linalg.transpose(a)
  if adjoint_b:
    b = linalg.adjoint(b)
  elif transpose_b:
    b = linalg.transpose(b)
  still_need_to_transpose = False

  # Shapes are read only now, since the transpose/adjoint may have changed
  # them.  b_extra_sh = S, b_main_sh = C + [n, r].
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # Permutation to put the extra dims at the end, and its inverse.
  perm = (
      np.concatenate(
          (np.arange(b_extra_ndims, b.shape.ndims),
           np.arange(0, b_extra_ndims)), 0))
  inverse_perm = np.argsort(perm)
  b_extra_on_end = array_ops.transpose(b, perm=perm)

  # Now squash this end into one long dim.
  b_squashed_end = array_ops.reshape(
      b_extra_on_end, array_ops.concat((b_main_sh[:-1], [-1]), 0))

  def reshape_inv(y):
    # Expand the extra dims hanging off the end, "b_extra_sh".
    # Note we use y_sh[:-1] + [b_main_sh[-1]] rather than b_main_sh, because y
    # could have different batch dims than a and b, because of broadcasting.
    y_extra_shape = array_ops.concat(
        (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
    y_extra_on_end = array_ops.reshape(y, y_extra_shape)
    return array_ops.transpose(y_extra_on_end, perm=inverse_perm)

  return a, b_squashed_end, reshape_inv, still_need_to_transpose
504