1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for DCT operations."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import importlib
22import itertools
23
24from absl.testing import parameterized
25import numpy as np
26
27from tensorflow.python.framework import test_util
28from tensorflow.python.ops.signal import dct_ops
29from tensorflow.python.platform import test
30from tensorflow.python.platform import tf_logging
31
32
def try_import(name):  # pylint: disable=invalid-name
  """Imports and returns the module `name`, or None if it cannot be imported.

  An ImportError is logged as a warning instead of being raised, so optional
  dependencies can be skipped gracefully.

  Args:
    name: Fully qualified module name, e.g. "scipy.fftpack".

  Returns:
    The imported module, or None if importing it raised ImportError.
  """
  try:
    return importlib.import_module(name)
  except ImportError as e:
    # Hand the arguments to the logger so formatting happens lazily.
    tf_logging.warning("Could not import %s: %s", name, e)
    return None
40
41
# Optional SciPy reference implementation. It is None (after a logged warning)
# when SciPy cannot be imported; the SciPy comparisons are then skipped.
fftpack = try_import(name="scipy.fftpack")
43
44
def _modify_input_for_dct(signals, n=None):
  """Pad or trim the provided NumPy array's innermost axis to length n.

  Args:
    signals: Array-like input. It is converted with `np.array`, which copies.
    n: Optional target length of the innermost axis. None leaves the input
      length unchanged. Shorter inputs are zero-padded at the end; longer
      inputs are truncated.

  Returns:
    A NumPy array whose innermost axis has length `n` (or its original length
    when `n` is None). The dtype of the input is preserved, including when
    zero padding is applied.

  Raises:
    ValueError: If `n` is not None and is smaller than 1.
  """
  signal = np.array(signals)
  if n is None:
    return signal
  if n < 1:
    # Previously this fell through with `signal_mod` unbound.
    raise ValueError("n must be a positive integer or None, got %r." % (n,))
  signal_len = signal.shape[-1]
  if n <= signal_len:
    # With n == signal_len this is simply a full view of the fresh copy.
    signal_mod = signal[..., :n]
  else:
    # Pad only the innermost axis. np.pad keeps the input dtype, whereas
    # np.zeros(...) would silently promote float32 to float64.
    pad_width = [(0, 0)] * (signal.ndim - 1) + [(0, n - signal_len)]
    signal_mod = np.pad(signal, pad_width, mode="constant")
  assert signal_mod.shape[-1] == n
  return signal_mod
62
63
def _np_dct1(signals, n=None, norm=None):
  """Computes the DCT-I manually with NumPy.

  X_k = x_0 + (-1)**k * x_{N-1}
        + 2 * sum_{n=1}^{N-2} x_n * cos(pi / (N - 1) * n * k),  k=0,...,N-1

  This is the unnormalized form, as in SciPy's `fftpack.dct(..., type=1)`.

  Args:
    signals: Array-like input signals; the transform runs over the innermost
      axis.
    n: Optional length to zero-pad or trim the innermost axis to.
    norm: Ignored; no normalized DCT-I variant is implemented.

  Returns:
    A NumPy array with the same shape and dtype as the padded/trimmed input.

  Raises:
    ValueError: If the padded/trimmed innermost axis has fewer than 2 points.
  """
  del norm  # Unused.
  signals_mod = _modify_input_for_dct(signals, n=n)
  dct_size = signals_mod.shape[-1]
  if dct_size < 2:
    # The basis divides by N - 1, so the transform is undefined for N = 1.
    raise ValueError("DCT-I requires at least two inputs, got %d." % dct_size)
  interior = signals_mod[..., 1:-1]
  interior_idx = np.arange(1, dct_size - 1)
  dct = np.zeros_like(signals_mod)
  for k in range(dct_size):
    phi = np.cos(np.pi * interior_idx * k / (dct_size - 1))
    dct[..., k] = 2 * np.sum(interior * phi, axis=-1) + (
        signals_mod[..., 0] + (-1)**k * signals_mod[..., -1])
  return dct
