1#!/usr/bin/python
2# Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
3# Use of this source code is governed by a BSD-style license that can be
4# found in the LICENSE file.
5
6
7import logging
8import numpy
9import os
10import re
11import subprocess
12import tempfile
13import threading
14import time
15
16from glob import glob
17from autotest_lib.client.bin import test, utils
18from autotest_lib.client.bin.input.input_device import *
19from autotest_lib.client.common_lib import error
20from autotest_lib.client.cros.audio import audio_data
21from autotest_lib.client.cros.audio import cmd_utils
22from autotest_lib.client.cros.audio import cras_utils
23from autotest_lib.client.cros.audio import sox_utils
24
# Name of the dynamic-linker search-path environment variable. Not referenced
# in the visible part of this file.
LD_LIBRARY_PATH = 'LD_LIBRARY_PATH'

# Executable run by get_audio_diagnostics().
_AUDIO_DIAGNOSTICS_PATH = '/usr/bin/audio_diagnostics'

# Default channel count used when extracting one channel with sox.
_DEFAULT_NUM_CHANNELS = 2
# Default recording command of record_sample(); the output path is appended.
# (arecord from hw:0,0 for 10 seconds in 'dat' format.)
_DEFAULT_REC_COMMAND = 'arecord -D hw:0,0 -d 10 -f dat'
# sox format options: raw, 16-bit signed samples, 48000 Hz, little-endian.
_DEFAULT_SOX_FORMAT = '-t raw -b 16 -e signed -r 48000 -L'
# Volume applied by cras_rms_test_setup() to the system and output node.
_DEFAULT_PLAYBACK_VOLUME = 100
# Capture gain applied by cras_rms_test_setup(); presumably in units of
# 0.01 dB (i.e. 25 dB, cf. _DEFAULT_ALSA_CAPTURE_GAIN) -- TODO confirm.
_DEFAULT_CAPTURE_GAIN = 2500
# ALSA-style settings; not referenced in the visible part of this file.
_DEFAULT_ALSA_MAX_VOLUME = '100%'
_DEFAULT_ALSA_CAPTURE_GAIN = '25dB'

# Minimum RMS value to pass when checking recorded file.
_DEFAULT_SOX_RMS_THRESHOLD = 0.08

# Matches an `amixer cget` output line reporting the control as on.
_JACK_VALUE_ON_RE = re.compile(r'.*values=on')
# Match `amixer controls` lines; group 1 captures the control's numid.
_HP_JACK_CONTROL_RE = re.compile(r'numid=(\d+).*Headphone\sJack')
_MIC_JACK_CONTROL_RE = re.compile(r'numid=(\d+).*Mic\sJack')

# Match lines of `sox ... stat` output; group 1 captures the value.
_SOX_RMS_AMPLITUDE_RE = re.compile(r'RMS\s+amplitude:\s+(.+)')
_SOX_ROUGH_FREQ_RE = re.compile(r'Rough\s+frequency:\s+(.+)')

# Uncompiled patterns searched (case-insensitively) in loopback_latency output.
_AUDIO_NOT_FOUND_RE = r'Audio\snot\sdetected'
_MEASURED_LATENCY_RE = r'Measured\sLatency:\s(\d+)\suS'
_REPORTED_LATENCY_RE = r'Reported\sLatency:\s(\d+)\suS'

# Tools from platform/audiotest
AUDIOFUNTEST_PATH = 'audiofuntest'
AUDIOLOOP_PATH = 'looptest'
LOOPBACK_LATENCY_PATH = 'loopback_latency'
# The sox binary used to build every sox command in this file.
SOX_PATH = 'sox'
TEST_TONES_PATH = 'test_tones'

# Signals whose L2 norm is at or below this are rejected by
# get_max_cross_correlation() as carrying no meaningful data.
_MINIMUM_NORM = 0.001
# Default correlation above which compare_one_channel_correlation() reports
# two channels as equal.
_CORRELATION_INDEX_THRESHOLD = 0.999
# The minimum difference of estimated frequencies between two sine waves.
_FREQUENCY_DIFF_THRESHOLD = 20
# The minimum RMS value of meaningful audio data.
_MEANINGFUL_RMS_THRESHOLD = 0.001
64
def set_mixer_controls(mixer_settings=None, card='0'):
    """Sets all mixer controls listed in the mixer settings on card.

    Each control is set with `amixer -c <card> cset name=<name> <value>`.
    A failing amixer invocation is logged and skipped, because a card is
    not required to support every control in the list.

    @param mixer_settings: An iterable of dicts with 'name' and 'value'
        keys. None (the default) means there is nothing to set.
    @param card: Index of audio card to set mixer settings for.
    """
    # Use None rather than a mutable {} default, which would be a single
    # object shared by every call.
    if mixer_settings is None:
        mixer_settings = []
    logging.info('Setting mixer control values on %s', card)
    for item in mixer_settings:
        logging.info('Setting %s to %s on card %s',
                     item['name'], item['value'], card)
        # NOTE(review): name and value are interpolated unquoted; a name
        # containing spaces must be pre-quoted by the caller.
        cmd = 'amixer -c %s cset name=%s %s'
        cmd = cmd % (card, item['name'], item['value'])
        try:
            utils.system(cmd)
        except error.CmdError:
            # A card is allowed not to support all the controls, so don't
            # fail the test here if we get an error.
            logging.info('amixer command failed: %s', cmd)
