1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""`LinearOperator` coming from a [[nested] block] circulant matrix."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import tensor_shape
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import check_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops.distributions import util as distribution_util
30from tensorflow.python.ops.linalg import linalg_impl as linalg
31from tensorflow.python.ops.linalg import linear_operator
32from tensorflow.python.ops.linalg import linear_operator_util
33from tensorflow.python.ops.signal import fft_ops
34from tensorflow.python.util.tf_export import tf_export
35
__all__ = [
    "LinearOperatorCirculant",
    "LinearOperatorCirculant2D",
    "LinearOperatorCirculant3D",
]

# Forward / inverse FFT ops keyed by block depth.  A block depth of `d` means
# the transform runs over the trailing `d` dimensions of its input.
_FFT_OP = {
    1: fft_ops.fft,
    2: fft_ops.fft2d,
    3: fft_ops.fft3d,
}
_IFFT_OP = {
    1: fft_ops.ifft,
    2: fft_ops.ifft2d,
    3: fft_ops.ifft3d,
}
45
46
47# TODO(langmore) Add transformations that create common spectrums, e.g.
48#   starting with the convolution kernel
49#   start with half a spectrum, and create a Hermitian one.
50#   common filters.
51# TODO(langmore) Support rectangular Toeplitz matrices.
class _BaseLinearOperatorCirculant(linear_operator.LinearOperator):
  """Base class for circulant operators.  Not user facing.

  `LinearOperator` acting like a [batch] [[nested] block] circulant matrix.

  The trailing `block_depth` dimensions of `spectrum` are the block
  dimensions; any leading dimensions are batch dimensions.  The operator is
  applied in the frequency domain using the `block_depth`-dimensional FFT
  (`fft`, `fft2d` or `fft3d`).
  """

  def __init__(self,
               spectrum,
               block_depth,
               input_output_dtype=dtypes.complex64,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=True,
               parameters=None,
               name="LinearOperatorCirculant"):
    r"""Initialize a `_BaseLinearOperatorCirculant`.

    Args:
      spectrum:  Shape `[B1,...,Bb] + block_shape` `Tensor`, where
        `block_shape` has length `block_depth`.  Allowed dtypes: `float16`,
        `float32`, `float64`, `complex64`, `complex128`.  Type can be different
        than `input_output_dtype`.
      block_depth:  Python integer, either 1, 2, or 3.  Will be 1 for circulant,
        2 for block circulant, and 3 for nested block circulant.
      input_output_dtype: `dtype` for input/output.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  If `spectrum` is real, this will always be true.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix\
            #Extension_for_non_symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
        A circulant operator is always square, so `False` is rejected.
      parameters: Python `dict` of parameters used to instantiate this
        `LinearOperator`.
      name:  A name to prepend to all ops created by this class.

    Raises:
      ValueError:  If `block_depth` is not an allowed value, if `is_square` is
        `False`, if `spectrum` is real and `is_self_adjoint` is `False`, or if
        `spectrum` has statically known rank less than `block_depth`.
      TypeError:  If `spectrum` is not an allowed type.
    """

    allowed_block_depths = [1, 2, 3]

    self._name = name

    if block_depth not in allowed_block_depths:
      raise ValueError("Expected block_depth to be in %s.  Found: %s." %
                       (allowed_block_depths, block_depth))
    self._block_depth = block_depth

    with ops.name_scope(name, values=[spectrum]):
      self._spectrum = self._check_spectrum_and_return_tensor(spectrum)

      # Check and auto-set hints.
      if not self.spectrum.dtype.is_complex:
        if is_self_adjoint is False:
          raise ValueError(
              "A real spectrum always corresponds to a self-adjoint operator.")
        is_self_adjoint = True

      if is_square is False:
        raise ValueError(
            "A [[nested] block] circulant operator is always square.")
      is_square = True

      super(_BaseLinearOperatorCirculant, self).__init__(
          dtype=dtypes.as_dtype(input_output_dtype),
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)
      # TODO(b/143910018) Remove graph_parents in V3.
      self._set_graph_parents([self.spectrum])

  def _check_spectrum_and_return_tensor(self, spectrum):
    """Static check of spectrum.  Then return `Tensor` version.

    Raises:
      ValueError:  If the rank of `spectrum` is statically known and smaller
        than `self.block_depth`.
    """
    spectrum = linear_operator_util.convert_nonref_to_tensor(spectrum,
                                                             name="spectrum")

    if spectrum.shape.ndims is not None:
      if spectrum.shape.ndims < self.block_depth:
        raise ValueError(
            "Argument spectrum must have at least %d dimensions.  Found: %s" %
            (self.block_depth, spectrum))
    return spectrum

  @property
  def block_depth(self):
    """Depth of recursively defined circulant blocks defining this `Operator`.

    With `A` the dense representation of this `Operator`,

    `block_depth = 1` means `A` is circulant.  For example,

    ```
    A = |w z y x|
        |x w z y|
        |y x w z|
        |z y x w|
    ```

    `block_depth = 2` means `A` is block circulant with circulant blocks.  For
    example, with `W`, `X`, `Y`, `Z` circulant,

    ```
    A = |W Z Y X|
        |X W Z Y|
        |Y X W Z|
        |Z Y X W|
    ```

    `block_depth = 3` means `A` is block circulant with block circulant
    blocks.

    Returns:
      Python `integer`.
    """
    return self._block_depth

  def block_shape_tensor(self):
    """Shape of the block dimensions of `self.spectrum`.

    Returns:
      Integer `Tensor` holding the trailing `block_depth` entries of the shape
      of `self.spectrum`.
    """
    # If spectrum.shape = [s0, s1, s2], and block_depth = 2,
    # block_shape = [s1, s2]
    return self._block_shape_tensor()

  def _block_shape_tensor(self, spectrum_shape=None):
    """Block shape as a `Tensor`, optionally taken from `spectrum_shape`.

    Args:
      spectrum_shape:  Optional 1-D integer `Tensor` to slice the block shape
        from when the static block shape is not fully defined.  Defaults to
        the dynamic shape of `self.spectrum`.

    Returns:
      1-D integer `Tensor` of length `block_depth`.
    """
    # Prefer the static block shape when it is fully known.
    if self.block_shape.is_fully_defined():
      return linear_operator_util.shape_tensor(
          self.block_shape.as_list(), name="block_shape")
    spectrum_shape = (
        array_ops.shape(self.spectrum)
        if spectrum_shape is None else spectrum_shape)
    return spectrum_shape[-self.block_depth:]

  @property
  def block_shape(self):
    """Static `TensorShape` of the trailing `block_depth` spectrum dims."""
    return self.spectrum.shape[-self.block_depth:]

  @property
  def spectrum(self):
    """The spectrum `Tensor` this operator was constructed from."""
    return self._spectrum

  def _vectorize_then_blockify(self, matrix):
    """Shape batch matrix to batch vector, then blockify trailing dimensions."""
    # Suppose
    #   matrix.shape = [m0, m1, m2, m3],
    # and matrix is a matrix because the final two dimensions are matrix dims.
    #   self.block_depth = 2,
    #   self.block_shape = [b0, b1]  (note b0 * b1 = m2).
    # We will reshape matrix to
    #   [m3, m0, m1, b0, b1].

    # Vectorize: Reshape to batch vector.
    #   [m0, m1, m2, m3] --> [m3, m0, m1, m2]
    # This is called "vectorize" because we have taken the final two matrix dims
    # and turned this into a size m3 batch of vectors.
    vec = distribution_util.rotate_transpose(matrix, shift=1)

    # Blockify: Blockfy trailing dimensions.
    #   [m3, m0, m1, m2] --> [m3, m0, m1, b0, b1]
    if (vec.shape.is_fully_defined() and
        self.block_shape.is_fully_defined()):
      # vec_leading_shape = [m3, m0, m1],
      # the parts of vec that will not be blockified.
      vec_leading_shape = vec.shape[:-1]
      final_shape = vec_leading_shape.concatenate(self.block_shape)
    else:
      vec_leading_shape = array_ops.shape(vec)[:-1]
      final_shape = array_ops.concat(
          (vec_leading_shape, self.block_shape_tensor()), 0)
    return array_ops.reshape(vec, final_shape)

  def _unblockify_then_matricize(self, vec):
    """Flatten the block dimensions then reshape to a batch matrix."""
    # Suppose
    #   vec.shape = [v0, v1, v2, v3],
    #   self.block_depth = 2.
    # Then
    #   leading shape = [v0, v1]
    #   block shape = [v2, v3].
    # We will reshape vec to
    #   [v1, v2*v3, v0].

    # Un-blockify: Flatten block dimensions.  Reshape
    #   [v0, v1, v2, v3] --> [v0, v1, v2*v3].
    if vec.shape.is_fully_defined():
      # vec_shape = [v0, v1, v2, v3]
      vec_shape = vec.shape.as_list()
      # vec_leading_shape = [v0, v1]
      vec_leading_shape = vec_shape[:-self.block_depth]
      # vec_block_shape = [v2, v3]
      vec_block_shape = vec_shape[-self.block_depth:]
      # flat_shape = [v0, v1, v2*v3]
      flat_shape = vec_leading_shape + [np.prod(vec_block_shape)]
    else:
      vec_shape = array_ops.shape(vec)
      vec_leading_shape = vec_shape[:-self.block_depth]
      vec_block_shape = vec_shape[-self.block_depth:]
      flat_shape = array_ops.concat(
          (vec_leading_shape, [math_ops.reduce_prod(vec_block_shape)]), 0)
    vec_flat = array_ops.reshape(vec, flat_shape)

    # Matricize:  Reshape to batch matrix.
    #   [v0, v1, v2*v3] --> [v1, v2*v3, v0],
    # representing a shape [v1] batch of [v2*v3, v0] matrices.
    matrix = distribution_util.rotate_transpose(vec_flat, shift=-1)
    return matrix

  def _fft(self, x):
    """FFT along the last self.block_depth dimensions of x.

    Args:
      x: `Tensor` with floating or complex `dtype`.
        Should be in the form returned by self._vectorize_then_blockify.

    Returns:
      `Tensor` with the complex `dtype` produced by `_to_complex(x)`.
    """
    x_complex = _to_complex(x)
    return _FFT_OP[self.block_depth](x_complex)

  def _ifft(self, x):
    """IFFT along the last self.block_depth dimensions of x.

    Args:
      x: `Tensor` with floating or complex dtype.  Should be in the form
        returned by self._vectorize_then_blockify.

    Returns:
      `Tensor` with the complex `dtype` produced by `_to_complex(x)`.
    """
    x_complex = _to_complex(x)
    return _IFFT_OP[self.block_depth](x_complex)

  def convolution_kernel(self, name="convolution_kernel"):
    """Convolution kernel corresponding to `self.spectrum`.

    The `D` dimensional DFT of this kernel is the frequency domain spectrum of
    this operator.

    Args:
      name:  A name to give this `Op`.

    Returns:
      `Tensor` with `dtype` `self.dtype`.  If `self.dtype` is real, the
      imaginary part of the kernel is dropped by the cast.
    """
    with self._name_scope(name):
      h = self._ifft(_to_complex(self.spectrum))
      return math_ops.cast(h, self.dtype)

  def _shape(self):
    """Static shape `batch_shape + [N, N]`, `N` the product of block dims."""
    s_shape = self._spectrum.shape
    # Suppose spectrum.shape = [a, b, c, d]
    # block_depth = 2
    # Then:
    #   batch_shape = [a, b]
    #   N = c*d
    # and we want to return
    #   [a, b, c*d, c*d]
    batch_shape = s_shape[:-self.block_depth]
    # trailing_dims = [c, d]
    trailing_dims = s_shape[-self.block_depth:]
    if trailing_dims.is_fully_defined():
      n = np.prod(trailing_dims.as_list())
    else:
      n = None
    n_x_n = tensor_shape.TensorShape([n, n])
    return batch_shape.concatenate(n_x_n)

  def _shape_tensor(self, spectrum=None):
    """Dynamic shape `batch_shape + [N, N]` of the operator.

    Args:
      spectrum:  Optional spectrum `Tensor` to compute the shape for.  Defaults
        to `self.spectrum`.
    """
    spectrum = self.spectrum if spectrum is None else spectrum
    # See self.shape for explanation of steps
    s_shape = array_ops.shape(spectrum)
    batch_shape = s_shape[:-self.block_depth]
    trailing_dims = s_shape[-self.block_depth:]
    n = math_ops.reduce_prod(trailing_dims)
    n_x_n = [n, n]
    return array_ops.concat((batch_shape, n_x_n), 0)

  def assert_hermitian_spectrum(self, name="assert_hermitian_spectrum"):
    """Returns an `Op` that asserts this operator has Hermitian spectrum.

    This operator corresponds to a real-valued matrix if and only if its
    spectrum is Hermitian, i.e. if and only if the convolution kernel (the
    inverse DFT of the spectrum) has zero imaginary part.  This checks that the
    magnitude of that imaginary part is below `eps * N`, with `N` the domain
    dimension and `eps` the machine epsilon of the kernel's real dtype.

    Args:
      name:  A name to give this `Op`.

    Returns:
      An `Op` that asserts this operator has Hermitian spectrum.
    """
    with self._name_scope(name):
      # Compute the kernel in the complex domain.  `self.convolution_kernel()`
      # casts to `self.dtype`, which drops the imaginary part (and makes this
      # check vacuous) whenever `input_output_dtype` is real.
      kernel = self._ifft(_to_complex(self.spectrum))
      abs_imag_kernel = math_ops.abs(math_ops.imag(kernel))
      real_dtype = abs_imag_kernel.dtype
      eps = np.finfo(real_dtype.as_numpy_dtype).eps
      # Assume linear accumulation of error.  The dimension is an integer
      # `Tensor`; cast it so the tolerance matches the compared quantity.
      max_err = eps * math_ops.cast(self.domain_dimension_tensor(), real_dtype)
      return check_ops.assert_less(
          abs_imag_kernel,
          max_err,
          message="Spectrum was not Hermitian")

  def _assert_non_singular(self):
    # The spectrum holds the eigenvalues, so a zero entry means singular.
    return linear_operator_util.assert_no_entries_with_modulus_zero(
        self.spectrum,
        message="Singular operator:  Spectrum contained zero values.")

  def _assert_positive_definite(self):
    # This operator has the action  Ax = F^H D F x,
    # where D is the diagonal matrix with self.spectrum on the diag.  Therefore,
    # <x, Ax> = <Fx, DFx>,
    # Since F is bijective, the condition for positive definite is the same as
    # for a diagonal matrix, i.e. real part of spectrum is positive.
    message = (
        "Not positive definite:  Real part of spectrum was not all positive.")
    return check_ops.assert_positive(
        math_ops.real(self.spectrum), message=message)

  def _assert_self_adjoint(self):
    # Recall correspondence between symmetry and real transforms.  See docstring
    return linear_operator_util.assert_zero_imag_part(
        self.spectrum,
        message=(
            "Not self-adjoint:  The spectrum contained non-zero imaginary part."
        ))

  def _broadcast_batch_dims(self, x, spectrum):
    """Broadcast batch dims of batch matrix `x` and spectrum."""
    spectrum = ops.convert_to_tensor_v2_with_dispatch(spectrum, name="spectrum")
    # spectrum.shape = batch_shape + block_shape
    # First make spectrum a batch matrix with
    #   spectrum.shape = batch_shape + [prod(block_shape), 1]
    batch_shape = self._batch_shape_tensor(
        shape=self._shape_tensor(spectrum=spectrum))
    spec_mat = array_ops.reshape(
        spectrum, array_ops.concat((batch_shape, [-1, 1]), axis=0))
    # Second, broadcast, possibly requiring an addition of array of zeros.
    x, spec_mat = linear_operator_util.broadcast_matrix_batch_dims((x,
                                                                    spec_mat))
    # Third, put the block shape back into spectrum.
    x_batch_shape = array_ops.shape(x)[:-2]
    spectrum_shape = array_ops.shape(spectrum)
    spectrum = array_ops.reshape(
        spec_mat,
        array_ops.concat(
            (x_batch_shape,
             self._block_shape_tensor(spectrum_shape=spectrum_shape)),
            axis=0))

    return x, spectrum

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    x = linalg.adjoint(x) if adjoint_arg else x
    # With F the matrix of a DFT, and F^{-1}, F^H the inverse and Hermitian
    # transpose, one can show that F^{-1} = F^{H} is the IDFT matrix.  Therefore
    # matmul(x) = F^{-1} diag(spectrum) F x,
    #           = F^{H} diag(spectrum) F x,
    # so that
    # matmul(x, adjoint=True) = F^{H} diag(conj(spectrum)) F x.
    spectrum = _to_complex(self.spectrum)
    if adjoint:
      spectrum = math_ops.conj(spectrum)

    x = math_ops.cast(x, spectrum.dtype)

    x, spectrum = self._broadcast_batch_dims(x, spectrum)

    x_vb = self._vectorize_then_blockify(x)
    fft_x_vb = self._fft(x_vb)
    block_vector_result = self._ifft(spectrum * fft_x_vb)
    y = self._unblockify_then_matricize(block_vector_result)

    return math_ops.cast(y, self.dtype)

  def _determinant(self):
    # The spectrum holds the eigenvalues; the determinant is their product.
    axis = [-(i + 1) for i in range(self.block_depth)]
    det = math_ops.reduce_prod(self.spectrum, axis=axis)
    return math_ops.cast(det, self.dtype)

  def _log_abs_determinant(self):
    # log|det| = sum of log|eigenvalue| over the block dimensions.
    axis = [-(i + 1) for i in range(self.block_depth)]
    lad = math_ops.reduce_sum(
        math_ops.log(math_ops.abs(self.spectrum)), axis=axis)
    return math_ops.cast(lad, self.dtype)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
    # With A = F^H diag(spectrum) F (see _matmul), the solution of A x = rhs is
    # x = F^H diag(1 / spectrum) F rhs; use conj(spectrum) for the adjoint.
    spectrum = _to_complex(self.spectrum)
    if adjoint:
      spectrum = math_ops.conj(spectrum)

    # Same as `_matmul`: bring `rhs` to the spectrum's complex dtype.  Without
    # this, e.g. a `float64` spectrum with `float32` input/output yields a
    # `complex64` FFT of `rhs` divided by a `complex128` spectrum.
    rhs = math_ops.cast(rhs, spectrum.dtype)

    rhs, spectrum = self._broadcast_batch_dims(rhs, spectrum)

    rhs_vb = self._vectorize_then_blockify(rhs)
    fft_rhs_vb = self._fft(rhs_vb)
    solution_vb = self._ifft(fft_rhs_vb / spectrum)
    x = self._unblockify_then_matricize(solution_vb)
    return math_ops.cast(x, self.dtype)

  def _diag_part(self):
    # Get ones in shape of diag, which is [B1,...,Bb, N]
    # Also get the size of the diag, "N".
    if self.shape.is_fully_defined():
      diag_shape = self.shape[:-1]
      diag_size = self.domain_dimension.value
    else:
      diag_shape = self.shape_tensor()[:-1]
      diag_size = self.domain_dimension_tensor()
    ones_diag = array_ops.ones(diag_shape, dtype=self.dtype)

    # As proved in comments in self._trace, the value on the diag is constant,
    # repeated N times.  This value is the trace divided by N.

    # The handling of self.shape = (0, 0) is tricky, and is the reason we choose
    # to compute trace and use that to compute diag_part, rather than computing
    # the value on the diagonal ("diag_value") directly.  Both result in a 0/0,
    # but in different places, and the current method gives the right result in
    # the end.

    # Here, if self.shape = (0, 0), then self.trace() = 0., and then
    # diag_value = 0. / 0. = NaN.
    diag_value = self.trace() / math_ops.cast(diag_size, self.dtype)

    # If self.shape = (0, 0), then ones_diag = [] (empty tensor), and then
    # the following line is NaN * [] = [], as needed.
    return diag_value[..., array_ops.newaxis] * ones_diag

  def _trace(self):
    # The diagonal of the [[nested] block] circulant operator is the mean of
    # the spectrum.
    # Proof:  For the [0,...,0] element, this follows from the IDFT formula.
    # Then the result follows since all diagonal elements are the same.

    # Therefore, the trace is the sum of the spectrum.

    # Get shape of diag along with the axis over which to reduce the spectrum.
    # We will reduce the spectrum over all block indices.
    if self.spectrum.shape.is_fully_defined():
      spec_rank = self.spectrum.shape.ndims
      axis = np.arange(spec_rank - self.block_depth, spec_rank, dtype=np.int32)
    else:
      spec_rank = array_ops.rank(self.spectrum)
      axis = math_ops.range(spec_rank - self.block_depth, spec_rank)

    # Real diag part "re_d".
    # Suppose spectrum.shape = [B1,...,Bb, N1, N2]
    # self.shape = [B1,...,Bb, N, N], with N1 * N2 = N.
    # re_d_value.shape = [B1,...,Bb]
    re_d_value = math_ops.reduce_sum(math_ops.real(self.spectrum), axis=axis)

    if not self.dtype.is_complex:
      return math_ops.cast(re_d_value, self.dtype)

    # Imaginary part, "im_d".
    if self.is_self_adjoint:
      im_d_value = array_ops.zeros_like(re_d_value)
    else:
      im_d_value = math_ops.reduce_sum(math_ops.imag(self.spectrum), axis=axis)

    return math_ops.cast(math_ops.complex(re_d_value, im_d_value), self.dtype)