78
79
def _np_dct2(signals, n=None, norm=None):
  """Reference DCT-II evaluated with an explicit NumPy loop.

  X_k = sum_{n=0}^{N-1} x_n * cos(pi / N * (n + 0.5) * k),  k=0,...,N-1

  Without normalization the result is scaled by 2.0, following SciPy's `dct`
  (https://github.com/scipy/scipy/blob/v1.2.1/scipy/fftpack/src/dct.c.src).
  With norm="ortho", the orthonormal factor of 0.5 cancels that 2.0, leaving
  sqrt(1/N) on X_0 and sqrt(2/N) on the remaining coefficients.

  Args:
    signals: Array-like input signals; the transform runs over the innermost
      axis.
    n: Optional length to zero-pad or trim the innermost axis to.
    norm: "ortho" for orthonormal scaling; any other value means unnormalized.

  Returns:
    A NumPy array with the same shape and dtype as the padded/trimmed input.
  """
  x = _modify_input_for_dct(signals, n=n)
  size = x.shape[-1]
  sample_pos = np.arange(size) + 0.5
  out = np.zeros_like(x)
  for k in range(size):
    basis = np.cos(np.pi * sample_pos * k / size)
    out[..., k] = np.sum(x * basis, axis=-1)
  if norm != "ortho":
    out *= 2.0
    return out
  out[..., 0] *= np.sqrt(1.0 / size)
  out[..., 1:] *= np.sqrt(2.0 / size)
  return out
99
100
def _np_dct3(signals, n=None, norm=None):
  """Reference DCT-III evaluated with an explicit NumPy loop.

  X_k = 0.5 * x_0
        + sum_{n=1}^{N-1} x_n * cos(pi / N * n * (k + 0.5)),  k=0,...,N-1

  The input is pre-scaled before the sum: by 2.0 when unnormalized, following
  SciPy's `dct`
  (https://github.com/scipy/scipy/blob/v1.2.1/scipy/fftpack/src/dct.c.src).
  With norm="ortho" it is scaled by sqrt(4/N) on x_0 and sqrt(2/N) elsewhere.

  Args:
    signals: Array-like input signals; the transform runs over the innermost
      axis.
    n: Optional length to zero-pad or trim the innermost axis to.
    norm: "ortho" for orthonormal scaling; any other value means unnormalized.

  Returns:
    A NumPy array with the same shape and dtype as the padded/trimmed input.
  """
  # Take a private copy because the scaling below happens in place.
  x = np.array(_modify_input_for_dct(signals, n=n))
  size = x.shape[-1]
  if norm == "ortho":
    x[..., 0] *= np.sqrt(4.0 / size)
    x[..., 1:] *= np.sqrt(2.0 / size)
  else:
    x *= 2.0
  freqs = np.arange(1, size)
  half_dc = 0.5 * x[..., 0]
  tail = x[..., 1:]
  out = np.zeros_like(x)
  for k in range(size):
    basis = np.cos(np.pi * freqs * (k + 0.5) / size)
    out[..., k] = half_dc + np.sum(tail * basis, axis=-1)
  return out
121
122
def _np_dct4(signals, n=None, norm=None):
  """Reference DCT-IV evaluated with an explicit NumPy loop.

  X_k = sum_{n=0}^{N-1} x_n * cos(pi / (4N) * (2n + 1) * (2k + 1)),
        k=0,...,N-1

  The input is pre-scaled by 2.0 when unnormalized, following SciPy's `dct`
  (https://github.com/scipy/scipy/blob/v1.2.1/scipy/fftpack/src/dct.c.src),
  or by sqrt(2/N) when norm="ortho".

  Args:
    signals: Array-like input signals; the transform runs over the innermost
      axis.
    n: Optional length to zero-pad or trim the innermost axis to.
    norm: "ortho" for orthonormal scaling; any other value means unnormalized.

  Returns:
    A NumPy array with the same shape and dtype as the padded/trimmed input.
  """
  # Take a private copy because the scaling below happens in place.
  x = np.array(_modify_input_for_dct(signals, n=n))
  size = x.shape[-1]
  x *= np.sqrt(2.0 / size) if norm == "ortho" else 2.0
  odd_n = 2 * np.arange(size) + 1
  out = np.zeros_like(x)
  for k in range(size):
    basis = np.cos(np.pi * odd_n * (2 * k + 1) / (4.0 * size))
    out[..., k] = np.sum(x * basis, axis=-1)
  return out
