1# Copyright 2023 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import asyncio
16import collections
17import enum
18import hci_packets as hci
19import link_layer_packets as ll
20import llcp_packets as llcp
21import py.bluetooth
22import sys
23import typing
24import unittest
25from typing import Optional, Tuple, Union
26from hci_packets import ErrorCode
27
28from ctypes import *
29
# RootCanal's native FFI library; every controller operation in this file
# goes through its ffi_* entry points.
rootcanal = cdll.LoadLibrary("lib_rootcanal_ffi.so")
# ffi_controller_new returns a pointer to the native controller instance.
# Declaring the return type as c_void_p keeps the full pointer width (ctypes
# would otherwise assume a C int); the value is passed back wrapped in
# c_void_p to the other ffi_controller_* calls.
rootcanal.ffi_controller_new.restype = c_void_p

# Signature of the callback invoked for each HCI packet emitted by the
# controller: (packet indicator, data pointer, data length).
SEND_HCI_FUNC = CFUNCTYPE(None, c_int, POINTER(c_ubyte), c_size_t)
# Signature of the callback invoked for each LL PDU emitted by the
# controller: (data pointer, data length, phy, tx power).
SEND_LL_FUNC = CFUNCTYPE(None, POINTER(c_ubyte), c_size_t, c_int, c_int)
35
36
class Idc(enum.IntEnum):
    """HCI packet indicator tagging each packet exchanged with the controller.

    The indicator is passed to ffi_controller_receive_hci along with the
    packet bytes, and is reported first by the HCI send callback."""
    Cmd = 0x01
    Acl = 0x02
    Sco = 0x03
    Evt = 0x04
    Iso = 0x05
43
44
class Phy(enum.IntEnum):
    """Physical transport tag passed as the phy argument of the LL FFI calls."""
    LowEnergy = 0x0
    BrEdr = 0x1
48
49
class LeFeatures:
    """Selected flags decoded from an LE supported features mask.

    Attributes:
        mask: the raw mask given to the constructor.
        ll_privacy, le_extended_advertising, le_periodic_advertising: True
            when the corresponding hci.LLFeaturesBits bit is set in mask.
    """

    def __init__(self, le_features: int):
        bits = hci.LLFeaturesBits
        self.mask = le_features
        # A flag is set whenever its bit survives the mask.
        self.ll_privacy = bool(le_features & bits.LL_PRIVACY)
        self.le_extended_advertising = bool(le_features & bits.LE_EXTENDED_ADVERTISING)
        self.le_periodic_advertising = bool(le_features & bits.LE_PERIODIC_ADVERTISING)
57
58
def generate_rpa(irk: bytes) -> hci.Address:
    """Generate a resolvable private address from the key irk.

    The six address bytes are written by RootCanal's ffi_generate_rpa into a
    ctypes buffer; their order is reversed before building the hci.Address."""
    out = (c_char * 6)()
    rootcanal.ffi_generate_rpa(c_char_p(irk), out)
    return hci.Address(out.raw[::-1])
