1 /*
2  *  Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #ifndef WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_NONLINEAR_BEAMFORMER_H_
12 #define WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_NONLINEAR_BEAMFORMER_H_
13 
14 // MSVC++ requires this to be set before any other includes to get M_PI.
15 #define _USE_MATH_DEFINES
16 
17 #include <math.h>
18 #include <vector>
19 
20 #include "webrtc/common_audio/lapped_transform.h"
21 #include "webrtc/common_audio/channel_buffer.h"
22 #include "webrtc/modules/audio_processing/beamformer/beamformer.h"
23 #include "webrtc/modules/audio_processing/beamformer/complex_matrix.h"
24 #include "webrtc/system_wrappers/include/scoped_vector.h"
25 
26 namespace webrtc {
27 
28 // Enhances sound sources coming directly in front of a uniform linear array
29 // and suppresses sound sources coming from all other directions. Operates on
30 // multichannel signals and produces single-channel output.
31 //
// The implemented nonlinear postfilter algorithm is taken from "A Robust
// Nonlinear Beamforming Postprocessor" by Bastiaan Kleijn.
34 class NonlinearBeamformer
35   : public Beamformer<float>,
36     public LappedTransform::Callback {
37  public:
38   static const float kHalfBeamWidthRadians;
39 
40   explicit NonlinearBeamformer(
41       const std::vector<Point>& array_geometry,
42       SphericalPointf target_direction =
43           SphericalPointf(static_cast<float>(M_PI) / 2.f, 0.f, 1.f));
44 
45   // Sample rate corresponds to the lower band.
46   // Needs to be called before the NonlinearBeamformer can be used.
47   void Initialize(int chunk_size_ms, int sample_rate_hz) override;
48 
49   // Process one time-domain chunk of audio. The audio is expected to be split
50   // into frequency bands inside the ChannelBuffer. The number of frames and
51   // channels must correspond to the constructor parameters. The same
52   // ChannelBuffer can be passed in as |input| and |output|.
53   void ProcessChunk(const ChannelBuffer<float>& input,
54                     ChannelBuffer<float>* output) override;
55 
56   void AimAt(const SphericalPointf& target_direction) override;
57 
58   bool IsInBeam(const SphericalPointf& spherical_point) override;
59 
60   // After processing each block |is_target_present_| is set to true if the
61   // target signal es present and to false otherwise. This methods can be called
62   // to know if the data is target signal or interference and process it
63   // accordingly.
is_target_present()64   bool is_target_present() override { return is_target_present_; }
65 
66  protected:
67   // Process one frequency-domain block of audio. This is where the fun
68   // happens. Implements LappedTransform::Callback.
69   void ProcessAudioBlock(const complex<float>* const* input,
70                          size_t num_input_channels,
71                          size_t num_freq_bins,
72                          size_t num_output_channels,
73                          complex<float>* const* output) override;
74 
75  private:
76   FRIEND_TEST_ALL_PREFIXES(NonlinearBeamformerTest,
77                            InterfAnglesTakeAmbiguityIntoAccount);
78 
79   typedef Matrix<float> MatrixF;
80   typedef ComplexMatrix<float> ComplexMatrixF;
81   typedef complex<float> complex_f;
82 
83   void InitLowFrequencyCorrectionRanges();
84   void InitHighFrequencyCorrectionRanges();
85   void InitInterfAngles();
86   void InitDelaySumMasks();
87   void InitTargetCovMats();
88   void InitDiffuseCovMats();
89   void InitInterfCovMats();
90   void NormalizeCovMats();
91 
92   // Calculates postfilter masks that minimize the mean squared error of our
93   // estimation of the desired signal.
94   float CalculatePostfilterMask(const ComplexMatrixF& interf_cov_mat,
95                                 float rpsiw,
96                                 float ratio_rxiw_rxim,
97                                 float rmxi_r);
98 
99   // Prevents the postfilter masks from degenerating too quickly (a cause of
100   // musical noise).
101   void ApplyMaskTimeSmoothing();
102   void ApplyMaskFrequencySmoothing();
103 
104   // The postfilter masks are unreliable at low frequencies. Calculates a better
105   // mask by averaging mid-low frequency values.
106   void ApplyLowFrequencyCorrection();
107 
108   // Postfilter masks are also unreliable at high frequencies. Average mid-high
109   // frequency masks to calculate a single mask per block which can be applied
110   // in the time-domain. Further, we average these block-masks over a chunk,
111   // resulting in one postfilter mask per audio chunk. This allows us to skip
112   // both transforming and blocking the high-frequency signal.
113   void ApplyHighFrequencyCorrection();
114 
115   // Compute the means needed for the above frequency correction.
116   float MaskRangeMean(size_t start_bin, size_t end_bin);
117 
118   // Applies both sets of masks to |input| and store in |output|.
119   void ApplyMasks(const complex_f* const* input, complex_f* const* output);
120 
121   void EstimateTargetPresence();
122 
123   static const size_t kFftSize = 256;
124   static const size_t kNumFreqBins = kFftSize / 2 + 1;
125 
126   // Deals with the fft transform and blocking.
127   size_t chunk_length_;
128   rtc::scoped_ptr<LappedTransform> lapped_transform_;
129   float window_[kFftSize];
130 
131   // Parameters exposed to the user.
132   const size_t num_input_channels_;
133   int sample_rate_hz_;
134 
135   const std::vector<Point> array_geometry_;
136   // The normal direction of the array if it has one and it is in the xy-plane.
137   const rtc::Optional<Point> array_normal_;
138 
139   // Minimum spacing between microphone pairs.
140   const float min_mic_spacing_;
141 
142   // Calculated based on user-input and constants in the .cc file.
143   size_t low_mean_start_bin_;
144   size_t low_mean_end_bin_;
145   size_t high_mean_start_bin_;
146   size_t high_mean_end_bin_;
147 
148   // Quickly varying mask updated every block.
149   float new_mask_[kNumFreqBins];
150   // Time smoothed mask.
151   float time_smooth_mask_[kNumFreqBins];
152   // Time and frequency smoothed mask.
153   float final_mask_[kNumFreqBins];
154 
155   float target_angle_radians_;
156   // Angles of the interferer scenarios.
157   std::vector<float> interf_angles_radians_;
158   // The angle between the target and the interferer scenarios.
159   const float away_radians_;
160 
161   // Array of length |kNumFreqBins|, Matrix of size |1| x |num_channels_|.
162   ComplexMatrixF delay_sum_masks_[kNumFreqBins];
163   ComplexMatrixF normalized_delay_sum_masks_[kNumFreqBins];
164 
165   // Arrays of length |kNumFreqBins|, Matrix of size |num_input_channels_| x
166   // |num_input_channels_|.
167   ComplexMatrixF target_cov_mats_[kNumFreqBins];
168   ComplexMatrixF uniform_cov_mat_[kNumFreqBins];
169   // Array of length |kNumFreqBins|, Matrix of size |num_input_channels_| x
170   // |num_input_channels_|. ScopedVector has a size equal to the number of
171   // interferer scenarios.
172   ScopedVector<ComplexMatrixF> interf_cov_mats_[kNumFreqBins];
173 
174   // Of length |kNumFreqBins|.
175   float wave_numbers_[kNumFreqBins];
176 
177   // Preallocated for ProcessAudioBlock()
178   // Of length |kNumFreqBins|.
179   float rxiws_[kNumFreqBins];
180   // The vector has a size equal to the number of interferer scenarios.
181   std::vector<float> rpsiws_[kNumFreqBins];
182 
183   // The microphone normalization factor.
184   ComplexMatrixF eig_m_;
185 
186   // For processing the high-frequency input signal.
187   float high_pass_postfilter_mask_;
188 
189   // True when the target signal is present.
190   bool is_target_present_;
191   // Number of blocks after which the data is considered interference if the
192   // mask does not pass |kMaskSignalThreshold|.
193   size_t hold_target_blocks_;
194   // Number of blocks since the last mask that passed |kMaskSignalThreshold|.
195   size_t interference_blocks_count_;
196 };
197 
198 }  // namespace webrtc
199 
200 #endif  // WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_NONLINEAR_BEAMFORMER_H_
201