83
def set_volume_levels(volume, capture):
    """Applies playback volume and capture gain through cras_test_client.

    After the two levels are set, the CRAS server info is dumped and
    playback is unmuted. Finally the ALSA mixer contents of card 0 are
    printed for debugging; a failure there is only logged.

    @param volume: Value passed to `cras_test_client --volume`.
    @param capture: Value passed to `cras_test_client --capture_gain`.

    @raises error.TestError: if any cras_test_client invocation fails.
    """
    cras_client = '/usr/bin/cras_test_client'

    logging.info('Setting volume level to %d', volume)
    try:
        utils.system('%s --volume %d' % (cras_client, volume))
        logging.info('Setting capture gain to %d', capture)
        utils.system('%s --capture_gain %d' % (cras_client, capture))
        for trailing_flag in ('--dump_server_info', '--mute 0'):
            utils.system('%s %s' % (cras_client, trailing_flag))
    except error.CmdError as cmd_error:
        raise error.TestError(
                '*** Can not tune volume through CRAS. *** (%s)'
                % str(cmd_error))

    # Dumping the mixer state is informational only.
    try:
        utils.system('amixer -c 0 contents')
    except error.CmdError as cmd_error:
        logging.info('amixer command failed: %s', str(cmd_error))
105
def loopback_latency_check(**args):
    """Checks loopback latency with the loopback_latency tool.

    Runs `loopback_latency -n <noise_threshold> -c` and parses the
    'Measured Latency' and 'Reported Latency' lines (case-insensitively)
    from its output.

    @param args: additional arguments for loopback_latency. Only 'n', the
        noise threshold passed with -n, is used; it defaults to 400.

    @return A tuple containing measured and reported latency in uS.
        Return None if the tool reports that no audio was detected, or if
        either latency line is missing from its output.
    """
    # dict.has_key() does not exist on Python 3; .get() works on both.
    noise_threshold = str(args.get('n', 400))

    cmd = '%s -n %s -c' % (LOOPBACK_LATENCY_PATH, noise_threshold)

    output = utils.system_output(cmd, retain_output=True)

    # Sleep for a short while to make sure device is not busy anymore
    # after called loopback_latency.
    time.sleep(.1)

    measured_latency = None
    reported_latency = None
    for line in output.split('\n'):
        match = re.search(_MEASURED_LATENCY_RE, line, re.I)
        if match:
            measured_latency = int(match.group(1))
            continue
        match = re.search(_REPORTED_LATENCY_RE, line, re.I)
        if match:
            reported_latency = int(match.group(1))
            continue
        if re.search(_AUDIO_NOT_FOUND_RE, line, re.I):
            return None

    # Compare against None: a parsed latency of 0 uS is a valid value and
    # must not be mistaken for a missing one.
    if measured_latency is not None and reported_latency is not None:
        return (measured_latency, reported_latency)
    # Should not normally reach here; a latency line was missing.
    return None
142
def get_mixer_jack_status(jack_reg_exp):
    """Reports a jack's plug state from the amixer controls of card 0.

    @param jack_reg_exp: Compiled regular expression matched against each
        line of `amixer -c0 controls`; group 1 must capture the numid.

    @return None if no control line matches. Otherwise True if
        `amixer -c0 cget numid=<numid>` prints a line matching
        'values=on', and False if it does not.
    """
    controls = utils.system_output('amixer -c0 controls', retain_output=True)
    line_matches = (jack_reg_exp.match(line) for line in controls.split('\n'))
    # Only the first matching control is considered.
    numid = next((found.group(1) for found in line_matches if found), None)

    # Proceed only when matched numid is not empty.
    if not numid:
        return None

    values = utils.system_output('amixer -c0 cget numid=%s' % numid)
    return any(_JACK_VALUE_ON_RE.match(line) for line in values.split('\n'))
168
def get_hp_jack_status():
    """Gets the status of headphone jack.

    The amixer 'Headphone Jack' control of card 0 is consulted first. If it
    does not exist, the first /dev/input/event* device that reports itself
    as a headphone jack is asked instead.

    @return The amixer status (True/False), else the input device's
        get_headphone_insert() value, else None if neither source exposes
        a headphone jack.
    """
    mixer_status = get_mixer_jack_status(_HP_JACK_CONTROL_RE)
    if mixer_status is not None:
        return mixer_status

    # TODO(hychao): Check hp/mic jack status dynamically from evdev. And
    # possibly replace the existing check using amixer.
    for event_path in glob('/dev/input/event*'):
        input_dev = InputDevice(event_path)
        if input_dev.is_hp_jack():
            return input_dev.get_headphone_insert()
    return None
186
def get_mic_jack_status():
    """Gets the status of mic jack.

    The amixer 'Mic Jack' control of card 0 is consulted first. If it does
    not exist, the first /dev/input/event* device that reports itself as a
    mic jack is asked instead.

    @return The amixer status (True/False), else the input device's
        get_microphone_insert() value, else None if neither source exposes
        a mic jack.
    """
    mixer_status = get_mixer_jack_status(_MIC_JACK_CONTROL_RE)
    if mixer_status is not None:
        return mixer_status

    # Fall back to the evdev input devices.
    for event_path in glob('/dev/input/event*'):
        input_dev = InputDevice(event_path)
        if input_dev.is_mic_jack():
            return input_dev.get_microphone_insert()
    return None
