# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Discrete Cosine Transform ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math as _math

from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops as _array_ops
from tensorflow.python.ops import math_ops as _math_ops
from tensorflow.python.ops.signal import fft_ops
from tensorflow.python.util.tf_export import tf_export


def _validate_dct_arguments(input_tensor, dct_type, n, axis, norm):
  """Checks that DCT/IDCT arguments are compatible and well formed.

  Args:
    input_tensor: The signal passed to `dct`/`idct`. Only its static `shape`
      (when it has one) is inspected, and only for the Type-I transform.
    dct_type: The requested transform type. Must be 1, 2 or 3.
    n: The transform length. Must be `None` (not implemented).
    axis: The axis to transform along. Must be `-1` (not implemented).
    norm: The normalization mode. Must be `None` or `'ortho'`.

  Raises:
    NotImplementedError: If `n` is not `None` or `axis` is not `-1`.
    ValueError: If `dct_type` is not 1, 2 or 3, if `norm` is not `None` or
      `'ortho'`, if `dct_type` is 1 and `norm` is `'ortho'`, or if `dct_type`
      is 1 and the last dimension is statically known to be less than 2.
  """
  if n is not None:
    raise NotImplementedError("The DCT length argument is not implemented.")
  if axis != -1:
    raise NotImplementedError("axis must be -1. Got: %s" % axis)
  if dct_type not in (1, 2, 3):
    raise ValueError("Only Types I, II and III (I)DCT are supported.")
  if dct_type == 1:
    if norm == "ortho":
      raise ValueError("Normalization is not supported for the Type-I DCT.")
    # The argument may not have been converted to a Tensor yet (e.g. a Python
    # list handed to `idct`), in which case there is no static shape to check
    # here; `dct` validates again after conversion.
    static_shape = getattr(input_tensor, "shape", None)
    if static_shape is not None:
      # `dimension_value` normalizes `Dimension` objects and plain ints, and
      # yields `None` when the size is not statically known.
      last_dim = tensor_shape.dimension_value(static_shape[-1])
      if last_dim is not None and last_dim < 2:
        raise ValueError(
            "Type-I DCT requires the dimension to be greater than one.")

  if norm not in (None, "ortho"):
    raise ValueError(
        "Unknown normalization. Expected None or 'ortho', got: %s" % norm)


# TODO(rjryan): Implement `n` and `axis` parameters.
@tf_export("signal.dct", v1=["signal.dct", "spectral.dct"])
def dct(input, type=2, n=None, axis=-1, norm=None, name=None):  # pylint: disable=redefined-builtin
  """Computes the 1D [Discrete Cosine Transform (DCT)][dct] of `input`.

  Currently only Types I, II and III are supported.
  Type I is implemented using the RFFT of the length `2N - 2` mirrored
  (even-symmetric) extension of the input.
  Type II is implemented using a length `2N` padded `tf.signal.rfft`, as
  described here:
  https://dsp.stackexchange.com/a/10606.
  Type III is a fairly straightforward inverse of Type II
  (i.e. using a length `2N` padded `tf.signal.irfft`).

  @compatibility(scipy)
  Equivalent to scipy.fftpack.dct for Type-I, Type-II and Type-III DCT.
  https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
  @end_compatibility

  Args:
    input: A `[..., samples]` `float32` `Tensor` containing the signals to
      take the DCT of.
    type: The DCT type to perform. Must be 1, 2 or 3.
    n: For future expansion. The length of the transform. Must be `None`.
    axis: For future expansion. The axis to compute the DCT along. Must be `-1`.
    norm: The normalization to apply. `None` for no normalization or `'ortho'`
      for orthonormal normalization.
    name: An optional name for the operation.

  Returns:
    A `[..., samples]` `float32` `Tensor` containing the DCT of `input`.

  Raises:
    NotImplementedError: If `n` is not `None` or `axis` is not `-1`.
    ValueError: If `type` is not `1`, `2` or `3`, or `norm` is not `None` or
      `'ortho'`.
    ValueError: If `type` is `1` and `norm` is `'ortho'`, or `type` is `1` and
      the last dimension of `input` is statically known to be less than 2.

  [dct]: https://en.wikipedia.org/wiki/Discrete_cosine_transform
  """
  with _ops.name_scope(name, "dct", [input]):
    # We use the RFFT to compute the DCT and TensorFlow only supports float32
    # for FFTs at the moment.
    # Convert before validating so that the static-shape check for Type-I also
    # works when `input` is a Python list or another non-Tensor value.
    input = _ops.convert_to_tensor(input, dtype=_dtypes.float32)
    _validate_dct_arguments(input, type, n, axis, norm)

    # Prefer the static size of the last axis; fall back to the dynamic shape
    # when it is not known at graph construction time.
    axis_dim = (tensor_shape.dimension_value(input.shape[-1])
                or _array_ops.shape(input)[-1])
    axis_dim_float = _math_ops.cast(axis_dim, _dtypes.float32)

    if type == 1:
      # Append the reversed interior samples (everything but the first and
      # last) to form a length 2N - 2 even-symmetric signal; the real part of
      # its RFFT has N bins and is the Type-I DCT.
      dct1_input = _array_ops.concat([input, input[..., -2:0:-1]], axis=-1)
      dct1 = _math_ops.real(fft_ops.rfft(dct1_input))
      return dct1

    if type == 2:
      # Per-bin phase factor 2 * exp(-j * pi * k / (2N)).
      scale = 2.0 * _math_ops.exp(
          _math_ops.complex(
              0.0, -_math_ops.range(axis_dim_float) * _math.pi * 0.5 /
              axis_dim_float))

      # TODO(rjryan): Benchmark performance and memory usage of the various
      # approaches to computing a DCT via the RFFT.
      dct2 = _math_ops.real(
          fft_ops.rfft(
              input, fft_length=[2 * axis_dim])[..., :axis_dim] * scale)

      if norm == "ortho":
        n1 = 0.5 * _math_ops.rsqrt(axis_dim_float)
        n2 = n1 * _math_ops.sqrt(2.0)
        # Use tf.pad to make a vector of [n1, n2, n2, n2, ...].
        weights = _array_ops.pad(
            _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]],
            constant_values=n2)
        dct2 *= weights

      return dct2

    elif type == 3:
      if norm == "ortho":
        n1 = _math_ops.sqrt(axis_dim_float)
        n2 = n1 * _math_ops.sqrt(0.5)
        # Use tf.pad to make a vector of [n1, n2, n2, n2, ...].
        weights = _array_ops.pad(
            _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]],
            constant_values=n2)
        input *= weights
      else:
        input *= axis_dim_float
      # Per-bin phase factor 2 * exp(+j * pi * k / (2N)), the conjugate of the
      # factor used by the Type-II branch.
      scale = 2.0 * _math_ops.exp(
          _math_ops.complex(
              0.0,
              _math_ops.range(axis_dim_float) * _math.pi * 0.5 /
              axis_dim_float))
      dct3 = _math_ops.real(
          fft_ops.irfft(
              scale * _math_ops.complex(input, 0.0),
              fft_length=[2 * axis_dim]))[..., :axis_dim]

      return dct3


