1 /*
2  *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/aec3/subband_erle_estimator.h"
12 
13 #include <algorithm>
14 #include <functional>
15 
16 #include "rtc_base/checks.h"
17 #include "rtc_base/numerics/safe_minmax.h"
18 #include "system_wrappers/include/field_trial.h"
19 
20 namespace webrtc {
21 
22 namespace {
23 
// X2 (render) band power below this value marks that band as having low render
// energy for the current accumulation batch.
constexpr float kX2BandEnergyThreshold = 44015068.0f;
// Number of blocks, counted from the last onset-detection update of a band,
// during which that band's ERLE is not decayed towards its onset ERLE.
constexpr int kBlocksToHoldErle = 100;
// Value a band's hold counter is reset to on an onset-detection update; when
// the counter has counted down to zero the band is re-armed for a new onset.
constexpr int kBlocksForOnsetDetection = kBlocksToHoldErle + 150;
// Number of blocks whose Y2/E2 spectra are summed before a new per-band ERLE
// observation is formed.
constexpr int kPointsToAccumulate = 6;
28 
SetMaxErleBands(float max_erle_l,float max_erle_h)29 std::array<float, kFftLengthBy2Plus1> SetMaxErleBands(float max_erle_l,
30                                                       float max_erle_h) {
31   std::array<float, kFftLengthBy2Plus1> max_erle;
32   std::fill(max_erle.begin(), max_erle.begin() + kFftLengthBy2 / 2, max_erle_l);
33   std::fill(max_erle.begin() + kFftLengthBy2 / 2, max_erle.end(), max_erle_h);
34   return max_erle;
35 }
36 
EnableMinErleDuringOnsets()37 bool EnableMinErleDuringOnsets() {
38   return !field_trial::IsEnabled("WebRTC-Aec3MinErleDuringOnsetsKillSwitch");
39 }
40 
41 }  // namespace
42 
SubbandErleEstimator(const EchoCanceller3Config & config,size_t num_capture_channels)43 SubbandErleEstimator::SubbandErleEstimator(const EchoCanceller3Config& config,
44                                            size_t num_capture_channels)
45     : use_onset_detection_(config.erle.onset_detection),
46       min_erle_(config.erle.min),
47       max_erle_(SetMaxErleBands(config.erle.max_l, config.erle.max_h)),
48       use_min_erle_during_onsets_(EnableMinErleDuringOnsets()),
49       accum_spectra_(num_capture_channels),
50       erle_(num_capture_channels),
51       erle_onsets_(num_capture_channels),
52       coming_onset_(num_capture_channels),
53       hold_counters_(num_capture_channels) {
54   Reset();
55 }
56 
// Defaulted out of line; no explicit cleanup is performed.
SubbandErleEstimator::~SubbandErleEstimator() = default;
58 
Reset()59 void SubbandErleEstimator::Reset() {
60   for (auto& erle : erle_) {
61     erle.fill(min_erle_);
62   }
63   for (size_t ch = 0; ch < erle_onsets_.size(); ++ch) {
64     erle_onsets_[ch].fill(min_erle_);
65     coming_onset_[ch].fill(true);
66     hold_counters_[ch].fill(0);
67   }
68   ResetAccumulatedSpectra();
69 }
70 
Update(rtc::ArrayView<const float,kFftLengthBy2Plus1> X2,rtc::ArrayView<const std::array<float,kFftLengthBy2Plus1>> Y2,rtc::ArrayView<const std::array<float,kFftLengthBy2Plus1>> E2,const std::vector<bool> & converged_filters)71 void SubbandErleEstimator::Update(
72     rtc::ArrayView<const float, kFftLengthBy2Plus1> X2,
73     rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
74     rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2,
75     const std::vector<bool>& converged_filters) {
76   UpdateAccumulatedSpectra(X2, Y2, E2, converged_filters);
77   UpdateBands(converged_filters);
78 
79   if (use_onset_detection_) {
80     DecreaseErlePerBandForLowRenderSignals();
81   }
82 
83   for (auto& erle : erle_) {
84     erle[0] = erle[1];
85     erle[kFftLengthBy2] = erle[kFftLengthBy2 - 1];
86   }
87 }
88 
Dump(const std::unique_ptr<ApmDataDumper> & data_dumper) const89 void SubbandErleEstimator::Dump(
90     const std::unique_ptr<ApmDataDumper>& data_dumper) const {
91   data_dumper->DumpRaw("aec3_erle_onset", ErleOnsets()[0]);
92 }
93 
UpdateBands(const std::vector<bool> & converged_filters)94 void SubbandErleEstimator::UpdateBands(
95     const std::vector<bool>& converged_filters) {
96   const int num_capture_channels = static_cast<int>(accum_spectra_.Y2.size());
97   for (int ch = 0; ch < num_capture_channels; ++ch) {
98     // Note that the use of the converged_filter flag already imposed
99     // a minimum of the erle that can be estimated as that flag would
100     // be false if the filter is performing poorly.
101     if (!converged_filters[ch]) {
102       continue;
103     }
104 
105     std::array<float, kFftLengthBy2> new_erle;
106     std::array<bool, kFftLengthBy2> is_erle_updated;
107     is_erle_updated.fill(false);
108 
109     for (size_t k = 1; k < kFftLengthBy2; ++k) {
110       if (accum_spectra_.num_points[ch] == kPointsToAccumulate &&
111           accum_spectra_.E2[ch][k] > 0.f) {
112         new_erle[k] = accum_spectra_.Y2[ch][k] / accum_spectra_.E2[ch][k];
113         is_erle_updated[k] = true;
114       }
115     }
116 
117     if (use_onset_detection_) {
118       for (size_t k = 1; k < kFftLengthBy2; ++k) {
119         if (is_erle_updated[k] && !accum_spectra_.low_render_energy[ch][k]) {
120           if (coming_onset_[ch][k]) {
121             coming_onset_[ch][k] = false;
122             if (!use_min_erle_during_onsets_) {
123               float alpha = new_erle[k] < erle_onsets_[ch][k] ? 0.3f : 0.15f;
124               erle_onsets_[ch][k] = rtc::SafeClamp(
125                   erle_onsets_[ch][k] +
126                       alpha * (new_erle[k] - erle_onsets_[ch][k]),
127                   min_erle_, max_erle_[k]);
128             }
129           }
130           hold_counters_[ch][k] = kBlocksForOnsetDetection;
131         }
132       }
133     }
134 
135     for (size_t k = 1; k < kFftLengthBy2; ++k) {
136       if (is_erle_updated[k]) {
137         float alpha = 0.05f;
138         if (new_erle[k] < erle_[ch][k]) {
139           alpha = accum_spectra_.low_render_energy[ch][k] ? 0.f : 0.1f;
140         }
141         erle_[ch][k] =
142             rtc::SafeClamp(erle_[ch][k] + alpha * (new_erle[k] - erle_[ch][k]),
143                            min_erle_, max_erle_[k]);
144       }
145     }
146   }
147 }
148 
DecreaseErlePerBandForLowRenderSignals()149 void SubbandErleEstimator::DecreaseErlePerBandForLowRenderSignals() {
150   const int num_capture_channels = static_cast<int>(accum_spectra_.Y2.size());
151   for (int ch = 0; ch < num_capture_channels; ++ch) {
152     for (size_t k = 1; k < kFftLengthBy2; ++k) {
153       --hold_counters_[ch][k];
154       if (hold_counters_[ch][k] <=
155           (kBlocksForOnsetDetection - kBlocksToHoldErle)) {
156         if (erle_[ch][k] > erle_onsets_[ch][k]) {
157           erle_[ch][k] = std::max(erle_onsets_[ch][k], 0.97f * erle_[ch][k]);
158           RTC_DCHECK_LE(min_erle_, erle_[ch][k]);
159         }
160         if (hold_counters_[ch][k] <= 0) {
161           coming_onset_[ch][k] = true;
162           hold_counters_[ch][k] = 0;
163         }
164       }
165     }
166   }
167 }
168 
ResetAccumulatedSpectra()169 void SubbandErleEstimator::ResetAccumulatedSpectra() {
170   for (size_t ch = 0; ch < erle_onsets_.size(); ++ch) {
171     accum_spectra_.Y2[ch].fill(0.f);
172     accum_spectra_.E2[ch].fill(0.f);
173     accum_spectra_.num_points[ch] = 0;
174     accum_spectra_.low_render_energy[ch].fill(false);
175   }
176 }
177 
UpdateAccumulatedSpectra(rtc::ArrayView<const float,kFftLengthBy2Plus1> X2,rtc::ArrayView<const std::array<float,kFftLengthBy2Plus1>> Y2,rtc::ArrayView<const std::array<float,kFftLengthBy2Plus1>> E2,const std::vector<bool> & converged_filters)178 void SubbandErleEstimator::UpdateAccumulatedSpectra(
179     rtc::ArrayView<const float, kFftLengthBy2Plus1> X2,
180     rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
181     rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2,
182     const std::vector<bool>& converged_filters) {
183   auto& st = accum_spectra_;
184   RTC_DCHECK_EQ(st.E2.size(), E2.size());
185   RTC_DCHECK_EQ(st.E2.size(), E2.size());
186   const int num_capture_channels = static_cast<int>(Y2.size());
187   for (int ch = 0; ch < num_capture_channels; ++ch) {
188     // Note that the use of the converged_filter flag already imposed
189     // a minimum of the erle that can be estimated as that flag would
190     // be false if the filter is performing poorly.
191     if (!converged_filters[ch]) {
192       continue;
193     }
194 
195     if (st.num_points[ch] == kPointsToAccumulate) {
196       st.num_points[ch] = 0;
197       st.Y2[ch].fill(0.f);
198       st.E2[ch].fill(0.f);
199       st.low_render_energy[ch].fill(false);
200     }
201 
202     std::transform(Y2[ch].begin(), Y2[ch].end(), st.Y2[ch].begin(),
203                    st.Y2[ch].begin(), std::plus<float>());
204     std::transform(E2[ch].begin(), E2[ch].end(), st.E2[ch].begin(),
205                    st.E2[ch].begin(), std::plus<float>());
206 
207     for (size_t k = 0; k < X2.size(); ++k) {
208       st.low_render_energy[ch][k] =
209           st.low_render_energy[ch][k] || X2[k] < kX2BandEnergyThreshold;
210     }
211 
212     ++st.num_points[ch];
213   }
214 }
215 
216 }  // namespace webrtc
217