200
def log_loopback_dongle_status():
    """Log the status of the loopback dongle to make sure it is equipped.

    The result is logged as PASS only if the mic jack and the headphone
    jack both read as plugged and a loopback latency check detects audio.
    """
    mic_jack_status = get_mic_jack_status()
    logging.info('Mic jack status: %s', mic_jack_status)

    hp_jack_status = get_hp_jack_status()
    logging.info('Headphone jack status: %s', hp_jack_status)

    # The latency check only proves that audio passes through the dongle;
    # the accuracy of the numbers is not asserted here.
    latency = loopback_latency_check(n=4000)
    if not latency:
        logging.info('Latency check fail.')
    else:
        measured, reported = latency[0], latency[1]
        logging.info('Got latency measured %d, reported %d',
                     measured, reported)

    dongle_status_ok = (bool(mic_jack_status) and bool(hp_jack_status)
                        and bool(latency))
    logging.info('audio loopback dongle test: %s',
                 'PASS' if dongle_status_ok else 'FAIL')
228
# Functions to test audio playback.
def play_sound(duration_seconds=None, audio_file_path=None):
    """Plays an audio file with aplay.

    @param duration_seconds: Seconds after which aplay stops (its -d
        option). When None (or any falsy value) the whole file is played.
    @param audio_file_path: File to play. When None (or empty) the default
        /usr/local/autotest/cros/audio/sine440.wav is played.
    """
    path = audio_file_path or '/usr/local/autotest/cros/audio/sine440.wav'
    if duration_seconds:
        duration_arg = '-d %d' % duration_seconds
    else:
        duration_arg = ''
    utils.system('aplay %s %s' % (duration_arg, path))
243
def get_play_sine_args(channel, odev='default', freq=1000, duration=10,
                       sample_size=16):
    """Builds the sox argument list that plays a sine tone to an ALSA device.

    @param channel: 0 for left, 1 for right; otherwize, mono.
    @param odev: alsa output device.
    @param freq: frequency of the generated sine tone.
    @param duration: duration of the generated sine tone.
    @param sample_size: output audio sample size. Default to 16.

    @return A list of command-line arguments, starting with the sox binary.
    """
    tone = str(freq)
    # The silent channel gets a 0 Hz sine.
    if channel == 0:
        synth_spec = ['sine', tone, 'sine', '0']
    elif channel == 1:
        synth_spec = ['sine', '0', 'sine', tone]
    else:
        synth_spec = ['sine', tone]

    base_args = [SOX_PATH, '-b', str(sample_size), '-n', '-t', 'alsa',
                 odev, 'synth', str(duration)]
    return base_args + synth_spec
264
def play_sine(channel, odev='default', freq=1000, duration=10,
              sample_size=16):
    """Plays a generated sine tone to odev by running sox.

    @param channel: 0 for left, 1 for right; otherwize, mono.
    @param odev: alsa output device.
    @param freq: frequency of the generated sine tone.
    @param duration: duration of the generated sine tone.
    @param sample_size: output audio sample size. Default to 16.
    """
    command_line = ' '.join(
            get_play_sine_args(channel, odev, freq, duration, sample_size))
    utils.system(command_line)
277
278# Functions to compose customized sox command, execute it and process the
279# output of sox command.
def get_sox_mixer_cmd(infile, channel,
                      num_channels=_DEFAULT_NUM_CHANNELS,
                      sox_format=_DEFAULT_SOX_FORMAT):
    """Gets sox mixer command to reduce channel.

    The command reads |infile| as |num_channels|-channel audio in
    |sox_format| and writes a single-channel stream in the same format to
    stdout. It uses the sox `mixer` effect with a pan vector that keeps
    only |channel|.

    @param infile: Input file name.
    @param channel: Zero-based index of the channel to keep.
    @param num_channels: The number of channels in |infile|.
    @param sox_format: Format options for both the input and the output.

    @return The sox command as a string.
    """
    # One pan weight per input channel: 1 for the selected channel, 0 for
    # every other one, e.g. channel=1, num_channels=2 -> '0,1'.
    pan_values = ','.join('1' if index == channel else '0'
                          for index in range(num_channels))

    # The input channel count must agree with the pan vector's length;
    # it used to be hard-coded to 2 regardless of num_channels.
    return '%s -c %d %s %s -c 1 %s - mixer %s' % (
            SOX_PATH, num_channels, sox_format, infile, sox_format,
            pan_values)
303
def sox_stat_output(infile, channel,
                    num_channels=_DEFAULT_NUM_CHANNELS,
                    sox_format=_DEFAULT_SOX_FORMAT):
    """Runs `sox stat` on a single channel of |infile|.

    The channel is extracted with get_sox_mixer_cmd() and piped into a
    second sox process that prints its statistics.

    @param infile: Input file name.
    @param channel: The selected channel.
    @param num_channels: The number of total channels to test.
    @param sox_format: Format to generate sox command.

    @return The output of sox stat command
    """
    extract_cmd = get_sox_mixer_cmd(infile, channel, num_channels, sox_format)
    # sox prints the stat report on stderr; fold it into stdout.
    stat_cmd = '%s -c 1 %s - -n stat 2>&1' % (SOX_PATH, sox_format)
    pipeline = ' | '.join([extract_cmd, stat_cmd])
    return utils.system_output(pipeline, retain_output=True)
