1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Spectral operations (e.g. Short-time Fourier Transform)."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import tensor_util
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops.signal import fft_ops
30from tensorflow.python.ops.signal import reconstruction_ops
31from tensorflow.python.ops.signal import shape_ops
32from tensorflow.python.ops.signal import window_ops
33from tensorflow.python.util.tf_export import tf_export
34
35
@tf_export('signal.stft')
def stft(signals, frame_length, frame_step, fft_length=None,
         window_fn=window_ops.hann_window,
         pad_end=False, name=None):
  """Computes the [Short-time Fourier Transform][stft] of `signals`.

  Implemented with GPU-compatible ops and supports gradients.

  Args:
    signals: A `[..., samples]` `float32` `Tensor` of real-valued signals.
    frame_length: An integer scalar `Tensor`. The window length in samples.
    frame_step: An integer scalar `Tensor`. The number of samples to step.
    fft_length: An integer scalar `Tensor`. The size of the FFT to apply.
      If not provided, uses the smallest power of 2 enclosing `frame_length`.
    window_fn: A callable that takes a window length and a `dtype` keyword
      argument and returns a `[window_length]` `Tensor` of samples in the
      provided datatype. If set to `None`, no windowing is used.
    pad_end: Whether to pad the end of `signals` with zeros when the provided
      frame length and step produces a frame that lies partially past its end.
    name: An optional name for the operation.

  Returns:
    A `[..., frames, fft_unique_bins]` `Tensor` of `complex64` STFT values where
    `fft_unique_bins` is `fft_length // 2 + 1` (the unique components of the
    FFT).

  Raises:
    ValueError: If `signals` is not at least rank 1, `frame_length` is
      not scalar, `frame_step` is not scalar, or `fft_length` is not scalar.

  [stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
  """
  with ops.name_scope(name, 'stft', [signals, frame_length,
                                     frame_step]):
    signals = ops.convert_to_tensor(signals, name='signals')
    # Called for its validation side effect: raises ValueError on rank 0.
    signals.shape.with_rank_at_least(1)
    frame_length = ops.convert_to_tensor(frame_length, name='frame_length')
    frame_length.shape.assert_has_rank(0)
    frame_step = ops.convert_to_tensor(frame_step, name='frame_step')
    frame_step.shape.assert_has_rank(0)

    if fft_length is None:
      fft_length = _enclosing_power_of_two(frame_length)
    else:
      fft_length = ops.convert_to_tensor(fft_length, name='fft_length')
      # Validate like `inverse_stft` does; `rfft` expects a scalar length that
      # is wrapped into a one-element list below.
      fft_length.shape.assert_has_rank(0)

    framed_signals = shape_ops.frame(
        signals, frame_length, frame_step, pad_end=pad_end)

    # Optionally window the framed signals.
    if window_fn is not None:
      window = window_fn(frame_length, dtype=framed_signals.dtype)
      framed_signals *= window

    # fft_ops.rfft produces the (fft_length // 2 + 1) unique components of the
    # FFT of the real windowed signals in framed_signals.
    return fft_ops.rfft(framed_signals, [fft_length])
