1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Fast-Fourier Transform ops."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21
22from tensorflow.python.framework import dtypes as _dtypes
23from tensorflow.python.framework import ops as _ops
24from tensorflow.python.framework import tensor_util as _tensor_util
25from tensorflow.python.ops import array_ops as _array_ops
26from tensorflow.python.ops import gen_spectral_ops
27from tensorflow.python.ops import math_ops as _math_ops
28from tensorflow.python.util.tf_export import tf_export
29
30
def _infer_fft_length_for_rfft(input_tensor, fft_rank):
  """Infers the `fft_length` argument for a `rank` RFFT from `input_tensor`."""
  # Static shape of the inner-most `fft_rank` dimensions.
  inner_shape = input_tensor.get_shape()[-fft_rank:]

  if inner_shape.is_fully_defined():
    # Every inner dimension is known statically: emit a constant.
    return _ops.convert_to_tensor(inner_shape.as_list(), _dtypes.int32)

  # Some inner dimension is unknown; compute the lengths at graph run time.
  return _array_ops.shape(input_tensor)[-fft_rank:]
42
43
def _infer_fft_length_for_irfft(input_tensor, fft_rank):
  """Infers the `fft_length` argument for a `rank` IRFFT from `input_tensor`."""
  # Static shape of the inner-most `fft_rank` dimensions.
  inner_shape = input_tensor.get_shape()[-fft_rank:]

  if inner_shape.is_fully_defined():
    # Static case: the real output length of the last axis is 2 * (n - 1),
    # clamped at zero.
    lengths = inner_shape.as_list()
    if lengths:
      lengths[-1] = max(0, 2 * (lengths[-1] - 1))
    return _ops.convert_to_tensor(lengths, _dtypes.int32)

  # Dynamic case: same arithmetic, expressed with tensor ops.
  dims = _array_ops.unstack(_array_ops.shape(input_tensor)[-fft_rank:])
  dims[-1] = _math_ops.maximum(0, 2 * (dims[-1] - 1))
  return _array_ops.stack(dims)
