1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Tests for third_party.tensorflow.contrib.ffmpeg.decode_audio_op."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os.path
22
23import six
24
25from tensorflow.contrib import ffmpeg
26from tensorflow.python.framework import dtypes
27from tensorflow.python.ops import array_ops
28from tensorflow.python.platform import resource_loader
29from tensorflow.python.platform import test
30
31
class DecodeAudioOpTest(test.TestCase):
  """Tests for `ffmpeg.decode_audio` using the fixtures in `testdata/`.

  Covered areas:
    * mp3, mp4 (mp3 and aac codecs), wav and ogg decoding, with different
      requested sample rates and channel counts.
    * Explicit stream selection.
    * Sample rates supplied through placeholders.
    * Argument validation.
    * Static shape inference of the decoded tensor.
  """

  def _loadFileAndTest(self, filename, file_format, duration_sec,
                       samples_per_second, channel_count,
                       samples_per_second_tensor=None, feed_dict=None,
                       stream=None):
    """Decodes a fixture file and checks the shape of the decoded tensor.

    The fixture is read from the `testdata` directory returned by
    `resource_loader.get_data_files_path()` and decoded with
    `ffmpeg.decode_audio`. The evaluated result must:
      * be rank 2,
      * have a row count within 10% of `duration_sec * samples_per_second`,
      * have exactly `channel_count` columns.

    Args:
      filename: Name of the fixture file inside `testdata`.
      file_format: Format string forwarded to `decode_audio`.
      duration_sec: Expected length of the decoded audio, in seconds.
      samples_per_second: Expected output sample rate. It is used to compute
        the expected row count. It is also the op's `samples_per_second`
        argument unless `samples_per_second_tensor` is given.
      channel_count: Requested (and expected) number of output channels.
      samples_per_second_tensor: Optional replacement for the op's
        `samples_per_second` argument, e.g. a placeholder. When None,
        `samples_per_second` is passed instead.
      feed_dict: Optional feeds used when evaluating the op, e.g. a value for
        a placeholder passed as `samples_per_second_tensor`. None is treated
        as an empty dict.
      stream: Optional stream identifier string (e.g. '0' or '1') forwarded
        unchanged to `decode_audio`. Defaults to None, meaning no explicit
        stream is requested. This presumably leaves the choice to ffmpeg;
        confirm in `decode_audio`.
    """
    # A caller-supplied tensor (such as a placeholder) overrides the plain
    # Python sample rate as the op input. The Python value is still used for
    # the expected-length check below.
    rate_input = (samples_per_second if samples_per_second_tensor is None
                  else samples_per_second_tensor)
    with self.cached_session():
      fixture_path = os.path.join(
          resource_loader.get_data_files_path(), 'testdata', filename)
      with open(fixture_path, 'rb') as fixture:
        encoded = fixture.read()

      decoded_op = ffmpeg.decode_audio(
          encoded,
          file_format=file_format,
          samples_per_second=rate_input,
          channel_count=channel_count,
          stream=stream)
      decoded = decoded_op.eval(feed_dict=feed_dict if feed_dict else {})

      # Check the rank before indexing into the shape.
      self.assertEqual(len(decoded.shape), 2)
      frames = decoded.shape[0]
      # The decoded length must match the nominal duration to within 10% of
      # the actual frame count.
      self.assertNear(duration_sec * samples_per_second, frames, 0.1 * frames)
      self.assertEqual(decoded.shape[1], channel_count)

  def testStreamIdentifier(self):
    # mono_16khz_mp3_32khz_aac.mp4 was generated from:
    # ffmpeg -i tensorflow/contrib/ffmpeg/testdata/mono_16khz_mp3.mp4 \
    #        -i tensorflow/contrib/ffmpeg/testdata/mono_32khz_aac.mp4 \
    #        -strict -2 -map 0:a -map 1:a \
    #        tensorflow/contrib/ffmpeg/testdata/mono_16khz_mp3_32khz_aac.mp4
    # The command maps the audio of both inputs into one file. Each stream
    # is selected and decoded explicitly.
    for stream_id in ('0', '1'):
      self._loadFileAndTest('mono_16khz_mp3_32khz_aac.mp4', 'mp4', 2.77,
                            20000, 1, stream=stream_id)

  def testMonoMp3(self):
    # Request one, then two, output channels at 20 kHz.
    for channels in (1, 2):
      self._loadFileAndTest('mono_16khz.mp3', 'mp3', 0.57, 20000, channels)

  def testMonoMp4Mp3Codec(self):
    # mp3 compressed audio streams in mp4 container.
    for channels in (1, 2):
      self._loadFileAndTest('mono_16khz_mp3.mp4', 'mp4', 2.77, 20000,
                            channels)

  def testMonoMp4AacCodec(self):
    # aac compressed audio streams in mp4 container.
    for channels in (1, 2):
      self._loadFileAndTest('mono_32khz_aac.mp4', 'mp4', 2.77, 20000,
                            channels)

  def testStereoMp3(self):
    for rate, channels in ((50000, 1), (20000, 2)):
      self._loadFileAndTest('stereo_48khz.mp3', 'mp3', 0.79, rate, channels)

  def testStereoMp4Mp3Codec(self):
    # mp3 compressed audio streams in mp4 container.
    for rate, channels in ((50000, 1), (20000, 2)):
      self._loadFileAndTest('stereo_48khz_mp3.mp4', 'mp4', 0.79, rate,
                            channels)

  def testStereoMp4AacCodec(self):
    # aac compressed audio streams in mp4 container.
    for rate, channels in ((50000, 1), (20000, 2)):
      self._loadFileAndTest('stereo_48khz_aac.mp4', 'mp4', 0.79, rate,
                            channels)

  def testMonoWav(self):
    for rate, channels in ((5000, 1), (10000, 4)):
      self._loadFileAndTest('mono_10khz.wav', 'wav', 0.57, rate, channels)

  def testOgg(self):
    self._loadFileAndTest('mono_10khz.ogg', 'ogg', 0.57, 10000, 1)

  def testInvalidFile(self):
    # Contents that are not a valid wav file are expected to evaluate to an
    # empty (0, 0) result rather than raise.
    with self.cached_session():
      bogus_op = ffmpeg.decode_audio(
          'invalid file',
          file_format='wav',
          samples_per_second=10000,
          channel_count=2)
      self.assertEqual(bogus_op.eval().shape, (0, 0))

  def testSampleRatePlaceholder(self):
    # The sample rate may be supplied at run time through an int32
    # placeholder.
    rate_ph = array_ops.placeholder(dtypes.int32)
    self._loadFileAndTest('mono_16khz.mp3', 'mp3', 0.57, 20000, 1,
                          samples_per_second_tensor=rate_ph,
                          feed_dict={rate_ph: 20000})

  def testSampleRateBadType(self):
    # A float32 sample-rate tensor must be rejected with a TypeError.
    float_rate = array_ops.placeholder(dtypes.float32)
    with self.assertRaises(TypeError):
      self._loadFileAndTest('mono_16khz.mp3', 'mp3', 0.57, 20000.0, 1,
                            samples_per_second_tensor=float_rate,
                            feed_dict={float_rate: 20000.0})

  def _assertSampleRateRejected(self, bad_rate):
    """Feeds `bad_rate` as the sample rate and expects a positivity error.

    Args:
      bad_rate: Integer fed through an int32 placeholder as the op's
        `samples_per_second`. The op must reject it with an error matching
        'samples_per_second must be positive'.
    """
    rate_ph = array_ops.placeholder(dtypes.int32)
    with six.assertRaisesRegex(self, Exception,
                               r'samples_per_second must be positive'):
      self._loadFileAndTest('mono_16khz.mp3', 'mp3', 0.57, 20000.0, 1,
                            samples_per_second_tensor=rate_ph,
                            feed_dict={rate_ph: bad_rate})

  def testSampleRateBadValue_Zero(self):
    self._assertSampleRateRejected(0)

  def testSampleRateBadValue_Negative(self):
    self._assertSampleRateRejected(-2)

  def testInvalidFileFormat(self):
    # 'docx' is not an accepted file_format.
    with six.assertRaisesRegex(self, Exception,
                               r'file_format must be one of'):
      self._loadFileAndTest('mono_16khz.mp3', 'docx', 0.57, 20000, 1)

  def _decodeDummyWav(self, channel_count):
    """Builds a `decode_audio` op over dummy bytes for static-shape checks.

    The shape tests only construct this op and never evaluate it, so the
    dummy contents do not need to be a real wav file.

    Args:
      channel_count: Value (Python int or tensor) for the op's
        `channel_count` argument.

    Returns:
      The tensor returned by `ffmpeg.decode_audio`.
    """
    return ffmpeg.decode_audio(b'~~~ wave ~~~',
                               file_format='wav',
                               samples_per_second=44100,
                               channel_count=channel_count)

  def _assertChannelCountRejected(self, channel_count):
    """Expects op construction to fail for a non-positive channel count."""
    with self.cached_session(), six.assertRaisesRegex(
        self, Exception, r'channel_count must be positive'):
      self._decodeDummyWav(channel_count)

  def testStaticShapeInference_ConstantChannelCount(self):
    # A constant channel count fixes the second dimension statically.
    with self.cached_session():
      static_shape = self._decodeDummyWav(2).shape.as_list()
      self.assertEqual([None, 2], static_shape)

  def testStaticShapeInference_NonConstantChannelCount(self):
    # With a placeholder channel count, both dimensions stay unknown.
    with self.cached_session():
      channels_ph = array_ops.placeholder(dtypes.int32)
      static_shape = self._decodeDummyWav(channels_ph).shape.as_list()
      self.assertEqual([None, None], static_shape)

  def testStaticShapeInference_ZeroChannelCountInvalid(self):
    self._assertChannelCountRejected(0)

  def testStaticShapeInference_NegativeChannelCountInvalid(self):
    self._assertChannelCountRejected(-2)
204
205
if __name__ == '__main__':
  # Only when this file is run as a script (not imported): hand off to
  # `test.main()` from tensorflow.python.platform, which is expected to run
  # the test case defined in this module.
  test.main()
208