1#!/usr/bin/env python
2# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3#
4# Use of this source code is governed by a BSD-style license
5# that can be found in the LICENSE file in the root of the source
6# tree. An additional intellectual property rights grant can be found
7# in the file PATENTS.  All contributing project authors may
8# be found in the AUTHORS file in the root of the source tree.
9
10"""Perform APM module quality assessment on one or more input files using one or
11   more APM simulator configuration files and one or more test data generators.
12
Usage: apm_quality_assessment.py -i audio1.wav [audio2.wav ...]
                                 -c cfg1.json [cfg2.json ...]
                                 -t white [echo ...]
                                 -e audio_level [polqa ...]
                                 -o /path/to/output
                                 --polqa_path /path/to/polqa
                                 --air_db_path /path/to/air_db
18"""
19
20import argparse
21import logging
22import os
23import sys
24
25import quality_assessment.audioproc_wrapper as audioproc_wrapper
26import quality_assessment.echo_path_simulation as echo_path_simulation
27import quality_assessment.eval_scores as eval_scores
28import quality_assessment.evaluation as evaluation
29import quality_assessment.eval_scores_factory as eval_scores_factory
30import quality_assessment.external_vad as external_vad
31import quality_assessment.test_data_generation as test_data_generation
32import quality_assessment.test_data_generation_factory as  \
33    test_data_generation_factory
34import quality_assessment.simulation as simulation
35
# Registries of the classes exposed by the quality_assessment package. The
# names offered on the command line are taken from these registries (for the
# test data generators and evaluation scores, from their keys).
_ECHO_PATH_SIMULATOR_NAMES = (
    echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES)