60
61
def _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length, is_reverse=False):
  """Pads `input_tensor` to `fft_length` on its inner-most `fft_rank` dims.

  Args:
    input_tensor: The tensor to pad.
    fft_rank: The number of inner-most dimensions the FFT operates over.
    fft_length: An int32 tensor of `fft_rank` target lengths.
    is_reverse: If True, pad for an inverse RFFT: the inner-most dimension is
      padded only to `fft_length // 2 + 1` (the length of a one-sided
      Hermitian spectrum) instead of `fft_length`.

  Returns:
    `input_tensor`, zero-padded on the right of each of its inner-most
    `fft_rank` dimensions up to the target lengths, or unchanged when no
    padding is needed.
  """
  # Recover as much static shape information from `fft_length` as possible.
  fft_shape = _tensor_util.constant_value_as_shape(fft_length)

  # Edge case: skip padding empty tensors.
  if (input_tensor.shape.ndims is not None and
      any(dim.value == 0 for dim in input_tensor.shape.dims)):
    return input_tensor

  # If we know the shapes ahead of time, we can either skip or pre-compute the
  # appropriate paddings. Otherwise, fall back to computing paddings in
  # TensorFlow.
  if fft_shape.is_fully_defined() and input_tensor.shape.ndims is not None:
    # Slice the last FFT-rank dimensions from input_tensor's shape.
    input_fft_shape = input_tensor.shape[-fft_shape.ndims:]

    if input_fft_shape.is_fully_defined():
      # In reverse, we only pad the inner-most dimension to fft_length / 2 + 1.
      if is_reverse:
        fft_shape = fft_shape[:-1].concatenate(
            fft_shape.dims[-1].value // 2 + 1)

      # Right-pad each inner dimension up to the target; never truncate
      # (negative differences are clamped to zero).
      paddings = [[0, max(fft_dim.value - input_dim.value, 0)]
                  for fft_dim, input_dim in zip(
                      fft_shape.dims, input_fft_shape.dims)]
      if any(pad > 0 for _, pad in paddings):
        # Outer (batch) dimensions are never padded.
        outer_paddings = [[0, 0]] * max((input_tensor.shape.ndims -
                                         fft_shape.ndims), 0)
        return _array_ops.pad(input_tensor, outer_paddings + paddings)
      return input_tensor

  # If we can't determine the paddings ahead of time, then we have to pad. If
  # the paddings end up as zero, tf.pad has a special-case that does no work.
  input_rank = _array_ops.rank(input_tensor)
  input_fft_shape = _array_ops.shape(input_tensor)[-fft_rank:]
  outer_dims = _math_ops.maximum(0, input_rank - fft_rank)
  outer_paddings = _array_ops.zeros([outer_dims], fft_length.dtype)
  # In reverse, we only pad the inner-most dimension to fft_length / 2 + 1.
  if is_reverse:
    fft_length = _array_ops.concat([fft_length[:-1],
                                    fft_length[-1:] // 2 + 1], 0)
  fft_paddings = _math_ops.maximum(0, fft_length - input_fft_shape)
  paddings = _array_ops.concat([outer_paddings, fft_paddings], 0)
  # Build a [rank, 2] paddings matrix: zeros on the left, computed amounts on
  # the right.
  paddings = _array_ops.stack([_array_ops.zeros_like(paddings), paddings],
                              axis=1)
  return _array_ops.pad(input_tensor, paddings)
108
109
def _rfft_wrapper(fft_fn, fft_rank, default_name):
  """Wrapper around gen_spectral_ops.rfft* that infers fft_length argument."""

  def _rfft(input_tensor, fft_length=None, name=None):
    """Wrapper around gen_spectral_ops.rfft* that infers fft_length argument."""
    with _ops.name_scope(name, default_name,
                         [input_tensor, fft_length]) as name:
      input_tensor = _ops.convert_to_tensor(input_tensor, _dtypes.float32)
      input_tensor.shape.with_rank_at_least(fft_rank)
      # Infer fft_length from the input when the caller did not supply one.
      if fft_length is not None:
        fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
      else:
        fft_length = _infer_fft_length_for_rfft(input_tensor, fft_rank)
      padded = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length)
      return fft_fn(padded, fft_length, name)

  _rfft.__doc__ = fft_fn.__doc__
  return _rfft
127
128
def _irfft_wrapper(ifft_fn, fft_rank, default_name):
  """Wrapper around gen_spectral_ops.irfft* that infers fft_length argument."""

  def _irfft(input_tensor, fft_length=None, name=None):
    """Wrapper irfft* that infers fft_length argument."""
    with _ops.name_scope(name, default_name,
                         [input_tensor, fft_length]) as name:
      input_tensor = _ops.convert_to_tensor(input_tensor, _dtypes.complex64)
      input_tensor.shape.with_rank_at_least(fft_rank)
      # Infer fft_length from the input when the caller did not supply one.
      if fft_length is not None:
        fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
      else:
        fft_length = _infer_fft_length_for_irfft(input_tensor, fft_rank)
      padded = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length,
                                   is_reverse=True)
      return ifft_fn(padded, fft_length, name)

  _irfft.__doc__ = ifft_fn.__doc__
  return _irfft
147
148
# FFT/IFFT 1/2/3D are exported via
# third_party/tensorflow/core/api_def/python_api/
fft = gen_spectral_ops.fft
ifft = gen_spectral_ops.ifft
fft2d = gen_spectral_ops.fft2d
ifft2d = gen_spectral_ops.ifft2d
fft3d = gen_spectral_ops.fft3d
ifft3d = gen_spectral_ops.ifft3d
# The real-input (R)FFT ops go through the wrappers above so that `fft_length`
# is inferred and the input is padded when necessary; each is then exported
# under both the `tf.signal` and (v1-only) `tf.spectral` namespaces.
rfft = _rfft_wrapper(gen_spectral_ops.rfft, 1, "rfft")
tf_export("signal.rfft", v1=["signal.rfft", "spectral.rfft"])(rfft)
irfft = _irfft_wrapper(gen_spectral_ops.irfft, 1, "irfft")
tf_export("signal.irfft", v1=["signal.irfft", "spectral.irfft"])(irfft)
rfft2d = _rfft_wrapper(gen_spectral_ops.rfft2d, 2, "rfft2d")
tf_export("signal.rfft2d", v1=["signal.rfft2d", "spectral.rfft2d"])(rfft2d)
irfft2d = _irfft_wrapper(gen_spectral_ops.irfft2d, 2, "irfft2d")
tf_export("signal.irfft2d", v1=["signal.irfft2d", "spectral.irfft2d"])(irfft2d)
rfft3d = _rfft_wrapper(gen_spectral_ops.rfft3d, 3, "rfft3d")
tf_export("signal.rfft3d", v1=["signal.rfft3d", "spectral.rfft3d"])(rfft3d)
irfft3d = _irfft_wrapper(gen_spectral_ops.irfft3d, 3, "irfft3d")
tf_export("signal.irfft3d", v1=["signal.irfft3d", "spectral.irfft3d"])(irfft3d)
169
170
def _fft_size_for_grad(grad, rank):
  """Returns the product of the inner-most `rank` dimensions of `grad`."""
  inner_dims = _array_ops.shape(grad)[-rank:]
  return _math_ops.reduce_prod(inner_dims)
173
174
@_ops.RegisterGradient("FFT")
def _fft_grad(_, grad):
  """Gradient of FFT: the IFFT of the upstream gradient, scaled by the size."""
  n = _math_ops.cast(_fft_size_for_grad(grad, 1), grad.dtype)
  return ifft(grad) * n
179
180
@_ops.RegisterGradient("IFFT")
def _ifft_grad(_, grad):
  """Gradient of IFFT: the FFT of the upstream gradient, scaled by 1/size."""
  n = _math_ops.cast(_fft_size_for_grad(grad, 1), grad.dtype.real_dtype)
  scale = _math_ops.cast(1. / n, grad.dtype)
  return fft(grad) * scale
187
188
@_ops.RegisterGradient("FFT2D")
def _fft2d_grad(_, grad):
  """Gradient of FFT2D: the IFFT2D of the upstream gradient, scaled by size."""
  n = _math_ops.cast(_fft_size_for_grad(grad, 2), grad.dtype)
  return ifft2d(grad) * n
193
194
@_ops.RegisterGradient("IFFT2D")
def _ifft2d_grad(_, grad):
  """Gradient of IFFT2D: the FFT2D of the upstream gradient, scaled by 1/size."""
  n = _math_ops.cast(_fft_size_for_grad(grad, 2), grad.dtype.real_dtype)
  scale = _math_ops.cast(1. / n, grad.dtype)
  return fft2d(grad) * scale
201
202
@_ops.RegisterGradient("FFT3D")
def _fft3d_grad(_, grad):
  """Gradient of FFT3D: the IFFT3D of the upstream gradient, scaled by size."""
  n = _math_ops.cast(_fft_size_for_grad(grad, 3), grad.dtype)
  return ifft3d(grad) * n
207
208
@_ops.RegisterGradient("IFFT3D")
def _ifft3d_grad(_, grad):
  """Gradient of IFFT3D: the FFT3D of the upstream gradient, scaled by 1/size."""
  n = _math_ops.cast(_fft_size_for_grad(grad, 3), grad.dtype.real_dtype)
  scale = _math_ops.cast(1. / n, grad.dtype)
  return fft3d(grad) * scale
215
216
def _rfft_grad_helper(rank, irfft_fn):
  """Returns a gradient function for an RFFT of the provided rank.

  Args:
    rank: The FFT rank; must be 1 or 2 (RFFT3D has no registered gradient).
    irfft_fn: The matching inverse-RFFT function (irfft or irfft2d).

  Returns:
    A gradient function suitable for `_ops.RegisterGradient`.
  """
  # Can't happen because we don't register a gradient for RFFT3D.
  assert rank in (1, 2), "Gradient for RFFT3D is not implemented."

  def _grad(op, grad):
    """A gradient function for RFFT with the provided `rank` and `irfft_fn`."""
    fft_length = op.inputs[1]
    input_shape = _array_ops.shape(op.inputs[0])
    # 1 + 0j when fft_length[-1] is even, 0 + 0j when odd.
    is_even = _math_ops.cast(1 - (fft_length[-1] % 2), _dtypes.complex64)

    def _tile_for_broadcasting(matrix, t):
      # Reshapes `matrix` to rank(t) by prepending size-1 dims, then tiles it
      # across t's leading (batch) dimensions so the inner 2 dims can be
      # batch-matmuled against t.
      expanded = _array_ops.reshape(
          matrix,
          _array_ops.concat([
              _array_ops.ones([_array_ops.rank(t) - 2], _dtypes.int32),
              _array_ops.shape(matrix)
          ], 0))
      return _array_ops.tile(
          expanded, _array_ops.concat([_array_ops.shape(t)[:-2], [1, 1]], 0))

    def _mask_matrix(length):
      """Computes the [length, length] DFT matrix exp(-2j*pi*a*b / length)."""
      # TODO(rjryan): Speed up computation of twiddle factors using the
      # following recurrence relation and cache them across invocations of RFFT.
      #
      # t_n = exp(sqrt(-1) * pi * n^2 / line_len)
      # for n = 0, 1,..., line_len-1.
      # For n > 2, use t_n = t_{n-1}^2 / t_{n-2} * t_1^2
      a = _array_ops.tile(
          _array_ops.expand_dims(_math_ops.range(length), 0), (length, 1))
      b = _array_ops.transpose(a, [1, 0])
      return _math_ops.exp(
          -2j * np.pi * _math_ops.cast(a * b, _dtypes.complex64) /
          _math_ops.cast(length, _dtypes.complex64))

    def _ymask(length):
      """A sequence of [1+0j, -1+0j, 1+0j, -1+0j, ...] with length `length`."""
      return _math_ops.cast(1 - 2 * (_math_ops.range(length) % 2),
                            _dtypes.complex64)

    # y0/ym are the DC and (for even lengths) Nyquist components of the
    # incoming gradient; these bins are not duplicated by Hermitian symmetry
    # and need explicit correction terms.
    y0 = grad[..., 0:1]
    if rank == 1:
      ym = grad[..., -1:]
      extra_terms = y0 + is_even * ym * _ymask(input_shape[-1])
    elif rank == 2:
      # Create a mask matrix for y0 and ym.
      base_mask = _mask_matrix(input_shape[-2])

      # Tile base_mask to match y0 in shape so that we can batch-matmul the
      # inner 2 dimensions.
      tiled_mask = _tile_for_broadcasting(base_mask, y0)

      y0_term = _math_ops.matmul(tiled_mask, _math_ops.conj(y0))
      extra_terms = y0_term

      ym = grad[..., -1:]
      ym_term = _math_ops.matmul(tiled_mask, _math_ops.conj(ym))

      inner_dim = input_shape[-1]
      # Broadcast the single Nyquist column across the inner dimension and
      # apply the alternating-sign mask.
      ym_term = _array_ops.tile(
          ym_term,
          _array_ops.concat([
              _array_ops.ones([_array_ops.rank(grad) - 1], _dtypes.int32),
              [inner_dim]
          ], 0)) * _ymask(inner_dim)

      extra_terms += is_even * ym_term

    # The gradient of RFFT is the IRFFT of the incoming gradient times a scaling
    # factor, plus some additional terms to make up for the components dropped
    # due to Hermitian symmetry.
    input_size = _math_ops.cast(
        _fft_size_for_grad(op.inputs[0], rank), _dtypes.float32)
    the_irfft = irfft_fn(grad, fft_length)
    # None is the gradient w.r.t. the fft_length input, which is not
    # differentiable.
    return 0.5 * (the_irfft * input_size + _math_ops.real(extra_terms)), None

  return _grad
295
296
def _irfft_grad_helper(rank, rfft_fn):
  """Returns a gradient function for an IRFFT of the provided rank.

  Args:
    rank: The FFT rank; must be 1 or 2 (IRFFT3D has no registered gradient).
    rfft_fn: The matching forward-RFFT function (rfft or rfft2d).

  Returns:
    A gradient function suitable for `_ops.RegisterGradient`.
  """
  # Can't happen because we don't register a gradient for IRFFT3D.
  assert rank in (1, 2), "Gradient for IRFFT3D is not implemented."

  def _grad(op, grad):
    """A gradient function for IRFFT with the provided `rank` and `rfft_fn`."""
    # Generate a simple mask like [1.0, 2.0, ..., 2.0, 1.0] for even-length FFTs
    # and [1.0, 2.0, ..., 2.0] for odd-length FFTs. To reduce extra ops in the
    # graph we special-case the situation where the FFT length and last
    # dimension of the input are known at graph construction time.
    fft_length = op.inputs[1]
    # 1 for odd FFT lengths, 0 for even; selects between the two mask shapes
    # above.
    is_odd = _math_ops.mod(fft_length[-1], 2)
    input_last_dimension = _array_ops.shape(op.inputs[0])[-1]
    mask = _array_ops.concat(
        [[1.0], 2.0 * _array_ops.ones([input_last_dimension - 2 + is_odd]),
         _array_ops.ones([1 - is_odd])], 0)

    rsize = _math_ops.reciprocal(_math_ops.cast(
        _fft_size_for_grad(grad, rank), _dtypes.float32))

    # The gradient of IRFFT is the RFFT of the incoming gradient times a scaling
    # factor and a mask. The mask scales the gradient for the Hermitian
    # symmetric components of the RFFT by a factor of two, since these
    # components are de-duplicated in the RFFT.
    the_rfft = rfft_fn(grad, fft_length)
    # None is the gradient w.r.t. the fft_length input, which is not
    # differentiable.
    return the_rfft * _math_ops.cast(rsize * mask, _dtypes.complex64), None

  return _grad
326
327
# Register gradients for the rank-1 and rank-2 real FFT ops. RFFT3D/IRFFT3D
# intentionally have no registered gradient (see the asserts in the helper
# factories above).
_ops.RegisterGradient("RFFT")(_rfft_grad_helper(1, irfft))
_ops.RegisterGradient("IRFFT")(_irfft_grad_helper(1, rfft))
_ops.RegisterGradient("RFFT2D")(_rfft_grad_helper(2, irfft2d))
_ops.RegisterGradient("IRFFT2D")(_irfft_grad_helper(2, rfft2d))
332