1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Spectral operations (e.g. Short-time Fourier Transform).""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.framework import constant_op 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import ops 26from tensorflow.python.framework import tensor_util 27from tensorflow.python.ops import array_ops 28from tensorflow.python.ops import math_ops 29from tensorflow.python.ops.signal import fft_ops 30from tensorflow.python.ops.signal import reconstruction_ops 31from tensorflow.python.ops.signal import shape_ops 32from tensorflow.python.ops.signal import window_ops 33from tensorflow.python.util.tf_export import tf_export 34 35 36@tf_export('signal.stft') 37def stft(signals, frame_length, frame_step, fft_length=None, 38 window_fn=window_ops.hann_window, 39 pad_end=False, name=None): 40 """Computes the [Short-time Fourier Transform][stft] of `signals`. 41 42 Implemented with GPU-compatible ops and supports gradients. 43 44 Args: 45 signals: A `[..., samples]` `float32` `Tensor` of real-valued signals. 46 frame_length: An integer scalar `Tensor`. The window length in samples. 47 frame_step: An integer scalar `Tensor`. The number of samples to step. 48 fft_length: An integer scalar `Tensor`. The size of the FFT to apply. 49 If not provided, uses the smallest power of 2 enclosing `frame_length`. 50 window_fn: A callable that takes a window length and a `dtype` keyword 51 argument and returns a `[window_length]` `Tensor` of samples in the 52 provided datatype. If set to `None`, no windowing is used. 53 pad_end: Whether to pad the end of `signals` with zeros when the provided 54 frame length and step produces a frame that lies partially past its end. 55 name: An optional name for the operation. 56 57 Returns: 58 A `[..., frames, fft_unique_bins]` `Tensor` of `complex64` STFT values where 59 `fft_unique_bins` is `fft_length // 2 + 1` (the unique components of the 60 FFT). 61 62 Raises: 63 ValueError: If `signals` is not at least rank 1, `frame_length` is 64 not scalar, or `frame_step` is not scalar. 65 66 [stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform 67 """ 68 with ops.name_scope(name, 'stft', [signals, frame_length, 69 frame_step]): 70 signals = ops.convert_to_tensor(signals, name='signals') 71 signals.shape.with_rank_at_least(1) 72 frame_length = ops.convert_to_tensor(frame_length, name='frame_length') 73 frame_length.shape.assert_has_rank(0) 74 frame_step = ops.convert_to_tensor(frame_step, name='frame_step') 75 frame_step.shape.assert_has_rank(0) 76 77 if fft_length is None: 78 fft_length = _enclosing_power_of_two(frame_length) 79 else: 80 fft_length = ops.convert_to_tensor(fft_length, name='fft_length') 81 82 framed_signals = shape_ops.frame( 83 signals, frame_length, frame_step, pad_end=pad_end) 84 85 # Optionally window the framed signals. 86 if window_fn is not None: 87 window = window_fn(frame_length, dtype=framed_signals.dtype) 88 framed_signals *= window 89 90 # fft_ops.rfft produces the (fft_length/2 + 1) unique components of the 91 # FFT of the real windowed signals in framed_signals. 92 return fft_ops.rfft(framed_signals, [fft_length]) 93 94 95@tf_export('signal.inverse_stft_window_fn') 96def inverse_stft_window_fn(frame_step, 97 forward_window_fn=window_ops.hann_window, 98 name=None): 99 """Generates a window function that can be used in `inverse_stft`. 100 101 Constructs a window that is equal to the forward window with a further 102 pointwise amplitude correction. `inverse_stft_window_fn` is equivalent to 103 `forward_window_fn` in the case where it would produce an exact inverse. 104 105 See examples in `inverse_stft` documentation for usage. 106 107 Args: 108 frame_step: An integer scalar `Tensor`. The number of samples to step. 109 forward_window_fn: window_fn used in the forward transform, `stft`. 110 name: An optional name for the operation. 111 112 Returns: 113 A callable that takes a window length and a `dtype` keyword argument and 114 returns a `[window_length]` `Tensor` of samples in the provided datatype. 115 The returned window is suitable for reconstructing original waveform in 116 inverse_stft. 117 """ 118 with ops.name_scope(name, 'inverse_stft_window_fn', [forward_window_fn]): 119 frame_step = ops.convert_to_tensor(frame_step, name='frame_step') 120 frame_step.shape.assert_has_rank(0) 121 122 def inverse_stft_window_fn_inner(frame_length, dtype): 123 """Computes a window that can be used in `inverse_stft`. 124 125 Args: 126 frame_length: An integer scalar `Tensor`. The window length in samples. 127 dtype: Data type of waveform passed to `stft`. 128 129 Returns: 130 A window suitable for reconstructing original waveform in `inverse_stft`. 131 132 Raises: 133 ValueError: If `frame_length` is not scalar, `forward_window_fn` is not a 134 callable that takes a window length and a `dtype` keyword argument and 135 returns a `[window_length]` `Tensor` of samples in the provided datatype 136 `frame_step` is not scalar, or `frame_step` is not scalar. 137 """ 138 with ops.name_scope(name, 'inverse_stft_window_fn', [forward_window_fn]): 139 frame_length = ops.convert_to_tensor(frame_length, name='frame_length') 140 frame_length.shape.assert_has_rank(0) 141 142 # Use equation 7 from Griffin + Lim. 143 forward_window = forward_window_fn(frame_length, dtype=dtype) 144 denom = math_ops.square(forward_window) 145 overlaps = -(-frame_length // frame_step) # Ceiling division. 146 denom = array_ops.pad(denom, [(0, overlaps * frame_step - frame_length)]) 147 denom = array_ops.reshape(denom, [overlaps, frame_step]) 148 denom = math_ops.reduce_sum(denom, 0, keepdims=True) 149 denom = array_ops.tile(denom, [overlaps, 1]) 150 denom = array_ops.reshape(denom, [overlaps * frame_step]) 151 152 return forward_window / denom[:frame_length] 153 return inverse_stft_window_fn_inner 154 155 156@tf_export('signal.inverse_stft') 157def inverse_stft(stfts, 158 frame_length, 159 frame_step, 160 fft_length=None, 161 window_fn=window_ops.hann_window, 162 name=None): 163 """Computes the inverse [Short-time Fourier Transform][stft] of `stfts`. 164 165 To reconstruct an original waveform, a complimentary window function should 166 be used in inverse_stft. Such a window function can be constructed with 167 tf.signal.inverse_stft_window_fn. 168 169 Example: 170 171 ```python 172 frame_length = 400 173 frame_step = 160 174 waveform = tf.placeholder(dtype=tf.float32, shape=[1000]) 175 stft = tf.signal.stft(waveform, frame_length, frame_step) 176 inverse_stft = tf.signal.inverse_stft( 177 stft, frame_length, frame_step, 178 window_fn=tf.signal.inverse_stft_window_fn(frame_step)) 179 ``` 180 181 if a custom window_fn is used in stft, it must be passed to 182 inverse_stft_window_fn: 183 184 ```python 185 frame_length = 400 186 frame_step = 160 187 window_fn = functools.partial(window_ops.hamming_window, periodic=True), 188 waveform = tf.placeholder(dtype=tf.float32, shape=[1000]) 189 stft = tf.signal.stft( 190 waveform, frame_length, frame_step, window_fn=window_fn) 191 inverse_stft = tf.signal.inverse_stft( 192 stft, frame_length, frame_step, 193 window_fn=tf.signal.inverse_stft_window_fn( 194 frame_step, forward_window_fn=window_fn)) 195 ``` 196 197 Implemented with GPU-compatible ops and supports gradients. 198 199 Args: 200 stfts: A `complex64` `[..., frames, fft_unique_bins]` `Tensor` of STFT bins 201 representing a batch of `fft_length`-point STFTs where `fft_unique_bins` 202 is `fft_length // 2 + 1` 203 frame_length: An integer scalar `Tensor`. The window length in samples. 204 frame_step: An integer scalar `Tensor`. The number of samples to step. 205 fft_length: An integer scalar `Tensor`. The size of the FFT that produced 206 `stfts`. If not provided, uses the smallest power of 2 enclosing 207 `frame_length`. 208 window_fn: A callable that takes a window length and a `dtype` keyword 209 argument and returns a `[window_length]` `Tensor` of samples in the 210 provided datatype. If set to `None`, no windowing is used. 211 name: An optional name for the operation. 212 213 Returns: 214 A `[..., samples]` `Tensor` of `float32` signals representing the inverse 215 STFT for each input STFT in `stfts`. 216 217 Raises: 218 ValueError: If `stfts` is not at least rank 2, `frame_length` is not scalar, 219 `frame_step` is not scalar, or `fft_length` is not scalar. 220 221 [stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform 222 """ 223 with ops.name_scope(name, 'inverse_stft', [stfts]): 224 stfts = ops.convert_to_tensor(stfts, name='stfts') 225 stfts.shape.with_rank_at_least(2) 226 frame_length = ops.convert_to_tensor(frame_length, name='frame_length') 227 frame_length.shape.assert_has_rank(0) 228 frame_step = ops.convert_to_tensor(frame_step, name='frame_step') 229 frame_step.shape.assert_has_rank(0) 230 if fft_length is None: 231 fft_length = _enclosing_power_of_two(frame_length) 232 else: 233 fft_length = ops.convert_to_tensor(fft_length, name='fft_length') 234 fft_length.shape.assert_has_rank(0) 235 236 real_frames = fft_ops.irfft(stfts, [fft_length]) 237 238 # frame_length may be larger or smaller than fft_length, so we pad or 239 # truncate real_frames to frame_length. 240 frame_length_static = tensor_util.constant_value(frame_length) 241 # If we don't know the shape of real_frames's inner dimension, pad and 242 # truncate to frame_length. 243 if (frame_length_static is None or 244 real_frames.shape.ndims is None or 245 real_frames.shape[-1].value is None): 246 real_frames = real_frames[..., :frame_length] 247 real_frames_rank = array_ops.rank(real_frames) 248 real_frames_shape = array_ops.shape(real_frames) 249 paddings = array_ops.concat( 250 [array_ops.zeros([real_frames_rank - 1, 2], 251 dtype=frame_length.dtype), 252 [[0, math_ops.maximum(0, frame_length - real_frames_shape[-1])]]], 0) 253 real_frames = array_ops.pad(real_frames, paddings) 254 # We know real_frames's last dimension and frame_length statically. If they 255 # are different, then pad or truncate real_frames to frame_length. 256 elif real_frames.shape[-1].value > frame_length_static: 257 real_frames = real_frames[..., :frame_length_static] 258 elif real_frames.shape[-1].value < frame_length_static: 259 pad_amount = frame_length_static - real_frames.shape[-1].value 260 real_frames = array_ops.pad(real_frames, 261 [[0, 0]] * (real_frames.shape.ndims - 1) + 262 [[0, pad_amount]]) 263 264 # The above code pads the inner dimension of real_frames to frame_length, 265 # but it does so in a way that may not be shape-inference friendly. 266 # Restore shape information if we are able to. 267 if frame_length_static is not None and real_frames.shape.ndims is not None: 268 real_frames.set_shape([None] * (real_frames.shape.ndims - 1) + 269 [frame_length_static]) 270 271 # Optionally window and overlap-add the inner 2 dimensions of real_frames 272 # into a single [samples] dimension. 273 if window_fn is not None: 274 window = window_fn(frame_length, dtype=stfts.dtype.real_dtype) 275 real_frames *= window 276 return reconstruction_ops.overlap_and_add(real_frames, frame_step) 277 278 279def _enclosing_power_of_two(value): 280 """Return 2**N for integer N such that 2**N >= value.""" 281 value_static = tensor_util.constant_value(value) 282 if value_static is not None: 283 return constant_op.constant( 284 int(2**np.ceil(np.log(value_static) / np.log(2.0))), value.dtype) 285 return math_ops.cast( 286 math_ops.pow( 287 2.0, 288 math_ops.ceil( 289 math_ops.log(math_ops.cast(value, dtypes.float32)) / 290 math_ops.log(2.0))), value.dtype) 291