1 /*
2  *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/aec3/reverb_model_estimator.h"
12 
13 #include <algorithm>
14 #include <array>
15 #include <cmath>
16 #include <numeric>
17 #include <vector>
18 
19 #include "absl/types/optional.h"
20 #include "api/array_view.h"
21 #include "api/audio/echo_canceller3_config.h"
22 #include "modules/audio_processing/aec3/aec3_common.h"
23 #include "modules/audio_processing/aec3/aec3_fft.h"
24 #include "modules/audio_processing/aec3/fft_data.h"
25 #include "rtc_base/checks.h"
26 #include "test/gtest.h"
27 
28 namespace webrtc {
29 
30 namespace {
31 
CreateConfigForTest(float default_decay)32 EchoCanceller3Config CreateConfigForTest(float default_decay) {
33   EchoCanceller3Config cfg;
34   cfg.ep_strength.default_len = default_decay;
35   cfg.filter.refined.length_blocks = 40;
36   return cfg;
37 }
38 
39 constexpr int kFilterDelayBlocks = 2;
40 
41 }  // namespace
42 
43 class ReverbModelEstimatorTest {
44  public:
ReverbModelEstimatorTest(float default_decay,size_t num_capture_channels)45   ReverbModelEstimatorTest(float default_decay, size_t num_capture_channels)
46       : aec3_config_(CreateConfigForTest(default_decay)),
47         estimated_decay_(default_decay),
48         h_(num_capture_channels,
49            std::vector<float>(
50                aec3_config_.filter.refined.length_blocks * kBlockSize,
51                0.f)),
52         H2_(num_capture_channels,
53             std::vector<std::array<float, kFftLengthBy2Plus1>>(
54                 aec3_config_.filter.refined.length_blocks)),
55         quality_linear_(num_capture_channels, 1.0f) {
56     CreateImpulseResponseWithDecay();
57   }
58   void RunEstimator();
GetDecay()59   float GetDecay() { return estimated_decay_; }
GetTrueDecay()60   float GetTrueDecay() { return kTruePowerDecay; }
GetPowerTailDb()61   float GetPowerTailDb() { return 10.f * std::log10(estimated_power_tail_); }
GetTruePowerTailDb()62   float GetTruePowerTailDb() { return 10.f * std::log10(true_power_tail_); }
63 
64  private:
65   void CreateImpulseResponseWithDecay();
66   static constexpr bool kStationaryBlock = false;
67   static constexpr float kTruePowerDecay = 0.5f;
68   const EchoCanceller3Config aec3_config_;
69   float estimated_decay_;
70   float estimated_power_tail_ = 0.f;
71   float true_power_tail_ = 0.f;
72   std::vector<std::vector<float>> h_;
73   std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>> H2_;
74   std::vector<absl::optional<float>> quality_linear_;
75 };
76 
CreateImpulseResponseWithDecay()77 void ReverbModelEstimatorTest::CreateImpulseResponseWithDecay() {
78   const Aec3Fft fft;
79   for (const auto& h_k : h_) {
80     RTC_DCHECK_EQ(h_k.size(),
81                   aec3_config_.filter.refined.length_blocks * kBlockSize);
82   }
83   for (const auto& H2_k : H2_) {
84     RTC_DCHECK_EQ(H2_k.size(), aec3_config_.filter.refined.length_blocks);
85   }
86   RTC_DCHECK_EQ(kFilterDelayBlocks, 2);
87 
88   float decay_sample = std::sqrt(powf(kTruePowerDecay, 1.f / kBlockSize));
89   const size_t filter_delay_coefficients = kFilterDelayBlocks * kBlockSize;
90   for (auto& h_i : h_) {
91     std::fill(h_i.begin(), h_i.end(), 0.f);
92     h_i[filter_delay_coefficients] = 1.f;
93     for (size_t k = filter_delay_coefficients + 1; k < h_i.size(); ++k) {
94       h_i[k] = h_i[k - 1] * decay_sample;
95     }
96   }
97 
98   for (size_t ch = 0; ch < H2_.size(); ++ch) {
99     for (size_t j = 0, k = 0; j < H2_[ch].size(); ++j, k += kBlockSize) {
100       std::array<float, kFftLength> fft_data;
101       fft_data.fill(0.f);
102       std::copy(h_[ch].begin() + k, h_[ch].begin() + k + kBlockSize,
103                 fft_data.begin());
104       FftData H_j;
105       fft.Fft(&fft_data, &H_j);
106       H_j.Spectrum(Aec3Optimization::kNone, H2_[ch][j]);
107     }
108   }
109   rtc::ArrayView<float> H2_tail(H2_[0][H2_[0].size() - 1]);
110   true_power_tail_ = std::accumulate(H2_tail.begin(), H2_tail.end(), 0.f);
111 }
RunEstimator()112 void ReverbModelEstimatorTest::RunEstimator() {
113   const size_t num_capture_channels = H2_.size();
114   constexpr bool kUsableLinearEstimate = true;
115   ReverbModelEstimator estimator(aec3_config_, num_capture_channels);
116   std::vector<bool> usable_linear_estimates(num_capture_channels,
117                                             kUsableLinearEstimate);
118   std::vector<int> filter_delay_blocks(num_capture_channels,
119                                        kFilterDelayBlocks);
120   for (size_t k = 0; k < 3000; ++k) {
121     estimator.Update(h_, H2_, quality_linear_, filter_delay_blocks,
122                      usable_linear_estimates, kStationaryBlock);
123   }
124   estimated_decay_ = estimator.ReverbDecay();
125   auto freq_resp_tail = estimator.GetReverbFrequencyResponse();
126   estimated_power_tail_ =
127       std::accumulate(freq_resp_tail.begin(), freq_resp_tail.end(), 0.f);
128 }
129 
TEST(ReverbModelEstimatorTests,NotChangingDecay)130 TEST(ReverbModelEstimatorTests, NotChangingDecay) {
131   constexpr float kDefaultDecay = 0.9f;
132   for (size_t num_capture_channels : {1, 2, 4, 8}) {
133     ReverbModelEstimatorTest test(kDefaultDecay, num_capture_channels);
134     test.RunEstimator();
135     EXPECT_EQ(test.GetDecay(), kDefaultDecay);
136     EXPECT_NEAR(test.GetPowerTailDb(), test.GetTruePowerTailDb(), 5.f);
137   }
138 }
139 
TEST(ReverbModelEstimatorTests,ChangingDecay)140 TEST(ReverbModelEstimatorTests, ChangingDecay) {
141   constexpr float kDefaultDecay = -0.9f;
142   for (size_t num_capture_channels : {1, 2, 4, 8}) {
143     ReverbModelEstimatorTest test(kDefaultDecay, num_capture_channels);
144     test.RunEstimator();
145     EXPECT_NEAR(test.GetDecay(), test.GetTrueDecay(), 0.1);
146     EXPECT_NEAR(test.GetPowerTailDb(), test.GetTruePowerTailDb(), 5.f);
147   }
148 }
149 
150 }  // namespace webrtc
151