143
144
# NumPy reference implementations of the forward DCT, keyed by DCT type.
NP_DCT = {
    1: _np_dct1,
    2: _np_dct2,
    3: _np_dct3,
    4: _np_dct4,
}
# The inverse of each type, up to a scale factor: types II and III invert each
# other, while types I and IV are their own inverses.
NP_IDCT = dict(NP_DCT)
NP_IDCT[2], NP_IDCT[3] = _np_dct3, _np_dct2
147
148
@test_util.run_all_in_graph_and_eager_modes
class DCTOpsTest(parameterized.TestCase, test.TestCase):
  """Checks `dct_ops.dct`/`dct_ops.idct` against NumPy and SciPy references."""

  def _compare(self, signals, n, norm, dct_type, atol, rtol):
    """Compares (I)DCT to SciPy (if available) and a NumPy implementation.

    Args:
      signals: NumPy array of input signals; the transform runs over the
        innermost axis.
      n: Optional length to pad/trim the innermost axis to for the forward
        DCT. The inverse transform is always computed without `n`.
      norm: None or "ortho".
      dct_type: DCT type, one of 1, 2, 3 or 4.
      atol: Absolute tolerance used for every comparison.
      rtol: Relative tolerance used for every comparison.
    """
    np_dct = NP_DCT[dct_type](signals, n=n, norm=norm)
    tf_dct = dct_ops.dct(signals, n=n, type=dct_type, norm=norm)
    self.assertEqual(tf_dct.dtype.as_numpy_dtype, signals.dtype)
    self.assertAllClose(np_dct, tf_dct, atol=atol, rtol=rtol)
    np_idct = NP_IDCT[dct_type](signals, n=None, norm=norm)
    tf_idct = dct_ops.idct(signals, type=dct_type, norm=norm)
    self.assertEqual(tf_idct.dtype.as_numpy_dtype, signals.dtype)
    self.assertAllClose(np_idct, tf_idct, atol=atol, rtol=rtol)
    if fftpack and dct_type != 4:
      scipy_dct = fftpack.dct(signals, n=n, type=dct_type, norm=norm)
      self.assertAllClose(scipy_dct, tf_dct, atol=atol, rtol=rtol)
      scipy_idct = fftpack.idct(signals, type=dct_type, norm=norm)
      self.assertAllClose(scipy_idct, tf_idct, atol=atol, rtol=rtol)
    # Verify inverse(forward(s)) == s, up to a normalization factor.
    # Since `n` is not implemented for IDCT operation, re-calculating tf_dct
    # without n.
    tf_dct = dct_ops.dct(signals, type=dct_type, norm=norm)
    tf_idct_dct = dct_ops.idct(tf_dct, type=dct_type, norm=norm)
    tf_dct_idct = dct_ops.dct(tf_idct, type=dct_type, norm=norm)
    if norm is None:
      if dct_type == 1:
        tf_idct_dct *= 0.5 / (signals.shape[-1] - 1)
        tf_dct_idct *= 0.5 / (signals.shape[-1] - 1)
      else:
        tf_idct_dct *= 0.5 / signals.shape[-1]
        tf_dct_idct *= 0.5 / signals.shape[-1]
    self.assertAllClose(signals, tf_idct_dct, atol=atol, rtol=rtol)
    self.assertAllClose(signals, tf_dct_idct, atol=atol, rtol=rtol)

  @parameterized.parameters(itertools.product(
      [1, 2, 3, 4],
      [None, "ortho"],
      [[2], [3], [10], [2, 20], [2, 3, 25]],
      [np.float32, np.float64]))
  def test_random(self, dct_type, norm, shape, dtype):
    """Test randomly generated batches of data."""
    # "ortho" normalization is not implemented for type I.
    if dct_type == 1 and norm == "ortho":
      return
    with self.session():
      tol = 5e-4 if dtype == np.float32 else 1e-7
      signals = np.random.rand(*shape).astype(dtype)
      # DCT-I requires at least two inputs, so never trim a type-I signal to a
      # single point; n=1 has no well-defined DCT-I reference to compare to.
      min_n = 2 if dct_type == 1 else 1
      n = np.random.randint(min_n, 2 * signals.shape[-1])
      n = np.random.choice([None, n])
      self._compare(signals, n, norm=norm, dct_type=dct_type,
                    rtol=tol, atol=tol)

  def test_error(self):
    """Invalid arguments are rejected with the expected exception types."""
    signals = np.random.rand(10)
    # Unsupported type.
    with self.assertRaises(ValueError):
      dct_ops.dct(signals, type=5)
    # Invalid n.
    with self.assertRaises(ValueError):
      dct_ops.dct(signals, n=-2)
    # DCT-I normalization not implemented.
    with self.assertRaises(ValueError):
      dct_ops.dct(signals, type=1, norm="ortho")
    # DCT-I requires at least two inputs.
    with self.assertRaises(ValueError):
      dct_ops.dct(np.random.rand(1), type=1)
    # Unknown normalization.
    with self.assertRaises(ValueError):
      dct_ops.dct(signals, norm="bad")
    # Passing an explicit `axis` (here 0) is expected to be unimplemented.
    with self.assertRaises(NotImplementedError):
      dct_ops.dct(signals, axis=0)
220
221
if __name__ == "__main__":
  # Script entry point: delegate to TensorFlow's test runner.
  test.main()
224