1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Fast-Fourier Transform ops."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21
22from tensorflow.python.framework import dtypes as _dtypes
23from tensorflow.python.framework import ops as _ops
24from tensorflow.python.framework import tensor_util as _tensor_util
25from tensorflow.python.ops import array_ops as _array_ops
26from tensorflow.python.ops import gen_spectral_ops
27from tensorflow.python.ops import manip_ops
28from tensorflow.python.ops import math_ops as _math_ops
29from tensorflow.python.util import dispatch
30from tensorflow.python.util.tf_export import tf_export
31
32
def _infer_fft_length_for_rfft(input_tensor, fft_rank):
  """Returns the `fft_length` implied by `input_tensor` for a rank-`fft_rank` RFFT.

  For a forward real FFT the transform length is just the size of each of the
  innermost `fft_rank` dimensions of the input. When all of those sizes are
  statically known the result is an int32 constant tensor; otherwise it is a
  slice of the dynamic `tf.shape(input_tensor)`.
  """
  inner_dims = input_tensor.get_shape()[-fft_rank:]
  if inner_dims.is_fully_defined():
    return _ops.convert_to_tensor(inner_dims.as_list(), _dtypes.int32)
  # At least one inner dimension is only known at run time.
  return _array_ops.shape(input_tensor)[-fft_rank:]
44
45
def _infer_fft_length_for_irfft(input_tensor, fft_rank):
  """Returns the `fft_length` implied by `input_tensor` for a rank-`fft_rank` IRFFT.

  All inner dimensions except the last are taken as-is. The last inner
  dimension of size `n` is mapped to `max(0, 2 * (n - 1))`, so the inferred
  innermost length is always even. The result is an int32 constant tensor when
  the inner shape is fully static, and a stacked dynamic tensor otherwise.
  """
  inner_shape = input_tensor.get_shape()[-fft_rank:]

  if inner_shape.is_fully_defined():
    lengths = inner_shape.as_list()
    if lengths:
      lengths[-1] = max(0, 2 * (lengths[-1] - 1))
    return _ops.convert_to_tensor(lengths, _dtypes.int32)

  # Some inner dimension is unknown statically: do the same math on tensors.
  dynamic_lengths = _array_ops.unstack(
      _array_ops.shape(input_tensor)[-fft_rank:])
  dynamic_lengths[-1] = _math_ops.maximum(0, 2 * (dynamic_lengths[-1] - 1))
  return _array_ops.stack(dynamic_lengths)
