1#!/usr/bin/env python3
2#
3#   Copyright 2017 - The Android Open Source Project
4#
5#   Licensed under the Apache License, Version 2.0 (the "License");
6#   you may not use this file except in compliance with the License.
7#   You may obtain a copy of the License at
8#
9#       http://www.apache.org/licenses/LICENSE-2.0
10#
11#   Unless required by applicable law or agreed to in writing, software
12#   distributed under the License is distributed on an "AS IS" BASIS,
13#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14#   See the License for the specific language governing permissions and
15#   limitations under the License.
16"""Audio Analysis tool to analyze wave file and detect artifacts."""
17
18import argparse
19import collections
20import json
21import logging
22import math
23import numpy
24import os
25import pprint
26import subprocess
27import tempfile
28import wave
29
30import acts.test_utils.audio_analysis_lib.audio_analysis as audio_analysis
31import acts.test_utils.audio_analysis_lib.audio_data as audio_data
32import acts.test_utils.audio_analysis_lib.audio_quality_measurement as \
33 audio_quality_measurement
34
# Holder for the quality parameters handed to the audio_quality_measurement
# module (see QualityChecker.do_spectral_analysis).
QualityParams = collections.namedtuple(
    'QualityParams',
    ('block_size_secs',
     'frequency_error_threshold',
     'delay_amplitude_threshold',
     'noise_amplitude_threshold',
     'burst_amplitude_threshold'))

# Defaults for the quality_* keyword arguments of quality_analysis().
DEFAULT_QUALITY_BLOCK_SIZE_SECS = 0.0015
DEFAULT_BURST_AMPLITUDE_THRESHOLD = 1.4
DEFAULT_DELAY_AMPLITUDE_THRESHOLD = 0.6
DEFAULT_FREQUENCY_ERROR_THRESHOLD = 0.5
DEFAULT_NOISE_AMPLITUDE_THRESHOLD = 0.5
47
48
class WaveFileException(Exception):
    """Raised when a wave file is compressed or in an unsupported format."""
52
53
class WaveFormatExtensibleException(Exception):
    """Raised for WAVE_FORMAT_EXTENSIBLE files, which cannot be read directly.

    WaveFile reacts to this error by converting the file with sox.
    """
