1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "host/frontend/webrtc/audio_handler.h"
18 
19 #include <algorithm>
20 #include <chrono>
21 
22 #include <android-base/logging.h>
23 #include <rtc_base/time_utils.h>
24 
25 namespace cuttlefish {
26 namespace {
27 
28 const virtio_snd_jack_info JACKS[] = {};
29 constexpr uint32_t NUM_JACKS = sizeof(JACKS) / sizeof(JACKS[0]);
30 
31 const virtio_snd_chmap_info CHMAPS[] = {{
32     .hdr = { .hda_fn_nid = Le32(0), },
33     .direction = (uint8_t) AudioStreamDirection::VIRTIO_SND_D_OUTPUT,
34     .channels = 2,
35     .positions = {
36         (uint8_t) AudioChannelMap::VIRTIO_SND_CHMAP_FL,
37         (uint8_t) AudioChannelMap::VIRTIO_SND_CHMAP_FR
38     },
39 }, {
40     .hdr = { .hda_fn_nid = Le32(0), },
41     .direction = (uint8_t) AudioStreamDirection::VIRTIO_SND_D_INPUT,
42     .channels = 2,
43     .positions = {
44         (uint8_t) AudioChannelMap::VIRTIO_SND_CHMAP_FL,
45         (uint8_t) AudioChannelMap::VIRTIO_SND_CHMAP_FR
46     },
47 }};
48 constexpr uint32_t NUM_CHMAPS = sizeof(CHMAPS) / sizeof(CHMAPS[0]);
49 
50 const virtio_snd_pcm_info STREAMS[] = {{
51     .hdr =
52         {
53             .hda_fn_nid = Le32(0),
54         },
55     .features = Le32(0),
56     // webrtc's api is quite primitive and doesn't allow for many different
57     // formats: It only takes the bits_per_sample as a parameter and assumes
58     // the underlying format to be one of the following:
59     .formats = Le64(
60         (((uint64_t)1) << (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_S8) |
61         (((uint64_t)1) << (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_S16) |
62         (((uint64_t)1) << (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_S24) |
63         (((uint64_t)1) << (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_S32)),
64     .rates = Le64(
65         (((uint64_t)1) << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_5512) |
66         (((uint64_t)1) << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_8000) |
67         (((uint64_t)1) << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_11025) |
68         (((uint64_t)1) << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_16000) |
69         (((uint64_t)1) << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_22050) |
70         (((uint64_t)1) << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_32000) |
71         (((uint64_t)1) << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_44100) |
72         (((uint64_t)1) << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_48000) |
73         (((uint64_t)1) << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_64000) |
74         (((uint64_t)1) << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_88200) |
75         (((uint64_t)1) << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_96000) |
76         (((uint64_t)1) << (uint8_t)
77              AudioStreamRate::VIRTIO_SND_PCM_RATE_176400) |
78         (((uint64_t)1) << (uint8_t)
79              AudioStreamRate::VIRTIO_SND_PCM_RATE_192000) |
80         (((uint64_t)1) << (uint8_t)
81              AudioStreamRate::VIRTIO_SND_PCM_RATE_384000)),
82     .direction = (uint8_t)AudioStreamDirection::VIRTIO_SND_D_OUTPUT,
83     .channels_min = 1,
84     .channels_max = 2,
85 }, {
86     .hdr =
87         {
88             .hda_fn_nid = Le32(0),
89         },
90     .features = Le32(0),
91     // webrtc's api is quite primitive and doesn't allow for many different
92     // formats: It only takes the bits_per_sample as a parameter and assumes
93     // the underlying format to be one of the following:
94     .formats = Le64(
95         (((uint64_t)1) << (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_S8) |
96         (((uint64_t)1) << (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_S16) |
97         (((uint64_t)1) << (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_S24) |
98         (((uint64_t)1) << (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_S32)),
99     .rates = Le64(
100         (((uint64_t)1) << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_5512) |
101         (((uint64_t)1) << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_8000) |
102         (((uint64_t)1) << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_11025) |
103         (((uint64_t)1) << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_16000) |
104         (((uint64_t)1) << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_22050) |
105         (((uint64_t)1) << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_32000) |
106         (((uint64_t)1) << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_44100) |
107         (((uint64_t)1) << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_48000) |
108         (((uint64_t)1) << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_64000) |
109         (((uint64_t)1) << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_88200) |
110         (((uint64_t)1) << (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_96000) |
111         (((uint64_t)1) << (uint8_t)
112              AudioStreamRate::VIRTIO_SND_PCM_RATE_176400) |
113         (((uint64_t)1) << (uint8_t)
114              AudioStreamRate::VIRTIO_SND_PCM_RATE_192000) |
115         (((uint64_t)1) << (uint8_t)
116              AudioStreamRate::VIRTIO_SND_PCM_RATE_384000)),
117     .direction = (uint8_t)AudioStreamDirection::VIRTIO_SND_D_INPUT,
118     .channels_min = 1,
119     .channels_max = 2,
120 }};
121 constexpr uint32_t NUM_STREAMS = sizeof(STREAMS) / sizeof(STREAMS[0]);
122 
IsCapture(uint32_t stream_id)123 bool IsCapture(uint32_t stream_id) {
124   CHECK(stream_id < NUM_STREAMS) << "Invalid stream id: " << stream_id;
125   return STREAMS[stream_id].direction ==
126          (uint8_t)AudioStreamDirection::VIRTIO_SND_D_INPUT;
127 }
128 
129 class CvdAudioFrameBuffer : public webrtc_streaming::AudioFrameBuffer {
130  public:
CvdAudioFrameBuffer(const uint8_t * buffer,int bits_per_sample,int sample_rate,int channels,int frames)131   CvdAudioFrameBuffer(const uint8_t* buffer, int bits_per_sample,
132                       int sample_rate, int channels, int frames)
133       : buffer_(buffer),
134         bits_per_sample_(bits_per_sample),
135         sample_rate_(sample_rate),
136         channels_(channels),
137         frames_(frames) {}
138 
bits_per_sample() const139   int bits_per_sample() const override { return bits_per_sample_; }
140 
sample_rate() const141   int sample_rate() const override { return sample_rate_; }
142 
channels() const143   int channels() const override { return channels_; }
144 
frames() const145   int frames() const override { return frames_; }
146 
data() const147   const uint8_t* data() const override { return buffer_; }
148 
149  private:
150   const uint8_t* buffer_;
151   int bits_per_sample_;
152   int sample_rate_;
153   int channels_;
154   int frames_;
155 };
156 
BitsPerSample(uint8_t virtio_format)157 int BitsPerSample(uint8_t virtio_format) {
158   switch (virtio_format) {
159     /* analog formats (width / physical width) */
160     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_IMA_ADPCM:
161       /*  4 /  4 bits */
162       return 4;
163     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_MU_LAW:
164       /*  8 /  8 bits */
165       return 8;
166     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_A_LAW:
167       /*  8 /  8 bits */
168       return 8;
169     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_S8:
170       /*  8 /  8 bits */
171       return 8;
172     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_U8:
173       /*  8 /  8 bits */
174       return 8;
175     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_S16:
176       /* 16 / 16 bits */
177       return 16;
178     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_U16:
179       /* 16 / 16 bits */
180       return 16;
181     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_S18_3:
182       /* 18 / 24 bits */
183       return 24;
184     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_U18_3:
185       /* 18 / 24 bits */
186       return 24;
187     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_S20_3:
188       /* 20 / 24 bits */
189       return 24;
190     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_U20_3:
191       /* 20 / 24 bits */
192       return 24;
193     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_S24_3:
194       /* 24 / 24 bits */
195       return 24;
196     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_U24_3:
197       /* 24 / 24 bits */
198       return 24;
199     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_S20:
200       /* 20 / 32 bits */
201       return 32;
202     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_U20:
203       /* 20 / 32 bits */
204       return 32;
205     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_S24:
206       /* 24 / 32 bits */
207       return 32;
208     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_U24:
209       /* 24 / 32 bits */
210       return 32;
211     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_S32:
212       /* 32 / 32 bits */
213       return 32;
214     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_U32:
215       /* 32 / 32 bits */
216       return 32;
217     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_FLOAT:
218       /* 32 / 32 bits */
219       return 32;
220     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_FLOAT64:
221       /* 64 / 64 bits */
222       return 64;
223     /* digital formats (width / physical width) */
224     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_DSD_U8:
225       /*  8 /  8 bits */
226       return 8;
227     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_DSD_U16:
228       /* 16 / 16 bits */
229       return 16;
230     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_DSD_U32:
231       /* 32 / 32 bits */
232       return 32;
233     case (uint8_t)AudioStreamFormat::VIRTIO_SND_PCM_FMT_IEC958_SUBFRAME:
234       /* 32 / 32 bits */
235       return 32;
236     default:
237       LOG(ERROR) << "Unknown virtio-snd audio format: " << virtio_format;
238       return -1;
239   }
240 }
241 
SampleRate(uint8_t virtio_rate)242 int SampleRate(uint8_t virtio_rate) {
243   switch (virtio_rate) {
244     case (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_5512:
245       return 5512;
246     case (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_8000:
247       return 8000;
248     case (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_11025:
249       return 11025;
250     case (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_16000:
251       return 16000;
252     case (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_22050:
253       return 22050;
254     case (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_32000:
255       return 32000;
256     case (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_44100:
257       return 44100;
258     case (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_48000:
259       return 48000;
260     case (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_64000:
261       return 64000;
262     case (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_88200:
263       return 88200;
264     case (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_96000:
265       return 96000;
266     case (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_176400:
267       return 176400;
268     case (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_192000:
269       return 192000;
270     case (uint8_t)AudioStreamRate::VIRTIO_SND_PCM_RATE_384000:
271       return 384000;
272     default:
273       LOG(ERROR) << "Unknown virtio-snd sample rate: " << virtio_rate;
274       return -1;
275   }
276 }
277 
278 }  // namespace
279 
AudioHandler(std::unique_ptr<AudioServer> audio_server,std::shared_ptr<webrtc_streaming::AudioSink> audio_sink,std::shared_ptr<webrtc_streaming::AudioSource> audio_source)280 AudioHandler::AudioHandler(
281     std::unique_ptr<AudioServer> audio_server,
282     std::shared_ptr<webrtc_streaming::AudioSink> audio_sink,
283     std::shared_ptr<webrtc_streaming::AudioSource> audio_source)
284     : audio_sink_(audio_sink),
285       audio_server_(std::move(audio_server)),
286       stream_descs_(NUM_STREAMS),
287       audio_source_(audio_source) {}
288 
Start()289 void AudioHandler::Start() {
290   server_thread_ = std::thread([this]() { Loop(); });
291 }
292 
Loop()293 [[noreturn]] void AudioHandler::Loop() {
294   for (;;) {
295     auto audio_client = audio_server_->AcceptClient(
296         NUM_STREAMS, NUM_JACKS, NUM_CHMAPS,
297         262144 /* tx_shm_len */, 262144 /* rx_shm_len */);
298     CHECK(audio_client) << "Failed to create audio client connection instance";
299 
300     std::thread playback_thread([this, &audio_client]() {
301       while (audio_client->ReceivePlayback(*this)) {
302       }
303     });
304     std::thread capture_thread([this, &audio_client]() {
305       while (audio_client->ReceiveCapture(*this)) {
306       }
307     });
308     // Wait for the client to do something
309     while (audio_client->ReceiveCommands(*this)) {
310     }
311     playback_thread.join();
312     capture_thread.join();
313   }
314 }
315 
StreamsInfo(StreamInfoCommand & cmd)316 void AudioHandler::StreamsInfo(StreamInfoCommand& cmd) {
317   if (cmd.start_id() >= NUM_STREAMS ||
318       cmd.start_id() + cmd.count() > NUM_STREAMS) {
319     cmd.Reply(AudioStatus::VIRTIO_SND_S_BAD_MSG, {});
320     return;
321   }
322   std::vector<virtio_snd_pcm_info> stream_info(
323       &STREAMS[cmd.start_id()], &STREAMS[0] + cmd.start_id() + cmd.count());
324   cmd.Reply(AudioStatus::VIRTIO_SND_S_OK, stream_info);
325 }
326 
SetStreamParameters(StreamSetParamsCommand & cmd)327 void AudioHandler::SetStreamParameters(StreamSetParamsCommand& cmd) {
328   if (cmd.stream_id() >= NUM_STREAMS) {
329     cmd.Reply(AudioStatus::VIRTIO_SND_S_BAD_MSG);
330     return;
331   }
332   const auto& stream_info = STREAMS[cmd.stream_id()];
333   auto bits_per_sample = BitsPerSample(cmd.format());
334   auto sample_rate = SampleRate(cmd.rate());
335   auto channels = cmd.channels();
336   if (bits_per_sample < 0 || sample_rate < 0 ||
337       channels < stream_info.channels_min ||
338       channels > stream_info.channels_max) {
339     cmd.Reply(AudioStatus::VIRTIO_SND_S_BAD_MSG);
340     return;
341   }
342   {
343     std::lock_guard<std::mutex> lock(stream_descs_[cmd.stream_id()].mtx);
344     stream_descs_[cmd.stream_id()].bits_per_sample = bits_per_sample;
345     stream_descs_[cmd.stream_id()].sample_rate = sample_rate;
346     stream_descs_[cmd.stream_id()].channels = channels;
347     auto len10ms = (channels * (sample_rate / 100) * bits_per_sample) / 8;
348     stream_descs_[cmd.stream_id()].buffer.Reset(len10ms);
349   }
350   cmd.Reply(AudioStatus::VIRTIO_SND_S_OK);
351 }
352 
PrepareStream(StreamControlCommand & cmd)353 void AudioHandler::PrepareStream(StreamControlCommand& cmd) {
354   if (cmd.stream_id() >= NUM_STREAMS) {
355     cmd.Reply(AudioStatus::VIRTIO_SND_S_BAD_MSG);
356     return;
357   }
358   cmd.Reply(AudioStatus::VIRTIO_SND_S_OK);
359 }
360 
ReleaseStream(StreamControlCommand & cmd)361 void AudioHandler::ReleaseStream(StreamControlCommand& cmd) {
362   if (cmd.stream_id() >= NUM_STREAMS) {
363     cmd.Reply(AudioStatus::VIRTIO_SND_S_BAD_MSG);
364     return;
365   }
366   cmd.Reply(AudioStatus::VIRTIO_SND_S_OK);
367 }
368 
StartStream(StreamControlCommand & cmd)369 void AudioHandler::StartStream(StreamControlCommand& cmd) {
370   if (cmd.stream_id() >= NUM_STREAMS) {
371     cmd.Reply(AudioStatus::VIRTIO_SND_S_BAD_MSG);
372     return;
373   }
374   stream_descs_[cmd.stream_id()].active = true;
375   cmd.Reply(AudioStatus::VIRTIO_SND_S_OK);
376 }
377 
StopStream(StreamControlCommand & cmd)378 void AudioHandler::StopStream(StreamControlCommand& cmd) {
379   if (cmd.stream_id() >= NUM_STREAMS) {
380     cmd.Reply(AudioStatus::VIRTIO_SND_S_BAD_MSG);
381     return;
382   }
383   stream_descs_[cmd.stream_id()].active = false;
384   cmd.Reply(AudioStatus::VIRTIO_SND_S_OK);
385 }
386 
ChmapsInfo(ChmapInfoCommand & cmd)387 void AudioHandler::ChmapsInfo(ChmapInfoCommand& cmd) {
388   if (cmd.start_id() >= NUM_CHMAPS ||
389       cmd.start_id() + cmd.count() > NUM_CHMAPS) {
390     cmd.Reply(AudioStatus::VIRTIO_SND_S_BAD_MSG, {});
391     return;
392   }
393   std::vector<virtio_snd_chmap_info> chmap_info(
394       &CHMAPS[cmd.start_id()], &CHMAPS[cmd.start_id()] + cmd.count());
395   cmd.Reply(AudioStatus::VIRTIO_SND_S_OK, chmap_info);
396 }
397 
JacksInfo(JackInfoCommand & cmd)398 void AudioHandler::JacksInfo(JackInfoCommand& cmd) {
399   if (cmd.start_id() >= NUM_JACKS ||
400       cmd.start_id() + cmd.count() > NUM_JACKS) {
401     cmd.Reply(AudioStatus::VIRTIO_SND_S_BAD_MSG, {});
402     return;
403   }
404   std::vector<virtio_snd_jack_info> jack_info(
405       &JACKS[cmd.start_id()], &JACKS[cmd.start_id()] + cmd.count());
406   cmd.Reply(AudioStatus::VIRTIO_SND_S_OK, jack_info);
407 }
408 
OnPlaybackBuffer(TxBuffer buffer)409 void AudioHandler::OnPlaybackBuffer(TxBuffer buffer) {
410   auto stream_id = buffer.stream_id();
411   auto& stream_desc = stream_descs_[stream_id];
412   {
413     std::lock_guard<std::mutex> lock(stream_desc.mtx);
414     auto& holding_buffer = stream_descs_[stream_id].buffer;
415     // Invalid or capture streams shouldn't send tx buffers
416     if (stream_id >= NUM_STREAMS || IsCapture(stream_id)) {
417       buffer.SendStatus(AudioStatus::VIRTIO_SND_S_BAD_MSG, 0, 0);
418       return;
419     }
420     // A buffer may be received for an inactive stream if we were slow to
421     // process it and the other side stopped the stream. Quietly ignore it in
422     // that case
423     if (!stream_desc.active) {
424       buffer.SendStatus(AudioStatus::VIRTIO_SND_S_OK, 0, buffer.len());
425       return;
426     }
427     // Webrtc will silently ignore any buffer with a length different than 10ms,
428     // so we must split any buffer bigger than that and temporarily store any
429     // remaining frames that are less than that size.
430     auto current_time = rtc::TimeMillis();
431     // The timestamp of the first 10ms chunk to be sent so that the last one
432     // will have the current time
433     auto base_time =
434         current_time - ((buffer.len() - 1) / holding_buffer.buffer.size()) * 10;
435     // number of frames in a 10 ms buffer
436     const int frames = stream_desc.sample_rate / 100;
437     size_t pos = 0;
438     while (pos < buffer.len()) {
439       if (holding_buffer.empty() &&
440           buffer.len() - pos >= holding_buffer.buffer.size()) {
441         // Avoid the extra copy into holding buffer
442         // This casts away volatility of the pointer, necessary because the
443         // webrtc api doesn't expect volatile memory. This should be safe though
444         // because webrtc will use the contents of the buffer before returning
445         // and only then we release it.
446         CvdAudioFrameBuffer audio_frame_buffer(
447             const_cast<const uint8_t*>(&buffer.get()[pos]),
448             stream_desc.bits_per_sample, stream_desc.sample_rate,
449             stream_desc.channels, frames);
450         audio_sink_->OnFrame(audio_frame_buffer, base_time);
451         pos += holding_buffer.buffer.size();
452       } else {
453         pos += holding_buffer.Add(buffer.get() + pos, buffer.len() - pos);
454         if (holding_buffer.full()) {
455           auto buffer_ptr = const_cast<const uint8_t*>(holding_buffer.data());
456           CvdAudioFrameBuffer audio_frame_buffer(
457               buffer_ptr, stream_desc.bits_per_sample, stream_desc.sample_rate,
458               stream_desc.channels, frames);
459           audio_sink_->OnFrame(audio_frame_buffer, base_time);
460           holding_buffer.count = 0;
461         }
462       }
463       base_time += 10;
464     }
465   }
466   buffer.SendStatus(AudioStatus::VIRTIO_SND_S_OK, 0, buffer.len());
467 }
468 
OnCaptureBuffer(RxBuffer buffer)469 void AudioHandler::OnCaptureBuffer(RxBuffer buffer) {
470   auto stream_id = buffer.stream_id();
471   auto& stream_desc = stream_descs_[stream_id];
472   {
473     std::lock_guard<std::mutex> lock(stream_desc.mtx);
474     // Invalid or playback streams shouldn't send rx buffers
475     if (stream_id >= NUM_STREAMS || !IsCapture(stream_id)) {
476       LOG(ERROR) << "Received capture buffers on playback stream " << stream_id;
477       buffer.SendStatus(AudioStatus::VIRTIO_SND_S_BAD_MSG, 0, 0);
478       return;
479     }
480     // A buffer may be received for an inactive stream if we were slow to
481     // process it and the other side stopped the stream. Quietly ignore it in
482     // that case
483     if (!stream_desc.active) {
484       buffer.SendStatus(AudioStatus::VIRTIO_SND_S_OK, 0, buffer.len());
485       return;
486     }
487     const auto bytes_per_sample = stream_desc.bits_per_sample / 8;
488     const auto samples_per_channel = stream_desc.sample_rate / 100;
489     const auto bytes_per_request =
490         samples_per_channel * bytes_per_sample * stream_desc.channels;
491     bool muted = false;
492     size_t bytes_read = 0;
493     auto& holding_buffer = stream_descs_[stream_id].buffer;
494     auto rx_buffer = const_cast<uint8_t*>(buffer.get());
495     if (!holding_buffer.empty()) {
496       // Consume any bytes remaining from previous requests
497       bytes_read += holding_buffer.Take(rx_buffer + bytes_read,
498                                         buffer.len() - bytes_read);
499     }
500     while (buffer.len() - bytes_read >= bytes_per_request) {
501       // Skip the holding buffer in as many reads as possible to avoid the extra
502       // copies
503       auto write_pos = rx_buffer + bytes_read;
504       auto res = audio_source_->GetMoreAudioData(
505           write_pos, bytes_per_sample, samples_per_channel,
506           stream_desc.channels, stream_desc.sample_rate, muted);
507       if (res < 0) {
508         // This is likely a recoverable error, log the error but don't let the
509         // VMM know about it so that it doesn't crash.
510         LOG(ERROR) << "Failed to receive audio data from client";
511         break;
512       }
513       if (muted) {
514         // The source is muted, just fill the buffer with zeros and return
515         memset(rx_buffer + bytes_read, 0, buffer.len() - bytes_read);
516         bytes_read = buffer.len();
517         break;
518       }
519       auto bytes_received = res * bytes_per_sample * stream_desc.channels;
520       bytes_read += bytes_received;
521     }
522     if (bytes_read < buffer.len()) {
523       // There is some buffer left to fill, but it's less than 10ms, read into
524       // holding buffer to ensure the remainder is kept around for future reads
525       auto write_pos = holding_buffer.data();
526       // Holding buffer is the exact size we need to read into and is emptied
527       // before we try to read into it.
528       CHECK(holding_buffer.freeCapacity() >= bytes_per_request)
529           << "Buffer too small for receiving audio";
530       auto res = audio_source_->GetMoreAudioData(
531           write_pos, bytes_per_sample, samples_per_channel,
532           stream_desc.channels, stream_desc.sample_rate, muted);
533       if (res < 0) {
534         // This is likely a recoverable error, log the error but don't let the
535         // VMM know about it so that it doesn't crash.
536         LOG(ERROR) << "Failed to receive audio data from client";
537       } else if (muted) {
538         // The source is muted, just fill the buffer with zeros and return
539         memset(rx_buffer + bytes_read, 0, buffer.len() - bytes_read);
540         bytes_read = buffer.len();
541       } else {
542         auto bytes_received = res * bytes_per_sample * stream_desc.channels;
543         holding_buffer.count += bytes_received;
544         bytes_read += holding_buffer.Take(rx_buffer + bytes_read,
545                                           buffer.len() - bytes_read);
546         // If the entire buffer is not full by now there is a bug above
547         // somewhere
548         CHECK(bytes_read == buffer.len()) << "Failed to read entire buffer";
549       }
550     }
551   }
552   buffer.SendStatus(AudioStatus::VIRTIO_SND_S_OK, 0, buffer.len());
553 }
554 
Reset(size_t size)555 void AudioHandler::HoldingBuffer::Reset(size_t size) {
556   buffer.resize(size);
557   count = 0;
558 }
559 
Add(const volatile uint8_t * data,size_t max_len)560 size_t AudioHandler::HoldingBuffer::Add(const volatile uint8_t* data,
561                                         size_t max_len) {
562   auto added_len = std::min(max_len, buffer.size() - count);
563   std::copy(data, data + added_len, &buffer[count]);
564   count += added_len;
565   return added_len;
566 }
567 
Take(uint8_t * dst,size_t len)568 size_t AudioHandler::HoldingBuffer::Take(uint8_t* dst, size_t len) {
569   auto n = std::min(len, count);
570   std::copy(buffer.begin(), buffer.begin() + n, dst);
571   std::copy(buffer.begin() + n, buffer.begin() + count, buffer.begin());
572   count -= n;
573   return n;
574 }
575 
empty() const576 bool AudioHandler::HoldingBuffer::empty() const { return count == 0; }
577 
full() const578 bool AudioHandler::HoldingBuffer::full() const {
579   return count == buffer.size();
580 }
581 
freeCapacity() const582 size_t AudioHandler::HoldingBuffer::freeCapacity() const {
583   return buffer.size() - count;
584 }
585 
data()586 uint8_t* AudioHandler::HoldingBuffer::data() { return buffer.data(); }
587 
588 }  // namespace cuttlefish
589