1 /*
2  *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/aec3/render_signal_analyzer.h"
12 
13 #include <math.h>
14 
15 #include <algorithm>
16 #include <utility>
17 #include <vector>
18 
19 #include "api/array_view.h"
20 #include "rtc_base/checks.h"
21 
22 namespace webrtc {
23 
24 namespace {
// A bin's consecutive narrow-band detection count must exceed this value
// before the spectral region around that bin is masked.
constexpr size_t kCounterThreshold{5};
26 
27 // Identifies local bands with narrow characteristics.
IdentifySmallNarrowBandRegions(const RenderBuffer & render_buffer,const absl::optional<size_t> & delay_partitions,std::array<size_t,kFftLengthBy2-1> * narrow_band_counters)28 void IdentifySmallNarrowBandRegions(
29     const RenderBuffer& render_buffer,
30     const absl::optional<size_t>& delay_partitions,
31     std::array<size_t, kFftLengthBy2 - 1>* narrow_band_counters) {
32   RTC_DCHECK(narrow_band_counters);
33 
34   if (!delay_partitions) {
35     narrow_band_counters->fill(0);
36     return;
37   }
38 
39   std::array<size_t, kFftLengthBy2 - 1> channel_counters;
40   channel_counters.fill(0);
41   rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> X2 =
42       render_buffer.Spectrum(*delay_partitions);
43   for (size_t ch = 0; ch < X2.size(); ++ch) {
44     for (size_t k = 1; k < kFftLengthBy2; ++k) {
45       if (X2[ch][k] > 3 * std::max(X2[ch][k - 1], X2[ch][k + 1])) {
46         ++channel_counters[k - 1];
47       }
48     }
49   }
50   for (size_t k = 1; k < kFftLengthBy2; ++k) {
51     (*narrow_band_counters)[k - 1] =
52         channel_counters[k - 1] > 0 ? (*narrow_band_counters)[k - 1] + 1 : 0;
53   }
54 }
55 
56 // Identifies whether the signal has a single strong narrow-band component.
IdentifyStrongNarrowBandComponent(const RenderBuffer & render_buffer,int strong_peak_freeze_duration,absl::optional<int> * narrow_peak_band,size_t * narrow_peak_counter)57 void IdentifyStrongNarrowBandComponent(const RenderBuffer& render_buffer,
58                                        int strong_peak_freeze_duration,
59                                        absl::optional<int>* narrow_peak_band,
60                                        size_t* narrow_peak_counter) {
61   RTC_DCHECK(narrow_peak_band);
62   RTC_DCHECK(narrow_peak_counter);
63   if (*narrow_peak_band &&
64       ++(*narrow_peak_counter) >
65           static_cast<size_t>(strong_peak_freeze_duration)) {
66     *narrow_peak_band = absl::nullopt;
67   }
68 
69   const std::vector<std::vector<std::vector<float>>>& x_latest =
70       render_buffer.Block(0);
71   float max_peak_level = 0.f;
72   for (size_t channel = 0; channel < x_latest[0].size(); ++channel) {
73     rtc::ArrayView<const float, kFftLengthBy2Plus1> X2_latest =
74         render_buffer.Spectrum(0)[channel];
75 
76     // Identify the spectral peak.
77     const int peak_bin =
78         static_cast<int>(std::max_element(X2_latest.begin(), X2_latest.end()) -
79                          X2_latest.begin());
80 
81     // Compute the level around the peak.
82     float non_peak_power = 0.f;
83     for (int k = std::max(0, peak_bin - 14); k < peak_bin - 4; ++k) {
84       non_peak_power = std::max(X2_latest[k], non_peak_power);
85     }
86     for (int k = peak_bin + 5;
87          k < std::min(peak_bin + 15, static_cast<int>(kFftLengthBy2Plus1));
88          ++k) {
89       non_peak_power = std::max(X2_latest[k], non_peak_power);
90     }
91 
92     // Assess the render signal strength.
93     auto result0 = std::minmax_element(x_latest[0][channel].begin(),
94                                        x_latest[0][channel].end());
95     float max_abs = std::max(fabs(*result0.first), fabs(*result0.second));
96 
97     if (x_latest.size() > 1) {
98       const auto result1 = std::minmax_element(x_latest[1][channel].begin(),
99                                                x_latest[1][channel].end());
100       max_abs =
101           std::max(max_abs, static_cast<float>(std::max(
102                                 fabs(*result1.first), fabs(*result1.second))));
103     }
104 
105     // Detect whether the spectral peak has as strong narrowband nature.
106     const float peak_level = X2_latest[peak_bin];
107     if (peak_bin > 0 && max_abs > 100 && peak_level > 100 * non_peak_power) {
108       // Store the strongest peak across channels.
109       if (peak_level > max_peak_level) {
110         max_peak_level = peak_level;
111         *narrow_peak_band = peak_bin;
112         *narrow_peak_counter = 0;
113       }
114     }
115   }
116 }
117 
118 }  // namespace
119 
RenderSignalAnalyzer(const EchoCanceller3Config & config)120 RenderSignalAnalyzer::RenderSignalAnalyzer(const EchoCanceller3Config& config)
121     : strong_peak_freeze_duration_(config.filter.refined.length_blocks) {
122   narrow_band_counters_.fill(0);
123 }
124 RenderSignalAnalyzer::~RenderSignalAnalyzer() = default;
125 
Update(const RenderBuffer & render_buffer,const absl::optional<size_t> & delay_partitions)126 void RenderSignalAnalyzer::Update(
127     const RenderBuffer& render_buffer,
128     const absl::optional<size_t>& delay_partitions) {
129   // Identify bands of narrow nature.
130   IdentifySmallNarrowBandRegions(render_buffer, delay_partitions,
131                                  &narrow_band_counters_);
132 
133   // Identify the presence of a strong narrow band.
134   IdentifyStrongNarrowBandComponent(render_buffer, strong_peak_freeze_duration_,
135                                     &narrow_peak_band_, &narrow_peak_counter_);
136 }
137 
MaskRegionsAroundNarrowBands(std::array<float,kFftLengthBy2Plus1> * v) const138 void RenderSignalAnalyzer::MaskRegionsAroundNarrowBands(
139     std::array<float, kFftLengthBy2Plus1>* v) const {
140   RTC_DCHECK(v);
141 
142   // Set v to zero around narrow band signal regions.
143   if (narrow_band_counters_[0] > kCounterThreshold) {
144     (*v)[1] = (*v)[0] = 0.f;
145   }
146   for (size_t k = 2; k < kFftLengthBy2 - 1; ++k) {
147     if (narrow_band_counters_[k - 1] > kCounterThreshold) {
148       (*v)[k - 2] = (*v)[k - 1] = (*v)[k] = (*v)[k + 1] = (*v)[k + 2] = 0.f;
149     }
150   }
151   if (narrow_band_counters_[kFftLengthBy2 - 2] > kCounterThreshold) {
152     (*v)[kFftLengthBy2] = (*v)[kFftLengthBy2 - 1] = 0.f;
153   }
154 }
155 
156 }  // namespace webrtc
157