65
66
class Controller:
    """Binder class over RootCanal's ffi interfaces.

    The methods send_cmd, send_iso, send_ll and send_llcp inject HCI or LL
    packets into the controller. HCI packets and LL PDUs emitted by the
    controller are queued by the ffi callbacks and retrieved with
    receive_evt, receive_iso and receive_ll (or checked with expect_evt)."""

    def __init__(self, address: hci.Address):
        self.address = address

        # Create the receive queues before the native controller exists, so
        # the callbacks below always find them, even if the controller emits
        # a packet while it is being constructed.
        self.evt_queue = collections.deque()
        self.acl_queue = collections.deque()
        self.iso_queue = collections.deque()
        self.ll_queue = collections.deque()
        self.evt_queue_event = asyncio.Event()
        self.acl_queue_event = asyncio.Event()
        self.iso_queue_event = asyncio.Event()
        self.ll_queue_event = asyncio.Event()

        # Callbacks for handling HCI and LL send events. The decorators
        # already produce the ctypes function pointers (no second wrapping
        # needed); the packet bytes are copied out of the C buffer in one go
        # before the callback returns.
        @SEND_HCI_FUNC
        def send_hci(idc: int, data: POINTER(c_ubyte), data_len: int):
            self.receive_hci_(int(idc), string_at(data, data_len))

        @SEND_LL_FUNC
        def send_ll(data: POINTER(c_ubyte), data_len: int, phy: int, tx_power: int):
            self.receive_ll_(string_at(data, data_len), int(phy), int(tx_power))

        # The function pointers are handed to native code which calls them
        # later; ctypes does not keep callback objects alive by itself, so
        # they stay referenced for the lifetime of the controller.
        self.send_hci_callback = send_hci
        self.send_ll_callback = send_ll

        # Create a c++ controller instance.
        self.instance = rootcanal.ffi_controller_new(c_char_p(address.address), self.send_hci_callback,
                                                     self.send_ll_callback)

    def __del__(self):
        # __init__ may have failed before the native instance was created,
        # or the native constructor may have returned NULL (None).
        instance = getattr(self, "instance", None)
        if instance is not None:
            rootcanal.ffi_controller_delete(c_void_p(instance))

    def receive_hci_(self, idc: int, packet: bytes):
        """Queue an HCI packet emitted by the controller.

        Events, ACL and ISO packets go to their own queue and wake up the
        matching waiter; any other packet type is dropped."""
        if idc == Idc.Evt:
            print(f"<-- received HCI event data={len(packet)}[..]")
            self.evt_queue.append(packet)
            self.evt_queue_event.set()
        elif idc == Idc.Acl:
            print(f"<-- received HCI ACL packet data={len(packet)}[..]")
            self.acl_queue.append(packet)
            self.acl_queue_event.set()
        elif idc == Idc.Iso:
            print(f"<-- received HCI ISO packet data={len(packet)}[..]")
            self.iso_queue.append(packet)
            self.iso_queue_event.set()
        else:
            print(f"ignoring HCI packet typ={idc}")

    def receive_ll_(self, packet: bytes, phy: int, tx_power: int):
        """Queue an LL PDU emitted by the controller (phy and tx_power are unused)."""
        print(f"<-- received LL pdu data={len(packet)}[..]")
        self.ll_queue.append(packet)
        self.ll_queue_event.set()

    def send_cmd(self, cmd: hci.Command):
        """Serialize an HCI command and inject it into the controller."""
        print(f"--> sending HCI command {cmd.__class__.__name__}")
        data = cmd.serialize()
        rootcanal.ffi_controller_receive_hci(c_void_p(self.instance), c_int(Idc.Cmd), c_char_p(data), c_int(len(data)))

    def send_iso(self, iso: hci.Iso):
        """Serialize an HCI ISO packet and inject it into the controller."""
        print(f"--> sending HCI iso pdu data={len(iso.payload)}[..]")
        data = iso.serialize()
        rootcanal.ffi_controller_receive_hci(c_void_p(self.instance), c_int(Idc.Iso), c_char_p(data), c_int(len(data)))

    def send_ll(self, pdu: ll.LinkLayerPacket, phy: Phy = Phy.LowEnergy, rssi: int = -90):
        """Serialize an LL PDU and deliver it to the controller as if received
        over the air on phy with the given rssi."""
        print(f"--> sending LL pdu {pdu.__class__.__name__}")
        data = pdu.serialize()
        rootcanal.ffi_controller_receive_ll(c_void_p(self.instance), c_char_p(data), c_int(len(data)), c_int(phy),
                                            c_int(rssi))

    def send_llcp(self,
                  source_address: hci.Address,
                  destination_address: hci.Address,
                  pdu: llcp.LlcpPacket,
                  phy: Phy = Phy.LowEnergy,
                  rssi: int = -90):
        """Wrap an LLCP PDU in an ll.Llcp packet and deliver it to the controller."""
        print(f"--> sending LLCP pdu {pdu.__class__.__name__}")
        ll_pdu = ll.Llcp(source_address=source_address,
                         destination_address=destination_address,
                         payload=pdu.serialize())
        data = ll_pdu.serialize()
        rootcanal.ffi_controller_receive_ll(c_void_p(self.instance), c_char_p(data), c_int(len(data)), c_int(phy),
                                            c_int(rssi))

    async def start(self):
        """Spawn the background task that ticks the controller every 5 ms."""

        async def timer():
            while True:
                await asyncio.sleep(0.005)
                rootcanal.ffi_controller_tick(c_void_p(self.instance))

        # Spawn the controller timer task.
        self.timer_task = asyncio.create_task(timer())

    def stop(self):
        """Stop the controller timer and check that no output was left unread.

        Raises:
            Exception: if HCI events, HCI ISO packets or LL PDUs emitted by the
                controller were never consumed. The ACL queue is not checked
                (no method of this class consumes it).
        """
        # Cancel the controller timer task. Dropping the reference alone does
        # not stop a running asyncio task.
        timer_task = getattr(self, "timer_task", None)
        if timer_task is not None:
            timer_task.cancel()
            del self.timer_task

        if self.evt_queue:
            print("evt queue not empty at stop():")
            for packet in self.evt_queue:
                evt = hci.Event.parse_all(packet)
                evt.show()
            raise Exception("evt queue not empty at stop()")

        if self.iso_queue:
            print("iso queue not empty at stop():")
            for packet in self.iso_queue:
                iso = hci.Iso.parse_all(packet)
                iso.show()
            raise Exception("iso queue not empty at stop()")

        if self.ll_queue:
            print("ll queue not empty at stop():")
            # receive_ll_ queues the raw PDU bytes only (no tuple).
            for packet in self.ll_queue:
                pdu = ll.LinkLayerPacket.parse_all(packet)
                pdu.show()
            raise Exception("ll queue not empty at stop()")

    async def receive_evt(self):
        """Wait until an HCI event is queued and return its raw bytes."""
        while not self.evt_queue:
            await self.evt_queue_event.wait()
            self.evt_queue_event.clear()
        return self.evt_queue.popleft()

    async def receive_iso(self):
        """Wait until an HCI ISO packet is queued and return its raw bytes."""
        while not self.iso_queue:
            await self.iso_queue_event.wait()
            self.iso_queue_event.clear()
        return self.iso_queue.popleft()

    async def expect_evt(self, expected_evt: hci.Event):
        """Wait for the next HCI event and raise Exception if it differs from expected_evt."""
        packet = await self.receive_evt()
        evt = hci.Event.parse_all(packet)
        if evt != expected_evt:
            print("received unexpected event")
            print("expected event:")
            expected_evt.show()
            print("received event:")
            evt.show()
            raise Exception(f"unexpected evt {evt.__class__.__name__}")

    async def receive_ll(self):
        """Wait until an LL PDU is queued and return its raw bytes."""
        while not self.ll_queue:
            await self.ll_queue_event.wait()
            self.ll_queue_event.clear()
        return self.ll_queue.popleft()
