1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""mel conversion ops."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import ops
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import math_ops
25from tensorflow.python.ops.signal import shape_ops
26from tensorflow.python.util.tf_export import tf_export
27
28
# Constants of the mel scale used by `_hertz_to_mel` and `_mel_to_hertz`:
#   mel   = _MEL_HIGH_FREQUENCY_Q * ln(1 + hertz / _MEL_BREAK_FREQUENCY_HERTZ)
#   hertz = _MEL_BREAK_FREQUENCY_HERTZ * (exp(mel / _MEL_HIGH_FREQUENCY_Q) - 1)
# Break frequency, in Hertz, that divides `hertz` in the formulas above.
_MEL_BREAK_FREQUENCY_HERTZ = 700.0
# Scale factor applied to the natural logarithm in the formulas above.
_MEL_HIGH_FREQUENCY_Q = 1127.0
32
33
def _mel_to_hertz(mel_values, name=None):
  """Maps frequencies on the mel scale back to linear frequencies in Hertz.

  Evaluates `700 * (exp(mel / 1127) - 1)` elementwise, which undoes the
  mapping performed by `_hertz_to_mel`.

  Args:
    mel_values: A `Tensor` (or a value convertible to one) of frequencies on
      the mel scale.
    name: An optional name for the operation.

  Returns:
    A `Tensor` with the same shape as `mel_values` (and, for floating point
    input, the same dtype) holding the corresponding frequencies in Hertz.
  """
  with ops.name_scope(name, 'mel_to_hertz', [mel_values]):
    mels = ops.convert_to_tensor(mel_values)
    exponent = mels / _MEL_HIGH_FREQUENCY_Q
    return _MEL_BREAK_FREQUENCY_HERTZ * (math_ops.exp(exponent) - 1.0)
50
51
def _hertz_to_mel(frequencies_hertz, name=None):
  """Maps linear frequencies in Hertz onto the mel scale.

  Evaluates `1127 * ln(1 + hertz / 700)` elementwise; `_mel_to_hertz` is the
  inverse mapping.

  Args:
    frequencies_hertz: A `Tensor` (or a value convertible to one) of
      frequencies in Hertz.
    name: An optional name for the operation.

  Returns:
    A `Tensor` with the same shape as `frequencies_hertz` (and, for floating
    point input, the same dtype) holding the corresponding mel values.
  """
  with ops.name_scope(name, 'hertz_to_mel', [frequencies_hertz]):
    hertz = ops.convert_to_tensor(frequencies_hertz)
    ratio = hertz / _MEL_BREAK_FREQUENCY_HERTZ
    return _MEL_HIGH_FREQUENCY_Q * math_ops.log(1.0 + ratio)
