# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Fast-Fourier Transform ops.

Wraps the generated `gen_spectral_ops` RFFT/IRFFT kernels so that the
`fft_length` argument may be omitted (it is inferred from the input's
inner dimensions), pads inputs up to `fft_length` when needed, and
registers gradients for the complex and real FFT ops.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import tensor_util as _tensor_util
from tensorflow.python.ops import array_ops as _array_ops
from tensorflow.python.ops import gen_spectral_ops
from tensorflow.python.ops import math_ops as _math_ops
from tensorflow.python.util.tf_export import tf_export


def _infer_fft_length_for_rfft(input_tensor, fft_rank):
  """Infers the `fft_length` argument for a `rank` RFFT from `input_tensor`.

  Args:
    input_tensor: The real-valued input `Tensor` to a forward RFFT.
    fft_rank: Integer rank of the RFFT (number of inner dimensions
      transformed).

  Returns:
    An int32 `Tensor` of shape `[fft_rank]` holding the sizes of the
    inner-most `fft_rank` dimensions of `input_tensor`. A constant tensor
    when the static shape is fully known, otherwise a graph-computed slice
    of `tf.shape(input_tensor)`.
  """
  # A TensorShape for the inner fft_rank dimensions.
  fft_shape = input_tensor.get_shape()[-fft_rank:]

  # If any dim is unknown, fall back to tensor-based math.
  if not fft_shape.is_fully_defined():
    return _array_ops.shape(input_tensor)[-fft_rank:]

  # Otherwise, return a constant.
  return _ops.convert_to_tensor(fft_shape.as_list(), _dtypes.int32)


def _infer_fft_length_for_irfft(input_tensor, fft_rank):
  """Infers the `fft_length` argument for a `rank` IRFFT from `input_tensor`.

  The inner-most dimension of an RFFT output holds `fft_length // 2 + 1`
  complex values (the remainder is dropped by Hermitian symmetry), so the
  original length is reconstructed here as `2 * (dim - 1)`.
  NOTE(review): this reconstruction assumes an even `fft_length`; an odd
  length cannot be distinguished from the input shape alone — callers that
  need an odd length must pass `fft_length` explicitly.

  Args:
    input_tensor: The complex-valued input `Tensor` to an IRFFT.
    fft_rank: Integer rank of the IRFFT.

  Returns:
    An int32 `Tensor` of shape `[fft_rank]` with the inferred FFT lengths.
  """
  # A TensorShape for the inner fft_rank dimensions.
  fft_shape = input_tensor.get_shape()[-fft_rank:]

  # If any dim is unknown, fall back to tensor-based math.
  if not fft_shape.is_fully_defined():
    fft_length = _array_ops.unstack(_array_ops.shape(input_tensor)[-fft_rank:])
    # Invert the RFFT packing of the inner dimension; clamp at 0 so a
    # zero-size dimension does not produce a negative length.
    fft_length[-1] = _math_ops.maximum(0, 2 * (fft_length[-1] - 1))
    return _array_ops.stack(fft_length)

  # Otherwise, return a constant.
  fft_length = fft_shape.as_list()
  if fft_length:
    fft_length[-1] = max(0, 2 * (fft_length[-1] - 1))
  return _ops.convert_to_tensor(fft_length, _dtypes.int32)