321
def get_audio_rms(sox_output):
    """Gets the audio RMS value from sox stat output

    @param sox_output: Output of sox stat command.

    @return The first 'RMS amplitude' value found, as a float, or None if
        no line matches.
    """
    candidates = (_SOX_RMS_AMPLITUDE_RE.match(line)
                  for line in sox_output.split('\n'))
    for found in candidates:
        if found:
            return float(found.group(1))
    return None
333
def get_rough_freq(sox_output):
    """Gets the rough audio frequency from sox stat output

    @param sox_output: Output of sox stat command.

    @return The first 'Rough frequency' value found, as an int, or None if
        no line matches.
    """
    candidates = (_SOX_ROUGH_FREQ_RE.match(line)
                  for line in sox_output.split('\n'))
    for found in candidates:
        if found:
            return int(found.group(1))
    return None
345
def check_audio_rms(sox_output, sox_threshold=_DEFAULT_SOX_RMS_THRESHOLD):
    """Verifies that the RMS amplitude in sox stat output is high enough.

    @param sox_output: The output from sox stat command.
    @param sox_threshold: The threshold to test RMS value against.

    @raises error.TestError if RMS amplitude can't be parsed.
    @raises error.TestFail if the RMS amplitude of the recording isn't above
            the threshold.
    """
    rms_val = get_audio_rms(sox_output)

    # An unparseable value is a test error, not an audio failure.
    if rms_val is None:
        raise error.TestError(
            'Failed to generate an audio RMS value from playback.')

    logging.info('Got audio RMS value of %f. Minimum pass is %f.',
                 rms_val, sox_threshold)
    too_quiet = rms_val < sox_threshold
    if too_quiet:
        failure_message = ('Audio RMS value %f too low. Minimum pass is %f.' %
                           (rms_val, sox_threshold))
        raise error.TestFail(failure_message)
369
def noise_reduce_file(in_file, noise_file, out_file,
                      sox_format=_DEFAULT_SOX_FORMAT,
                      num_channels=_DEFAULT_NUM_CHANNELS):
    """Runs the sox command to reduce noise.

    Builds a noise profile from noise_file with `sox ... -n noiseprof` and
    pipes it into `sox ... noisered`, which writes the noise-reduced
    version of in_file to out_file.

    @param in_file: The file to noise reduce.
    @param noise_file: The file containing the noise profile.
        This can be created by recording silence.
    @param out_file: The file that receives the noise-reduced sound.
    @param sox_format: The sox format used for all three files.
    @param num_channels: The channel count of all three files. Defaults to
        2, the value that used to be hard-coded here.
    """
    prof_cmd = '%s -c %d %s %s -n noiseprof' % (
            SOX_PATH, num_channels, sox_format, noise_file)
    reduce_cmd = '%s -c %d %s %s -c %d %s %s noisered' % (
            SOX_PATH, num_channels, sox_format, in_file,
            num_channels, sox_format, out_file)
    utils.system('%s | %s' % (prof_cmd, reduce_cmd))
388
def record_sample(tmpfile, record_command=_DEFAULT_REC_COMMAND):
    """Records audio into |tmpfile| by running |record_command|.

    @param tmpfile: Path to record to; appended as the last argument of
        |record_command|.
    @param record_command: The command to record audio.
    """
    full_command = '%s %s' % (record_command, tmpfile)
    utils.system(full_command)
396
def create_wav_file(wav_dir, prefix=""):
    """Builds a timestamped .wav path inside |wav_dir|.

    Only the path is returned; no file is created. The name has the form
    '<prefix>-<time.time()>.wav', so it is unique only as long as two calls
    with the same prefix do not yield the same timestamp. The file is meant
    to be kept in the autotest result directory for future analysis.

    @param wav_dir: The directory of created wav file.
    @param prefix: specified file name prefix.

    @return The joined path.
    """
    timestamp = time.time()
    return os.path.join(wav_dir, "%s-%s.wav" % (prefix, timestamp))
408
def run_in_parallel(*funs):
    """Runs methods in parallel and waits for all of them to finish.

    Each callable runs with no arguments in its own thread. An exception
    raised inside a thread is logged with its traceback. Once every thread
    has been joined, the first recorded exception is re-raised in the
    caller, so a failing callback is no longer silently lost in its thread.

    @param funs: methods to run.

    @raises: The first exception raised by any of |funs|, after all threads
             have finished.
    """
    errors = []

    def _guarded(fun):
        """Wraps |fun| so its exception is recorded instead of lost."""
        def _run():
            try:
                fun()
            except Exception as e:
                logging.exception('Exception in parallel function %r', fun)
                errors.append(e)
        return _run

    threads = [threading.Thread(target=_guarded(f)) for f in funs]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    if errors:
        raise errors[0]
