1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import asyncio
16import grpc
17import logging
18import struct
19
20from bumble.core import AdvertisingData
21from bumble.decoder import G722Decoder
22from bumble.device import Connection, Connection as BumbleConnection, Device
23from bumble.gatt import (
24    GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
25    GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
26    GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
27    GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
28    GATT_ASHA_SERVICE,
29    GATT_ASHA_VOLUME_CHARACTERISTIC,
30    Characteristic,
31    CharacteristicValue,
32    TemplateService,
33)
34from bumble.l2cap import Channel
35from bumble.pandora import utils
36from bumble.utils import AsyncRunner
37from google.protobuf.empty_pb2 import Empty  # pytype: disable=pyi-error
38from pandora_experimental.asha_grpc_aio import AshaServicer
39from pandora_experimental.asha_pb2 import CaptureAudioRequest, CaptureAudioResponse, RegisterRequest
40from typing import AsyncGenerator, List, Optional
41
42
class AshaGattService(TemplateService):
    """GATT server side of the ASHA (Audio Streaming for Hearing Aid) service.

    Exposes the ReadOnlyProperties, AudioControlPoint, AudioStatus, Volume and
    LE_PSM_OUT characteristics, and registers an L2CAP CoC server on which the
    audio stream is received.

    Events emitted (via ``self.emit``):
      - ``"volume"`` (connection, volume_byte)
      - ``"start"`` (connection, {"codec", "audiotype", "volume", "otherstate"})
      - ``"stop"`` (connection)
      - ``"read_only_properties"`` (connection, value_bytes)
      - ``"le_psm_out"`` (connection, psm)
      - ``"data"`` (connection, data_bytes) for every CoC SDU received
    """

    # TODO: update bumble and remove this when complete
    UUID = GATT_ASHA_SERVICE
    OPCODE_START = 1
    OPCODE_STOP = 2
    OPCODE_STATUS = 3
    PROTOCOL_VERSION = 0x01
    RESERVED_FOR_FUTURE_USE = [0x00, 0x00]
    FEATURE_MAP = [0x01]  # [LE CoC audio output streaming supported]
    SUPPORTED_CODEC_ID = [0x02, 0x01]  # Codec IDs [G.722 at 16 kHz]
    RENDER_DELAY = [0x00, 0x00]
    # Names for the START command's audio-type byte, indexed by its value.
    AUDIO_TYPE_NAMES = ("Unknown", "Ringtone", "Phone Call", "Media")
    # START payload: opcode, codec, audio type, volume, otherstate.
    START_COMMAND_LENGTH = 5
    # STATUS payload: opcode, connected.
    STATUS_COMMAND_LENGTH = 2

    def __init__(self, capability: int, hisyncid: List[int], device: Device, psm: int = 0) -> None:
        """Build the characteristics and register the L2CAP CoC server.

        Args:
            capability: DeviceCapabilities byte reported in ReadOnlyProperties
                and in the advertising data.
            hisyncid: HiSyncId bytes reported in ReadOnlyProperties; the
                first 4 are also used in the advertising data.
            device: Bumble device used for notifications and the CoC server.
            psm: PSM to register the CoC server on; 0 lets the device pick a
                free one. A non-zero value is mainly for testing.
        """
        self.hisyncid = hisyncid
        self.capability = capability  # Device Capabilities [Left, Monaural]
        self.device = device
        # Every CoC SDU received is appended here.
        self.audio_out_data = b""
        self.psm: int = psm  # a non-zero psm is mainly for testing purpose

        logger = logging.getLogger(__name__)

        # Handler for volume control
        def on_volume_write(connection: Connection, value: bytes) -> None:
            if not value:
                logger.warning("--- VOLUME Write: empty payload ignored")
                return
            logger.info(f"--- VOLUME Write:{value[0]}")
            self.emit("volume", connection, value[0])

        # Handler for audio control commands
        def on_audio_control_point_write(connection: Connection, value: bytes) -> None:
            logger.info(f"--- AUDIO CONTROL POINT Write:{value.hex()}")
            if not value:
                # Nothing to decode; indexing value[0] would raise IndexError.
                logger.warning("### AUDIO CONTROL POINT: empty payload ignored")
                return

            opcode = value[0]
            if opcode == AshaGattService.OPCODE_START:
                if len(value) < AshaGattService.START_COMMAND_LENGTH:
                    logger.warning(f"### START: malformed payload {value.hex()}")
                    return
                audio_type_code = value[2]
                # Out-of-range audio types must not crash the write handler.
                if audio_type_code < len(AshaGattService.AUDIO_TYPE_NAMES):
                    audio_type = AshaGattService.AUDIO_TYPE_NAMES[audio_type_code]
                else:
                    audio_type = f"Unknown({audio_type_code})"
                logger.info(
                    f"### START: codec={value[1]}, "
                    f"audio_type={audio_type}, "
                    f"volume={value[3]}, "
                    f"otherstate={value[4]}"
                )
                self.emit(
                    "start",
                    connection,
                    {
                        "codec": value[1],
                        "audiotype": value[2],
                        "volume": value[3],
                        "otherstate": value[4],
                    },
                )
            elif opcode == AshaGattService.OPCODE_STOP:
                logger.info("### STOP")
                self.emit("stop", connection)
            elif opcode == AshaGattService.OPCODE_STATUS:
                if len(value) < AshaGattService.STATUS_COMMAND_LENGTH:
                    logger.warning(f"### STATUS: malformed payload {value.hex()}")
                    return
                logger.info(f"### STATUS: connected={value[1]}")

            # OPCODE_STATUS does not need audio status point update
            if opcode != AshaGattService.OPCODE_STATUS:
                AsyncRunner.spawn(device.notify_subscribers(self.audio_status_characteristic, force=True))  # type: ignore[no-untyped-call]

        def on_read_only_properties_read(connection: Connection) -> bytes:
            # Version | capability | HiSyncId | FeatureMap | RenderDelay | Reserved | Codec IDs
            value = (
                bytes(
                    [
                        AshaGattService.PROTOCOL_VERSION,  # Version
                        self.capability,
                    ]
                )
                + bytes(self.hisyncid)
                + bytes(AshaGattService.FEATURE_MAP)
                + bytes(AshaGattService.RENDER_DELAY)
                + bytes(AshaGattService.RESERVED_FOR_FUTURE_USE)
                + bytes(AshaGattService.SUPPORTED_CODEC_ID)
            )
            self.emit("read_only_properties", connection, value)
            return value

        def on_le_psm_out_read(connection: Connection) -> bytes:
            self.emit("le_psm_out", connection, self.psm)
            # PSM is reported as a 16-bit little-endian value.
            return struct.pack("<H", self.psm)

        self.read_only_properties_characteristic = Characteristic(
            GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
            Characteristic.READ,
            Characteristic.READABLE,
            CharacteristicValue(read=on_read_only_properties_read),  # type: ignore[no-untyped-call]
        )

        self.audio_control_point_characteristic = Characteristic(
            GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
            Characteristic.WRITE | Characteristic.WRITE_WITHOUT_RESPONSE,
            Characteristic.WRITEABLE,
            CharacteristicValue(write=on_audio_control_point_write),  # type: ignore[no-untyped-call]
        )
        self.audio_status_characteristic = Characteristic(
            GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
            Characteristic.READ | Characteristic.NOTIFY,
            Characteristic.READABLE,
            bytes([0]),
        )
        self.volume_characteristic = Characteristic(
            GATT_ASHA_VOLUME_CHARACTERISTIC,
            Characteristic.WRITE_WITHOUT_RESPONSE,
            Characteristic.WRITEABLE,
            CharacteristicValue(write=on_volume_write),  # type: ignore[no-untyped-call]
        )

        # Register an L2CAP CoC server
        def on_coc(channel: Channel) -> None:
            def on_data(data: bytes) -> None:
                logger.debug(f"data received:{data.hex()}")

                self.emit("data", channel.connection, data)
                self.audio_out_data += data

            channel.sink = on_data  # type: ignore[no-untyped-call]

        # let the server find a free PSM
        self.psm = self.device.register_l2cap_channel_server(self.psm, on_coc, 8)  # type: ignore[no-untyped-call]
        self.le_psm_out_characteristic = Characteristic(
            GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
            Characteristic.READ,
            Characteristic.READABLE,
            CharacteristicValue(read=on_le_psm_out_read),  # type: ignore[no-untyped-call]
        )

        characteristics = [
            self.read_only_properties_characteristic,
            self.audio_control_point_characteristic,
            self.audio_status_characteristic,
            self.volume_characteristic,
            self.le_psm_out_characteristic,
        ]

        super().__init__(characteristics)  # type: ignore[no-untyped-call]

    def get_advertising_data(self) -> bytes:
        """Return the serialized ASHA service-data advertising payload.

        Contains the service UUID, protocol version, capability byte and the
        first 4 bytes of the HiSyncId.
        """
        # Advertisement only uses 4 least significant bytes of the HiSyncId.
        return bytes(
            AdvertisingData(
                [
                    (
                        AdvertisingData.SERVICE_DATA_16_BIT_UUID,
                        bytes(GATT_ASHA_SERVICE)
                        + bytes(
                            [
                                AshaGattService.PROTOCOL_VERSION,
                                self.capability,
                            ]
                        )
                        + bytes(self.hisyncid[:4]),
                    ),
                ]
            )
        )
