1 /*
2  *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/aec3/filter_analyzer.h"
12 
13 #include <math.h>
14 
15 #include <algorithm>
16 #include <array>
17 #include <numeric>
18 
19 #include "modules/audio_processing/aec3/aec3_common.h"
20 #include "modules/audio_processing/aec3/render_buffer.h"
21 #include "modules/audio_processing/logging/apm_data_dumper.h"
22 #include "rtc_base/atomic_ops.h"
23 #include "rtc_base/checks.h"
24 
25 namespace webrtc {
26 namespace {
27 
FindPeakIndex(rtc::ArrayView<const float> filter_time_domain,size_t peak_index_in,size_t start_sample,size_t end_sample)28 size_t FindPeakIndex(rtc::ArrayView<const float> filter_time_domain,
29                      size_t peak_index_in,
30                      size_t start_sample,
31                      size_t end_sample) {
32   size_t peak_index_out = peak_index_in;
33   float max_h2 =
34       filter_time_domain[peak_index_out] * filter_time_domain[peak_index_out];
35   for (size_t k = start_sample; k <= end_sample; ++k) {
36     float tmp = filter_time_domain[k] * filter_time_domain[k];
37     if (tmp > max_h2) {
38       peak_index_out = k;
39       max_h2 = tmp;
40     }
41   }
42 
43   return peak_index_out;
44 }
45 
46 }  // namespace
47 
// Number of FilterAnalyzer instances created so far. The constructor
// increments it via rtc::AtomicOps::Increment and uses the result as the
// instance id of its ApmDataDumper.
int FilterAnalyzer::instance_count_ = 0;

// Creates an analyzer with `num_capture_channels` independent per-channel
// states. Each channel gets a high-passed filter buffer sized to the
// time-domain length of the refined filter
// (`config.filter.refined.length_blocks` blocks), an analysis state built
// from `config`, and a filter delay. Reset() then puts all of that state into
// its initial values.
FilterAnalyzer::FilterAnalyzer(const EchoCanceller3Config& config,
                               size_t num_capture_channels)
    : data_dumper_(
          new ApmDataDumper(rtc::AtomicOps::Increment(&instance_count_))),
      bounded_erl_(config.ep_strength.bounded_erl),
      default_gain_(config.ep_strength.default_gain),
      h_highpass_(num_capture_channels,
                  std::vector<float>(
                      GetTimeDomainLength(config.filter.refined.length_blocks),
                      0.f)),
      filter_analysis_states_(num_capture_channels,
                              FilterAnalysisState(config)),
      filter_delays_blocks_(num_capture_channels, 0) {
  Reset();
}

FilterAnalyzer::~FilterAnalyzer() = default;
67 
Reset()68 void FilterAnalyzer::Reset() {
69   blocks_since_reset_ = 0;
70   ResetRegion();
71   for (auto& state : filter_analysis_states_) {
72     state.peak_index = 0;
73     state.gain = default_gain_;
74     state.consistent_filter_detector.Reset();
75   }
76   std::fill(filter_delays_blocks_.begin(), filter_delays_blocks_.end(), 0);
77 }
78 
Update(rtc::ArrayView<const std::vector<float>> filters_time_domain,const RenderBuffer & render_buffer,bool * any_filter_consistent,float * max_echo_path_gain)79 void FilterAnalyzer::Update(
80     rtc::ArrayView<const std::vector<float>> filters_time_domain,
81     const RenderBuffer& render_buffer,
82     bool* any_filter_consistent,
83     float* max_echo_path_gain) {
84   RTC_DCHECK(any_filter_consistent);
85   RTC_DCHECK(max_echo_path_gain);
86   RTC_DCHECK_EQ(filters_time_domain.size(), filter_analysis_states_.size());
87   RTC_DCHECK_EQ(filters_time_domain.size(), h_highpass_.size());
88 
89   ++blocks_since_reset_;
90   SetRegionToAnalyze(filters_time_domain[0].size());
91   AnalyzeRegion(filters_time_domain, render_buffer);
92 
93   // Aggregate the results for all capture channels.
94   auto& st_ch0 = filter_analysis_states_[0];
95   *any_filter_consistent = st_ch0.consistent_estimate;
96   *max_echo_path_gain = st_ch0.gain;
97   min_filter_delay_blocks_ = filter_delays_blocks_[0];
98   for (size_t ch = 1; ch < filters_time_domain.size(); ++ch) {
99     auto& st_ch = filter_analysis_states_[ch];
100     *any_filter_consistent =
101         *any_filter_consistent || st_ch.consistent_estimate;
102     *max_echo_path_gain = std::max(*max_echo_path_gain, st_ch.gain);
103     min_filter_delay_blocks_ =
104         std::min(min_filter_delay_blocks_, filter_delays_blocks_[ch]);
105   }
106 }
107 
AnalyzeRegion(rtc::ArrayView<const std::vector<float>> filters_time_domain,const RenderBuffer & render_buffer)108 void FilterAnalyzer::AnalyzeRegion(
109     rtc::ArrayView<const std::vector<float>> filters_time_domain,
110     const RenderBuffer& render_buffer) {
111   // Preprocess the filter to avoid issues with low-frequency components in the
112   // filter.
113   PreProcessFilters(filters_time_domain);
114   data_dumper_->DumpRaw("aec3_linear_filter_processed_td", h_highpass_[0]);
115 
116   constexpr float kOneByBlockSize = 1.f / kBlockSize;
117   for (size_t ch = 0; ch < filters_time_domain.size(); ++ch) {
118     RTC_DCHECK_LT(region_.start_sample_, filters_time_domain[ch].size());
119     RTC_DCHECK_LT(region_.end_sample_, filters_time_domain[ch].size());
120 
121     auto& st_ch = filter_analysis_states_[ch];
122     RTC_DCHECK_EQ(h_highpass_[ch].size(), filters_time_domain[ch].size());
123     RTC_DCHECK_GT(h_highpass_[ch].size(), 0);
124     st_ch.peak_index = std::min(st_ch.peak_index, h_highpass_[ch].size() - 1);
125 
126     st_ch.peak_index =
127         FindPeakIndex(h_highpass_[ch], st_ch.peak_index, region_.start_sample_,
128                       region_.end_sample_);
129     filter_delays_blocks_[ch] = st_ch.peak_index >> kBlockSizeLog2;
130     UpdateFilterGain(h_highpass_[ch], &st_ch);
131     st_ch.filter_length_blocks =
132         filters_time_domain[ch].size() * kOneByBlockSize;
133 
134     st_ch.consistent_estimate = st_ch.consistent_filter_detector.Detect(
135         h_highpass_[ch], region_,
136         render_buffer.Block(-filter_delays_blocks_[ch])[0], st_ch.peak_index,
137         filter_delays_blocks_[ch]);
138   }
139 }
140 
UpdateFilterGain(rtc::ArrayView<const float> filter_time_domain,FilterAnalysisState * st)141 void FilterAnalyzer::UpdateFilterGain(
142     rtc::ArrayView<const float> filter_time_domain,
143     FilterAnalysisState* st) {
144   bool sufficient_time_to_converge =
145       blocks_since_reset_ > 5 * kNumBlocksPerSecond;
146 
147   if (sufficient_time_to_converge && st->consistent_estimate) {
148     st->gain = fabsf(filter_time_domain[st->peak_index]);
149   } else {
150     // TODO(peah): Verify whether this check against a float is ok.
151     if (st->gain) {
152       st->gain = std::max(st->gain, fabsf(filter_time_domain[st->peak_index]));
153     }
154   }
155 
156   if (bounded_erl_ && st->gain) {
157     st->gain = std::max(st->gain, 0.01f);
158   }
159 }
160 
PreProcessFilters(rtc::ArrayView<const std::vector<float>> filters_time_domain)161 void FilterAnalyzer::PreProcessFilters(
162     rtc::ArrayView<const std::vector<float>> filters_time_domain) {
163   for (size_t ch = 0; ch < filters_time_domain.size(); ++ch) {
164     RTC_DCHECK_LT(region_.start_sample_, filters_time_domain[ch].size());
165     RTC_DCHECK_LT(region_.end_sample_, filters_time_domain[ch].size());
166 
167     RTC_DCHECK_GE(h_highpass_[ch].capacity(), filters_time_domain[ch].size());
168     h_highpass_[ch].resize(filters_time_domain[ch].size());
169     // Minimum phase high-pass filter with cutoff frequency at about 600 Hz.
170     constexpr std::array<float, 3> h = {
171         {0.7929742f, -0.36072128f, -0.47047766f}};
172 
173     std::fill(h_highpass_[ch].begin() + region_.start_sample_,
174               h_highpass_[ch].begin() + region_.end_sample_ + 1, 0.f);
175     for (size_t k = std::max(h.size() - 1, region_.start_sample_);
176          k <= region_.end_sample_; ++k) {
177       for (size_t j = 0; j < h.size(); ++j) {
178         h_highpass_[ch][k] += filters_time_domain[ch][k - j] * h[j];
179       }
180     }
181   }
182 }
183 
ResetRegion()184 void FilterAnalyzer::ResetRegion() {
185   region_.start_sample_ = 0;
186   region_.end_sample_ = 0;
187 }
188 
SetRegionToAnalyze(size_t filter_size)189 void FilterAnalyzer::SetRegionToAnalyze(size_t filter_size) {
190   constexpr size_t kNumberBlocksToUpdate = 1;
191   auto& r = region_;
192   r.start_sample_ = r.end_sample_ >= filter_size - 1 ? 0 : r.end_sample_ + 1;
193   r.end_sample_ =
194       std::min(r.start_sample_ + kNumberBlocksToUpdate * kBlockSize - 1,
195                filter_size - 1);
196 
197   // Check range.
198   RTC_DCHECK_LT(r.start_sample_, filter_size);
199   RTC_DCHECK_LT(r.end_sample_, filter_size);
200   RTC_DCHECK_LE(r.start_sample_, r.end_sample_);
201 }
202 
// Precomputes the render-activity threshold used by Detect():
// active_render_limit^2 * kFftLengthBy2. A render channel counts as active
// when the sum of squares of its samples exceeds this value.
FilterAnalyzer::ConsistentFilterDetector::ConsistentFilterDetector(
    const EchoCanceller3Config& config)
    : active_render_threshold_(config.render_levels.active_render_limit *
                               config.render_levels.active_render_limit *
                               kFftLengthBy2) {}
208 
Reset()209 void FilterAnalyzer::ConsistentFilterDetector::Reset() {
210   significant_peak_ = false;
211   filter_floor_accum_ = 0.f;
212   filter_secondary_peak_ = 0.f;
213   filter_floor_low_limit_ = 0;
214   filter_floor_high_limit_ = 0;
215   consistent_estimate_counter_ = 0;
216   consistent_delay_reference_ = -10;
217 }
218 
Detect(rtc::ArrayView<const float> filter_to_analyze,const FilterRegion & region,rtc::ArrayView<const std::vector<float>> x_block,size_t peak_index,int delay_blocks)219 bool FilterAnalyzer::ConsistentFilterDetector::Detect(
220     rtc::ArrayView<const float> filter_to_analyze,
221     const FilterRegion& region,
222     rtc::ArrayView<const std::vector<float>> x_block,
223     size_t peak_index,
224     int delay_blocks) {
225   if (region.start_sample_ == 0) {
226     filter_floor_accum_ = 0.f;
227     filter_secondary_peak_ = 0.f;
228     filter_floor_low_limit_ = peak_index < 64 ? 0 : peak_index - 64;
229     filter_floor_high_limit_ =
230         peak_index > filter_to_analyze.size() - 129 ? 0 : peak_index + 128;
231   }
232 
233   for (size_t k = region.start_sample_;
234        k < std::min(region.end_sample_ + 1, filter_floor_low_limit_); ++k) {
235     float abs_h = fabsf(filter_to_analyze[k]);
236     filter_floor_accum_ += abs_h;
237     filter_secondary_peak_ = std::max(filter_secondary_peak_, abs_h);
238   }
239 
240   for (size_t k = std::max(filter_floor_high_limit_, region.start_sample_);
241        k <= region.end_sample_; ++k) {
242     float abs_h = fabsf(filter_to_analyze[k]);
243     filter_floor_accum_ += abs_h;
244     filter_secondary_peak_ = std::max(filter_secondary_peak_, abs_h);
245   }
246 
247   if (region.end_sample_ == filter_to_analyze.size() - 1) {
248     float filter_floor = filter_floor_accum_ /
249                          (filter_floor_low_limit_ + filter_to_analyze.size() -
250                           filter_floor_high_limit_);
251 
252     float abs_peak = fabsf(filter_to_analyze[peak_index]);
253     significant_peak_ = abs_peak > 10.f * filter_floor &&
254                         abs_peak > 2.f * filter_secondary_peak_;
255   }
256 
257   if (significant_peak_) {
258     bool active_render_block = false;
259     for (auto& x_channel : x_block) {
260       const float x_energy = std::inner_product(
261           x_channel.begin(), x_channel.end(), x_channel.begin(), 0.f);
262       if (x_energy > active_render_threshold_) {
263         active_render_block = true;
264         break;
265       }
266     }
267 
268     if (consistent_delay_reference_ == delay_blocks) {
269       if (active_render_block) {
270         ++consistent_estimate_counter_;
271       }
272     } else {
273       consistent_estimate_counter_ = 0;
274       consistent_delay_reference_ = delay_blocks;
275     }
276   }
277   return consistent_estimate_counter_ > 1.5f * kNumBlocksPerSecond;
278 }
279 
280 }  // namespace webrtc
281