1 /*
2  *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_H_
12 #define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_H_
13 
14 #include <stddef.h>
15 #include <sys/types.h>
16 
17 #include <array>
18 #include <vector>
19 
20 #include "api/array_view.h"
21 #include "api/function_view.h"
22 #include "modules/audio_processing/agc2/rnn_vad/common.h"
23 #include "rtc_base/system/arch.h"
24 
25 namespace webrtc {
26 namespace rnn_vad {
27 
// Upper bound on the number of units of a fully-connected layer. The output
// vectors of fully-connected layers are std::array instances over-allocated to
// this size; hence, the value should match the number of units of the largest
// fully-connected layer.
constexpr size_t kFullyConnectedLayersMaxUnits{24};
33 
// Upper bound on the number of units of a recurrent layer. The state vectors
// of recurrent layers are std::array instances over-allocated to this size;
// hence, the value should match the number of units of the largest recurrent
// layer.
constexpr size_t kRecurrentLayersMaxUnits{24};
39 
40 // Fully-connected layer.
41 class FullyConnectedLayer {
42  public:
43   FullyConnectedLayer(size_t input_size,
44                       size_t output_size,
45                       rtc::ArrayView<const int8_t> bias,
46                       rtc::ArrayView<const int8_t> weights,
47                       rtc::FunctionView<float(float)> activation_function,
48                       Optimization optimization);
49   FullyConnectedLayer(const FullyConnectedLayer&) = delete;
50   FullyConnectedLayer& operator=(const FullyConnectedLayer&) = delete;
51   ~FullyConnectedLayer();
input_size()52   size_t input_size() const { return input_size_; }
output_size()53   size_t output_size() const { return output_size_; }
optimization()54   Optimization optimization() const { return optimization_; }
55   rtc::ArrayView<const float> GetOutput() const;
56   // Computes the fully-connected layer output.
57   void ComputeOutput(rtc::ArrayView<const float> input);
58 
59  private:
60   const size_t input_size_;
61   const size_t output_size_;
62   const std::vector<float> bias_;
63   const std::vector<float> weights_;
64   rtc::FunctionView<float(float)> activation_function_;
65   // The output vector of a recurrent layer has length equal to |output_size_|.
66   // However, for efficiency, over-allocation is used.
67   std::array<float, kFullyConnectedLayersMaxUnits> output_;
68   const Optimization optimization_;
69 };
70 
71 // Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as
72 // activation functions for the update/reset and output gates respectively.
73 class GatedRecurrentLayer {
74  public:
75   GatedRecurrentLayer(size_t input_size,
76                       size_t output_size,
77                       rtc::ArrayView<const int8_t> bias,
78                       rtc::ArrayView<const int8_t> weights,
79                       rtc::ArrayView<const int8_t> recurrent_weights,
80                       Optimization optimization);
81   GatedRecurrentLayer(const GatedRecurrentLayer&) = delete;
82   GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete;
83   ~GatedRecurrentLayer();
input_size()84   size_t input_size() const { return input_size_; }
output_size()85   size_t output_size() const { return output_size_; }
optimization()86   Optimization optimization() const { return optimization_; }
87   rtc::ArrayView<const float> GetOutput() const;
88   void Reset();
89   // Computes the recurrent layer output and updates the status.
90   void ComputeOutput(rtc::ArrayView<const float> input);
91 
92  private:
93   const size_t input_size_;
94   const size_t output_size_;
95   const std::vector<float> bias_;
96   const std::vector<float> weights_;
97   const std::vector<float> recurrent_weights_;
98   // The state vector of a recurrent layer has length equal to |output_size_|.
99   // However, to avoid dynamic allocation, over-allocation is used.
100   std::array<float, kRecurrentLayersMaxUnits> state_;
101   const Optimization optimization_;
102 };
103 
104 // Recurrent network based VAD.
105 class RnnBasedVad {
106  public:
107   RnnBasedVad();
108   RnnBasedVad(const RnnBasedVad&) = delete;
109   RnnBasedVad& operator=(const RnnBasedVad&) = delete;
110   ~RnnBasedVad();
111   void Reset();
112   // Compute and returns the probability of voice (range: [0.0, 1.0]).
113   float ComputeVadProbability(
114       rtc::ArrayView<const float, kFeatureVectorSize> feature_vector,
115       bool is_silence);
116 
117  private:
118   FullyConnectedLayer input_layer_;
119   GatedRecurrentLayer hidden_layer_;
120   FullyConnectedLayer output_layer_;
121 };
122 
123 }  // namespace rnn_vad
124 }  // namespace webrtc
125 
126 #endif  // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_H_
127