def _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length, is_reverse=False):
  """Pads `input_tensor` to `fft_length` on its inner-most `fft_rank` dims.

  Args:
    input_tensor: The `Tensor` to zero-pad (at the end of each inner dim).
    fft_rank: Integer rank of the FFT; the number of inner dimensions that
      may be padded.
    fft_length: int32 `Tensor` of shape `[fft_rank]` with the target sizes.
    is_reverse: If True (the IRFFT case), the inner-most dimension is padded
      to `fft_length[-1] // 2 + 1` (the packed RFFT width) rather than to
      `fft_length[-1]` itself.

  Returns:
    `input_tensor` padded with zeros up to the target sizes, or
    `input_tensor` unchanged when no padding is needed.
  """
  fft_shape = _tensor_util.constant_value_as_shape(fft_length)

  # Edge case: skip padding empty tensors. (Only statically-known zero
  # dimensions are detected here; unknown dims fall through.)
  if (input_tensor.shape.ndims is not None and
      any(dim.value == 0 for dim in input_tensor.shape.dims)):
    return input_tensor

  # If we know the shapes ahead of time, we can either skip or pre-compute the
  # appropriate paddings. Otherwise, fall back to computing paddings in
  # TensorFlow.
  if fft_shape.is_fully_defined() and input_tensor.shape.ndims is not None:
    # Slice the last FFT-rank dimensions from input_tensor's shape.
    input_fft_shape = input_tensor.shape[-fft_shape.ndims:]

    if input_fft_shape.is_fully_defined():
      # In reverse, we only pad the inner-most dimension to fft_length / 2 + 1.
      if is_reverse:
        fft_shape = fft_shape[:-1].concatenate(
            fft_shape.dims[-1].value // 2 + 1)

      # Amount of trailing padding per inner dimension (never negative).
      paddings = [[0, max(fft_dim.value - input_dim.value, 0)]
                  for fft_dim, input_dim in zip(
                      fft_shape.dims, input_fft_shape.dims)]
      if any(pad > 0 for _, pad in paddings):
        # Outer (batch) dimensions are left unpadded.
        outer_paddings = [[0, 0]] * max((input_tensor.shape.ndims -
                                         fft_shape.ndims), 0)
        return _array_ops.pad(input_tensor, outer_paddings + paddings)
      return input_tensor

  # If we can't determine the paddings ahead of time, then we have to pad. If
  # the paddings end up as zero, tf.pad has a special-case that does no work.
  input_rank = _array_ops.rank(input_tensor)
  input_fft_shape = _array_ops.shape(input_tensor)[-fft_rank:]
  outer_dims = _math_ops.maximum(0, input_rank - fft_rank)
  outer_paddings = _array_ops.zeros([outer_dims], fft_length.dtype)
  # In reverse, we only pad the inner-most dimension to fft_length / 2 + 1.
  if is_reverse:
    fft_length = _array_ops.concat([fft_length[:-1],
                                    fft_length[-1:] // 2 + 1], 0)
  fft_paddings = _math_ops.maximum(0, fft_length - input_fft_shape)
  paddings = _array_ops.concat([outer_paddings, fft_paddings], 0)
  # Build the [rank, 2] paddings argument expected by tf.pad: zeros before,
  # the computed amounts after.
  paddings = _array_ops.stack([_array_ops.zeros_like(paddings), paddings],
                              axis=1)
  return _array_ops.pad(input_tensor, paddings)


def _rfft_wrapper(fft_fn, fft_rank, default_name):
  """Wrapper around gen_spectral_ops.rfft* that infers fft_length argument.

  Args:
    fft_fn: The generated RFFT op to wrap (e.g. `gen_spectral_ops.rfft`).
    fft_rank: Integer rank of the RFFT (1, 2 or 3).
    default_name: Default name-scope name for the returned op.

  Returns:
    A function with signature `(input_tensor, fft_length=None, name=None)`
    that infers `fft_length` when omitted and zero-pads the input as needed
    before calling `fft_fn`.
  """

  def _rfft(input_tensor, fft_length=None, name=None):
    """Wrapper around gen_spectral_ops.rfft* that infers fft_length argument."""
    with _ops.name_scope(name, default_name,
                         [input_tensor, fft_length]) as name:
      # RFFT inputs are real-valued float32.
      input_tensor = _ops.convert_to_tensor(input_tensor, _dtypes.float32)
      input_tensor.shape.with_rank_at_least(fft_rank)
      if fft_length is None:
        fft_length = _infer_fft_length_for_rfft(input_tensor, fft_rank)
      else:
        fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
      input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length)
      return fft_fn(input_tensor, fft_length, name)
  # Surface the generated op's documentation on the wrapper.
  _rfft.__doc__ = fft_fn.__doc__
  return _rfft


def _irfft_wrapper(ifft_fn, fft_rank, default_name):
  """Wrapper around gen_spectral_ops.irfft* that infers fft_length argument.

  Args:
    ifft_fn: The generated IRFFT op to wrap (e.g. `gen_spectral_ops.irfft`).
    fft_rank: Integer rank of the IRFFT (1, 2 or 3).
    default_name: Default name-scope name for the returned op.

  Returns:
    A function with signature `(input_tensor, fft_length=None, name=None)`
    that infers `fft_length` when omitted and zero-pads the input (to the
    packed RFFT width on the inner dimension) before calling `ifft_fn`.
  """

  def _irfft(input_tensor, fft_length=None, name=None):
    """Wrapper irfft* that infers fft_length argument."""
    with _ops.name_scope(name, default_name,
                         [input_tensor, fft_length]) as name:
      # IRFFT inputs are complex64 (the packed output of an RFFT).
      input_tensor = _ops.convert_to_tensor(input_tensor, _dtypes.complex64)
      input_tensor.shape.with_rank_at_least(fft_rank)
      if fft_length is None:
        fft_length = _infer_fft_length_for_irfft(input_tensor, fft_rank)
      else:
        fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
      # is_reverse=True: pad the inner dimension only up to
      # fft_length // 2 + 1, the width an IRFFT input actually has.
      input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length,
                                         is_reverse=True)
      return ifft_fn(input_tensor, fft_length, name)
  # Surface the generated op's documentation on the wrapper.
  _irfft.__doc__ = ifft_fn.__doc__
  return _irfft


# FFT/IFFT 1/2/3D are exported via
# third_party/tensorflow/core/api_def/python_api/
fft = gen_spectral_ops.fft
ifft = gen_spectral_ops.ifft
fft2d = gen_spectral_ops.fft2d
ifft2d = gen_spectral_ops.ifft2d
fft3d = gen_spectral_ops.fft3d
ifft3d = gen_spectral_ops.ifft3d
# The RFFT/IRFFT family goes through the wrappers above and is exported
# under both the `tf.signal` and the deprecated v1 `tf.spectral` namespaces.
rfft = _rfft_wrapper(gen_spectral_ops.rfft, 1, "rfft")
tf_export("signal.rfft", v1=["signal.rfft", "spectral.rfft"])(rfft)
irfft = _irfft_wrapper(gen_spectral_ops.irfft, 1, "irfft")
tf_export("signal.irfft", v1=["signal.irfft", "spectral.irfft"])(irfft)
rfft2d = _rfft_wrapper(gen_spectral_ops.rfft2d, 2, "rfft2d")
tf_export("signal.rfft2d", v1=["signal.rfft2d", "spectral.rfft2d"])(rfft2d)
irfft2d = _irfft_wrapper(gen_spectral_ops.irfft2d, 2, "irfft2d")
tf_export("signal.irfft2d", v1=["signal.irfft2d", "spectral.irfft2d"])(irfft2d)
rfft3d = _rfft_wrapper(gen_spectral_ops.rfft3d, 3, "rfft3d")
tf_export("signal.rfft3d",
          v1=["signal.rfft3d", "spectral.rfft3d"])(rfft3d)
irfft3d = _irfft_wrapper(gen_spectral_ops.irfft3d, 3, "irfft3d")
tf_export("signal.irfft3d", v1=["signal.irfft3d", "spectral.irfft3d"])(irfft3d)


def _fft_size_for_grad(grad, rank):
  """Returns the product of the inner `rank` dimension sizes of `grad`."""
  return _math_ops.reduce_prod(_array_ops.shape(grad)[-rank:])


@_ops.RegisterGradient("FFT")
def _fft_grad(_, grad):
  """Gradient of FFT: the IFFT of the upstream gradient, scaled by N."""
  size = _math_ops.cast(_fft_size_for_grad(grad, 1), grad.dtype)
  return ifft(grad) * size


@_ops.RegisterGradient("IFFT")
def _ifft_grad(_, grad):
  """Gradient of IFFT: the FFT of the upstream gradient, scaled by 1/N."""
  rsize = _math_ops.cast(
      1. / _math_ops.cast(_fft_size_for_grad(grad, 1), grad.dtype.real_dtype),
      grad.dtype)
  return fft(grad) * rsize


@_ops.RegisterGradient("FFT2D")
def _fft2d_grad(_, grad):
  """Gradient of FFT2D: the IFFT2D of the upstream gradient, scaled by N."""
  size = _math_ops.cast(_fft_size_for_grad(grad, 2), grad.dtype)
  return ifft2d(grad) * size


@_ops.RegisterGradient("IFFT2D")
def _ifft2d_grad(_, grad):
  """Gradient of IFFT2D: the FFT2D of the upstream gradient, scaled by 1/N."""
  rsize = _math_ops.cast(
      1. / _math_ops.cast(_fft_size_for_grad(grad, 2), grad.dtype.real_dtype),
      grad.dtype)
  return fft2d(grad) * rsize


@_ops.RegisterGradient("FFT3D")
def _fft3d_grad(_, grad):
  """Gradient of FFT3D: the IFFT3D of the upstream gradient, scaled by N."""
  size = _math_ops.cast(_fft_size_for_grad(grad, 3), grad.dtype)
  return ifft3d(grad) * size


@_ops.RegisterGradient("IFFT3D")
def _ifft3d_grad(_, grad):
  """Gradient of IFFT3D: the FFT3D of the upstream gradient, scaled by 1/N."""
  rsize = _math_ops.cast(
      1. / _math_ops.cast(_fft_size_for_grad(grad, 3), grad.dtype.real_dtype),
      grad.dtype)
  return fft3d(grad) * rsize


def _rfft_grad_helper(rank, irfft_fn):
  """Returns a gradient function for an RFFT of the provided rank.

  Args:
    rank: Integer rank of the RFFT; only 1 and 2 are supported.
    irfft_fn: The matching IRFFT function (`irfft` or `irfft2d`) used to
      transform the upstream gradient back to the time domain.

  Returns:
    A gradient function suitable for `ops.RegisterGradient`.
  """
  # Can't happen because we don't register a gradient for RFFT3D.
  assert rank in (1, 2), "Gradient for RFFT3D is not implemented."

  def _grad(op, grad):
    """A gradient function for RFFT with the provided `rank` and `irfft_fn`."""
    fft_length = op.inputs[1]
    input_shape = _array_ops.shape(op.inputs[0])
    # 1+0j when fft_length[-1] is even, 0+0j when odd; gates the ym
    # correction terms below.
    is_even = _math_ops.cast(1 - (fft_length[-1] % 2), _dtypes.complex64)

    def _tile_for_broadcasting(matrix, t):
      """Tiles 2-D `matrix` across the batch dimensions of `t`."""
      expanded = _array_ops.reshape(
          matrix,
          _array_ops.concat([
              _array_ops.ones([_array_ops.rank(t) - 2], _dtypes.int32),
              _array_ops.shape(matrix)
          ], 0))
      return _array_ops.tile(
          expanded, _array_ops.concat([_array_ops.shape(t)[:-2], [1, 1]], 0))

    def _mask_matrix(length):
      """Computes the DFT matrix exp(-2j * pi * a * b / length)."""
      # TODO(rjryan): Speed up computation of twiddle factors using the
      # following recurrence relation and cache them across invocations of RFFT.
      #
      # t_n = exp(sqrt(-1) * pi * n^2 / line_len)
      # for n = 0, 1,..., line_len-1.
      # For n > 2, use t_n = t_{n-1}^2 / t_{n-2} * t_1^2
      a = _array_ops.tile(
          _array_ops.expand_dims(_math_ops.range(length), 0), (length, 1))
      b = _array_ops.transpose(a, [1, 0])
      return _math_ops.exp(
          -2j * np.pi * _math_ops.cast(a * b, _dtypes.complex64) /
          _math_ops.cast(length, _dtypes.complex64))

    def _ymask(length):
      """A sequence of [1+0j, -1+0j, 1+0j, -1+0j, ...] with length `length`."""
      return _math_ops.cast(1 - 2 * (_math_ops.range(length) % 2),
                            _dtypes.complex64)

    # y0: the DC component of the upstream gradient; ym: its last (Nyquist,
    # for even lengths) component.
    y0 = grad[..., 0:1]
    if rank == 1:
      ym = grad[..., -1:]
      extra_terms = y0 + is_even * ym * _ymask(input_shape[-1])
    elif rank == 2:
      # Create a mask matrix for y0 and ym.
      base_mask = _mask_matrix(input_shape[-2])

      # Tile base_mask to match y0 in shape so that we can batch-matmul the
      # inner 2 dimensions.
      tiled_mask = _tile_for_broadcasting(base_mask, y0)

      y0_term = _math_ops.matmul(tiled_mask, _math_ops.conj(y0))
      extra_terms = y0_term

      ym = grad[..., -1:]
      ym_term = _math_ops.matmul(tiled_mask, _math_ops.conj(ym))

      inner_dim = input_shape[-1]
      # Broadcast the single ym column across the inner dimension with the
      # alternating-sign mask applied.
      ym_term = _array_ops.tile(
          ym_term,
          _array_ops.concat([
              _array_ops.ones([_array_ops.rank(grad) - 1], _dtypes.int32),
              [inner_dim]
          ], 0)) * _ymask(inner_dim)

      extra_terms += is_even * ym_term

    # The gradient of RFFT is the IRFFT of the incoming gradient times a scaling
    # factor, plus some additional terms to make up for the components dropped
    # due to Hermitian symmetry.
    input_size = _math_ops.cast(
        _fft_size_for_grad(op.inputs[0], rank), _dtypes.float32)
    the_irfft = irfft_fn(grad, fft_length)
    # RFFT takes two inputs (input, fft_length); fft_length gets no gradient.
    return 0.5 * (the_irfft * input_size + _math_ops.real(extra_terms)), None

  return _grad


def _irfft_grad_helper(rank, rfft_fn):
  """Returns a gradient function for an IRFFT of the provided rank.

  Args:
    rank: Integer rank of the IRFFT; only 1 and 2 are supported.
    rfft_fn: The matching RFFT function (`rfft` or `rfft2d`) used to
      transform the upstream gradient back to the frequency domain.

  Returns:
    A gradient function suitable for `ops.RegisterGradient`.
  """
  # Can't happen because we don't register a gradient for IRFFT3D.
  assert rank in (1, 2), "Gradient for IRFFT3D is not implemented."

  def _grad(op, grad):
    """A gradient function for IRFFT with the provided `rank` and `rfft_fn`."""
    # Generate a simple mask like [1.0, 2.0, ..., 2.0, 1.0] for even-length
    # FFTs and [1.0, 2.0, ..., 2.0] for odd-length FFTs.
    # TODO: when fft_length and the input's last dimension are statically
    # known, this mask could be built from constants instead of graph ops.
    fft_length = op.inputs[1]
    is_odd = _math_ops.mod(fft_length[-1], 2)
    input_last_dimension = _array_ops.shape(op.inputs[0])[-1]
    mask = _array_ops.concat(
        [[1.0], 2.0 * _array_ops.ones([input_last_dimension - 2 + is_odd]),
         _array_ops.ones([1 - is_odd])], 0)

    rsize = _math_ops.reciprocal(_math_ops.cast(
        _fft_size_for_grad(grad, rank), _dtypes.float32))

    # The gradient of IRFFT is the RFFT of the incoming gradient times a scaling
    # factor and a mask. The mask scales the gradient for the Hermitian
    # symmetric components of the RFFT by a factor of two, since these
    # components are de-duplicated in the RFFT.
    the_rfft = rfft_fn(grad, fft_length)
    # IRFFT takes two inputs (input, fft_length); fft_length gets no gradient.
    return the_rfft * _math_ops.cast(rsize * mask, _dtypes.complex64), None

  return _grad


# Register RFFT/IRFFT gradients for ranks 1 and 2 (3-D variants have none;
# see the asserts in the helpers above).
_ops.RegisterGradient("RFFT")(_rfft_grad_helper(1, irfft))
_ops.RegisterGradient("IRFFT")(_irfft_grad_helper(1, rfft))
_ops.RegisterGradient("RFFT2D")(_rfft_grad_helper(2, irfft2d))
_ops.RegisterGradient("IRFFT2D")(_irfft_grad_helper(2, rfft2d))