62
63
def _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length, is_reverse=False):
  """Pads `input_tensor` to `fft_length` on its inner-most `fft_rank` dims.

  Only zero padding at the end of each dimension is ever added; a dimension
  that is already at least as large as its target is left as-is (this
  function never crops). When all relevant sizes are known statically, the
  paddings are computed in Python and no op is emitted if none are needed;
  otherwise the paddings are computed with TensorFlow ops and `tf.pad` is
  always emitted.

  Args:
    input_tensor: The `Tensor` to pad.
    fft_rank: Number of inner-most dimensions the FFT operates on.
    fft_length: A 1-D integer `Tensor` with the target length of each of the
      inner-most dimensions (callers in this file pass an int32 tensor).
    is_reverse: If True (the IRFFT case), the inner-most dimension is padded to
      `fft_length[-1] // 2 + 1` instead of `fft_length[-1]`.

  Returns:
    `input_tensor`, possibly padded.
  """
  fft_shape = _tensor_util.constant_value_as_shape(fft_length)

  # Edge case: skip padding empty tensors.
  if (input_tensor.shape.ndims is not None and
      any(dim.value == 0 for dim in input_tensor.shape.dims)):
    return input_tensor

  # If we know the shapes ahead of time, we can either skip or pre-compute the
  # appropriate paddings. Otherwise, fall back to computing paddings in
  # TensorFlow.
  if fft_shape.is_fully_defined() and input_tensor.shape.ndims is not None:
    # Slice the last FFT-rank dimensions from input_tensor's shape.
    # NOTE(review): this slices by `fft_shape.ndims` while the dynamic path
    # below slices by `fft_rank`; they agree only when `fft_length` has exactly
    # `fft_rank` entries, which is not checked here.
    input_fft_shape = input_tensor.shape[-fft_shape.ndims:]

    # If any inner input dim is unknown, fall through to the dynamic path.
    if input_fft_shape.is_fully_defined():
      # In reverse, we only pad the inner-most dimension to fft_length / 2 + 1.
      if is_reverse:
        fft_shape = fft_shape[:-1].concatenate(
            fft_shape.dims[-1].value // 2 + 1)

      # One [before, after] pair per inner dim; padding only goes "after".
      paddings = [[0, max(fft_dim.value - input_dim.value, 0)]
                  for fft_dim, input_dim in zip(
                      fft_shape.dims, input_fft_shape.dims)]
      if any(pad > 0 for _, pad in paddings):
        # Outer (batch) dimensions are never padded.
        outer_paddings = [[0, 0]] * max((input_tensor.shape.ndims -
                                         fft_shape.ndims), 0)
        return _array_ops.pad(input_tensor, outer_paddings + paddings)
      return input_tensor

  # If we can't determine the paddings ahead of time, then we have to pad. If
  # the paddings end up as zero, tf.pad has a special-case that does no work.
  input_rank = _array_ops.rank(input_tensor)
  input_fft_shape = _array_ops.shape(input_tensor)[-fft_rank:]
  outer_dims = _math_ops.maximum(0, input_rank - fft_rank)
  outer_paddings = _array_ops.zeros([outer_dims], fft_length.dtype)
  # In reverse, we only pad the inner-most dimension to fft_length / 2 + 1.
  if is_reverse:
    fft_length = _array_ops.concat([fft_length[:-1],
                                    fft_length[-1:] // 2 + 1], 0)
  # Clamp at zero so dimensions larger than the target are not cropped.
  fft_paddings = _math_ops.maximum(0, fft_length - input_fft_shape)
  paddings = _array_ops.concat([outer_paddings, fft_paddings], 0)
  # Build the [rank, 2] layout tf.pad expects: 0 before, `paddings` after.
  paddings = _array_ops.stack([_array_ops.zeros_like(paddings), paddings],
                              axis=1)
  return _array_ops.pad(input_tensor, paddings)
110
111
def _rfft_wrapper(fft_fn, fft_rank, default_name):
  """Builds an RFFT entry point around `fft_fn` that infers `fft_length`.

  Args:
    fft_fn: The raw `gen_spectral_ops.rfft*` op function to wrap.
    fft_rank: Number of innermost dimensions the FFT is taken over.
    default_name: Name scope used when the caller does not pass `name`.

  Returns:
    A function `(input_tensor, fft_length=None, name=None)` whose docstring is
    copied from `fft_fn`.
  """

  def _rfft(input_tensor, fft_length=None, name=None):
    """Wrapper around gen_spectral_ops.rfft* that infers fft_length argument."""
    with _ops.name_scope(name, default_name,
                         [input_tensor, fft_length]) as name:
      input_tensor = _ops.convert_to_tensor(input_tensor,
                                            preferred_dtype=_dtypes.float32)
      real_dtype = input_tensor.dtype
      if real_dtype not in (_dtypes.float32, _dtypes.float64):
        raise ValueError(
            "RFFT requires tf.float32 or tf.float64 inputs, got: %s" %
            input_tensor)
      # Only float32/float64 can reach this point (checked just above).
      complex_dtype = (_dtypes.complex64 if real_dtype == _dtypes.float32
                       else _dtypes.complex128)

      # Called for its validation side effect; the returned shape is unused.
      input_tensor.shape.with_rank_at_least(fft_rank)

      if fft_length is not None:
        fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
      else:
        fft_length = _infer_fft_length_for_rfft(input_tensor, fft_rank)

      padded = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length)

      # When `fft_length` folds to a constant, pass that value to the op
      # instead of the tensor.
      folded = _tensor_util.constant_value(fft_length)
      if folded is not None:
        fft_length = folded
      return fft_fn(padded, fft_length, Tcomplex=complex_dtype, name=name)

  _rfft.__doc__ = fft_fn.__doc__
  return _rfft
144
145
def _irfft_wrapper(ifft_fn, fft_rank, default_name):
  """Builds an IRFFT entry point around `ifft_fn` that infers `fft_length`.

  Args:
    ifft_fn: The raw `gen_spectral_ops.irfft*` op function to wrap.
    fft_rank: Number of innermost dimensions the inverse FFT runs over.
    default_name: Name scope used when the caller does not pass `name`.

  Returns:
    A function `(input_tensor, fft_length=None, name=None)` whose docstring is
    copied from `ifft_fn`.
  """

  def _irfft(input_tensor, fft_length=None, name=None):
    """Wrapper irfft* that infers fft_length argument."""
    with _ops.name_scope(name, default_name,
                         [input_tensor, fft_length]) as name:
      input_tensor = _ops.convert_to_tensor(input_tensor,
                                            preferred_dtype=_dtypes.complex64)
      # Called for its validation side effect; the returned shape is unused.
      input_tensor.shape.with_rank_at_least(fft_rank)
      if input_tensor.dtype not in (_dtypes.complex64, _dtypes.complex128):
        raise ValueError(
            "IRFFT requires tf.complex64 or tf.complex128 inputs, got: %s" %
            input_tensor)
      output_dtype = input_tensor.dtype.real_dtype

      if fft_length is not None:
        fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
      else:
        fft_length = _infer_fft_length_for_irfft(input_tensor, fft_rank)

      padded = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length,
                                   is_reverse=True)

      # When `fft_length` folds to a constant, pass that value to the op
      # instead of the tensor.
      folded = _tensor_util.constant_value(fft_length)
      if folded is not None:
        fft_length = folded
      return ifft_fn(padded, fft_length, Treal=output_dtype, name=name)

  _irfft.__doc__ = ifft_fn.__doc__
  return _irfft
174
175
# FFT/IFFT 1/2/3D are exported via
# third_party/tensorflow/core/api_def/python_api/
fft = gen_spectral_ops.fft
ifft = gen_spectral_ops.ifft
fft2d = gen_spectral_ops.fft2d
ifft2d = gen_spectral_ops.ifft2d
fft3d = gen_spectral_ops.fft3d
ifft3d = gen_spectral_ops.ifft3d
# The real-input transforms go through `_rfft_wrapper` / `_irfft_wrapper` so
# that `fft_length` can be inferred and the input padded. Each is exported as
# `signal.<name>` (and additionally `spectral.<name>` in the v1 API) with
# dispatch support.
# NOTE(review): the value returned by `tf_export(...)(...)` is not assigned
# back, so the module-level names below stay bound to the plain wrappers
# (these are what the gradient code later in this file calls).
rfft = _rfft_wrapper(gen_spectral_ops.rfft, 1, "rfft")
tf_export("signal.rfft", v1=["signal.rfft", "spectral.rfft"])(
    dispatch.add_dispatch_support(rfft))
irfft = _irfft_wrapper(gen_spectral_ops.irfft, 1, "irfft")
tf_export("signal.irfft", v1=["signal.irfft", "spectral.irfft"])(
    dispatch.add_dispatch_support(irfft))
rfft2d = _rfft_wrapper(gen_spectral_ops.rfft2d, 2, "rfft2d")
tf_export("signal.rfft2d", v1=["signal.rfft2d", "spectral.rfft2d"])(
    dispatch.add_dispatch_support(rfft2d))
irfft2d = _irfft_wrapper(gen_spectral_ops.irfft2d, 2, "irfft2d")
tf_export("signal.irfft2d", v1=["signal.irfft2d", "spectral.irfft2d"])(
    dispatch.add_dispatch_support(irfft2d))
rfft3d = _rfft_wrapper(gen_spectral_ops.rfft3d, 3, "rfft3d")
tf_export("signal.rfft3d", v1=["signal.rfft3d", "spectral.rfft3d"])(
    dispatch.add_dispatch_support(rfft3d))
irfft3d = _irfft_wrapper(gen_spectral_ops.irfft3d, 3, "irfft3d")
tf_export("signal.irfft3d", v1=["signal.irfft3d", "spectral.irfft3d"])(
    dispatch.add_dispatch_support(irfft3d))
202
203
def _fft_size_for_grad(grad, rank):
  """Returns the element count of the innermost `rank` dimensions of `grad`.

  This is the total transform size N that the FFT gradients below scale by.
  It is computed from the dynamic shape, so it works for unknown static shapes.
  """
  inner_dims = _array_ops.shape(grad)[-rank:]
  return _math_ops.reduce_prod(inner_dims)
206
207
@_ops.RegisterGradient("FFT")
def _fft_grad(_, grad):
  """Gradient of FFT: `ifft(grad)` scaled up by the transform size N."""
  num_points = _fft_size_for_grad(grad, 1)
  scale = _math_ops.cast(num_points, grad.dtype)
  return ifft(grad) * scale
212
213
@_ops.RegisterGradient("IFFT")
def _ifft_grad(_, grad):
  """Gradient of IFFT: `fft(grad)` scaled down by the transform size N."""
  # Take 1/N in the real dtype, then cast it to the complex gradient dtype.
  num_points = _math_ops.cast(_fft_size_for_grad(grad, 1),
                              grad.dtype.real_dtype)
  inv_scale = _math_ops.cast(1. / num_points, grad.dtype)
  return fft(grad) * inv_scale
220
221
@_ops.RegisterGradient("FFT2D")
def _fft2d_grad(_, grad):
  """Gradient of FFT2D: `ifft2d(grad)` scaled up by the 2-D transform size."""
  num_points = _fft_size_for_grad(grad, 2)
  scale = _math_ops.cast(num_points, grad.dtype)
  return ifft2d(grad) * scale
226
227
@_ops.RegisterGradient("IFFT2D")
def _ifft2d_grad(_, grad):
  """Gradient of IFFT2D: `fft2d(grad)` scaled down by the 2-D transform size."""
  # Take 1/N in the real dtype, then cast it to the complex gradient dtype.
  num_points = _math_ops.cast(_fft_size_for_grad(grad, 2),
                              grad.dtype.real_dtype)
  inv_scale = _math_ops.cast(1. / num_points, grad.dtype)
  return fft2d(grad) * inv_scale
234
235
@_ops.RegisterGradient("FFT3D")
def _fft3d_grad(_, grad):
  """Gradient of FFT3D: `ifft3d(grad)` scaled up by the 3-D transform size."""
  num_points = _fft_size_for_grad(grad, 3)
  scale = _math_ops.cast(num_points, grad.dtype)
  return ifft3d(grad) * scale
240
241
@_ops.RegisterGradient("IFFT3D")
def _ifft3d_grad(_, grad):
  """Gradient of IFFT3D: `fft3d(grad)` scaled down by the 3-D transform size."""
  # Take 1/N in the real dtype, then cast it to the complex gradient dtype.
  num_points = _math_ops.cast(_fft_size_for_grad(grad, 3),
                              grad.dtype.real_dtype)
  inv_scale = _math_ops.cast(1. / num_points, grad.dtype)
  return fft3d(grad) * inv_scale
248
249
def _rfft_grad_helper(rank, irfft_fn):
  """Returns a gradient function for an RFFT of the provided rank.

  Args:
    rank: The FFT rank of the RFFT op. Only 1 and 2 are supported.
    irfft_fn: The IRFFT wrapper of the same rank (`irfft` / `irfft2d` at the
      registration sites in this file), used to map the complex gradient back
      to the real input.

  Returns:
    A function `(op, grad) -> (input_grad, None)` suitable for
    `RegisterGradient`; the `fft_length` input receives no gradient.
  """
  # Can't happen because we don't register a gradient for RFFT3D.
  assert rank in (1, 2), "Gradient for RFFT3D is not implemented."

  def _grad(op, grad):
    """A gradient function for RFFT with the provided `rank` and `irfft_fn`."""
    fft_length = op.inputs[1]
    complex_dtype = grad.dtype
    real_dtype = complex_dtype.real_dtype
    input_shape = _array_ops.shape(op.inputs[0])
    # 1 when the innermost FFT length is even, 0 when odd (cast to complex).
    is_even = _math_ops.cast(1 - (fft_length[-1] % 2), complex_dtype)

    def _tile_for_broadcasting(matrix, t):
      # Reshapes the 2-D `matrix` to rank(t) dims (leading 1s) and tiles it
      # over t's leading (batch) dims so it can be batch-matmul'ed with `t`.
      expanded = _array_ops.reshape(
          matrix,
          _array_ops.concat([
              _array_ops.ones([_array_ops.rank(t) - 2], _dtypes.int32),
              _array_ops.shape(matrix)
          ], 0))
      return _array_ops.tile(
          expanded, _array_ops.concat([_array_ops.shape(t)[:-2], [1, 1]], 0))

    def _mask_matrix(length):
      """Returns the `length` x `length` DFT matrix in `complex_dtype`.

      Entry (i, j) is exp(-2 * pi * sqrt(-1) * i * j / length).
      """
      # NOTE(review): the TODO below describes the factor
      # t_n = exp(sqrt(-1) * pi * n^2 / line_len), which is not what this
      # function computes; it may be stale.
      # TODO(rjryan): Speed up computation of twiddle factors using the
      # following recurrence relation and cache them across invocations of RFFT.
      #
      # t_n = exp(sqrt(-1) * pi * n^2 / line_len)
      # for n = 0, 1,..., line_len-1.
      # For n > 2, use t_n = t_{n-1}^2 / t_{n-2} * t_1^2
      a = _array_ops.tile(
          _array_ops.expand_dims(_math_ops.range(length), 0), (length, 1))
      b = _array_ops.transpose(a, [1, 0])
      return _math_ops.exp(
          -2j * np.pi * _math_ops.cast(a * b, complex_dtype) /
          _math_ops.cast(length, complex_dtype))

    def _ymask(length):
      """A sequence of [1+0j, -1+0j, 1+0j, -1+0j, ...] with length `length`."""
      return _math_ops.cast(1 - 2 * (_math_ops.range(length) % 2),
                            complex_dtype)

    # y0 is the first output bin. ym (below) is the last output bin; its term
    # only contributes (via `is_even`) when the FFT length is even.
    y0 = grad[..., 0:1]
    if rank == 1:
      ym = grad[..., -1:]
      extra_terms = y0 + is_even * ym * _ymask(input_shape[-1])
    elif rank == 2:
      # Create a mask matrix for y0 and ym.
      base_mask = _mask_matrix(input_shape[-2])

      # Tile base_mask to match y0 in shape so that we can batch-matmul the
      # inner 2 dimensions.
      tiled_mask = _tile_for_broadcasting(base_mask, y0)

      y0_term = _math_ops.matmul(tiled_mask, _math_ops.conj(y0))
      extra_terms = y0_term

      ym = grad[..., -1:]
      ym_term = _math_ops.matmul(tiled_mask, _math_ops.conj(ym))

      # Repeat ym_term along the innermost axis to the input's inner size,
      # alternating sign via _ymask.
      inner_dim = input_shape[-1]
      ym_term = _array_ops.tile(
          ym_term,
          _array_ops.concat([
              _array_ops.ones([_array_ops.rank(grad) - 1], _dtypes.int32),
              [inner_dim]
          ], 0)) * _ymask(inner_dim)

      extra_terms += is_even * ym_term

    # The gradient of RFFT is the IRFFT of the incoming gradient times a scaling
    # factor, plus some additional terms to make up for the components dropped
    # due to Hermitian symmetry.
    input_size = _math_ops.cast(
        _fft_size_for_grad(op.inputs[0], rank), real_dtype)
    the_irfft = irfft_fn(grad, fft_length)
    return 0.5 * (the_irfft * input_size + _math_ops.real(extra_terms)), None

  return _grad
330
331
def _irfft_grad_helper(rank, rfft_fn):
  """Returns a gradient function for an IRFFT of the provided rank.

  Args:
    rank: The FFT rank of the IRFFT op. Only 1 and 2 are supported.
    rfft_fn: The RFFT wrapper of the same rank (`rfft` / `rfft2d` at the
      registration sites in this file), used to map the real-valued gradient
      back to the complex input.

  Returns:
    A function `(op, grad) -> (input_grad, None)` suitable for
    `RegisterGradient`; the `fft_length` input receives no gradient.
  """
  # Can't happen because we don't register a gradient for IRFFT3D.
  assert rank in (1, 2), "Gradient for IRFFT3D is not implemented."

  def _grad(op, grad):
    """A gradient function for IRFFT with the provided `rank` and `rfft_fn`.

    Raises:
      TypeError: If `grad` is not float32 or float64.
    """
    # Generate a simple mask like [1.0, 2.0, ..., 2.0, 1.0] for even-length FFTs
    # and [1.0, 2.0, ..., 2.0] for odd-length FFTs. To reduce extra ops in the
    # graph we use the FFT length as a constant when it is known at graph
    # construction time (the last input dimension is always read dynamically).
    fft_length = op.inputs[1]
    fft_length_static = _tensor_util.constant_value(fft_length)
    if fft_length_static is not None:
      fft_length = fft_length_static
    real_dtype = grad.dtype
    if real_dtype == _dtypes.float32:
      complex_dtype = _dtypes.complex64
    elif real_dtype == _dtypes.float64:
      complex_dtype = _dtypes.complex128
    else:
      # Without this branch `complex_dtype` would be unbound below and fail
      # with an opaque UnboundLocalError.
      raise TypeError(
          "IRFFT gradient requires tf.float32 or tf.float64, got: %s" %
          real_dtype)
    is_odd = _math_ops.mod(fft_length[-1], 2)
    input_last_dimension = _array_ops.shape(op.inputs[0])[-1]
    # Length `input_last_dimension`: leading 1, then 2s, then a trailing 1
    # only when the FFT length is even.
    mask = _array_ops.concat(
        [[1.0], 2.0 * _array_ops.ones(
            [input_last_dimension - 2 + is_odd], real_dtype),
         _array_ops.ones([1 - is_odd], real_dtype)], 0)

    rsize = _math_ops.reciprocal(_math_ops.cast(
        _fft_size_for_grad(grad, rank), real_dtype))

    # The gradient of IRFFT is the RFFT of the incoming gradient times a scaling
    # factor and a mask. The mask scales the gradient for the Hermitian
    # symmetric components of the RFFT by a factor of two, since these
    # components are de-duplicated in the RFFT.
    the_rfft = rfft_fn(grad, fft_length)
    return the_rfft * _math_ops.cast(rsize * mask, complex_dtype), None

  return _grad
370
371
def _fftshift_axes_and_half_sizes(x, axes):
  """Resolves `axes` for (i)fftshift and returns `(axes, half_sizes)`.

  `half_sizes` holds `size // 2` for every shifted axis, computed from the
  dynamic shape of `x`, in a form that `manip_ops.roll` accepts alongside the
  returned `axes`.

  Args:
    x: A `Tensor`.
    axes: `None` (all axes), a Python `int`, or a sequence/tensor of ints.
      Negative values count from the last axis.

  Returns:
    A tuple `(axes, half_sizes)`.
  """
  if axes is None:
    if x.shape.ndims is None:
      # Rank unknown at graph-construction time: enumerate axes dynamically
      # instead of failing on `range(None)`.
      axes = _math_ops.range(_array_ops.rank(x))
    else:
      axes = tuple(range(x.shape.ndims))
    half_sizes = _array_ops.shape(x) // 2
  elif isinstance(axes, int):
    half_sizes = _array_ops.shape(x)[axes] // 2
  else:
    rank = _array_ops.rank(x)
    # allows negative axis
    axes = _array_ops.where(_math_ops.less(axes, 0), axes + rank, axes)
    half_sizes = _array_ops.gather(_array_ops.shape(x), axes) // 2
  return axes, half_sizes


@tf_export("signal.fftshift")
@dispatch.add_dispatch_support
def fftshift(x, axes=None, name=None):
  """Shift the zero-frequency component to the center of the spectrum.

  This function swaps half-spaces for all axes listed (defaults to all).
  Note that ``y[0]`` is the Nyquist component only if ``len(x)`` is even.

  @compatibility(numpy)
  Equivalent to numpy.fft.fftshift.
  https://docs.scipy.org/doc/numpy/reference/generated/numpy.fft.fftshift.html
  @end_compatibility

  For example:

  ```python
  x = tf.signal.fftshift([ 0.,  1.,  2.,  3.,  4., -5., -4., -3., -2., -1.])
  x.numpy() # array([-5., -4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4.])
  ```

  Args:
    x: `Tensor`, input tensor.
    axes: `int` or shape `tuple`, optional. Axes over which to shift. Default
      is None, which shifts all axes.
    name: An optional name for the operation.

  Returns:
    A `Tensor`, The shifted tensor.
  """
  with _ops.name_scope(name, "fftshift") as name:
    x = _ops.convert_to_tensor(x)
    axes, half_sizes = _fftshift_axes_and_half_sizes(x, axes)
    # Rolling forward by size // 2 moves the zero-frequency bin to the center.
    return manip_ops.roll(x, half_sizes, axes, name)


@tf_export("signal.ifftshift")
@dispatch.add_dispatch_support
def ifftshift(x, axes=None, name=None):
  """The inverse of fftshift.

  Although identical for even-length x,
  the functions differ by one sample for odd-length x.

  @compatibility(numpy)
  Equivalent to numpy.fft.ifftshift.
  https://docs.scipy.org/doc/numpy/reference/generated/numpy.fft.ifftshift.html
  @end_compatibility

  For example:

  ```python
  x = tf.signal.ifftshift([[ 0.,  1.,  2.],[ 3.,  4., -4.],[-3., -2., -1.]])
  x.numpy() # array([[ 4., -4.,  3.],[-2., -1., -3.],[ 1.,  2.,  0.]])
  ```

  Args:
    x: `Tensor`, input tensor.
    axes: `int` or shape `tuple`, optional. Axes over which to calculate.
      Defaults to None, which shifts all axes.
    name: An optional name for the operation.

  Returns:
    A `Tensor`, The shifted tensor.
  """
  with _ops.name_scope(name, "ifftshift") as name:
    x = _ops.convert_to_tensor(x)
    axes, half_sizes = _fftshift_axes_and_half_sizes(x, axes)
    # Rolling backward by size // 2 exactly undoes fftshift, including for
    # odd lengths.
    return manip_ops.roll(x, -half_sizes, axes, name)
460
461
# Gradients for the real-input transforms. Each RFFT gradient is expressed via
# the IRFFT wrapper of the same rank and vice versa, so these registrations
# must come after those wrappers are defined. No gradients are registered for
# RFFT3D / IRFFT3D.
_ops.RegisterGradient("RFFT")(_rfft_grad_helper(1, irfft))
_ops.RegisterGradient("IRFFT")(_irfft_grad_helper(1, rfft))
_ops.RegisterGradient("RFFT2D")(_rfft_grad_helper(2, irfft2d))
_ops.RegisterGradient("IRFFT2D")(_irfft_grad_helper(2, rfft2d))
466