422
def loopback_test_channels(noise_file_name, wav_dir,
                           playback_callback=None,
                           check_recorded_callback=check_audio_rms,
                           preserve_test_file=True,
                           num_channels=_DEFAULT_NUM_CHANNELS,
                           record_callback=record_sample,
                           mix_callback=None):
    """Tests loopback on all channels.

    For each channel, recording, playback and mixing run concurrently. The
    RMS of the mix and of the recording is logged. The recording is then
    noise-reduced, and the sox stat output of the reduced file is handed to
    |check_recorded_callback|.

    @param noise_file_name: Name of the file contains pre-recorded noise.
    @param wav_dir: The directory of created wav file.
    @param playback_callback: The callback to do the playback for
        one channel.
    @param record_callback: The callback to do the recording.
    @param check_recorded_callback: The callback to check recorded file.
    @param preserve_test_file: Retain the recorded files for future debugging.
    @param num_channels: The number of total channels to test.
    @param mix_callback: The callback to do on the one-channel file.
    """
    for channel in xrange(num_channels):
        record_path = create_wav_file(wav_dir, "record-%d" % channel)
        mix_path = None
        # Bind per-iteration values as lambda defaults.
        tasks = [lambda path=record_path: record_callback(path)]
        if playback_callback:
            tasks.append(lambda ch=channel: playback_callback(ch))
        if mix_callback:
            mix_path = create_wav_file(wav_dir, "mix-%d" % channel)
            tasks.append(lambda path=mix_path: mix_callback(path))

        run_in_parallel(*tasks)

        if mix_callback:
            mix_rms = get_audio_rms(sox_stat_output(mix_path, channel))
            logging.info('Got mixed audio RMS value of %f.', mix_rms)

        record_rms = get_audio_rms(sox_stat_output(record_path, channel))
        logging.info('Got recorded audio RMS value of %f.', record_rms)

        reduced_path = create_wav_file(wav_dir, "reduced-%d" % channel)
        noise_reduce_file(record_path, noise_file_name, reduced_path)
        reduced_stats = sox_stat_output(reduced_path, channel)

        if not preserve_test_file:
            os.unlink(reduced_path)
            os.unlink(record_path)
            if mix_callback:
                os.unlink(mix_path)

        check_recorded_callback(reduced_stats)
479
480
def get_channel_sox_stat(
        input_audio, channel_index, channels=2, bits=16, rate=48000):
    """Gets the sox stat info of the selected channel in the input audio file.

    For multi-channel input, the selected channel is first extracted by one
    sox process and piped into a second sox process that computes stats.

    @param input_audio: The input audio file to be analyzed.
    @param channel_index: The index of the channel to be analyzed.
                          (1 for the first channel).
    @param channels: The number of channels in the input audio.
    @param bits: The number of bits of each audio sample.
    @param rate: The sampling rate.

    @returns: The stat result from sox_utils.get_stat() for mono input, or
              from sox_utils.parse_stat_output() otherwise.

    @raises: ValueError if channel_index is not within [1, channels].
    """
    if channel_index <= 0 or channel_index > channels:
        raise ValueError('incorrect channel_index: %d' % channel_index)

    if channels == 1:
        return sox_utils.get_stat(
                input_audio, channels=channels, bits=bits, rate=rate)

    p1 = cmd_utils.popen(
            sox_utils.extract_channel_cmd(
                    input_audio, '-', channel_index,
                    channels=channels, bits=bits, rate=rate),
            stdout=subprocess.PIPE)
    p2 = cmd_utils.popen(
            sox_utils.stat_cmd('-', channels=1, bits=bits, rate=rate),
            stdin=p1.stdout, stderr=subprocess.PIPE)
    # The stat report is read from the second process's stderr.
    stat_output = p2.stderr.read()
    cmd_utils.wait_and_check_returncode(p1, p2)
    return sox_utils.parse_stat_output(stat_output)
510
511
def get_rms(input_audio, channels=1, bits=16, rate=48000):
    """Gets the RMS values of all channels of the input audio.

    @param input_audio: The input audio file to be checked.
    @param channels: The number of channels in the input audio.
    @param bits: The number of bits of each audio sample.
    @param rate: The sampling rate.

    @return A list with the RMS of each channel, in channel order.
    """
    stats = []
    # get_channel_sox_stat uses 1-based channel indices.
    for channel_index in range(1, channels + 1):
        stats.append(get_channel_sox_stat(
                input_audio, channel_index, channels=channels, bits=bits,
                rate=rate))

    logging.info('sox stat: %s', [str(stat) for stat in stats])
    return [stat.rms for stat in stats]
526
527
def reduce_noise_and_get_rms(
        input_audio, noise_file, channels=1, bits=16, rate=48000):
    """Reduces noise in the input audio by the given noise file and then gets
    the RMS values of all channels of the input audio.

    A noise profile of |noise_file| is piped into a noise-reduction pass
    over |input_audio|. The result is written to a temporary file, which
    is then measured with get_rms().

    @param input_audio: The input audio file to be analyzed.
    @param noise_file: The noise file used to reduce noise in the input audio.
    @param channels: The number of channels in the input audio.
    @param bits: The number of bits of each audio sample.
    @param rate: The sampling rate.
    """
    with tempfile.NamedTemporaryFile() as reduced_file:
        profile_cmd = sox_utils.noise_profile_cmd(
                noise_file, '-', channels=channels, bits=bits, rate=rate)
        profiler = cmd_utils.popen(profile_cmd, stdout=subprocess.PIPE)

        # '-' makes the reducer read the noise profile from the pipe.
        reduce_cmd = sox_utils.noise_reduce_cmd(
                input_audio, reduced_file.name, '-',
                channels=channels, bits=bits, rate=rate)
        reducer = cmd_utils.popen(reduce_cmd, stdin=profiler.stdout)

        cmd_utils.wait_and_check_returncode(profiler, reducer)
        return get_rms(reduced_file.name, channels, bits, rate)
