1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Ops for computing common window functions."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import tensor_util
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import control_flow_ops
29from tensorflow.python.ops import math_ops
30from tensorflow.python.util.tf_export import tf_export
31
32
@tf_export('signal.hann_window')
def hann_window(window_length, periodic=True, dtype=dtypes.float32, name=None):
  """Generate a [Hann window][hann].

  Sample `k` of the window is `0.5 - 0.5 * cos(2 * pi * k / n)`, where the
  denominator `n` is `window_length` for a periodic window of even length and
  `window_length - 1` otherwise. In particular, for an odd `window_length`
  the periodic and symmetric windows are identical. A window of length 1 is
  `[1.0]`.

  Args:
    window_length: A scalar `Tensor` indicating the window length to generate.
    periodic: A bool `Tensor` indicating whether to generate a periodic or
      symmetric window. Periodic windows are typically used for spectral
      analysis while symmetric windows are typically used for digital
      filter design.
    dtype: The data type to produce. Must be a floating point type.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`.

  Raises:
    ValueError: If `dtype` is not a floating point type, or if
      `window_length` or `periodic` has a known rank other than 0.

  [hann]: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
  """
  # The Hann window is the raised cosine window with alpha = beta = 0.5.
  return _raised_cosine_window(
      name=name, default_name='hann_window', window_length=window_length,
      periodic=periodic, dtype=dtype, a=0.5, b=0.5)
56
57
@tf_export('signal.hamming_window')
def hamming_window(window_length, periodic=True, dtype=dtypes.float32,
                   name=None):
  """Generate a [Hamming][hamming] window.

  Sample `k` of the window is `0.54 - 0.46 * cos(2 * pi * k / n)`, where the
  denominator `n` is `window_length` for a periodic window of even length and
  `window_length - 1` otherwise. In particular, for an odd `window_length`
  the periodic and symmetric windows are identical. A window of length 1 is
  `[1.0]`.

  Args:
    window_length: A scalar `Tensor` indicating the window length to generate.
    periodic: A bool `Tensor` indicating whether to generate a periodic or
      symmetric window. Periodic windows are typically used for spectral
      analysis while symmetric windows are typically used for digital
      filter design.
    dtype: The data type to produce. Must be a floating point type.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`.

  Raises:
    ValueError: If `dtype` is not a floating point type, or if
      `window_length` or `periodic` has a known rank other than 0.

  [hamming]: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
  """
  # The Hamming window is the raised cosine window with alpha = 0.54 and
  # beta = 0.46.
  return _raised_cosine_window(
      name=name, default_name='hamming_window', window_length=window_length,
      periodic=periodic, dtype=dtype, a=0.54, b=0.46)
82
83
def _raised_cosine_window(name, default_name, window_length, periodic,
                          dtype, a, b):
  """Helper function for computing a raised cosine window.

  The returned window is `w[k] = a - b * cos(2 * pi * k / n)` for
  `k = 0, ..., window_length - 1`. The denominator `n` is `window_length`
  when `periodic` is true and `window_length` is even, and
  `window_length - 1` otherwise. Consequently, for an odd `window_length`
  the periodic and symmetric windows are identical. A window of length 1 is
  always `[1.0]`.

  Args:
    name: Name to use for the scope.
    default_name: Default name to use for the scope.
    window_length: A scalar `Tensor` or integer indicating the window length.
    periodic: A bool `Tensor` indicating whether to generate a periodic or
      symmetric window.
    dtype: A floating point `DType`, or any value accepted by
      `dtypes.as_dtype` (e.g. `'float32'` or `np.float32`).
    a: The alpha parameter to the raised cosine window.
    b: The beta parameter to the raised cosine window.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`.

  Raises:
    TypeError: If `dtype` cannot be converted to a `DType`.
    ValueError: If `dtype` is not a floating point type or `window_length` is
      not scalar or `periodic` is not scalar.
  """
  # Normalize so that strings / numpy dtypes work with `.is_floating` below.
  dtype = dtypes.as_dtype(dtype)
  if not dtype.is_floating:
    raise ValueError('dtype must be a floating point type. Found %s' % dtype)

  # Half-precision types cannot represent the sample indices exactly (float16
  # is exact only for integers up to 2048, bfloat16 only up to 256) and give a
  # coarse cosine argument, so evaluate the window in float32 and cast the
  # result to `dtype` at the end. float32/float64 are computed in their own
  # dtype, exactly as before.
  if dtype in (dtypes.float16, dtypes.bfloat16):
    compute_dtype = dtypes.float32
  else:
    compute_dtype = dtype

  with ops.name_scope(name, default_name, [window_length, periodic]):
    window_length = ops.convert_to_tensor(window_length, dtype=dtypes.int32,
                                          name='window_length')
    window_length.shape.assert_has_rank(0)
    # When the length is known statically, no runtime `cond` is needed.
    window_length_const = tensor_util.constant_value(window_length)
    if window_length_const == 1:
      return array_ops.ones([1], dtype=dtype)
    periodic = math_ops.cast(
        ops.convert_to_tensor(periodic, dtype=dtypes.bool, name='periodic'),
        dtypes.int32)
    periodic.shape.assert_has_rank(0)
    even = 1 - math_ops.mod(window_length, 2)

    # Denominator: N for periodic even-length windows, N - 1 otherwise.
    n = math_ops.cast(window_length + periodic * even - 1, dtype=compute_dtype)
    count = math_ops.cast(math_ops.range(window_length), compute_dtype)
    two_pi = constant_op.constant(2 * np.pi, dtype=compute_dtype)
    cos_arg = two_pi * count / n

    def _window():
      """Evaluates the raised cosine and casts it to the requested dtype."""
      return math_ops.cast(a - b * math_ops.cos(cos_arg), dtype=dtype)

    if window_length_const is not None:
      return _window()
    # For a runtime length of 1 the denominator `n` above is 0, so the
    # length-1 window is selected explicitly at run time.
    return control_flow_ops.cond(
        math_ops.equal(window_length, 1),
        lambda: array_ops.ones([1], dtype=dtype),
        _window)
131