516
517
@tf_export("linalg.LinearOperatorCirculant")
class LinearOperatorCirculant(_BaseLinearOperatorCirculant):
  """`LinearOperator` acting like a circulant matrix.

  This operator acts like a circulant matrix `A` with
  shape `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `N x N` matrix.  This matrix `A` is not materialized, but for
  purposes of broadcasting this shape will be relevant.

  #### Description in terms of circulant matrices

  Circulant means the entries of `A` are generated by a single vector, the
  convolution kernel `h`: `A_{mn} := h_{m-n mod N}`.  With `h = [w, x, y, z]`,

  ```
  A = |w z y x|
      |x w z y|
      |y x w z|
      |z y x w|
  ```

  This means that the result of matrix multiplication `v = Au` has `Lth` column
  given by the circular convolution of `h` with the `Lth` column of `u`.

  #### Description in terms of the frequency spectrum

  There is an equivalent description in terms of the [batch] spectrum `H` and
  Fourier transforms.  Here we consider `A.shape = [N, N]` and ignore batch
  dimensions.  Define the discrete Fourier transform (DFT) and its inverse by

  ```
  DFT[ h[n] ] = H[k] := sum_{n = 0}^{N - 1} h_n e^{-i 2pi k n / N}
  IDFT[ H[k] ] = h[n] = N^{-1} sum_{k = 0}^{N - 1} H_k e^{i 2pi k n / N}
  ```

  From these definitions, we see that

  ```
  H[0] = sum_{n = 0}^{N - 1} h_n
  H[1] = "the first positive frequency"
  H[N - 1] = "the first negative frequency"
  ```

  Loosely speaking, with `*` element-wise multiplication, matrix multiplication
  is equal to the action of a Fourier multiplier: `A u = IDFT[ H * DFT[u] ]`.
  Precisely speaking, given `[N, R]` matrix `u`, let `DFT[u]` be the `[N, R]`
  matrix with `rth` column equal to the DFT of the `rth` column of `u`.
  Define the `IDFT` similarly.
  Matrix multiplication may be expressed columnwise:

  ```(A u)_r = IDFT[ H * (DFT[u])_r ]```

  #### Operator properties deduced from the spectrum.

  Letting `u` be the `kth` Euclidean basis vector, and `U = IDFT[u]`.
  The above formulas show that `A U = H_k * U`.  We conclude that the elements
  of `H` are the eigenvalues of this operator.   Therefore

  * This operator is positive definite if and only if `Real{H} > 0`.

  A general property of Fourier transforms is the correspondence between
  Hermitian functions and real valued transforms.

  Suppose `H.shape = [B1,...,Bb, N]`.  We say that `H` is a Hermitian spectrum
  if, with `%` meaning modulus division,

  ```H[..., n % N] = ComplexConjugate[ H[..., (-n) % N] ]```

  * This operator corresponds to a real matrix if and only if `H` is Hermitian.
  * This operator is self-adjoint if and only if `H` is real.

  See e.g. "Discrete-Time Signal Processing", Oppenheim and Schafer.

  #### Example of a self-adjoint positive definite operator

  ```python
  # spectrum is real ==> operator is self-adjoint
  # spectrum is positive ==> operator is positive definite
  spectrum = [6., 4, 2]

  operator = LinearOperatorCirculant(spectrum)

  # IFFT[spectrum]
  operator.convolution_kernel()
  ==> [4 + 0j, 1 + 0.58j, 1 - 0.58j]

  operator.to_dense()
  ==> [[4 + 0.0j, 1 - 0.58j, 1 + 0.58j],
       [1 + 0.58j, 4 + 0.0j, 1 - 0.58j],
       [1 - 0.58j, 1 + 0.58j, 4 + 0.0j]]
  ```

  #### Example of defining in terms of a real convolution kernel

  ```python
  # convolution_kernel is real ==> spectrum is Hermitian.
  convolution_kernel = [1., 2., 1.]
  spectrum = tf.signal.fft(tf.cast(convolution_kernel, tf.complex64))

  # spectrum is Hermitian ==> operator is real.
  # spectrum is shape [3] ==> operator is shape [3, 3]
  # We force the input/output type to be real, which allows this to operate
  # like a real matrix.
  operator = LinearOperatorCirculant(spectrum, input_output_dtype=tf.float32)

  operator.to_dense()
  ==> [[ 1, 1, 2],
       [ 2, 1, 1],
       [ 1, 2, 1]]
  ```

  #### Example of Hermitian spectrum

  ```python
  # spectrum is shape [3] ==> operator is shape [3, 3]
  # spectrum is Hermitian ==> operator is real.
  spectrum = [1, 1j, -1j]

  operator = LinearOperatorCirculant(spectrum)

  operator.to_dense()
  ==> [[ 0.33 + 0j,  0.91 + 0j, -0.24 + 0j],
       [-0.24 + 0j,  0.33 + 0j,  0.91 + 0j],
       [ 0.91 + 0j, -0.24 + 0j,  0.33 + 0j]]
  ```

  #### Example of forcing real `dtype` when spectrum is Hermitian

  ```python
  # spectrum is shape [4] ==> operator is shape [4, 4]
  # spectrum is real ==> operator is self-adjoint
  # spectrum is Hermitian ==> operator is real
  # spectrum has positive real part ==> operator is positive-definite.
  spectrum = [6., 4, 2, 4]

  # Force the input dtype to be float32.
  # Cast the output to float32.  This is fine because the operator will be
  # real due to Hermitian spectrum.
  operator = LinearOperatorCirculant(spectrum, input_output_dtype=tf.float32)

  operator.shape
  ==> [4, 4]

  operator.to_dense()
  ==> [[4, 1, 0, 1],
       [1, 4, 1, 0],
       [0, 1, 4, 1],
       [1, 0, 1, 4]]

  # convolution_kernel = tf.signal.ifft(spectrum)
  operator.convolution_kernel()
  ==> [4, 1, 0, 1]
  ```

  #### Performance

  Suppose `operator` is a `LinearOperatorCirculant` of shape `[N, N]`,
  and `x.shape = [N, R]`.  Then

  * `operator.matmul(x)` is `O(R*N*Log[N])`
  * `operator.solve(x)` is `O(R*N*Log[N])`
  * `operator.determinant()` involves a size `N` `reduce_prod`.

  If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
  `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.

  References:
    Toeplitz and Circulant Matrices - A Review:
      [Gray, 2006](https://www.nowpublishers.com/article/Details/CIT-006)
      ([pdf](https://ee.stanford.edu/~gray/toeplitz.pdf))
  """

  def __init__(self,
               spectrum,
               input_output_dtype=dtypes.complex64,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=True,
               name="LinearOperatorCirculant"):
    r"""Initialize a `LinearOperatorCirculant`.

    This `LinearOperator` is initialized to have shape `[B1,...,Bb, N, N]`
    by providing `spectrum`, a `[B1,...,Bb, N]` `Tensor`.

    If `input_output_dtype = DTYPE`:

    * Arguments to methods such as `matmul` or `solve` must be `DTYPE`.
    * Values returned by all methods, such as `matmul` or `determinant` will be
      cast to `DTYPE`.

    Note that if the spectrum is not Hermitian, then this operator corresponds
    to a complex matrix with non-zero imaginary part.  In this case, setting
    `input_output_dtype` to a real type will forcibly cast the output to be
    real, resulting in incorrect results!

    If on the other hand the spectrum is Hermitian, then this operator
    corresponds to a real-valued matrix, and setting `input_output_dtype` to
    a real type is fine.

    Args:
      spectrum:  Shape `[B1,...,Bb, N]` `Tensor`.  Allowed dtypes: `float16`,
        `float32`, `float64`, `complex64`, `complex128`.  Type can be different
        than `input_output_dtype`.
      input_output_dtype: `dtype` for input/output.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  If `spectrum` is real, this will always be true.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix\
            #Extension_for_non_symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
        A circulant operator is always square, so `False` is rejected.
      name:  A name to prepend to all ops created by this class.

    Raises:
      ValueError:  If `is_square` is `False`, if `spectrum` is real and
        `is_self_adjoint` is `False`, or if `spectrum` is statically known to
        be a scalar.
    """
    parameters = dict(
        spectrum=spectrum,
        input_output_dtype=input_output_dtype,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )
    super(LinearOperatorCirculant, self).__init__(
        spectrum,
        block_depth=1,
        input_output_dtype=input_output_dtype,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        parameters=parameters,
        name=name)

  def _eigvals(self):
    # The elements of the spectrum are the eigenvalues of this operator (see
    # class docstring).  Returned in the spectrum's own dtype.
    return ops.convert_to_tensor_v2_with_dispatch(self.spectrum)
