1 /*
2  *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "test/fuzzers/audio_processing_fuzzer_helper.h"
12 
13 #include <algorithm>
14 #include <array>
15 #include <cmath>
16 #include <limits>
17 
18 #include "api/audio/audio_frame.h"
19 #include "modules/audio_processing/include/audio_frame_proxies.h"
20 #include "modules/audio_processing/include/audio_processing.h"
21 #include "rtc_base/checks.h"
22 
23 namespace webrtc {
24 namespace {
// Returns true only when `x` is a finite value inside the closed interval
// [-1, 1]. NaN and +/-infinity are rejected.
bool ValidForApm(float x) {
  if (!std::isfinite(x)) {
    return false;
  }
  return x >= -1.0f && x <= 1.0f;
}
28 
GenerateFloatFrame(test::FuzzDataHelper * fuzz_data,size_t input_rate,size_t num_channels,float * const * float_frames)29 void GenerateFloatFrame(test::FuzzDataHelper* fuzz_data,
30                         size_t input_rate,
31                         size_t num_channels,
32                         float* const* float_frames) {
33   const size_t samples_per_input_channel =
34       rtc::CheckedDivExact(input_rate, static_cast<size_t>(100));
35   RTC_DCHECK_LE(samples_per_input_channel, 480);
36   for (size_t i = 0; i < num_channels; ++i) {
37     std::fill(float_frames[i], float_frames[i] + samples_per_input_channel, 0);
38     const size_t read_bytes = sizeof(float) * samples_per_input_channel;
39     if (fuzz_data->CanReadBytes(read_bytes)) {
40       rtc::ArrayView<const uint8_t> byte_array =
41           fuzz_data->ReadByteArray(read_bytes);
42       memmove(float_frames[i], byte_array.begin(), read_bytes);
43     }
44 
45     // Sanitize input.
46     for (size_t j = 0; j < samples_per_input_channel; ++j) {
47       if (!ValidForApm(float_frames[i][j])) {
48         float_frames[i][j] = 0.f;
49       }
50     }
51   }
52 }
53 
GenerateFixedFrame(test::FuzzDataHelper * fuzz_data,size_t input_rate,size_t num_channels,AudioFrame * fixed_frame)54 void GenerateFixedFrame(test::FuzzDataHelper* fuzz_data,
55                         size_t input_rate,
56                         size_t num_channels,
57                         AudioFrame* fixed_frame) {
58   const size_t samples_per_input_channel =
59       rtc::CheckedDivExact(input_rate, static_cast<size_t>(100));
60   fixed_frame->samples_per_channel_ = samples_per_input_channel;
61   fixed_frame->sample_rate_hz_ = input_rate;
62   fixed_frame->num_channels_ = num_channels;
63 
64   RTC_DCHECK_LE(samples_per_input_channel * num_channels,
65                 AudioFrame::kMaxDataSizeSamples);
66   for (size_t i = 0; i < samples_per_input_channel * num_channels; ++i) {
67     fixed_frame->mutable_data()[i] = fuzz_data->ReadOrDefaultValue<int16_t>(0);
68   }
69 }
70 }  // namespace
71 
FuzzAudioProcessing(test::FuzzDataHelper * fuzz_data,std::unique_ptr<AudioProcessing> apm)72 void FuzzAudioProcessing(test::FuzzDataHelper* fuzz_data,
73                          std::unique_ptr<AudioProcessing> apm) {
74   AudioFrame fixed_frame;
75   // Normal usage is up to 8 channels. Allowing to fuzz one beyond this allows
76   // us to catch implicit assumptions about normal usage.
77   constexpr int kMaxNumChannels = 9;
78   std::array<std::array<float, 480>, kMaxNumChannels> float_frames;
79   std::array<float*, kMaxNumChannels> float_frame_ptrs;
80   for (int i = 0; i < kMaxNumChannels; ++i) {
81     float_frame_ptrs[i] = float_frames[i].data();
82   }
83   float* const* ptr_to_float_frames = &float_frame_ptrs[0];
84 
85   using Rate = AudioProcessing::NativeRate;
86   const Rate rate_kinds[] = {Rate::kSampleRate8kHz, Rate::kSampleRate16kHz,
87                              Rate::kSampleRate32kHz, Rate::kSampleRate48kHz};
88 
89   // We may run out of fuzz data in the middle of a loop iteration. In
90   // that case, default values will be used for the rest of that
91   // iteration.
92   while (fuzz_data->CanReadBytes(1)) {
93     const bool is_float = fuzz_data->ReadOrDefaultValue(true);
94     // Decide input/output rate for this iteration.
95     const auto input_rate =
96         static_cast<size_t>(fuzz_data->SelectOneOf(rate_kinds));
97     const auto output_rate =
98         static_cast<size_t>(fuzz_data->SelectOneOf(rate_kinds));
99 
100     const uint8_t stream_delay = fuzz_data->ReadOrDefaultValue<uint8_t>(0);
101 
102     // API call needed for AEC-2 and AEC-m to run.
103     apm->set_stream_delay_ms(stream_delay);
104 
105     const bool key_pressed = fuzz_data->ReadOrDefaultValue(true);
106     apm->set_stream_key_pressed(key_pressed);
107 
108     // Make the APM call depending on capture/render mode and float /
109     // fix interface.
110     const bool is_capture = fuzz_data->ReadOrDefaultValue(true);
111 
112     // Fill the arrays with audio samples from the data.
113     int apm_return_code = AudioProcessing::Error::kNoError;
114     if (is_float) {
115       const int num_channels =
116           fuzz_data->ReadOrDefaultValue<uint8_t>(1) % kMaxNumChannels;
117 
118       GenerateFloatFrame(fuzz_data, input_rate, num_channels,
119                          ptr_to_float_frames);
120       if (is_capture) {
121         apm_return_code = apm->ProcessStream(
122             ptr_to_float_frames, StreamConfig(input_rate, num_channels),
123             StreamConfig(output_rate, num_channels), ptr_to_float_frames);
124       } else {
125         apm_return_code = apm->ProcessReverseStream(
126             ptr_to_float_frames, StreamConfig(input_rate, num_channels),
127             StreamConfig(output_rate, num_channels), ptr_to_float_frames);
128       }
129     } else {
130       const int num_channels = fuzz_data->ReadOrDefaultValue(true) ? 2 : 1;
131       GenerateFixedFrame(fuzz_data, input_rate, num_channels, &fixed_frame);
132 
133       if (is_capture) {
134         apm_return_code = ProcessAudioFrame(apm.get(), &fixed_frame);
135       } else {
136         apm_return_code = ProcessReverseAudioFrame(apm.get(), &fixed_frame);
137       }
138     }
139 
140     // Cover stats gathering code paths.
141     static_cast<void>(apm->GetStatistics(true /*has_remote_tracks*/));
142 
143     RTC_DCHECK_NE(apm_return_code, AudioProcessing::kBadDataLengthError);
144   }
145 }
146 }  // namespace webrtc
147