_TEST_DATA_GENERATOR_CLASSES = (
    test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
_TEST_DATA_GENERATORS_NAMES = _TEST_DATA_GENERATOR_CLASSES.keys()

_EVAL_SCORE_WORKER_CLASSES = (
    eval_scores.EvaluationScore.REGISTERED_CLASSES)
_EVAL_SCORE_WORKER_NAMES = _EVAL_SCORE_WORKER_CLASSES.keys()

# APM simulator configuration used when --config_files is not given.
_DEFAULT_CONFIG_FILE = 'apm_configs/default.json'

# File name of the POLQA executable, joined to --polqa_path.
_POLQA_BIN_NAME = 'PolqaOem64'
47
48
def _ParseBoolArgument(value):
  """Converts a boolean command line value into a bool.

  argparse passes option values as strings and any non-empty string (including
  'False' or '0') is truthy; this converter lets e.g.
  '--copy_with_identity_generator false' actually disable the option.

  Args:
    value: string given on the command line.

  Returns:
    True for 'true', 't', 'yes', 'y', 'on', '1'; False for 'false', 'f', 'no',
    'n', 'off', '0' (case-insensitive, surrounding blanks ignored).

  Raises:
    argparse.ArgumentTypeError: if value is not a recognized boolean string
      (argparse reports it as a usage error).
  """
  normalized = value.strip().lower()
  if normalized in ('true', 't', 'yes', 'y', 'on', '1'):
    return True
  if normalized in ('false', 'f', 'no', 'n', 'off', '0'):
    return False
  raise argparse.ArgumentTypeError(
      'invalid boolean value: %r (use true or false)' % value)


def _InstanceArgumentsParser():
  """Arguments parser factory.

  Returns:
    An argparse.ArgumentParser defining all the options of this script.
  """
  parser = argparse.ArgumentParser(description=(
      'Perform APM module quality assessment on one or more input files using '
      'one or more APM simulator configuration files and one or more '
      'test data generators.'))

  parser.add_argument('-c', '--config_files', nargs='+', required=False,
                      help=('path to the configuration files defining the '
                            'arguments with which the APM simulator tool is '
                            'called'),
                      default=[_DEFAULT_CONFIG_FILE])

  parser.add_argument('-i', '--capture_input_files', nargs='+', required=True,
                      help='path to the capture input wav files (one or more)')

  parser.add_argument('-r', '--render_input_files', nargs='+', required=False,
                      help=('path to the render input wav files; either '
                            'omitted or one file for each file in '
                            '--capture_input_files (files will be paired by '
                            'index)'), default=None)

  parser.add_argument('-p', '--echo_path_simulator', required=False,
                      help=('custom echo path simulator name; required if '
                            '--render_input_files is specified'),
                      choices=_ECHO_PATH_SIMULATOR_NAMES,
                      default=echo_path_simulation.NoEchoPathSimulator.NAME)

  parser.add_argument('-t', '--test_data_generators', nargs='+', required=False,
                      help='custom list of test data generators to use',
                      choices=_TEST_DATA_GENERATORS_NAMES,
                      default=_TEST_DATA_GENERATORS_NAMES)

  parser.add_argument('--additive_noise_tracks_path', required=False,
                      help=('path to the wav files for the additive noise '
                            'test data generator'),
                      default=test_data_generation.  \
                              AdditiveNoiseTestDataGenerator.  \
                              DEFAULT_NOISE_TRACKS_PATH)

  parser.add_argument('-e', '--eval_scores', nargs='+', required=False,
                      help='custom list of evaluation scores to use',
                      choices=_EVAL_SCORE_WORKER_NAMES,
                      default=_EVAL_SCORE_WORKER_NAMES)

  parser.add_argument('-o', '--output_dir', required=False,
                      help=('base path to the output directory in which the '
                            'output wav files and the evaluation outcomes '
                            'are saved'),
                      default='output')

  parser.add_argument('--polqa_path', required=True,
                      help='path to the POLQA tool')

  parser.add_argument('--air_db_path', required=True,
                      help='path to the Aachen IR database')

  parser.add_argument('--apm_sim_path', required=False,
                      help='path to the APM simulator tool',
                      default=audioproc_wrapper.  \
                              AudioProcWrapper.  \
                              DEFAULT_APM_SIMULATOR_BIN_PATH)

  parser.add_argument('--echo_metric_tool_bin_path', required=False,
                      help=('path to the echo metric binary '
                            '(required for the echo eval score)'),
                      default=None)

  # A value is still expected (backward compatible), but it is now parsed as a
  # boolean instead of being used as a raw, always-truthy string. argparse does
  # not pass non-string defaults through `type`, so the default stays False.
  parser.add_argument('--copy_with_identity_generator', required=False,
                      type=_ParseBoolArgument,
                      help=('If true, the identity test data generator makes a '
                            'copy of the clean speech input file '
                            '(true/false, yes/no, 1/0).'),
                      default=False)

  parser.add_argument('--external_vad_paths', nargs='+', required=False,
                      help=('Paths to external VAD programs. Each must take '
                            '\'-i <wav file> -o <output>\' inputs'), default=[])

  parser.add_argument('--external_vad_names', nargs='+', required=False,
                      help=('Keys to the vad paths. Must be different and '
                            'as many as the paths.'), default=[])

  return parser
131
132
def _ValidateArguments(args, parser):
  """Checks cross-argument constraints that argparse cannot express.

  On failure, parser.error() prints the usage plus the message to stderr and
  terminates the process (argparse uses exit status 2), so this function only
  returns when every check passes.

  Args:
    args: namespace returned by parser.parse_args().
    parser: the argparse.ArgumentParser that produced args; used to report
      errors.
  """
  if args.capture_input_files and args.render_input_files and (
      len(args.capture_input_files) != len(args.render_input_files)):
    parser.error('--render_input_files and --capture_input_files must be lists '
                 'having the same length')

  if args.render_input_files and not args.echo_path_simulator:
    parser.error('when --render_input_files is set, --echo_path_simulator is '
                 'also required')

  if len(args.external_vad_names) != len(args.external_vad_paths):
    parser.error('If provided, --external_vad_paths and '
                 '--external_vad_names must '
                 'have the same number of arguments.')

  # The names are used as keys for the VAD paths, hence they must be distinct.
  if len(set(args.external_vad_names)) != len(args.external_vad_names):
    parser.error('--external_vad_names must be unique.')
150
151
def main():
  """Script entry point: parses the command line and runs the APM simulations.

  The test data generator factory, the evaluation score factory, the APM
  simulator wrapper, the evaluator and the external VADs are built from the
  parsed arguments and handed to an ApmModuleSimulator, which is then run on
  the requested inputs. The process exits with status 0 afterwards.
  """
  # TODO(alessiob): level = logging.INFO once debugged.
  logging.basicConfig(level=logging.DEBUG)

  parser = _InstanceArgumentsParser()
  args = parser.parse_args()
  _ValidateArguments(args, parser)

  # Collaborators are created in the same order the simulator receives them.
  test_data_generator_factory = (
      test_data_generation_factory.TestDataGeneratorFactory(
          aechen_ir_database_path=args.air_db_path,
          noise_tracks_path=args.additive_noise_tracks_path,
          copy_with_identity=args.copy_with_identity_generator))
  polqa_bin_path = os.path.join(args.polqa_path, _POLQA_BIN_NAME)
  evaluation_score_factory = eval_scores_factory.EvaluationScoreWorkerFactory(
      polqa_tool_bin_path=polqa_bin_path,
      echo_metric_tool_bin_path=args.echo_metric_tool_bin_path)
  ap_wrapper = audioproc_wrapper.AudioProcWrapper(args.apm_sim_path)
  evaluator = evaluation.ApmModuleEvaluator()
  external_vads = external_vad.ExternalVad.ConstructVadDict(
      args.external_vad_paths, args.external_vad_names)

  simulator = simulation.ApmModuleSimulator(
      test_data_generator_factory=test_data_generator_factory,
      evaluation_score_factory=evaluation_score_factory,
      ap_wrapper=ap_wrapper,
      evaluator=evaluator,
      external_vads=external_vads)

  simulator.Run(
      config_filepaths=args.config_files,
      capture_input_filepaths=args.capture_input_files,
      render_input_filepaths=args.render_input_files,
      echo_path_simulator_name=args.echo_path_simulator,
      test_data_generator_names=args.test_data_generators,
      eval_score_names=args.eval_scores,
      output_dir=args.output_dir)
  sys.exit(0)
182
183
# Run the assessment only when executed as a script, not when imported.
if __name__ == '__main__':
  main()
186