552
553
def cras_rms_test_setup():
    """Setups for the cras_rms_tests.

    Makes sure the line_out-to-mic_in path is all green. Both the system
    volume and the selected output node volume are set to
    _DEFAULT_PLAYBACK_VOLUME, and the capture gain to _DEFAULT_CAPTURE_GAIN.
    Playback and capture are then unmuted.
    """
    # TODO(owenlin): Now, the nodes are chosen by chrome.
    #                We should do it here.
    playback_volume = _DEFAULT_PLAYBACK_VOLUME
    cras_utils.set_system_volume(playback_volume)
    cras_utils.set_selected_output_node_volume(playback_volume)
    cras_utils.set_capture_gain(_DEFAULT_CAPTURE_GAIN)

    cras_utils.set_system_mute(False)
    cras_utils.set_capture_mute(False)
568
569
def generate_rms_postmortem():
    """Generates postmortem for rms tests.

    Logs the loopback dongle status and the audio_diagnostics output. Any
    exception raised along the way is logged with its traceback and
    suppressed.
    """
    try:
        logging.info('audio postmortem report')
        log_loopback_dongle_status()
        diagnostics = get_audio_diagnostics()
        logging.info(diagnostics)
    except Exception:
        logging.exception('Error while generating postmortem report')
578
579
def get_audio_diagnostics():
    """Runs the audio_diagnostics tool.

    @returns: what cmd_utils.execute() returns for the tool with stdout
              piped; expected to be a string containing diagnostic results.

    """
    diagnostics_cmd = [_AUDIO_DIAGNOSTICS_PATH]
    return cmd_utils.execute(diagnostics_cmd, stdout=subprocess.PIPE)
587
588
def get_max_cross_correlation(signal_a, signal_b):
    """Gets max cross-correlation and best time delay of two signals.

    Computes cross-correlation function between two
    signals and gets the maximum value and time delay.
    The steps includes:
      1. Compute cross-correlation function of X and Y and get Cxy.
         The correlation function Cxy is an array where Cxy[k] is the
         cross product of X and Y when Y is delayed by k.
         Refer to manual of numpy.correlate for detail of correlation.
      2. Find the maximum value C_max and index C_index in Cxy.
      3. Compute L2 norm of X and Y to get norm(X) and norm(Y).
      4. Divide C_max by norm(X)*norm(Y) to get max cross-correlation.

    Max cross-correlation indicates the similarity of X and Y. Its
    magnitude never exceeds 1, and it is 1 if X equals Y (at some delay)
    multiplied by a positive scalar. Because the maximum over all delays is
    taken, a negatively scaled copy does not in general give -1.
    Any constant level shift will be regarded as distortion and will make
    max cross-correlation value deviated from 1.
    C_index is the best time delay of Y that make Y looks similar to X.
    Refer to http://en.wikipedia.org/wiki/Cross-correlation.

    @param signal_a: A list of numbers which contains the first signal.
    @param signal_b: A list of numbers which contains the second signal.

    @raises: ValueError if any number in signal_a or signal_b is not a float.
             ValueError if norm of any array is less than _MINIMUM_NORM.

    @returns: A tuple (correlation index, best delay). If there are more than
              one best delay, just return the first one.
    """
    def check_list_contains_float(numbers):
        """Checks the elements in a list are all float.

        @param numbers: A list of numbers.

        @raises: ValueError if there is any element which is not a float
                 in the list.
        """
        if any(not isinstance(x, float) for x in numbers):
            raise ValueError('List contains number which is not a float')

    check_list_contains_float(signal_a)
    check_list_contains_float(signal_b)

    norm_a = numpy.linalg.norm(signal_a)
    norm_b = numpy.linalg.norm(signal_b)
    logging.debug('norm_a: %f', norm_a)
    logging.debug('norm_b: %f', norm_b)
    if norm_a <= _MINIMUM_NORM or norm_b <= _MINIMUM_NORM:
        raise ValueError('No meaningful data as norm is too small.')

    correlation = numpy.correlate(signal_a, signal_b, 'full')
    # Vectorized search instead of Python-level max()/enumerate over the
    # (len(a) + len(b) - 1)-long array. argmax returns the first maximum.
    best_delay = int(numpy.argmax(correlation))
    max_correlation = correlation[best_delay]
    best_delays = numpy.flatnonzero(correlation == max_correlation)
    if len(best_delays) > 1:
        logging.warning('There are more than one best delay: %r',
                        best_delays.tolist())
    return max_correlation / (norm_a * norm_b), best_delay
