1 /*
2  *  Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/agc2/signal_classifier.h"
12 
13 #include <algorithm>
14 #include <numeric>
15 #include <vector>
16 
17 #include "api/array_view.h"
18 #include "modules/audio_processing/agc2/down_sampler.h"
19 #include "modules/audio_processing/agc2/noise_spectrum_estimator.h"
20 #include "modules/audio_processing/logging/apm_data_dumper.h"
21 #include "rtc_base/checks.h"
22 #include "system_wrappers/include/cpu_features_wrapper.h"
23 
24 namespace webrtc {
25 namespace {
26 
// Reports whether SSE2 may be used on this machine. When the build does not
// define WEBRTC_ARCH_X86_FAMILY the answer is always false; otherwise the
// result of the WebRtc_GetCPUInfo(kSSE2) query decides.
bool IsSse2Available() {
#if !defined(WEBRTC_ARCH_X86_FAMILY)
  return false;
#else
  const bool has_sse2 = WebRtc_GetCPUInfo(kSSE2) != 0;
  return has_sse2;
#endif
}
34 
RemoveDcLevel(rtc::ArrayView<float> x)35 void RemoveDcLevel(rtc::ArrayView<float> x) {
36   RTC_DCHECK_LT(0, x.size());
37   float mean = std::accumulate(x.data(), x.data() + x.size(), 0.f);
38   mean /= x.size();
39 
40   for (float& v : x) {
41     v -= mean;
42   }
43 }
44 
PowerSpectrum(const OouraFft * ooura_fft,rtc::ArrayView<const float> x,rtc::ArrayView<float> spectrum)45 void PowerSpectrum(const OouraFft* ooura_fft,
46                    rtc::ArrayView<const float> x,
47                    rtc::ArrayView<float> spectrum) {
48   RTC_DCHECK_EQ(65, spectrum.size());
49   RTC_DCHECK_EQ(128, x.size());
50   float X[128];
51   std::copy(x.data(), x.data() + x.size(), X);
52   ooura_fft->Fft(X);
53 
54   float* X_p = X;
55   RTC_DCHECK_EQ(X_p, &X[0]);
56   spectrum[0] = (*X_p) * (*X_p);
57   ++X_p;
58   RTC_DCHECK_EQ(X_p, &X[1]);
59   spectrum[64] = (*X_p) * (*X_p);
60   for (int k = 1; k < 64; ++k) {
61     ++X_p;
62     RTC_DCHECK_EQ(X_p, &X[2 * k]);
63     spectrum[k] = (*X_p) * (*X_p);
64     ++X_p;
65     RTC_DCHECK_EQ(X_p, &X[2 * k + 1]);
66     spectrum[k] += (*X_p) * (*X_p);
67   }
68 }
69 
ClassifySignal(rtc::ArrayView<const float> signal_spectrum,rtc::ArrayView<const float> noise_spectrum,ApmDataDumper * data_dumper)70 webrtc::SignalClassifier::SignalType ClassifySignal(
71     rtc::ArrayView<const float> signal_spectrum,
72     rtc::ArrayView<const float> noise_spectrum,
73     ApmDataDumper* data_dumper) {
74   int num_stationary_bands = 0;
75   int num_highly_nonstationary_bands = 0;
76 
77   // Detect stationary and highly nonstationary bands.
78   for (size_t k = 1; k < 40; k++) {
79     if (signal_spectrum[k] < 3 * noise_spectrum[k] &&
80         signal_spectrum[k] * 3 > noise_spectrum[k]) {
81       ++num_stationary_bands;
82     } else if (signal_spectrum[k] > 9 * noise_spectrum[k]) {
83       ++num_highly_nonstationary_bands;
84     }
85   }
86 
87   data_dumper->DumpRaw("lc_num_stationary_bands", 1, &num_stationary_bands);
88   data_dumper->DumpRaw("lc_num_highly_nonstationary_bands", 1,
89                        &num_highly_nonstationary_bands);
90 
91   // Use the detected number of bands to classify the overall signal
92   // stationarity.
93   if (num_stationary_bands > 15) {
94     return SignalClassifier::SignalType::kStationary;
95   } else {
96     return SignalClassifier::SignalType::kNonStationary;
97   }
98 }
99 
100 }  // namespace
101 
FrameExtender(size_t frame_size,size_t extended_frame_size)102 SignalClassifier::FrameExtender::FrameExtender(size_t frame_size,
103                                                size_t extended_frame_size)
104     : x_old_(extended_frame_size - frame_size, 0.f) {}
105 
106 SignalClassifier::FrameExtender::~FrameExtender() = default;
107 
ExtendFrame(rtc::ArrayView<const float> x,rtc::ArrayView<float> x_extended)108 void SignalClassifier::FrameExtender::ExtendFrame(
109     rtc::ArrayView<const float> x,
110     rtc::ArrayView<float> x_extended) {
111   RTC_DCHECK_EQ(x_old_.size() + x.size(), x_extended.size());
112   std::copy(x_old_.data(), x_old_.data() + x_old_.size(), x_extended.data());
113   std::copy(x.data(), x.data() + x.size(), x_extended.data() + x_old_.size());
114   std::copy(x_extended.data() + x_extended.size() - x_old_.size(),
115             x_extended.data() + x_extended.size(), x_old_.data());
116 }
117 
SignalClassifier(ApmDataDumper * data_dumper)118 SignalClassifier::SignalClassifier(ApmDataDumper* data_dumper)
119     : data_dumper_(data_dumper),
120       down_sampler_(data_dumper_),
121       noise_spectrum_estimator_(data_dumper_),
122       ooura_fft_(IsSse2Available()) {
123   Initialize(48000);
124 }
~SignalClassifier()125 SignalClassifier::~SignalClassifier() {}
126 
Initialize(int sample_rate_hz)127 void SignalClassifier::Initialize(int sample_rate_hz) {
128   down_sampler_.Initialize(sample_rate_hz);
129   noise_spectrum_estimator_.Initialize();
130   frame_extender_.reset(new FrameExtender(80, 128));
131   sample_rate_hz_ = sample_rate_hz;
132   initialization_frames_left_ = 2;
133   consistent_classification_counter_ = 3;
134   last_signal_type_ = SignalClassifier::SignalType::kNonStationary;
135 }
136 
Analyze(rtc::ArrayView<const float> signal)137 SignalClassifier::SignalType SignalClassifier::Analyze(
138     rtc::ArrayView<const float> signal) {
139   RTC_DCHECK_EQ(signal.size(), sample_rate_hz_ / 100);
140 
141   // Compute the signal power spectrum.
142   float downsampled_frame[80];
143   down_sampler_.DownSample(signal, downsampled_frame);
144   float extended_frame[128];
145   frame_extender_->ExtendFrame(downsampled_frame, extended_frame);
146   RemoveDcLevel(extended_frame);
147   float signal_spectrum[65];
148   PowerSpectrum(&ooura_fft_, extended_frame, signal_spectrum);
149 
150   // Classify the signal based on the estimate of the noise spectrum and the
151   // signal spectrum estimate.
152   const SignalType signal_type = ClassifySignal(
153       signal_spectrum, noise_spectrum_estimator_.GetNoiseSpectrum(),
154       data_dumper_);
155 
156   // Update the noise spectrum based on the signal spectrum.
157   noise_spectrum_estimator_.Update(signal_spectrum,
158                                    initialization_frames_left_ > 0);
159 
160   // Update the number of frames until a reliable signal spectrum is achieved.
161   initialization_frames_left_ = std::max(0, initialization_frames_left_ - 1);
162 
163   if (last_signal_type_ == signal_type) {
164     consistent_classification_counter_ =
165         std::max(0, consistent_classification_counter_ - 1);
166   } else {
167     last_signal_type_ = signal_type;
168     consistent_classification_counter_ = 3;
169   }
170 
171   if (consistent_classification_counter_ > 0) {
172     return SignalClassifier::SignalType::kNonStationary;
173   }
174   return signal_type;
175 }
176 
177 }  // namespace webrtc
178