772
773
@tf_export("linalg.LinearOperatorCirculant2D")
class LinearOperatorCirculant2D(_BaseLinearOperatorCirculant):
  """`LinearOperator` acting like a block circulant matrix.

  This operator acts like a block circulant matrix `A` with
  shape `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `N x N` matrix.  This matrix `A` is not materialized, but for
  purposes of broadcasting this shape will be relevant.

  #### Description in terms of block circulant matrices

  If `A` is block circulant, with block sizes `N0, N1` (`N0 * N1 = N`):
  `A` has a block circulant structure, composed of `N0 x N0` blocks, with each
  block an `N1 x N1` circulant matrix.

  For example, with `W`, `X`, `Y`, `Z` each circulant,

  ```
  A = |W Z Y X|
      |X W Z Y|
      |Y X W Z|
      |Z Y X W|
  ```

  Note that `A` itself will not in general be circulant.

  #### Description in terms of the frequency spectrum

  There is an equivalent description in terms of the [batch] spectrum `H` and
  Fourier transforms.  Here we consider `A.shape = [N, N]` and ignore batch
  dimensions.

  If `H.shape = [N0, N1]`, (`N0 * N1 = N`):
  Loosely speaking, matrix multiplication is equal to the action of a
  Fourier multiplier:  `A u = IDFT2[ H DFT2[u] ]`.
  Precisely speaking, given `[N, R]` matrix `u`, let `DFT2[u]` be the
  `[N0, N1, R]` `Tensor` defined by re-shaping `u` to `[N0, N1, R]` and taking
  a two dimensional DFT across the first two dimensions.  Let `IDFT2` be the
  inverse of `DFT2`.  Matrix multiplication may be expressed columnwise:

  ```(A u)_r = IDFT2[ H * (DFT2[u])_r ]```

  #### Operator properties deduced from the spectrum.

  * This operator is positive definite if and only if `Real{H} > 0`.

  A general property of Fourier transforms is the correspondence between
  Hermitian functions and real valued transforms.

  Suppose `H.shape = [B1,...,Bb, N0, N1]`.  We say that `H` is a Hermitian
  spectrum if, with `%` indicating modulus division,

  ```
  H[..., n0 % N0, n1 % N1]
    = ComplexConjugate[ H[..., (-n0) % N0, (-n1) % N1] ].
  ```

  * This operator corresponds to a real matrix if and only if `H` is Hermitian.
  * This operator is self-adjoint if and only if `H` is real.

  See e.g. "Discrete-Time Signal Processing", Oppenheim and Schafer.

  #### Example of a self-adjoint positive definite operator

  ```python
  # spectrum is real ==> operator is self-adjoint
  # spectrum is positive ==> operator is positive definite
  spectrum = [[1., 2., 3.],
              [4., 5., 6.],
              [7., 8., 9.]]

  operator = LinearOperatorCirculant2D(spectrum)

  # IFFT[spectrum]
  operator.convolution_kernel()
  ==> [[5.0+0.0j, -0.5-.3j, -0.5+.3j],
       [-1.5-.9j,        0,        0],
       [-1.5+.9j,        0,        0]]

  operator.to_dense()
  ==> Complex self adjoint 9 x 9 matrix.
  ```

  #### Example of defining in terms of a real convolution kernel

  ```python
  # convolution_kernel is real ==> spectrum is Hermitian.
  convolution_kernel = [[1., 2., 1.], [5., -1., 1.]]
  spectrum = tf.signal.fft2d(tf.cast(convolution_kernel, tf.complex64))

  # spectrum is shape [2, 3] ==> operator is shape [6, 6]
  # spectrum is Hermitian ==> operator is real.
  operator = LinearOperatorCirculant2D(spectrum, input_output_dtype=tf.float32)
  ```

  #### Performance

  Suppose `operator` is a `LinearOperatorCirculant2D` of shape `[N, N]`,
  and `x.shape = [N, R]`.  Then

  * `operator.matmul(x)` is `O(R*N*Log[N])`
  * `operator.solve(x)` is `O(R*N*Log[N])`
  * `operator.determinant()` involves a size `N` `reduce_prod`.

  If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
  `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:
  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               spectrum,
               input_output_dtype=dtypes.complex64,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=True,
               name="LinearOperatorCirculant2D"):
    r"""Initialize a `LinearOperatorCirculant2D`.

    This `LinearOperator` is initialized to have shape `[B1,...,Bb, N, N]`
    by providing `spectrum`, a `[B1,...,Bb, N0, N1]` `Tensor` with `N0*N1 = N`.

    If `input_output_dtype = DTYPE`:

    * Arguments to methods such as `matmul` or `solve` must be `DTYPE`.
    * Values returned by all methods, such as `matmul` or `determinant` will be
      cast to `DTYPE`.

    Note that if the spectrum is not Hermitian, then this operator corresponds
    to a complex matrix with non-zero imaginary part.  In this case, setting
    `input_output_dtype` to a real type will forcibly cast the output to be
    real, resulting in incorrect results!

    If on the other hand the spectrum is Hermitian, then this operator
    corresponds to a real-valued matrix, and setting `input_output_dtype` to
    a real type is fine.

    Args:
      spectrum:  Shape `[B1,...,Bb, N0, N1]` `Tensor`.  Allowed dtypes:
        `float16`, `float32`, `float64`, `complex64`, `complex128`.  Type can
        be different than `input_output_dtype`.
      input_output_dtype: `dtype` for input/output.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  If `spectrum` is real, this will always be true.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix\
            #Extension_for_non_symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name:  A name to prepend to all ops created by this class.
    """
    # Record the constructor arguments so the base class can expose them
    # through `parameters`.
    parameters = dict(
        spectrum=spectrum,
        input_output_dtype=input_output_dtype,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )
    # block_depth=2: the trailing two dims of `spectrum` are (N0, N1).
    super(LinearOperatorCirculant2D, self).__init__(
        spectrum,
        block_depth=2,
        input_output_dtype=input_output_dtype,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        parameters=parameters,
        name=name)
959
960
@tf_export("linalg.LinearOperatorCirculant3D")
class LinearOperatorCirculant3D(_BaseLinearOperatorCirculant):
  """`LinearOperator` acting like a nested block circulant matrix.

  This operator acts like a block circulant matrix `A` with
  shape `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `N x N` matrix.  This matrix `A` is not materialized, but for
  purposes of broadcasting this shape will be relevant.

  #### Description in terms of block circulant matrices

  If `A` is nested block circulant, with block sizes `N0, N1, N2`
  (`N0 * N1 * N2 = N`):
  `A` has a block circulant structure, composed of `N0 x N0` blocks, with each
  block an `(N1 * N2) x (N1 * N2)` block circulant matrix (itself composed of
  `N1 x N1` blocks, each an `N2 x N2` circulant matrix).

  For example, with `W`, `X`, `Y`, `Z` each block circulant,

  ```
  A = |W Z Y X|
      |X W Z Y|
      |Y X W Z|
      |Z Y X W|
  ```

  Note that `A` itself will not in general be circulant.

  #### Description in terms of the frequency spectrum

  There is an equivalent description in terms of the [batch] spectrum `H` and
  Fourier transforms.  Here we consider `A.shape = [N, N]` and ignore batch
  dimensions.

  If `H.shape = [N0, N1, N2]`, (`N0 * N1 * N2 = N`):
  Loosely speaking, matrix multiplication is equal to the action of a
  Fourier multiplier:  `A u = IDFT3[ H DFT3[u] ]`.
  Precisely speaking, given `[N, R]` matrix `u`, let `DFT3[u]` be the
  `[N0, N1, N2, R]` `Tensor` defined by re-shaping `u` to `[N0, N1, N2, R]` and
  taking a three dimensional DFT across the first three dimensions.  Let `IDFT3`
  be the inverse of `DFT3`.  Matrix multiplication may be expressed columnwise:

  ```(A u)_r = IDFT3[ H * (DFT3[u])_r ]```

  #### Operator properties deduced from the spectrum.

  * This operator is positive definite if and only if `Real{H} > 0`.

  A general property of Fourier transforms is the correspondence between
  Hermitian functions and real valued transforms.

  Suppose `H.shape = [B1,...,Bb, N0, N1, N2]`.  We say that `H` is a Hermitian
  spectrum if, with `%` meaning modulus division,

  ```
  H[..., n0 % N0, n1 % N1, n2 % N2]
    = ComplexConjugate[ H[..., (-n0) % N0, (-n1) % N1, (-n2) % N2] ].
  ```

  * This operator corresponds to a real matrix if and only if `H` is Hermitian.
  * This operator is self-adjoint if and only if `H` is real.

  See e.g. "Discrete-Time Signal Processing", Oppenheim and Schafer.

  #### Examples

  See `LinearOperatorCirculant` and `LinearOperatorCirculant2D` for examples.

  #### Performance

  Suppose `operator` is a `LinearOperatorCirculant3D` of shape `[N, N]`,
  and `x.shape = [N, R]`.  Then

  * `operator.matmul(x)` is `O(R*N*Log[N])`
  * `operator.solve(x)` is `O(R*N*Log[N])`
  * `operator.determinant()` involves a size `N` `reduce_prod`.

  If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
  `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:
  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               spectrum,
               input_output_dtype=dtypes.complex64,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=True,
               name="LinearOperatorCirculant3D"):
    r"""Initialize a `LinearOperatorCirculant3D`.

    This `LinearOperator` is initialized to have shape `[B1,...,Bb, N, N]`
    by providing `spectrum`, a `[B1,...,Bb, N0, N1, N2]` `Tensor`
    with `N0*N1*N2 = N`.

    If `input_output_dtype = DTYPE`:

    * Arguments to methods such as `matmul` or `solve` must be `DTYPE`.
    * Values returned by all methods, such as `matmul` or `determinant` will be
      cast to `DTYPE`.

    Note that if the spectrum is not Hermitian, then this operator corresponds
    to a complex matrix with non-zero imaginary part.  In this case, setting
    `input_output_dtype` to a real type will forcibly cast the output to be
    real, resulting in incorrect results!

    If on the other hand the spectrum is Hermitian, then this operator
    corresponds to a real-valued matrix, and setting `input_output_dtype` to
    a real type is fine.

    Args:
      spectrum:  Shape `[B1,...,Bb, N0, N1, N2]` `Tensor`.  Allowed dtypes:
        `float16`, `float32`, `float64`, `complex64`, `complex128`.  Type can
        be different than `input_output_dtype`.
      input_output_dtype: `dtype` for input/output.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  If `spectrum` is real, this will always be true.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix\
            #Extension_for_non_symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name:  A name to prepend to all ops created by this class.
    """
    # Record the constructor arguments so the base class can expose them
    # through `parameters`.
    parameters = dict(
        spectrum=spectrum,
        input_output_dtype=input_output_dtype,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )
    # block_depth=3: the trailing three dims of `spectrum` are (N0, N1, N2).
    super(LinearOperatorCirculant3D, self).__init__(
        spectrum,
        block_depth=3,
        input_output_dtype=input_output_dtype,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        parameters=parameters,
        name=name)
1119
1120
def _to_complex(x):
  """Returns `x` as a complex `Tensor`.

  A complex `x` is returned unchanged.  Otherwise `x` is cast with
  `math_ops.cast`: `float64` input becomes `complex128`, and any other
  non-complex dtype becomes `complex64`.
  """
  if x.dtype.is_complex:
    return x
  target_dtype = (
      dtypes.complex128 if x.dtype == dtypes.float64 else dtypes.complex64)
  return math_ops.cast(x, target_dtype)
1129