1 /*
2  *  Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
#include "modules/audio_processing/utility/pffft_wrapper.h"

#include <limits>

#include "rtc_base/checks.h"
#include "third_party/pffft/src/pffft.h"
15 
16 namespace webrtc {
17 namespace {
18 
GetBufferSize(size_t fft_size,Pffft::FftType fft_type)19 size_t GetBufferSize(size_t fft_size, Pffft::FftType fft_type) {
20   return fft_size * (fft_type == Pffft::FftType::kReal ? 1 : 2);
21 }
22 
AllocatePffftBuffer(size_t size)23 float* AllocatePffftBuffer(size_t size) {
24   return static_cast<float*>(pffft_aligned_malloc(size * sizeof(float)));
25 }
26 
27 }  // namespace
28 
FloatBuffer(size_t fft_size,FftType fft_type)29 Pffft::FloatBuffer::FloatBuffer(size_t fft_size, FftType fft_type)
30     : size_(GetBufferSize(fft_size, fft_type)),
31       data_(AllocatePffftBuffer(size_)) {}
32 
~FloatBuffer()33 Pffft::FloatBuffer::~FloatBuffer() {
34   pffft_aligned_free(data_);
35 }
36 
GetConstView() const37 rtc::ArrayView<const float> Pffft::FloatBuffer::GetConstView() const {
38   return {data_, size_};
39 }
40 
GetView()41 rtc::ArrayView<float> Pffft::FloatBuffer::GetView() {
42   return {data_, size_};
43 }
44 
Pffft(size_t fft_size,FftType fft_type)45 Pffft::Pffft(size_t fft_size, FftType fft_type)
46     : fft_size_(fft_size),
47       fft_type_(fft_type),
48       pffft_status_(pffft_new_setup(
49           fft_size_,
50           fft_type == Pffft::FftType::kReal ? PFFFT_REAL : PFFFT_COMPLEX)),
51       scratch_buffer_(
52           AllocatePffftBuffer(GetBufferSize(fft_size_, fft_type_))) {
53   RTC_DCHECK(pffft_status_);
54   RTC_DCHECK(scratch_buffer_);
55 }
56 
~Pffft()57 Pffft::~Pffft() {
58   pffft_destroy_setup(pffft_status_);
59   pffft_aligned_free(scratch_buffer_);
60 }
61 
IsValidFftSize(size_t fft_size,FftType fft_type)62 bool Pffft::IsValidFftSize(size_t fft_size, FftType fft_type) {
63   if (fft_size == 0) {
64     return false;
65   }
66   // PFFFT only supports transforms for inputs of length N of the form
67   // N = (2^a)*(3^b)*(5^c) where b >=0 and c >= 0 and a >= 5 for the real FFT
68   // and a >= 4 for the complex FFT.
69   constexpr int kFactors[] = {2, 3, 5};
70   int factorization[] = {0, 0, 0};
71   int n = static_cast<int>(fft_size);
72   for (int i = 0; i < 3; ++i) {
73     while (n % kFactors[i] == 0) {
74       n = n / kFactors[i];
75       factorization[i]++;
76     }
77   }
78   int a_min = (fft_type == Pffft::FftType::kReal) ? 5 : 4;
79   return factorization[0] >= a_min && n == 1;
80 }
81 
IsSimdEnabled()82 bool Pffft::IsSimdEnabled() {
83   return pffft_simd_size() > 1;
84 }
85 
CreateBuffer() const86 std::unique_ptr<Pffft::FloatBuffer> Pffft::CreateBuffer() const {
87   // Cannot use make_unique from absl because Pffft is the only friend of
88   // Pffft::FloatBuffer.
89   std::unique_ptr<Pffft::FloatBuffer> buffer(
90       new Pffft::FloatBuffer(fft_size_, fft_type_));
91   return buffer;
92 }
93 
ForwardTransform(const FloatBuffer & in,FloatBuffer * out,bool ordered)94 void Pffft::ForwardTransform(const FloatBuffer& in,
95                              FloatBuffer* out,
96                              bool ordered) {
97   RTC_DCHECK_EQ(in.size(), GetBufferSize(fft_size_, fft_type_));
98   RTC_DCHECK_EQ(in.size(), out->size());
99   RTC_DCHECK(scratch_buffer_);
100   if (ordered) {
101     pffft_transform_ordered(pffft_status_, in.const_data(), out->data(),
102                             scratch_buffer_, PFFFT_FORWARD);
103   } else {
104     pffft_transform(pffft_status_, in.const_data(), out->data(),
105                     scratch_buffer_, PFFFT_FORWARD);
106   }
107 }
108 
BackwardTransform(const FloatBuffer & in,FloatBuffer * out,bool ordered)109 void Pffft::BackwardTransform(const FloatBuffer& in,
110                               FloatBuffer* out,
111                               bool ordered) {
112   RTC_DCHECK_EQ(in.size(), GetBufferSize(fft_size_, fft_type_));
113   RTC_DCHECK_EQ(in.size(), out->size());
114   RTC_DCHECK(scratch_buffer_);
115   if (ordered) {
116     pffft_transform_ordered(pffft_status_, in.const_data(), out->data(),
117                             scratch_buffer_, PFFFT_BACKWARD);
118   } else {
119     pffft_transform(pffft_status_, in.const_data(), out->data(),
120                     scratch_buffer_, PFFFT_BACKWARD);
121   }
122 }
123 
FrequencyDomainConvolve(const FloatBuffer & fft_x,const FloatBuffer & fft_y,FloatBuffer * out,float scaling)124 void Pffft::FrequencyDomainConvolve(const FloatBuffer& fft_x,
125                                     const FloatBuffer& fft_y,
126                                     FloatBuffer* out,
127                                     float scaling) {
128   RTC_DCHECK_EQ(fft_x.size(), GetBufferSize(fft_size_, fft_type_));
129   RTC_DCHECK_EQ(fft_x.size(), fft_y.size());
130   RTC_DCHECK_EQ(fft_x.size(), out->size());
131   pffft_zconvolve_accumulate(pffft_status_, fft_x.const_data(),
132                              fft_y.const_data(), out->data(), scaling);
133 }
134 
135 }  // namespace webrtc
136