1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Discrete Cosine Transform ops."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import math as _math
21
22from tensorflow.python.framework import dtypes as _dtypes
23from tensorflow.python.framework import ops as _ops
24from tensorflow.python.framework import tensor_shape
25from tensorflow.python.ops import array_ops as _array_ops
26from tensorflow.python.ops import math_ops as _math_ops
27from tensorflow.python.ops.signal import fft_ops
28from tensorflow.python.util.tf_export import tf_export
29
30
def _validate_dct_arguments(input_tensor, dct_type, n, axis, norm):
  """Raises if the (I)DCT arguments are unsupported or inconsistent.

  Checks run in a fixed order (length, axis, type, Type-I specifics, norm), so
  the first offending argument in that order determines the error raised.

  Args:
    input_tensor: The signal to be transformed. Only its `shape` attribute is
      read, and only for a Type-I transform without `'ortho'` normalization.
    dct_type: The requested transform type.
    n: The requested transform length; only `None` is accepted.
    axis: The requested transform axis; only `-1` is accepted.
    norm: The requested normalization; `None` or `'ortho'`.

  Raises:
    NotImplementedError: If `n` is not `None` or `axis` is not `-1`.
    ValueError: If `dct_type` is not 1, 2 or 3; if `dct_type` is 1 and `norm`
      is `'ortho'`; if `dct_type` is 1 and the last dimension of
      `input_tensor` is known and smaller than 2; or if `norm` is neither
      `None` nor `'ortho'`.
  """
  # Transform length and axis are placeholders for future expansion; only
  # their default values are accepted.
  if n is not None:
    raise NotImplementedError("The DCT length argument is not implemented.")
  if axis != -1:
    raise NotImplementedError("axis must be -1. Got: %s" % axis)

  supported_types = (1, 2, 3)
  if dct_type not in supported_types:
    raise ValueError("Only Types I, II and III (I)DCT are supported.")

  if dct_type == 1:
    if norm == "ortho":
      raise ValueError("Normalization is not supported for the Type-I DCT.")
    last_dim = input_tensor.shape[-1]
    # An unknown (None) static size passes; this is a static-shape check only.
    if last_dim is not None and last_dim < 2:
      raise ValueError(
          "Type-I DCT requires the dimension to be greater than one.")

  supported_norms = (None, "ortho")
  if norm not in supported_norms:
    raise ValueError(
        "Unknown normalization. Expected None or 'ortho', got: %s" % norm)
49
50
# TODO(rjryan): Implement `n` and `axis` parameters.
@tf_export("signal.dct", v1=["signal.dct", "spectral.dct"])
def dct(input, type=2, n=None, axis=-1, norm=None, name=None):  # pylint: disable=redefined-builtin
  """Computes the 1D [Discrete Cosine Transform (DCT)][dct] of `input`.

  Currently only Types I, II and III are supported.
  Type I is implemented by applying `tf.signal.rfft` to the length `2N - 2`
  even-symmetric extension of the input (the input followed by its reversed
  interior samples).
  Type II is implemented using a length `2N` padded `tf.signal.rfft`, as
  described here:
  https://dsp.stackexchange.com/a/10606.
  Type III is a fairly straightforward inverse of Type II
  (i.e. using a length `2N` padded `tf.signal.irfft`).

  @compatibility(scipy)
  Equivalent to scipy.fftpack.dct for Type-I, Type-II and Type-III DCT.
  https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
  @end_compatibility

  Args:
    input: A `[..., samples]` `float32` `Tensor` (or a value convertible to
      one, such as a nested Python list) containing the signals to take the
      DCT of.
    type: The DCT type to perform. Must be 1, 2 or 3.
    n: For future expansion. The length of the transform. Must be `None`.
    axis: For future expansion. The axis to compute the DCT along. Must be `-1`.
    norm: The normalization to apply. `None` for no normalization or `'ortho'`
      for orthonormal normalization.
    name: An optional name for the operation.

  Returns:
    A `[..., samples]` `float32` `Tensor` containing the DCT of `input`.

  Raises:
    NotImplementedError: If `n` is not `None` or `axis` is not `-1`.
    ValueError: If `type` is not `1`, `2` or `3`, or `norm` is not `None` or
      `'ortho'`.
    ValueError: If `type` is `1` and `norm` is `'ortho'`, or if `type` is `1`
      and the statically known size of the last dimension is less than 2.

  [dct]: https://en.wikipedia.org/wiki/Discrete_cosine_transform
  """
  with _ops.name_scope(name, "dct", [input]):
    # We use the RFFT to compute the DCT and TensorFlow only supports float32
    # for FFTs at the moment. `convert_to_tensor` does not cast: a tensor of a
    # different dtype is rejected rather than converted.
    input = _ops.convert_to_tensor(input, dtype=_dtypes.float32)
    # Validate only after conversion: the Type-I checks read `input.shape`,
    # which non-Tensor inputs (e.g. Python lists) do not have.
    _validate_dct_arguments(input, type, n, axis, norm)

    # Prefer the static length (a Python int); fall back to the dynamic shape
    # when it is unknown.
    axis_dim = (tensor_shape.dimension_value(input.shape[-1])
                or _array_ops.shape(input)[-1])
    axis_dim_float = _math_ops.cast(axis_dim, _dtypes.float32)

    if type == 1:
      # Even-symmetric extension [x0, ..., x_{N-1}, x_{N-2}, ..., x1] of length
      # 2N - 2; the real part of its RFFT (N bins) is the unnormalized DCT-I.
      dct1_input = _array_ops.concat([input, input[..., -2:0:-1]], axis=-1)
      dct1 = _math_ops.real(fft_ops.rfft(dct1_input))
      return dct1

    if type == 2:
      # Twiddle factors 2 * exp(-i * pi * k / (2N)); the real part of the
      # twiddled length-2N RFFT is 2 * sum_n x_n cos(pi * k * (2n + 1) / (2N)).
      scale = 2.0 * _math_ops.exp(
          _math_ops.complex(
              0.0, -_math_ops.range(axis_dim_float) * _math.pi * 0.5 /
              axis_dim_float))

      # TODO(rjryan): Benchmark performance and memory usage of the various
      # approaches to computing a DCT via the RFFT.
      dct2 = _math_ops.real(
          fft_ops.rfft(
              input, fft_length=[2 * axis_dim])[..., :axis_dim] * scale)

      if norm == "ortho":
        # n1 = 1 / sqrt(4N) for the DC term, n2 = 1 / sqrt(2N) for the rest.
        n1 = 0.5 * _math_ops.rsqrt(axis_dim_float)
        n2 = n1 * _math_ops.sqrt(2.0)
        # Use tf.pad to make a vector of [n1, n2, n2, n2, ...].
        weights = _array_ops.pad(
            _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]],
            constant_values=n2)
        dct2 *= weights

      return dct2

    # Remaining case: type == 3 (the only other value validation accepts).
    # The IRFFT below carries a 1 / (2N) factor; pre-scaling the input by N
    # (or by [sqrt(N), sqrt(N / 2), ...] for 'ortho') compensates for it.
    if norm == "ortho":
      n1 = _math_ops.sqrt(axis_dim_float)
      n2 = n1 * _math_ops.sqrt(0.5)
      # Use tf.pad to make a vector of [n1, n2, n2, n2, ...].
      weights = _array_ops.pad(
          _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]],
          constant_values=n2)
      input *= weights
    else:
      input *= axis_dim_float
    scale = 2.0 * _math_ops.exp(
        _math_ops.complex(
            0.0,
            _math_ops.range(axis_dim_float) * _math.pi * 0.5 /
            axis_dim_float))
    dct3 = _math_ops.real(
        fft_ops.irfft(
            scale * _math_ops.complex(input, 0.0),
            fft_length=[2 * axis_dim]))[..., :axis_dim]

    return dct3