57
58
class WaveFile(object):
    """Class which handles wave file reading.

    Properties:
        raw_data: audio_data.AudioRawData object for data in wave file.
        rate: sampling rate.

    """

    def __init__(self, filename):
        """Inits a wave file.

        If the file turns out to be in WAVE_FORMAT_EXTENSIBLE format, it is
        converted to WAVE_FORMAT_PCM with sox and the converted copy is read.

        Args:
            filename: file name of the wave file.

        Raises:
            WaveFileException: Wave file format is not supported.
            WaveFormatExtensibleException: The file is still in
                WAVE_FORMAT_EXTENSIBLE format after conversion.
            RuntimeError: Conversion is needed but sox is not available.
            subprocess.CalledProcessError: The sox conversion failed.

        """
        self.raw_data = None
        self.rate = None

        self._wave_reader = None
        self._n_channels = None
        self._sample_width_bits = None
        self._n_frames = None
        self._binary = None

        try:
            self._read_wave_file(filename)
        except WaveFormatExtensibleException:
            logging.warning(
                'WAVE_FORMAT_EXTENSIBLE is not supported. '
                'Try command "sox in.wav -t wavpcm out.wav" to convert '
                'the file to WAVE_FORMAT_PCM format.')
            self._convert_and_read_wav_file(filename)

    def _convert_and_read_wav_file(self, filename):
        """Converts the wav file with sox and reads the converted copy.

        Converts the file into WAVE_FORMAT_PCM format using sox command and
        reads its content. The converted copy lives in a temporary file that
        is removed when the `with` block exits.

        Args:
            filename: The wave file to be read.

        Raises:
            RuntimeError: sox is not installed (or `sox --version` fails).
            subprocess.CalledProcessError: The sox conversion fails.

        """
        # Checks if sox is installed. Only catch "cannot run" errors so that
        # KeyboardInterrupt / SystemExit are not turned into RuntimeError.
        try:
            subprocess.check_output(['sox', '--version'])
        except (OSError, subprocess.CalledProcessError) as e:
            raise RuntimeError('sox command is not installed. '
                               'Try sudo apt-get install sox') from e

        with tempfile.NamedTemporaryFile(suffix='.wav') as converted_file:
            command = ['sox', filename, '-t', 'wavpcm', converted_file.name]
            logging.debug('Convert the file using sox: %s', command)
            subprocess.check_call(command)
            self._read_wave_file(converted_file.name)

    def _read_wave_file(self, filename):
        """Reads wave file header and samples.

        Args:
            filename: The wave file to be read.

        Raises:
            WaveFormatExtensibleException: Wave file is in
                WAVE_FORMAT_EXTENSIBLE format (format tag 65534).
            WaveFileException: Wave file format is not supported.

        """
        # Drop any reader left over from a previous attempt so the `finally`
        # below only closes the reader opened by this call.
        self._wave_reader = None
        try:
            self._wave_reader = wave.open(filename, 'r')
            self._read_wave_header()
            self._read_wave_binary()
        except wave.Error as e:
            # 65534 (0xFFFE) is the WAVE_FORMAT_EXTENSIBLE format tag, which
            # the wave module reports as an unknown format.
            if 'unknown format: 65534' in str(e):
                raise WaveFormatExtensibleException() from e
            logging.exception('Unsupported wave format')
            raise WaveFileException('Unsupported wave format: %s' % e) from e
        finally:
            if self._wave_reader:
                self._wave_reader.close()

    def _read_wave_header(self):
        """Reads wave file header.

        Sets the channel count, sample width (in bits), sampling rate and
        frame count from the open reader.

        Raises:
            WaveFileException: wave file is compressed.

        """
        header = self._wave_reader.getparams()
        logging.debug('Wave header: %s', header)

        self._n_channels = header.nchannels
        # getparams() reports the sample width in bytes.
        self._sample_width_bits = header.sampwidth * 8
        self.rate = header.framerate
        self._n_frames = header.nframes

        if header.comptype != 'NONE' or header.compname != 'not compressed':
            raise WaveFileException('Can not support compressed wav file.')

    def _read_wave_binary(self):
        """Reads in samples in wave file into an AudioRawData object."""
        self._binary = self._wave_reader.readframes(self._n_frames)
        format_str = 'S%d_LE' % self._sample_width_bits
        self.raw_data = audio_data.AudioRawData(
            binary=self._binary,
            channel=self._n_channels,
            sample_format=format_str)
173
174
class QualityCheckerError(Exception):
    """Base class for errors raised by QualityChecker."""
178
179
class CompareFailure(QualityCheckerError):
    """Raised when a channel's dominant frequency check fails.

    Either the channel has no dominant frequency, or it is too far from the
    expected frequency.
    """
183
184
class QualityFailure(QualityCheckerError):
    """Raised when the quality measurement reports playback artifacts."""
