# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for computing common window functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export('signal.hann_window')
def hann_window(window_length, periodic=True, dtype=dtypes.float32, name=None):
  """Generate a [Hann window][hann].

  The window is the raised cosine `0.5 - 0.5 * cos(2 * pi * k / n)` evaluated
  for `k` in `[0, window_length)`.

  Args:
    window_length: A scalar `Tensor` indicating the window length to generate.
    periodic: A bool `Tensor` indicating whether to generate a periodic or
      symmetric window. Periodic windows are typically used for spectral
      analysis while symmetric windows are typically used for digital
      filter design.
    dtype: The data type to produce. Must be a floating point type.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`.

  Raises:
    ValueError: If `dtype` is not a floating point type.

  [hann]: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
  """
  return _raised_cosine_window(
      name, 'hann_window', window_length, periodic, dtype, a=0.5, b=0.5)


@tf_export('signal.hamming_window')
def hamming_window(window_length, periodic=True, dtype=dtypes.float32,
                   name=None):
  """Generate a [Hamming][hamming] window.

  The window is the raised cosine `0.54 - 0.46 * cos(2 * pi * k / n)`
  evaluated for `k` in `[0, window_length)`.

  Args:
    window_length: A scalar `Tensor` indicating the window length to generate.
    periodic: A bool `Tensor` indicating whether to generate a periodic or
      symmetric window. Periodic windows are typically used for spectral
      analysis while symmetric windows are typically used for digital
      filter design.
    dtype: The data type to produce. Must be a floating point type.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`.

  Raises:
    ValueError: If `dtype` is not a floating point type.

  [hamming]: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
  """
  return _raised_cosine_window(
      name, 'hamming_window', window_length, periodic, dtype, a=0.54, b=0.46)


def _raised_cosine_window(name, default_name, window_length, periodic,
                          dtype, a, b):
  """Build the raised cosine window `a - b * cos(2 * pi * k / n)`.

  A window of length one is always returned as `[1.0]`, whether the length is
  known when the graph is built or only when it runs.

  Args:
    name: Name to use for the scope.
    default_name: Scope name used when `name` is `None`.
    window_length: A scalar `Tensor` or integer indicating the window length.
    periodic: A bool `Tensor` indicating whether to generate a periodic or
      symmetric window.
    dtype: A floating point `DType`.
    a: The alpha (offset) parameter of the raised cosine window.
    b: The beta (cosine amplitude) parameter of the raised cosine window.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`.

  Raises:
    ValueError: If `dtype` is not a floating point type or `window_length` is
      not scalar or `periodic` is not scalar.
  """
  if not dtype.is_floating:
    raise ValueError('dtype must be a floating point type. Found %s' % dtype)

  with ops.name_scope(name, default_name, [window_length, periodic]):
    window_length = ops.convert_to_tensor(
        window_length, dtype=dtypes.int32, name='window_length')
    window_length.shape.assert_has_rank(0)

    # `static_length` is None when the length cannot be determined while the
    # graph is being built; the comparison below is then simply False.
    static_length = tensor_util.constant_value(window_length)
    if static_length == 1:
      return array_ops.ones([1], dtype=dtype)

    periodic_flag = ops.convert_to_tensor(
        periodic, dtype=dtypes.bool, name='periodic')
    periodic = math_ops.cast(periodic_flag, dtypes.int32)
    periodic.shape.assert_has_rank(0)
    is_even = 1 - math_ops.mod(window_length, 2)

    # The denominator is `window_length` for a periodic window of even length
    # and `window_length - 1` in every other case.
    denominator = math_ops.cast(
        window_length + periodic * is_even - 1, dtype=dtype)
    sample_index = math_ops.cast(math_ops.range(window_length), dtype)
    two_pi = constant_op.constant(2 * np.pi, dtype=dtype)
    phase = two_pi * sample_index / denominator

    def _cosine_taps():
      # Shared by the static path and the dynamic `cond` branch.
      return math_ops.cast(a - b * math_ops.cos(phase), dtype=dtype)

    if static_length is not None:
      # Length is known and is not one (handled above).
      return _cosine_taps()

    # Length only known at run time: pick the length-one special case there.
    return control_flow_ops.cond(
        math_ops.equal(window_length, 1),
        lambda: array_ops.ones([1], dtype=dtype),
        _cosine_taps)