198
199
class AshaService(AshaServicer):
    """Pandora ASHA gRPC servicer backed by a Bumble device.

    ``Register`` hosts an ``AshaGattService`` on the device (or updates the
    existing one). ``CaptureAudio`` streams the G.722-decoded audio received
    on a given connection.
    """

    # Number of payload bytes handed to the G.722 decoder per decode_frame call.
    DECODE_FRAME_LENGTH = 80

    device: Device
    asha_service: Optional[AshaGattService]

    def __init__(self, device: Device) -> None:
        self.log = utils.BumbleServerLoggerAdapter(logging.getLogger(), {"service_name": "Asha", "device": device})
        self.device = device
        # Created lazily by Register.
        self.asha_service = None

    @utils.rpc
    async def Register(self, request: RegisterRequest, context: grpc.ServicerContext) -> Empty:
        """Create the ASHA GATT service, or update capability/HiSyncId of the existing one."""
        self.log.info("Register")
        if self.asha_service:
            self.asha_service.capability = request.capability
            self.asha_service.hisyncid = request.hisyncid
        else:
            self.asha_service = AshaGattService(request.capability, request.hisyncid, self.device)
            self.device.add_service(self.asha_service)  # type: ignore[no-untyped-call]
        return Empty()

    @utils.rpc
    async def CaptureAudio(
        self, request: CaptureAudioRequest, context: grpc.ServicerContext
    ) -> AsyncGenerator[CaptureAudioResponse, None]:
        """Stream decoded audio received from the ASHA service on one connection.

        Raises:
            RuntimeError: if the connection handle is unknown, or if Register
                has not been called yet (no ASHA service exists).
        """
        connection_handle = int.from_bytes(request.connection.cookie.value, "big")
        self.log.info(f"CaptureAudioData connection_handle:{connection_handle}")

        if not (connection := self.device.lookup_connection(connection_handle)):
            raise RuntimeError(f"Unknown connection for connection_handle:{connection_handle}")

        asha_service = self.asha_service
        if asha_service is None:
            # Without this check the listener registration below fails with an
            # opaque AttributeError on None.
            raise RuntimeError("ASHA service is not registered; call Register first")

        decoder = G722Decoder()  # type: ignore
        queue: asyncio.Queue[bytes] = asyncio.Queue()

        def on_data(asha_connection: BumbleConnection, data: bytes) -> None:
            # The GATT service is shared by all connections; keep only ours.
            if asha_connection == connection:
                queue.put_nowait(data)

        asha_service.on("data", on_data)

        frame_length = AshaService.DECODE_FRAME_LENGTH
        try:
            # NOTE(review): an empty packet ends the stream — confirm intended.
            while data := await queue.get():
                output_bytes = bytearray()
                # First byte is sequence number, last 160 bytes are audio payload.
                audio_payload = data[1:]
                # A trailing partial frame (if any) is ignored.
                frame_count = len(audio_payload) // frame_length
                for i in range(frame_count):
                    start = i * frame_length
                    frame = audio_payload[start : start + frame_length]
                    output_bytes.extend(decoder.decode_frame(frame))  # type: ignore

                yield CaptureAudioResponse(data=bytes(output_bytes))
        finally:
            # Always detach the listener, including when the client cancels.
            asha_service.remove_listener("data", on_data)
258