188
189
class QualityChecker(object):
    """Quality checker controls the flow of checking quality of raw data.

    Typical use: construct with decoded audio, call do_spectral_analysis(),
    then dump(), check_freqs() and/or check_quality().
    """

    # (artifact key, message template) pairs inspected by check_quality(), in
    # the order their messages are reported.
    _ARTIFACT_MESSAGES = (
        ('noise_before_playback', 'Found noise before playback: %s'),
        ('noise_after_playback', 'Found noise after playback: %s'),
        ('delay_during_playback', 'Found delay during playback: %s'),
        ('burst_during_playback', 'Found burst during playback: %s'),
    )

    def __init__(self, raw_data, rate):
        """Inits a quality checker.

        Args:
            raw_data: An audio_data.AudioRawData object.
            rate: Sampling rate in samples per second. Example inputs: 44100,
                48000

        """
        self._raw_data = raw_data
        self._rate = rate
        # Spectral results of successfully analyzed channels, in channel
        # order. Silent or failed channels are not included. This is what
        # dump() writes.
        self._spectrals = []
        # The same spectral results keyed by their real channel index, so
        # check_freqs() stays aligned with channels when some are skipped.
        self._channel_spectrals = {}
        self._quality_result = []

    def do_spectral_analysis(self, ignore_high_freq, check_quality,
                             quality_params):
        """Gets the spectral_analysis result.

        Channels with no signal (empty or all zeros) are skipped. A channel
        whose analysis raises is logged and skipped; it contributes no
        spectral or quality result.

        Args:
            ignore_high_freq: Frequencies at or above this threshold (Hz) are
                dropped from each channel's spectral result.
            check_quality: Check quality of each channel.
            quality_params: A QualityParams object for quality measurement.
                Only used when check_quality is True.

        Raises:
            QualityCheckerError: if data or rate is not set yet.

        """
        self.has_data()
        for channel_idx in range(self._raw_data.channel):
            signal = self._raw_data.channel_data[channel_idx]
            abs_signal = numpy.abs(signal)
            # Guard against empty channels: max() of an empty sequence raises.
            max_abs = abs_signal.max() if abs_signal.size else 0
            logging.debug('Channel %d max abs signal: %f', channel_idx,
                          max_abs)
            if max_abs == 0:
                logging.info('No data on channel %d, skip this channel',
                             channel_idx)
                continue

            saturate_value = audio_data.get_maximum_value_from_sample_format(
                self._raw_data.sample_format)
            normalized_signal = audio_analysis.normalize_signal(
                signal, saturate_value)
            logging.debug('saturate_value: %f', saturate_value)
            logging.debug('max signal after normalized: %f',
                          max(normalized_signal))
            # NOTE(review): assumed to be (frequency, coefficient) pairs with
            # the dominant frequency first; check_freqs() relies on that.
            spectral = audio_analysis.spectral_analysis(
                normalized_signal, self._rate)

            logging.debug('Channel %d spectral:\n%s', channel_idx,
                          pprint.pformat(spectral))

            # Ignore high frequencies above the threshold.
            spectral = [(f, c) for (f, c) in spectral if f < ignore_high_freq]

            logging.info('Channel %d spectral after ignoring high frequencies '
                         'above %f:\n%s', channel_idx, ignore_high_freq,
                         pprint.pformat(spectral))

            try:
                if check_quality:
                    quality = audio_quality_measurement.quality_measurement(
                        signal=normalized_signal,
                        rate=self._rate,
                        dominant_frequency=spectral[0][0],
                        block_size_secs=quality_params.block_size_secs,
                        frequency_error_threshold=quality_params.
                        frequency_error_threshold,
                        delay_amplitude_threshold=quality_params.
                        delay_amplitude_threshold,
                        noise_amplitude_threshold=quality_params.
                        noise_amplitude_threshold,
                        burst_amplitude_threshold=quality_params.
                        burst_amplitude_threshold)

                    logging.debug('Channel %d quality:\n%s', channel_idx,
                                  pprint.pformat(quality))
                    self._quality_result.append(quality)
                self._spectrals.append(spectral)
                self._channel_spectrals[channel_idx] = spectral
            except Exception as error:
                # Deliberate best effort: one bad channel must not abort the
                # analysis of the remaining channels.
                logging.warning('Failed to analyze channel %d with error: %s',
                                channel_idx, error)

    def has_data(self):
        """Checks if data has been set.

        Raises:
            QualityCheckerError: if data or rate is not set yet.

        """
        if not self._raw_data or not self._rate:
            raise QualityCheckerError('Data and rate is not set yet')

    def check_freqs(self, expected_freqs, freq_threshold):
        """Checks the dominant frequencies in the channels.

        Args:
            expected_freqs: A list of frequencies, one per channel index. If
                frequency is 0, it means this channel should be ignored.
            freq_threshold: The difference threshold to compare two
                frequencies.

        Raises:
            CompareFailure: A checked channel has no dominant frequency (it
                was silent, failed analysis, or does not exist), or its
                dominant frequency differs from the expected one by more than
                freq_threshold.

        """
        logging.debug('expected_freqs: %s', expected_freqs)
        for idx, expected_freq in enumerate(expected_freqs):
            if expected_freq == 0:
                continue
            # Look up by real channel index; skipped channels are absent.
            spectral = self._channel_spectrals.get(idx)
            if not spectral:
                raise CompareFailure(
                    'Failed at channel %d: no dominant frequency' % idx)
            dominant_freq = spectral[0][0]
            if abs(dominant_freq - expected_freq) > freq_threshold:
                raise CompareFailure(
                    'Failed at channel %d: %f is too far away from %f' %
                    (idx, dominant_freq, expected_freq))

    def check_quality(self):
        """Checks the quality measurement results on each channel.

        Raises:
            QualityFailure when there is artifact. The message lists every
            artifact found across all measured channels.

        """
        error_msgs = []
        for quality_res in self._quality_result:
            artifacts = quality_res['artifacts']
            for key, template in self._ARTIFACT_MESSAGES:
                if artifacts[key]:
                    # Wrap in a tuple so a tuple-valued artifact is printed
                    # as a whole rather than unpacked into the format.
                    error_msgs.append(template % (artifacts[key],))
        if error_msgs:
            raise QualityFailure(
                'Found bad quality: %s' % '\n'.join(error_msgs))

    def dump(self, output_file):
        """Dumps the result into a file in json format.

        Args:
            output_file: A file path to dump spectral and quality
                measurement result of each channel.

        """
        dump_dict = {
            'spectrals': self._spectrals,
            'quality_result': self._quality_result
        }
        with open(output_file, 'w') as f:
            json.dump(dump_dict, f)