93
94
@tf_export('signal.inverse_stft_window_fn')
def inverse_stft_window_fn(frame_step,
                           forward_window_fn=window_ops.hann_window,
                           name=None):
  """Generates a window function that can be used in `inverse_stft`.

  Constructs a window that is equal to the forward window with a further
  pointwise amplitude correction.  `inverse_stft_window_fn` is equivalent to
  `forward_window_fn` in the case where it would produce an exact inverse.

  See examples in `inverse_stft` documentation for usage.

  Args:
    frame_step: An integer scalar `Tensor`. The number of samples to step.
    forward_window_fn: window_fn used in the forward transform, `stft`.
    name: An optional name for the operation.

  Returns:
    A callable that takes a window length and a `dtype` keyword argument and
      returns a `[window_length]` `Tensor` of samples in the provided datatype.
      The returned window is suitable for reconstructing original waveform in
      inverse_stft. The callable raises `ValueError` if `frame_step` or the
      window length it receives is not scalar.
  """
  # `frame_step` is intentionally NOT converted to a `Tensor` here. Converting
  # it would create the tensor in whatever graph is current when this factory
  # runs. The returned callable may be invoked later while a different graph
  # is being built, and mixing tensors from two graphs fails. Conversion and
  # validation therefore happen inside the callable.

  def inverse_stft_window_fn_inner(frame_length, dtype):
    """Computes a window that can be used in `inverse_stft`.

    Args:
      frame_length: An integer scalar `Tensor`. The window length in samples.
      dtype: Data type of waveform passed to `stft`.

    Returns:
      A window suitable for reconstructing original waveform in `inverse_stft`.

    Raises:
      ValueError: If `frame_length` is not scalar, `frame_step` is not scalar,
        or `forward_window_fn` is not a callable that takes a window length and
        a `dtype` keyword argument and returns a `[window_length]` `Tensor` of
        samples in the provided datatype.
    """
    with ops.name_scope(name, 'inverse_stft_window_fn', [forward_window_fn]):
      # Use a distinct local name: rebinding `frame_step` here would make it
      # local to this function and break the closure over the outer argument.
      frame_step_ = ops.convert_to_tensor(frame_step, name='frame_step')
      frame_step_.shape.assert_has_rank(0)
      frame_length = ops.convert_to_tensor(frame_length, name='frame_length')
      frame_length.shape.assert_has_rank(0)

      # Use equation 7 from Griffin + Lim: divide the forward window pointwise
      # by the sum of its squared values at every position congruent modulo
      # `frame_step`, i.e. the overlap-added squared windows.
      forward_window = forward_window_fn(frame_length, dtype=dtype)
      denom = math_ops.square(forward_window)
      overlaps = -(-frame_length // frame_step_)  # Ceiling division.
      # Zero-pad so the squared window splits evenly into `overlaps` rows of
      # `frame_step` samples, then sum the rows column-wise.
      denom = array_ops.pad(denom, [(0, overlaps * frame_step_ - frame_length)])
      denom = array_ops.reshape(denom, [overlaps, frame_step_])
      denom = math_ops.reduce_sum(denom, 0, keepdims=True)
      # Repeat the per-phase sums back out to the padded window length.
      denom = array_ops.tile(denom, [overlaps, 1])
      denom = array_ops.reshape(denom, [overlaps * frame_step_])

      return forward_window / denom[:frame_length]
  return inverse_stft_window_fn_inner
154
155
@tf_export('signal.inverse_stft')
def inverse_stft(stfts,
                 frame_length,
                 frame_step,
                 fft_length=None,
                 window_fn=window_ops.hann_window,
                 name=None):
  """Computes the inverse [Short-time Fourier Transform][stft] of `stfts`.

  To reconstruct an original waveform, a complementary window function should
  be used in inverse_stft. Such a window function can be constructed with
  tf.signal.inverse_stft_window_fn.

  Example:

  ```python
  frame_length = 400
  frame_step = 160
  waveform = tf.placeholder(dtype=tf.float32, shape=[1000])
  stft = tf.signal.stft(waveform, frame_length, frame_step)
  inverse_stft = tf.signal.inverse_stft(
      stft, frame_length, frame_step,
      window_fn=tf.signal.inverse_stft_window_fn(frame_step))
  ```

  If a custom window_fn is used in stft, it must be passed to
  inverse_stft_window_fn:

  ```python
  frame_length = 400
  frame_step = 160
  window_fn = functools.partial(tf.signal.hamming_window, periodic=True)
  waveform = tf.placeholder(dtype=tf.float32, shape=[1000])
  stft = tf.signal.stft(
      waveform, frame_length, frame_step, window_fn=window_fn)
  inverse_stft = tf.signal.inverse_stft(
      stft, frame_length, frame_step,
      window_fn=tf.signal.inverse_stft_window_fn(
          frame_step, forward_window_fn=window_fn))
  ```

  Implemented with GPU-compatible ops and supports gradients.

  Args:
    stfts: A `complex64` `[..., frames, fft_unique_bins]` `Tensor` of STFT bins
      representing a batch of `fft_length`-point STFTs where `fft_unique_bins`
      is `fft_length // 2 + 1`.
    frame_length: An integer scalar `Tensor`. The window length in samples.
    frame_step: An integer scalar `Tensor`. The number of samples to step.
    fft_length: An integer scalar `Tensor`. The size of the FFT that produced
      `stfts`. If not provided, uses the smallest power of 2 enclosing
      `frame_length`.
    window_fn: A callable that takes a window length and a `dtype` keyword
      argument and returns a `[window_length]` `Tensor` of samples in the
      provided datatype. If set to `None`, no windowing is used.
    name: An optional name for the operation.

  Returns:
    A `[..., samples]` `Tensor` of `float32` signals representing the inverse
    STFT for each input STFT in `stfts`.

  Raises:
    ValueError: If `stfts` is not at least rank 2, `frame_length` is not scalar,
      `frame_step` is not scalar, or `fft_length` is not scalar.

  [stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
  """
  with ops.name_scope(name, 'inverse_stft', [stfts]):
    stfts = ops.convert_to_tensor(stfts, name='stfts')
    # Called for its validation side effect: raises ValueError below rank 2.
    stfts.shape.with_rank_at_least(2)
    frame_length = ops.convert_to_tensor(frame_length, name='frame_length')
    frame_length.shape.assert_has_rank(0)
    frame_step = ops.convert_to_tensor(frame_step, name='frame_step')
    frame_step.shape.assert_has_rank(0)
    if fft_length is None:
      fft_length = _enclosing_power_of_two(frame_length)
    else:
      fft_length = ops.convert_to_tensor(fft_length, name='fft_length')
      fft_length.shape.assert_has_rank(0)

    real_frames = fft_ops.irfft(stfts, [fft_length])

    # frame_length may be larger or smaller than fft_length, so we pad or
    # truncate real_frames to frame_length.
    frame_length_static = tensor_util.constant_value(frame_length)
    real_frames_ndims = real_frames.shape.ndims
    # Read the static inner dimension through `as_list()`. It yields a plain
    # int (or None) regardless of whether indexing a shape returns a
    # `Dimension` (V1 shapes) or an int (V2 shapes); `shape[-1].value` only
    # works for the former. `as_list()` requires a known rank, hence the guard.
    inner_dim_static = (None if real_frames_ndims is None
                        else real_frames.shape.as_list()[-1])

    if frame_length_static is None or inner_dim_static is None:
      # Not enough static information: truncate, then zero-pad the inner
      # dimension up to frame_length at graph run time.
      real_frames = real_frames[..., :frame_length]
      real_frames_rank = array_ops.rank(real_frames)
      # Match frame_length's dtype so the subtraction below also works when
      # frame_length is int64 (array_ops.shape defaults to int32).
      real_frames_shape = array_ops.shape(real_frames,
                                          out_type=frame_length.dtype)
      paddings = array_ops.concat(
          [array_ops.zeros([real_frames_rank - 1, 2],
                           dtype=frame_length.dtype),
           [[0, math_ops.maximum(0, frame_length - real_frames_shape[-1])]]], 0)
      real_frames = array_ops.pad(real_frames, paddings)
    # We know real_frames's last dimension and frame_length statically. If they
    # are different, then pad or truncate real_frames to frame_length.
    elif inner_dim_static > frame_length_static:
      real_frames = real_frames[..., :frame_length_static]
    elif inner_dim_static < frame_length_static:
      pad_amount = frame_length_static - inner_dim_static
      real_frames = array_ops.pad(real_frames,
                                  [[0, 0]] * (real_frames_ndims - 1) +
                                  [[0, pad_amount]])

    # The above code pads the inner dimension of real_frames to frame_length,
    # but it does so in a way that may not be shape-inference friendly.
    # Restore shape information if we are able to.
    if frame_length_static is not None and real_frames.shape.ndims is not None:
      real_frames.set_shape([None] * (real_frames.shape.ndims - 1) +
                            [frame_length_static])

    # Optionally window and overlap-add the inner 2 dimensions of real_frames
    # into a single [samples] dimension.
    if window_fn is not None:
      window = window_fn(frame_length, dtype=stfts.dtype.real_dtype)
      real_frames *= window
    return reconstruction_ops.overlap_and_add(real_frames, frame_step)
277
278
def _enclosing_power_of_two(value):
  """Return 2**N for integer N such that 2**N >= value.

  Args:
    value: A scalar `Tensor`, expected to hold an integer such as a frame
      length.

  Returns:
    A scalar `Tensor` of `value.dtype` holding the smallest power of two that
    is greater than or equal to `value`. If `value` is statically known the
    result is a constant; otherwise it is computed when the graph runs.
  """
  value_static = tensor_util.constant_value(value)
  if value_static is not None:
    if value.dtype.is_integer and value_static >= 1:
      # Exact integer arithmetic. A float64 log ratio can land slightly above
      # an integer for exact powers of two; `ceil` then overshoots and the
      # result doubles. For n >= 1, (n - 1).bit_length() is ceil(log2(n)).
      enclosing = 1 << (int(value_static) - 1).bit_length()
    else:
      # Non-integer or non-positive inputs keep the original formula.
      enclosing = int(2**np.ceil(np.log(value_static) / np.log(2.0)))
    return constant_op.constant(enclosing, value.dtype)
  # NOTE(review): this float32 path can suffer the same rounding issue for
  # exact powers of two (and loses precision above 2**24); consider an
  # integer-exact formulation if large dynamic frame lengths matter.
  return math_ops.cast(
      math_ops.pow(
          2.0,
          math_ops.ceil(
              math_ops.log(math_ops.cast(value, dtypes.float32)) /
              math_ops.log(2.0))), value.dtype)
291