222
223
class Any:
    """Wildcard value that compares equal to everything.

    Put an instance in a field of an expected packet so that whatever value
    the Controller returns for that field is accepted."""

    # Overriding __eq__ already makes instances unhashable; spelled out here.
    __hash__ = None

    def __eq__(self, other) -> bool:
        # Unconditional match, whatever the other operand is.
        return True

    def __format__(self, format_spec: str) -> str:
        # Always rendered as a bare placeholder, ignoring format_spec.
        return "_"
234
235
class ControllerTest(unittest.IsolatedAsyncioTestCase):
    """Helper class for writing controller tests using the python bindings.
    The test sets up the controller by sending the Reset command and
    configuring the event masks to allow all events. The local device address
    is always configured as 11:11:11:11:11:11."""

    # Wildcard usable in any field of an expected packet (see Any).
    Any = Any()

    def setUp(self):
        self.controller = Controller(hci.Address('11:11:11:11:11:11'))

    async def asyncSetUp(self):
        controller = self.controller

        # Start the controller timer.
        await controller.start()

        # Reset the controller and enable all events and LE events.
        controller.send_cmd(hci.Reset())
        await controller.expect_evt(hci.ResetComplete(status=ErrorCode.SUCCESS, num_hci_command_packets=1))
        controller.send_cmd(hci.SetEventMask(event_mask=0xffffffffffffffff))
        await controller.expect_evt(hci.SetEventMaskComplete(status=ErrorCode.SUCCESS, num_hci_command_packets=1))
        controller.send_cmd(hci.LeSetEventMask(le_event_mask=0xffffffffffffffff))
        await controller.expect_evt(hci.LeSetEventMaskComplete(status=ErrorCode.SUCCESS, num_hci_command_packets=1))

        # Load the local supported features to be able to disable tests
        # that rely on unsupported features.
        controller.send_cmd(hci.LeReadLocalSupportedFeatures())
        evt = await self.expect_cmd_complete(hci.LeReadLocalSupportedFeaturesComplete)
        controller.le_features = LeFeatures(evt.le_features)

    async def expect_evt(self, expected_evt: typing.Union[hci.Event, type], timeout: int = 3) -> hci.Event:
        """Wait for the next HCI event and check it against expected_evt.

        Args:
            expected_evt: an event class (any event of that class matches) or
                an event instance (compared with ==, so fields may use Any).
            timeout: seconds to wait for the event.

        Returns:
            The parsed event.
        """
        packet = await asyncio.wait_for(self.controller.receive_evt(), timeout=timeout)
        evt = hci.Event.parse_all(packet)

        if isinstance(expected_evt, type) and not isinstance(evt, expected_evt):
            print("received unexpected event")
            # expected_evt is itself a class: print its name, not its metaclass'.
            print(f"expected event: {expected_evt.__name__}")
            print("received event:")
            evt.show()
            self.fail(f"unexpected evt {evt.__class__.__name__}")

        if isinstance(expected_evt, hci.Event) and evt != expected_evt:
            print("received unexpected event")
            print("expected event:")
            expected_evt.show()
            print("received event:")
            evt.show()
            self.fail(f"unexpected evt {evt.__class__.__name__}")

        return evt

    async def expect_cmd_complete(self, expected_evt: type, timeout: int = 3) -> hci.Event:
        """Wait for a successful command complete event of class expected_evt
        that returns one HCI command credit, and return it."""
        evt = await self.expect_evt(expected_evt, timeout=timeout)
        # TestCase assertions, unlike bare assert, are not stripped under -O.
        self.assertEqual(evt.status, ErrorCode.SUCCESS)
        self.assertEqual(evt.num_hci_command_packets, 1)
        return evt

    async def expect_iso(self, expected_iso: hci.Iso, timeout: int = 3):
        """Wait for the next HCI ISO packet and check it equals expected_iso."""
        packet = await asyncio.wait_for(self.controller.receive_iso(), timeout=timeout)
        iso = hci.Iso.parse_all(packet)

        if iso != expected_iso:
            print("received unexpected iso packet")
            print("expected packet:")
            expected_iso.show()
            print("received packet:")
            iso.show()
            self.fail("unexpected iso packet")

    async def expect_ll(self,
                        expected_pdus: typing.Union[list, typing.Union[ll.LinkLayerPacket, type]],
                        ignored_pdus: typing.Union[list, type, None] = None,
                        timeout: int = 3) -> ll.LinkLayerPacket:
        """Wait for an LL PDU matching one of expected_pdus and return it.

        Args:
            expected_pdus: a PDU class or instance, or a list of them. A class
                matches any PDU of that class; an instance is compared with ==.
            ignored_pdus: a PDU class, or list of classes, whose PDUs are
                skipped while waiting. Defaults to none.
            timeout: seconds allowed for the whole wait, including the time
                spent skipping ignored PDUs.

        Returns:
            The first matching PDU. Any PDU that is neither expected nor
            ignored fails the test.
        """
        if ignored_pdus is None:
            ignored_pdus = []
        elif not isinstance(ignored_pdus, list):
            ignored_pdus = [ignored_pdus]

        if not isinstance(expected_pdus, list):
            expected_pdus = [expected_pdus]

        # The enclosing timeout bounds the total wait, so the individual
        # receive calls need no timeout of their own.
        async with asyncio.timeout(timeout):
            while True:
                packet = await self.controller.receive_ll()
                pdu = ll.LinkLayerPacket.parse_all(packet)

                for expected_pdu in expected_pdus:
                    if isinstance(expected_pdu, type) and isinstance(pdu, expected_pdu):
                        return pdu
                    if isinstance(expected_pdu, ll.LinkLayerPacket) and pdu == expected_pdu:
                        return pdu

                # Skip PDUs of an ignored class and wait for the next one.
                if any(isinstance(pdu, ignored_pdu) for ignored_pdu in ignored_pdus):
                    continue

                print("received unexpected pdu:")
                pdu.show()
                print("expected pdus:")
                for expected_pdu in expected_pdus:
                    if isinstance(expected_pdu, type):
                        print(f"- {expected_pdu.__name__}")
                    if isinstance(expected_pdu, ll.LinkLayerPacket):
                        print(f"- {expected_pdu.__class__.__name__}")
                        expected_pdu.show()

                self.fail(f"unexpected pdu {pdu.__class__.__name__}")

    async def expect_llcp(self,
                          source_address: hci.Address,
                          destination_address: hci.Address,
                          expected_pdu: llcp.LlcpPacket,
                          timeout: int = 3) -> llcp.LlcpPacket:
        """Wait for the next LL PDU, check it is an LLCP PDU between the given
        addresses whose payload equals expected_pdu, and return the payload."""
        packet = await asyncio.wait_for(self.controller.receive_ll(), timeout=timeout)
        pdu = ll.LinkLayerPacket.parse_all(packet)

        if (pdu.type != ll.PacketType.LLCP or pdu.source_address != source_address or
                pdu.destination_address != destination_address):
            print("received unexpected pdu:")
            pdu.show()
            print(f"expected pdu: {source_address} -> {destination_address}")
            expected_pdu.show()
            self.fail("unexpected ll pdu")

        pdu = llcp.LlcpPacket.parse_all(pdu.payload)
        if pdu != expected_pdu:
            print("received unexpected pdu:")
            pdu.show()
            print("expected pdu:")
            expected_pdu.show()
            self.fail(f"unexpected llcp pdu {pdu.__class__.__name__}")

        return pdu

    async def enable_connected_isochronous_stream_host_support(self):
        """Enable Connected Isochronous Stream Host Support in the LE Feature mask."""
        self.controller.send_cmd(
            hci.LeSetHostFeature(bit_number=hci.LeHostFeatureBits.CONNECTED_ISO_STREAM_HOST_SUPPORT,
                                 bit_value=hci.Enable.ENABLED))

        await self.expect_evt(hci.LeSetHostFeatureComplete(status=ErrorCode.SUCCESS, num_hci_command_packets=1))

    async def establish_le_connection_central(self, peer_address: hci.Address) -> int:
        """Establish a connection with the selected peer as Central.
        Returns the ACL connection handle for the opened link."""
        self.controller.send_cmd(
            hci.LeExtendedCreateConnection(initiator_filter_policy=hci.InitiatorFilterPolicy.USE_PEER_ADDRESS,
                                           own_address_type=hci.OwnAddressType.PUBLIC_DEVICE_ADDRESS,
                                           peer_address_type=hci.AddressType.PUBLIC_DEVICE_ADDRESS,
                                           peer_address=peer_address,
                                           initiating_phys=0x1,
                                           initiating_phy_parameters=[
                                               hci.InitiatingPhyParameters(
                                                   scan_interval=0x200,
                                                   scan_window=0x100,
                                                   connection_interval_min=0x200,
                                                   connection_interval_max=0x200,
                                                   max_latency=0x6,
                                                   supervision_timeout=0xc80,
                                                   min_ce_length=0,
                                                   max_ce_length=0,
                                               )
                                           ]))

        await self.expect_evt(hci.LeExtendedCreateConnectionStatus(status=ErrorCode.SUCCESS, num_hci_command_packets=1))

        self.controller.send_ll(ll.LeLegacyAdvertisingPdu(source_address=peer_address,
                                                          advertising_address_type=ll.AddressType.PUBLIC,
                                                          advertising_type=ll.LegacyAdvertisingType.ADV_IND,
                                                          advertising_data=[]),
                                rssi=-16)

        await self.expect_ll(
            ll.LeConnect(source_address=self.controller.address,
                         destination_address=peer_address,
                         initiating_address_type=ll.AddressType.PUBLIC,
                         advertising_address_type=ll.AddressType.PUBLIC,
                         conn_interval=0x200,
                         conn_peripheral_latency=0x6,
                         conn_supervision_timeout=0xc80))

        self.controller.send_ll(
            ll.LeConnectComplete(source_address=peer_address,
                                 destination_address=self.controller.address,
                                 initiating_address_type=ll.AddressType.PUBLIC,
                                 advertising_address_type=ll.AddressType.PUBLIC,
                                 conn_interval=0x200,
                                 conn_peripheral_latency=0x6,
                                 conn_supervision_timeout=0xc80))

        connection_complete = await self.expect_evt(
            hci.LeEnhancedConnectionComplete(status=ErrorCode.SUCCESS,
                                             connection_handle=self.Any,
                                             role=hci.Role.CENTRAL,
                                             peer_address_type=hci.AddressType.PUBLIC_DEVICE_ADDRESS,
                                             peer_address=peer_address,
                                             connection_interval=0x200,
                                             peripheral_latency=0x6,
                                             supervision_timeout=0xc80,
                                             central_clock_accuracy=hci.ClockAccuracy.PPM_500))

        acl_connection_handle = connection_complete.connection_handle
        await self.expect_evt(
            hci.LeChannelSelectionAlgorithm(connection_handle=acl_connection_handle,
                                            channel_selection_algorithm=hci.ChannelSelectionAlgorithm.ALGORITHM_1))

        return acl_connection_handle

    async def establish_le_connection_peripheral(self, peer_address: hci.Address) -> int:
        """Establish a connection with the selected peer as Peripheral.
        Returns the ACL connection handle for the opened link."""
        self.controller.send_cmd(
            hci.LeSetAdvertisingParameters(advertising_interval_min=0x200,
                                           advertising_interval_max=0x200,
                                           advertising_type=hci.AdvertisingType.ADV_IND,
                                           own_address_type=hci.OwnAddressType.PUBLIC_DEVICE_ADDRESS,
                                           advertising_channel_map=0x7,
                                           advertising_filter_policy=hci.AdvertisingFilterPolicy.ALL_DEVICES))

        await self.expect_evt(
            hci.LeSetAdvertisingParametersComplete(status=ErrorCode.SUCCESS, num_hci_command_packets=1))

        self.controller.send_cmd(hci.LeSetAdvertisingEnable(advertising_enable=True))

        await self.expect_evt(hci.LeSetAdvertisingEnableComplete(status=ErrorCode.SUCCESS, num_hci_command_packets=1))

        self.controller.send_ll(ll.LeConnect(source_address=peer_address,
                                             destination_address=self.controller.address,
                                             initiating_address_type=ll.AddressType.PUBLIC,
                                             advertising_address_type=ll.AddressType.PUBLIC,
                                             conn_interval=0x200,
                                             conn_peripheral_latency=0x200,
                                             conn_supervision_timeout=0x200),
                                rssi=-16)

        await self.expect_ll(
            ll.LeConnectComplete(source_address=self.controller.address,
                                 destination_address=peer_address,
                                 conn_interval=0x200,
                                 conn_peripheral_latency=0x200,
                                 conn_supervision_timeout=0x200))

        connection_complete = await self.expect_evt(
            hci.LeEnhancedConnectionComplete(status=ErrorCode.SUCCESS,
                                             connection_handle=self.Any,
                                             role=hci.Role.PERIPHERAL,
                                             peer_address_type=hci.AddressType.PUBLIC_DEVICE_ADDRESS,
                                             peer_address=peer_address,
                                             connection_interval=0x200,
                                             peripheral_latency=0x200,
                                             supervision_timeout=0x200,
                                             central_clock_accuracy=hci.ClockAccuracy.PPM_500))

        return connection_complete.connection_handle

    def tearDown(self):
        self.controller.stop()
492