67
68
69def _validate_arguments(num_mel_bins, sample_rate,
70                        lower_edge_hertz, upper_edge_hertz, dtype):
71  """Checks the inputs to linear_to_mel_weight_matrix."""
72  if num_mel_bins <= 0:
73    raise ValueError('num_mel_bins must be positive. Got: %s' % num_mel_bins)
74  if sample_rate <= 0.0:
75    raise ValueError('sample_rate must be positive. Got: %s' % sample_rate)
76  if lower_edge_hertz < 0.0:
77    raise ValueError('lower_edge_hertz must be non-negative. Got: %s' %
78                     lower_edge_hertz)
79  if lower_edge_hertz >= upper_edge_hertz:
80    raise ValueError('lower_edge_hertz %.1f >= upper_edge_hertz %.1f' %
81                     (lower_edge_hertz, upper_edge_hertz))
82  if upper_edge_hertz > sample_rate / 2:
83    raise ValueError('upper_edge_hertz must not be larger than the Nyquist '
84                     'frequency (sample_rate / 2). Got: %s for sample_rate: %s'
85                     % (upper_edge_hertz, sample_rate))
86  if not dtype.is_floating:
87    raise ValueError('dtype must be a floating point type. Got: %s' % dtype)
88
89
@tf_export('signal.linear_to_mel_weight_matrix')
def linear_to_mel_weight_matrix(num_mel_bins=20,
                                num_spectrogram_bins=129,
                                sample_rate=8000,
                                lower_edge_hertz=125.0,
                                upper_edge_hertz=3800.0,
                                dtype=dtypes.float32,
                                name=None):
  """Returns a matrix to warp linear scale spectrograms to the [mel scale][mel].

  The returned weight matrix re-weights `num_spectrogram_bins` linearly spaced
  frequency bins covering `[0, sample_rate / 2]` into `num_mel_bins`
  triangular bands spanning `[lower_edge_hertz, upper_edge_hertz]` on the
  [mel scale][mel].

  Right-multiplying a spectrogram `S` of linear scale spectrum values (e.g.
  STFT magnitudes) with shape `[frames, num_spectrogram_bins]` by the returned
  matrix `A` yields a "mel spectrogram" `M` of shape `[frames, num_mel_bins]`:

      # `S` has shape [frames, num_spectrogram_bins]
      # `M` has shape [frames, num_mel_bins]
      M = tf.matmul(S, A)

  For a linear scale `Tensor` of arbitrary rank, contract its last axis with
  `tf.tensordot` instead:

      # S has shape [..., num_spectrogram_bins].
      # M has shape [..., num_mel_bins].
      M = tf.tensordot(S, A, 1)
      # tf.tensordot does not support shape inference for this case yet.
      M.set_shape(S.shape[:-1].concatenate(A.shape[-1:]))

  The `num_mel_bins + 2` band edge and center points are evenly spaced in mel
  between `lower_edge_hertz` and `upper_edge_hertz`. Each band is a triangle
  that is linear in the mel domain, not in Hertz. It rises from 0 at its
  lower edge to 1 at its center and falls back to 0 at its upper edge. The
  row for the DC bin (0 Hz) is always zero.

  Args:
    num_mel_bins: Python int. How many bands in the resulting mel spectrum.
    num_spectrogram_bins: An integer `Tensor`. How many bins there are in the
      source spectrogram data, which is understood to be `fft_size // 2 + 1`,
      i.e. the spectrogram only contains the nonredundant FFT bins.
    sample_rate: Python float. Samples per second of the input signal used to
      create the spectrogram. It determines the frequency in Hertz of each
      spectrogram bin, which dictates how that bin is mapped into the mel
      scale.
    lower_edge_hertz: Python float. Lower bound on the frequencies to be
      included in the mel spectrum. This corresponds to the lower edge of the
      lowest triangular band.
    upper_edge_hertz: Python float. The desired top edge of the highest
      frequency band.
    dtype: The `DType` of the result matrix. Must be a floating point type.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of shape `[num_spectrogram_bins, num_mel_bins]`.

  Raises:
    ValueError: If any of the following holds:
      * `num_mel_bins` or `sample_rate` is not positive;
      * `lower_edge_hertz` is negative;
      * `lower_edge_hertz >= upper_edge_hertz`;
      * `upper_edge_hertz` is larger than the Nyquist frequency;
      * `dtype` is not a floating point type.
      `num_spectrogram_bins` is not checked by this function; invalid values
      are rejected by the underlying linspace op.

  [mel]: https://en.wikipedia.org/wiki/Mel_scale
  """
  with ops.name_scope(name, 'linear_to_mel_weight_matrix') as scope_name:
    # `num_spectrogram_bins` is deliberately left unchecked here. Its only
    # consumer is `math_ops.linspace`, whose shape function and kernel already
    # validate it.
    _validate_arguments(num_mel_bins, sample_rate,
                        lower_edge_hertz, upper_edge_hertz, dtype)

    # When every argument is a Python value, the ops below involve no Tensor
    # inputs and graph optimization can constant-fold the whole matrix.
    rate = ops.convert_to_tensor(sample_rate, dtype, name='sample_rate')
    low_hz = ops.convert_to_tensor(
        lower_edge_hertz, dtype, name='lower_edge_hertz')
    high_hz = ops.convert_to_tensor(
        upper_edge_hertz, dtype, name='upper_edge_hertz')
    zero = ops.convert_to_tensor(0.0, dtype)

    # HTK excludes the spectrogram DC bin. It is dropped from the bin
    # frequencies here and re-added as an all-zero row at the end.
    num_dc_bins = 1
    nyquist = rate / 2.0
    bin_hertz = math_ops.linspace(zero, nyquist, num_spectrogram_bins)
    bin_hertz = bin_hertz[num_dc_bins:]
    # Column vector of the remaining bin frequencies in mel:
    # shape [num_spectrogram_bins - 1, 1].
    bin_mel = array_ops.expand_dims(_hertz_to_mel(bin_hertz), 1)

    # Place num_mel_bins + 2 points evenly in mel across
    # [lower_edge_hertz, upper_edge_hertz]. Overlapping windows of 3
    # consecutive points give each band's (lower edge, center, upper edge),
    # so a band's center doubles as its neighbours' edges.
    low_mel = _hertz_to_mel(low_hz)
    high_mel = _hertz_to_mel(high_hz)
    mel_points = math_ops.linspace(low_mel, high_mel, num_mel_bins + 2)
    band_triples = shape_ops.frame(mel_points, frame_length=3, frame_step=1)

    # band_triples is [num_mel_bins, 3]. Turn its three columns into row
    # vectors of shape [1, num_mel_bins] so they broadcast against bin_mel.
    lower_mel, center_mel, upper_mel = [
        array_ops.reshape(column, [1, num_mel_bins])
        for column in array_ops.split(band_triples, 3, axis=1)]

    # Rising and falling sides of every triangle, evaluated at every bin. They
    # are straight lines in the mel domain, not in Hertz.
    rising = (bin_mel - lower_mel) / (center_mel - lower_mel)
    falling = (upper_mel - bin_mel) / (upper_mel - center_mel)

    # A bin's weight is the lower of the two sides, clipped to zero outside
    # the band.
    weights = math_ops.maximum(zero, math_ops.minimum(rising, falling))

    # Put back the all-zero DC row removed above.
    return array_ops.pad(
        weights, [[num_dc_bins, 0], [0, 0]], name=scope_name)
203