647
648
def trim_data(data, threshold=0):
    """Trims a data by removing value that is too small in head and tail.

    Removes elements in head and tail whose absolute value is smaller than
    or equal to threshold; elements in between are kept untouched.
    E.g. trim_data([0.0, 0.1, 0.2, 0.3, 0.2, 0.1, 0.0], 0.1) =
    ([0.2, 0.3, 0.2], 2)
    Note the comparison is strict: with threshold 0.2 the same input gives
    ([0.3], 3).

    @param data: A sliceable sequence of numbers.
    @param threshold: The threshold to compare against.

    @returns: A tuple (trimmed_data, end_trimmed_length), where
              end_trimmed_length is the length of original data being trimmed
              from the end.
              Returns ([], None) if there is no valid data.
    """
    valid_indices = [
            i for i, value in enumerate(data) if abs(value) > threshold]
    if not valid_indices:
        logging.warning(
                'There is no element with absolute value greater '
                'than threshold %f', threshold)
        return [], None
    first_valid, last_valid = valid_indices[0], valid_indices[-1]
    logging.debug('Start and end of indice_valid: %d, %d',
                  first_valid, last_valid)
    end_trimmed_length = len(data) - last_valid - 1
    logging.debug('Trimmed length in the end: %d', end_trimmed_length)
    return (data[first_valid : last_valid + 1], end_trimmed_length)
677
678
def get_one_channel_correlation(test_data, golden_data):
    """Gets max cross-correlation of test_data and golden_data.

    Trims test data and compute the max cross-correlation against golden_data.
    Signal can be trimmed because those zero values in the head and tail of
    a signal will not affect correlation computation.

    @param test_data: A list containing the data to compare against golden data.
    @param golden_data: A list containing the golden data.

    @returns: A tuple (max cross-correlation, best_delay) if data is valid.
              Otherwise returns (None, None). Test data consisting only of
              zeros is treated as invalid. Refer to docstring of
              get_max_cross_correlation.

    @raises: ValueError from get_max_cross_correlation if golden_data has no
             meaningful data.
    """
    trimmed_test_data, end_trimmed_length = trim_data(test_data)

    # trim_data returns ([], None) when every test sample is zero. Report
    # that as invalid data, as documented, instead of letting
    # get_max_cross_correlation raise on the zero norm.
    if end_trimmed_length is None:
        return None, None

    # get_max_cross_correlation requires float elements.
    max_cross_correlation, best_delay = get_max_cross_correlation(
            [float(x) for x in golden_data],
            [float(x) for x in trimmed_test_data])

    # The reason to add back the trimmed length in the end.
    # E.g.:
    # golden data:
    #
    # |-----------vvvv----------------|  vvvv is the signal of interest.
    #       a                 b
    #
    # test data:
    #
    # |---x----vvvv--------x----------------|  x is the place to trim.
    #   c   d         e            f
    #
    # trimmed test data:
    #
    # |----vvvv--------|
    #   d         e
    #
    # The first output of cross correlation computation :
    #
    #                  |-----------vvvv----------------|
    #                       a                 b
    #
    # |----vvvv--------|
    #   d         e
    #
    # The largest output of cross correlation computation happens at
    # delay a + e.
    #
    #                  |-----------vvvv----------------|
    #                       a                 b
    #
    #                         |----vvvv--------|
    #                           d         e
    #
    # Cross correlation starts computing by aligning the last sample
    # of the trimmed test data to the first sample of golden data.
    # The best delay calculated from trimmed test data and golden data
    # cross correlation is e + a. But the real best delay that should be
    # identical on two channel should be e + a + f.
    # So we need to add back the length being trimmed in the end.

    if max_cross_correlation:
        return max_cross_correlation, best_delay + end_trimmed_length
    else:
        return None, None
754
755
def compare_one_channel_correlation(test_data, golden_data, parameters):
    """Compares two one-channel data by correlation.

    @param test_data: A list containing the data to compare against golden data.
    @param golden_data: A list containing the golden data.
    @param parameters: A dict containing parameters for method. The optional
        'correlation_threshold' key overrides _CORRELATION_INDEX_THRESHOLD.

    @returns: A dict containing:
              index: The index of similarity where 1 means they are different
                  only by a positive scale.
              best_delay: The best delay of test data in relative to golden
                  data.
              equal: A bool containing comparing result.
    """
    threshold = (parameters['correlation_threshold']
                 if 'correlation_threshold' in parameters
                 else _CORRELATION_INDEX_THRESHOLD)

    index, best_delay = get_one_channel_correlation(test_data, golden_data)
    result_dict = {
        'index': index,
        'best_delay': best_delay,
        # A missing (None) or zero index never counts as equal.
        'equal': bool(index and index > threshold),
    }
    logging.debug('result_dict: %r', result_dict)
    return result_dict
