1 /*
2  *  Copyright (c) 2013 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/transient/wpd_node.h"
12 
13 #include <math.h>
14 #include <string.h>
15 
16 #include "common_audio/fir_filter.h"
17 #include "common_audio/fir_filter_factory.h"
18 #include "modules/audio_processing/transient/dyadic_decimator.h"
19 #include "rtc_base/checks.h"
20 
21 namespace webrtc {
22 
WPDNode(size_t length,const float * coefficients,size_t coefficients_length)23 WPDNode::WPDNode(size_t length,
24                  const float* coefficients,
25                  size_t coefficients_length)
26     :  // The data buffer has parent data length to be able to contain and
27        // filter it.
28       data_(new float[2 * length + 1]),
29       length_(length),
30       filter_(
31           CreateFirFilter(coefficients, coefficients_length, 2 * length + 1)) {
32   RTC_DCHECK_GT(length, 0);
33   RTC_DCHECK(coefficients);
34   RTC_DCHECK_GT(coefficients_length, 0);
35   memset(data_.get(), 0.f, (2 * length + 1) * sizeof(data_[0]));
36 }
37 
~WPDNode()38 WPDNode::~WPDNode() {}
39 
Update(const float * parent_data,size_t parent_data_length)40 int WPDNode::Update(const float* parent_data, size_t parent_data_length) {
41   if (!parent_data || (parent_data_length / 2) != length_) {
42     return -1;
43   }
44 
45   // Filter data.
46   filter_->Filter(parent_data, parent_data_length, data_.get());
47 
48   // Decimate data.
49   const bool kOddSequence = true;
50   size_t output_samples = DyadicDecimate(data_.get(), parent_data_length,
51                                          kOddSequence, data_.get(), length_);
52   if (output_samples != length_) {
53     return -1;
54   }
55 
56   // Get abs to all values.
57   for (size_t i = 0; i < length_; ++i) {
58     data_[i] = fabs(data_[i]);
59   }
60 
61   return 0;
62 }
63 
set_data(const float * new_data,size_t length)64 int WPDNode::set_data(const float* new_data, size_t length) {
65   if (!new_data || length != length_) {
66     return -1;
67   }
68   memcpy(data_.get(), new_data, length * sizeof(data_[0]));
69   return 0;
70 }
71 
72 }  // namespace webrtc
73