422
423
class CheckQualityError(Exception):
    """Raised by the module-level helpers, e.g. for an unsupported file type."""
427
428
def read_audio_file(filename, channel, bit_width, rate):
    """Reads audio file.

    The format is selected by the file extension, compared
    case-insensitively. ``.wav`` files are parsed with WaveFile, which
    supplies its own sampling rate; ``channel`` and ``bit_width`` are unused
    for them. ``.raw`` files are read as-is and described by ``channel`` and
    an ``S<bit_width>_LE`` sample format.

    Args:
        filename: The wav or raw file to check.
        channel: For raw file. Number of channels.
        bit_width: For raw file. Bit width of a sample.
        rate: Sampling rate in samples per second. Example inputs: 44100,
            48000. Returned unchanged for raw files; replaced by the rate
            from the header for wav files.

    Returns:
        A tuple (raw_data, rate) where raw_data is audio_data.AudioRawData, rate
            is sampling rate.

    Raises:
        CheckQualityError: The file extension is neither .wav nor .raw.

    """
    lowered_name = filename.lower()
    if lowered_name.endswith('.wav'):
        wavefile = WaveFile(filename)
        raw_data = wavefile.raw_data
        rate = wavefile.rate
    elif lowered_name.endswith('.raw'):
        with open(filename, 'rb') as f:
            binary = f.read()
        raw_data = audio_data.AudioRawData(
            binary=binary, channel=channel, sample_format='S%d_LE' % bit_width)
    else:
        raise CheckQualityError(
            'File format for %s is not supported' % filename)

    return raw_data, rate