785
786
def compare_data_correlation(golden_data_binary, golden_data_format,
                             test_data_binary, test_data_format,
                             channel_map, parameters=None):
    """Compares two raw data using correlation.

    Every test channel that channel_map maps to a golden channel is compared
    with compare_one_channel_correlation. All compared pairs must be equal,
    and they must all agree on a single best delay.

    @param golden_data_binary: The binary containing golden data.
    @param golden_data_format: The data format of golden data.
    @param test_data_binary: The binary containing test data.
    @param test_data_format: The data format of test data.
    @param channel_map: A list containing channel mapping.
                        E.g. [1, 0, None, None, None, None, None, None] means
                        channel 0 of test data should map to channel 1 of
                        golden data. Channel 1 of test data should map to
                        channel 0 of golden data. Channel 2 to 7 of test data
                        should be skipped.
    @param parameters: A dict containing parameters for method, if needed.

    @raises: NotImplementedError if file type is not raw.
             NotImplementedError if sampling rates of two data are not the same.
             error.TestFail if golden data and test data are not equal, or if
             the compared channels do not share the same best delay.
    """
    if parameters is None:
        parameters = dict()

    if (golden_data_format['file_type'] != 'raw' or
        test_data_format['file_type'] != 'raw'):
        raise NotImplementedError('Only support raw data in compare_data.')
    if golden_data_format['rate'] != test_data_format['rate']:
        raise NotImplementedError(
                'Only support comparing data with the same sampling rate')
    golden_data = audio_data.AudioRawData(
            binary=golden_data_binary,
            channel=golden_data_format['channel'],
            sample_format=golden_data_format['sample_format'])
    test_data = audio_data.AudioRawData(
            binary=test_data_binary,
            channel=test_data_format['channel'],
            sample_format=test_data_format['sample_format'])
    compare_results = []
    for test_channel, golden_channel in enumerate(channel_map):
        if golden_channel is None:
            logging.info('Skipped channel %d', test_channel)
            continue
        test_data_one_channel = test_data.channel_data[test_channel]
        golden_data_one_channel = golden_data.channel_data[golden_channel]
        result_dict = dict(test_channel=test_channel,
                           golden_channel=golden_channel)
        result_dict.update(
                compare_one_channel_correlation(
                        test_data_one_channel, golden_data_one_channel,
                        parameters))
        compare_results.append(result_dict)
    logging.info('compare_results: %r', compare_results)
    for result in compare_results:
        if not result['equal']:
            # The index is None when no correlation was found. Formatting None
            # with %f would raise TypeError and hide the real test failure.
            index = result['index']
            index_str = '%f' % index if index is not None else 'None'
            error_msg = ('Failed on test channel %d and golden channel %d with '
                         'index %s') % (
                                 result['test_channel'],
                                 result['golden_channel'],
                                 index_str)
            logging.error(error_msg)
            raise error.TestFail(error_msg)
    # Also checks best delay are exactly the same.
    best_delays = {result['best_delay'] for result in compare_results}
    if len(best_delays) > 1:
        error_msg = 'There are more than one best delay: %s' % best_delays
        logging.error(error_msg)
        raise error.TestFail(error_msg)
855
856
class _base_rms_test(test.test):
    """Common base class shared by the RMS audio tests."""

    def postprocess(self):
        """Runs the base postprocess, then a postmortem if anything failed.

        self.failed_constraints is iterated as one collection of failed
        constraints per iteration. A postmortem is generated when the total
        number of failed constraints across all iterations is non-zero.
        """
        super(_base_rms_test, self).postprocess()

        failed_counts = [len(constraints)
                         for constraints in self.failed_constraints]
        if sum(failed_counts) == 0:
            return
        generate_rms_postmortem()
866
867
class chrome_rms_test(_base_rms_test):
    """Base test class for audio RMS test with Chrome.

    The chrome instance can be accessed by self.chrome.
    """
    def warmup(self):
        """Starts Chrome, then sets up CRAS for the RMS test.

        If the CRAS setup fails, Chrome is closed and the exception is
        re-raised.
        """
        super(chrome_rms_test, self).warmup()

        # Not all client of this file using telemetry.
        # Just do the import here for those who really need it.
        from autotest_lib.client.common_lib.cros import chrome

        self.chrome = chrome.Chrome(init_network_controller=True)

        # The audio configuration could be changed when we
        # restart chrome.
        try:
            cras_rms_test_setup()
        except Exception:
            # Close via the same API cleanup() uses rather than only the
            # browser. Then drop the reference so a later cleanup() does not
            # close the instance a second time.
            try:
                self.chrome.close()
            finally:
                self.chrome = None
            raise


    def cleanup(self, *args):
        """Closes Chrome if it is still open, then runs the base cleanup."""
        try:
            # self.chrome is unset if chrome.Chrome() failed in warmup, and
            # None if warmup already closed it after a setup failure.
            chrome_instance = getattr(self, 'chrome', None)
            if chrome_instance is not None:
                self.chrome = None
                chrome_instance.close()
        finally:
            super(chrome_rms_test, self).cleanup()
896
class cras_rms_test(_base_rms_test):
    """Base test class for CRAS audio RMS test."""

    def warmup(self):
        """Stops the UI so no other streams run, then sets up CRAS."""
        super(cras_rms_test, self).warmup()
        # Stop ui to make sure there are not other streams.
        utils.stop_service('ui', ignore_status=True)
        cras_rms_test_setup()

    def cleanup(self, *args):
        """Restarts the UI stopped in warmup, then runs the base cleanup."""
        try:
            # Restart ui.
            utils.start_service('ui', ignore_status=True)
        finally:
            # Chain to the base class, as chrome_rms_test.cleanup does.
            super(cras_rms_test, self).cleanup()
909
910
class alsa_rms_test(_base_rms_test):
    """Base test class for ALSA audio RMS test.

    The warmup plays briefly through CRAS and then sleeps 11 seconds. The
    device cannot be used before warmup returns.
    """
    def warmup(self):
        """Lets CRAS initialize volume and gain, then waits for it to let go."""
        super(alsa_rms_test, self).warmup()

        cras_rms_test_setup()

        # A short CRAS playback makes CRAS initialize the volume and gain.
        cras_utils.playback(duration=1, playback_file="/dev/zero")

        # CRAS will release the device after 10 seconds; wait one extra
        # second so ALSA can open it afterwards.
        release_wait_seconds = 10 + 1
        time.sleep(release_wait_seconds)
925