# TODO(rjryan): Implement `n` and `axis` parameters.
@tf_export("signal.idct", v1=["signal.idct", "spectral.idct"])
def idct(input, type=2, n=None, axis=-1, norm=None, name=None):  # pylint: disable=redefined-builtin
  """Computes the 1D [Inverse Discrete Cosine Transform (DCT)][idct] of `input`.

  Currently only Types I, II and III are supported. Type III is the inverse of
  Type II, and vice versa. Type I is its own inverse.

  Note that for Types II and III you must re-normalize by 1/(2n) to obtain an
  inverse if `norm` is not `'ortho'`. That is:
  `signal == idct(dct(signal)) * 0.5 / signal.shape[-1]`.
  For Type I (which does not support `norm='ortho'`) the corresponding factor
  is 1/(2(n-1)).
  When `norm='ortho'`, we have:
  `signal == idct(dct(signal, norm='ortho'), norm='ortho')`.

  @compatibility(scipy)
  Equivalent to scipy.fftpack.idct for Type-I, Type-II and Type-III DCT.
  https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.idct.html
  @end_compatibility

  Args:
    input: A `[..., samples]` `float32` `Tensor` containing the signals to take
      the IDCT of.
    type: The IDCT type to perform. Must be 1, 2 or 3.
    n: For future expansion. The length of the transform. Must be `None`.
    axis: For future expansion. The axis to compute the DCT along. Must be `-1`.
    norm: The normalization to apply. `None` for no normalization or `'ortho'`
      for orthonormal normalization.
    name: An optional name for the operation.

  Returns:
    A `[..., samples]` `float32` `Tensor` containing the IDCT of `input`.

  Raises:
    NotImplementedError: If `n` is not `None` or `axis` is not `-1`.
    ValueError: If `type` is not `1`, `2` or `3`, or `norm` is not `None` or
      `'ortho'`.
    ValueError: If `type` is `1` and `norm` is `'ortho'`, or `type` is `1` and
      the last dimension of `input` is statically known to be less than 2.

  [idct]:
  https://en.wikipedia.org/wiki/Discrete_cosine_transform#Inverse_transforms
  """
  # Validate up front so that an unsupported `type` produces the documented
  # error rather than a KeyError from the lookup below.
  _validate_dct_arguments(input, type, n, axis, norm)
  # The inverse of each DCT type is (up to scaling) another forward DCT.
  inverse_type = {1: 1, 2: 3, 3: 2}[type]
  return dct(input, type=inverse_type, n=n, axis=axis, norm=norm, name=name)