460
461
def get_quality_params(
        quality_block_size_secs, quality_frequency_error_threshold,
        quality_delay_amplitude_threshold, quality_noise_amplitude_threshold,
        quality_burst_amplitude_threshold):
    """Bundles the quality-measurement arguments into a QualityParams.

    Args:
        quality_block_size_secs: Block size in seconds.
        quality_frequency_error_threshold: The frequency error threshold.
        quality_delay_amplitude_threshold: The delay amplitude threshold.
        quality_noise_amplitude_threshold: The noise amplitude threshold.
        quality_burst_amplitude_threshold: The burst amplitude threshold.

    Returns:
        A QualityParams object holding the five values above.

    """
    return QualityParams(
        burst_amplitude_threshold=quality_burst_amplitude_threshold,
        noise_amplitude_threshold=quality_noise_amplitude_threshold,
        delay_amplitude_threshold=quality_delay_amplitude_threshold,
        frequency_error_threshold=quality_frequency_error_threshold,
        block_size_secs=quality_block_size_secs,
    )
491
492
def quality_analysis(
        filename,
        output_file,
        bit_width,
        rate,
        channel,
        freqs=None,
        freq_threshold=5,
        ignore_high_freq=5000,
        spectral_only=False,
        quality_block_size_secs=DEFAULT_QUALITY_BLOCK_SIZE_SECS,
        quality_burst_amplitude_threshold=DEFAULT_BURST_AMPLITUDE_THRESHOLD,
        quality_delay_amplitude_threshold=DEFAULT_DELAY_AMPLITUDE_THRESHOLD,
        quality_frequency_error_threshold=DEFAULT_FREQUENCY_ERROR_THRESHOLD,
        quality_noise_amplitude_threshold=DEFAULT_NOISE_AMPLITUDE_THRESHOLD,
):
    """Runs spectral analysis and optional quality checks on an audio file.

    The analysis result is written to ``output_file`` before any check runs,
    so it is available even when a check raises.

    Args:
        filename: The wav or raw file to check.
        output_file: Output file to dump analysis result in JSON format.
        bit_width: For raw file. Bit width of a sample.
        rate: Sampling rate in samples per second. Example inputs: 44100,
            48000
        channel: For raw file. Number of channels.
        freqs: Expected frequencies in the channels. Falsy skips the
            frequency check.
        freq_threshold: Frequency difference threshold in Hz.
        ignore_high_freq: Frequency threshold in Hz to be ignored for high
            frequency. Default is 5KHz
        spectral_only: Only do spectral analysis on each channel.
        quality_block_size_secs: Block size in seconds.
        quality_frequency_error_threshold: The frequency error threshold.
        quality_delay_amplitude_threshold: The delay amplitude threshold.
        quality_noise_amplitude_threshold: The noise amplitude threshold.
        quality_burst_amplitude_threshold: The burst amplitude threshold.

    Raises:
        CheckQualityError: The file type is not supported.
        CompareFailure: A dominant frequency does not match ``freqs``.
        QualityFailure: Artifacts were found (only when not spectral_only).
    """
    audio, sample_rate = read_audio_file(filename, channel, bit_width, rate)
    checker = QualityChecker(audio, sample_rate)
    params = get_quality_params(
        quality_block_size_secs=quality_block_size_secs,
        quality_frequency_error_threshold=quality_frequency_error_threshold,
        quality_delay_amplitude_threshold=quality_delay_amplitude_threshold,
        quality_noise_amplitude_threshold=quality_noise_amplitude_threshold,
        quality_burst_amplitude_threshold=quality_burst_amplitude_threshold)

    measure_quality = not spectral_only
    checker.do_spectral_analysis(
        ignore_high_freq=ignore_high_freq,
        check_quality=measure_quality,
        quality_params=params)

    # Persist results first so they survive a failing check below.
    checker.dump(output_file)

    if freqs:
        checker.check_freqs(freqs, freq_threshold)
    if measure_quality:
        checker.check_quality()
    logging.debug("Audio analysis completed.")
556