149
150
# TODO(rjryan): Implement `n` and `axis` parameters.
@tf_export("signal.idct", v1=["signal.idct", "spectral.idct"])
def idct(input, type=2, n=None, axis=-1, norm=None, name=None):  # pylint: disable=redefined-builtin
  """Computes the 1D [Inverse Discrete Cosine Transform (DCT)][idct] of `input`.

  Currently only Types I, II and III are supported. Type III is the inverse of
  Type II, and vice versa. Type I is its own inverse.

  Note that you must re-normalize to obtain an inverse if `norm` is not
  `'ortho'`. For Types II and III the factor is `1/(2N)`, where
  `N = signal.shape[-1]`. That is:
  `signal == idct(dct(signal)) * 0.5 / signal.shape[-1]`.
  For Type I the factor is `1/(2(N-1))`:
  `signal == idct(dct(signal, type=1), type=1) * 0.5 / (signal.shape[-1] - 1)`.
  When `norm='ortho'` (Types II and III only), we have:
  `signal == idct(dct(signal, norm='ortho'), norm='ortho')`.

  @compatibility(scipy)
  Equivalent to scipy.fftpack.idct for Type-I, Type-II and Type-III DCT.
  https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.idct.html
  @end_compatibility

  Args:
    input: A `[..., samples]` `float32` `Tensor` containing the signals to take
      the IDCT of.
    type: The IDCT type to perform. Must be 1, 2 or 3.
    n: For future expansion. The length of the transform. Must be `None`.
    axis: For future expansion. The axis to compute the DCT along. Must be `-1`.
    norm: The normalization to apply. `None` for no normalization or `'ortho'`
      for orthonormal normalization.
    name: An optional name for the operation.

  Returns:
    A `[..., samples]` `float32` `Tensor` containing the IDCT of `input`.

  Raises:
    NotImplementedError: If `n` is not `None` or `axis` is not `-1`.
    ValueError: If `type` is not `1`, `2` or `3`, or `norm` is not `None` or
      `'ortho'`.
    ValueError: If `type` is `1` and `norm` is `'ortho'`, or if `type` is `1`
      and the statically known size of the last dimension is less than 2.

  [idct]:
  https://en.wikipedia.org/wiki/Discrete_cosine_transform#Inverse_transforms
  """
  # Type I is self-inverse and Types II/III invert each other. Any other value
  # is forwarded unchanged so that `dct` performs all argument validation
  # (including rejecting unsupported types, with the same error) in one place.
  inverse_type = {2: 3, 3: 2}.get(type, type)
  return dct(input, type=inverse_type, n=n, axis=axis, norm=norm, name=name)
193