1# Copyright 2023 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import asyncio
16import logging
17
18from avatar import BumblePandoraDevice, PandoraDevice, PandoraDevices
19from avatar.aio import asynchronous
20from bumble import smp
21from bumble.hci import Address
22from bumble.pairing import PairingDelegate
23from concurrent import futures
24from contextlib import suppress
25from mobly import base_test, signals, test_runner
26from mobly.asserts import assert_equal  # type: ignore
27from mobly.asserts import assert_false  # type: ignore
28from mobly.asserts import assert_is_not_none  # type: ignore
29from mobly.asserts import assert_true  # type: ignore
30from pandora.host_pb2 import RANDOM, DataTypes, OwnAddressType, ScanningResponse
31from pandora.security_pb2 import LE_LEVEL3, PairingEventAnswer
32from typing import NoReturn, Optional
33
34
class SmpTest(base_test.BaseTestClass):  # type: ignore[misc]
    """Security Manager Protocol (SMP) tests between a DUT and a reference device.

    `dut` is the device under test and `ref` is the peer. Tests that patch the
    reference's SMP session require `ref` to be a Bumble device and skip otherwise.
    """

    devices: Optional[PandoraDevices] = None

    dut: PandoraDevice
    ref: PandoraDevice

    def setup_class(self) -> None:
        """Create the Pandora devices and bind the first two as `dut` and `ref`."""
        self.devices = PandoraDevices(self)
        self.dut, self.ref, *_ = self.devices

        # Enable BR/EDR mode for Bumble devices.
        for device in self.devices:
            if isinstance(device, BumblePandoraDevice):
                device.config.setdefault('classic_enabled', True)

    def teardown_class(self) -> None:
        """Stop all devices created in `setup_class`, if any were created."""
        if self.devices:
            self.devices.stop_all()

    @asynchronous
    async def setup_test(self) -> None:
        """Reset both devices concurrently before each test."""
        await asyncio.gather(self.dut.reset(), self.ref.reset())

    async def handle_pairing_events(self) -> NoReturn:
        """Answer pairing events on both devices until the task is cancelled.

        For each DUT pairing event:
          * `passkey_entry_notification`: the reference's next event must be a
            `passkey_entry_request`; it is answered with the passkey the DUT shows.
          * anything else: the DUT event is confirmed, then the reference's next
            event is confirmed too.

        Both OnPairing streams are cancelled when this coroutine exits.
        """
        dut_pairing_stream = self.dut.aio.security.OnPairing()
        ref_pairing_stream = self.ref.aio.security.OnPairing()
        try:
            while True:
                dut_pairing_event = await anext(dut_pairing_stream)

                if dut_pairing_event.method_variant() == 'passkey_entry_notification':
                    ref_pairing_event = await anext(ref_pairing_stream)

                    assert_equal(ref_pairing_event.method_variant(), 'passkey_entry_request')
                    assert_is_not_none(dut_pairing_event.passkey_entry_notification)
                    assert dut_pairing_event.passkey_entry_notification is not None

                    ref_ev_answer = PairingEventAnswer(
                        event=ref_pairing_event,
                        passkey=dut_pairing_event.passkey_entry_notification,
                    )
                    ref_pairing_stream.send_nowait(ref_ev_answer)
                else:
                    dut_pairing_stream.send_nowait(PairingEventAnswer(
                        event=dut_pairing_event,
                        confirm=True,
                    ))
                    ref_pairing_event = await anext(ref_pairing_stream)

                    ref_pairing_stream.send_nowait(PairingEventAnswer(
                        event=ref_pairing_event,
                        confirm=True,
                    ))

        finally:
            # Close both streams so neither outlives this handler.
            dut_pairing_stream.cancel()
            ref_pairing_stream.cancel()

    async def dut_pair(self, dut_address_type: OwnAddressType, ref_address_type: OwnAddressType) -> ScanningResponse:
        """Connect the DUT to the advertising reference over LE and secure the link at LE_LEVEL3.

        Args:
            dut_address_type: own address type the DUT scans and connects with.
            ref_address_type: own address type the reference advertises with.

        Returns:
            The scanning response through which the DUT discovered the reference.

        Both security results are asserted to be 'success', and the reference
        disconnects the link before this returns.
        """
        advertisement = self.ref.aio.host.Advertise(
            legacy=True,
            connectable=True,
            own_address_type=ref_address_type,
            data=DataTypes(manufacturer_specific_data=b'pause cafe'),
        )

        scan = self.dut.aio.host.Scan(own_address_type=dut_address_type)
        ref = await anext((x async for x in scan if b'pause cafe' in x.data.manufacturer_specific_data))
        scan.cancel()

        pairing = asyncio.create_task(self.handle_pairing_events())
        try:
            (dut_ref_res, ref_dut_res) = await asyncio.gather(
                self.dut.aio.host.ConnectLE(own_address_type=dut_address_type, **ref.address_asdict()),
                anext(aiter(advertisement)),
            )

            advertisement.cancel()
            ref_dut, dut_ref = ref_dut_res.connection, dut_ref_res.connection
            assert_is_not_none(dut_ref)
            assert dut_ref

            (secure, wait_security) = await asyncio.gather(
                self.dut.aio.security.Secure(connection=dut_ref, le=LE_LEVEL3),
                self.ref.aio.security.WaitSecurity(connection=ref_dut, le=LE_LEVEL3),
            )
        finally:
            # Stop the pairing handler even when connecting/securing failed, so it
            # cannot keep answering pairing events in later steps or tests.
            pairing.cancel()
            with suppress(asyncio.CancelledError, futures.CancelledError):
                await pairing

        assert_equal(secure.result_variant(), 'success')
        assert_equal(wait_security.result_variant(), 'success')

        await asyncio.gather(
            self.ref.aio.host.Disconnect(connection=ref_dut),
            self.dut.aio.host.WaitDisconnection(connection=dut_ref),
        )
        return ref

    @asynchronous
    async def test_le_pairing_delete_dup_bond_record(self) -> None:
        """Pairing twice with one identity address must leave only the newest bond on the DUT.

        The reference advertises from a different random address each time, but
        its SMP session is patched to always distribute the same identity address.
        """
        if isinstance(self.dut, BumblePandoraDevice):
            raise signals.TestSkip('TODO: Fix test for Bumble DUT')
        if not isinstance(self.ref, BumblePandoraDevice):
            raise signals.TestSkip('Test require Bumble as reference device(s)')

        class Session(smp.Session):
            # Hack to send same identity address from ref during both pairing
            def send_command(self: smp.Session, command: smp.SMP_Command) -> None:
                if isinstance(command, smp.SMP_Identity_Address_Information_Command):
                    command = smp.SMP_Identity_Address_Information_Command(
                        addr_type=Address.RANDOM_IDENTITY_ADDRESS,
                        bd_addr=Address(
                            'F6:F7:F8:F9:FA:FB',
                            Address.RANDOM_IDENTITY_ADDRESS,
                        ),
                    )
                self.manager.send_command(self.connection, command)

        self.ref.device.smp_session_proxy = Session

        # Pair with same device 2 times.
        # Ref device advertises with different random address but uses same identity address
        ref1 = await self.dut_pair(dut_address_type=RANDOM, ref_address_type=RANDOM)
        is_bonded = await self.dut.aio.security_storage.IsBonded(random=ref1.random)
        assert_true(is_bonded.value, "")

        # The reset drops the patched session proxy, so it is installed again.
        await self.ref.reset()
        self.ref.device.smp_session_proxy = Session

        ref2 = await self.dut_pair(dut_address_type=RANDOM, ref_address_type=RANDOM)
        is_bonded = await self.dut.aio.security_storage.IsBonded(random=ref2.random)
        assert_true(is_bonded.value, "")

        # The first bond (same identity address) must have been replaced.
        is_bonded = await self.dut.aio.security_storage.IsBonded(random=ref1.random)
        assert_false(is_bonded.value, "")

    @asynchronous
    async def test_mitm_sec_req_on_enc(self) -> None:
        """On an encrypted reconnection, a pairing request from ref must not trigger re-pairing.

        The test first bonds at LE_LEVEL3, then reconnects. The reference's SMP
        manager is asked to request pairing both immediately and on every
        encryption change. The test passes if the link reports an encryption key
        refresh, and fails if the DUT sends an SMP Pairing Request.
        """
        if not isinstance(self.ref, BumblePandoraDevice):
            raise signals.TestSkip('Test require Bumble as reference device(s)')

        io_capability = PairingDelegate.IoCapability.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT
        self.ref.server_config.io_capability = io_capability

        advertisement = self.ref.aio.host.Advertise(
            legacy=True,
            connectable=True,
            own_address_type=RANDOM,
            data=DataTypes(manufacturer_specific_data=b'pause cafe'),
        )

        scan = self.dut.aio.host.Scan(own_address_type=RANDOM)
        ref = await anext((x async for x in scan if b'pause cafe' in x.data.manufacturer_specific_data))
        scan.cancel()

        # Keep a reference to the handler task so it can be cancelled (and is not
        # garbage-collected) once the test ends.
        pairing = asyncio.create_task(self.handle_pairing_events())
        try:
            (dut_ref_res, ref_dut_res) = await asyncio.gather(
                self.dut.aio.host.ConnectLE(own_address_type=RANDOM, **ref.address_asdict()),
                anext(aiter(advertisement)),
            )

            advertisement.cancel()
            ref_dut, dut_ref = ref_dut_res.connection, dut_ref_res.connection
            assert_is_not_none(dut_ref)
            assert dut_ref

            # Pair with MITM requirements
            (secure, wait_security) = await asyncio.gather(
                self.dut.aio.security.Secure(connection=dut_ref, le=LE_LEVEL3),
                self.ref.aio.security.WaitSecurity(connection=ref_dut, le=LE_LEVEL3),
            )

            assert_equal(secure.result_variant(), 'success')
            assert_equal(wait_security.result_variant(), 'success')

            # Disconnect
            await asyncio.gather(
                self.ref.aio.host.Disconnect(connection=ref_dut),
                self.dut.aio.host.WaitDisconnection(connection=dut_ref),
            )

            advertisement = self.ref.aio.host.Advertise(
                legacy=True,
                connectable=True,
                own_address_type=RANDOM,
                data=DataTypes(manufacturer_specific_data=b'pause cafe'),
            )

            scan = self.dut.aio.host.Scan(own_address_type=RANDOM)
            ref = await anext((x async for x in scan if b'pause cafe' in x.data.manufacturer_specific_data))
            scan.cancel()

            (dut_ref_res, ref_dut_res) = await asyncio.gather(
                self.dut.aio.host.ConnectLE(own_address_type=RANDOM, **ref.address_asdict()),
                anext(aiter(advertisement)),
            )
            # Mirror the first connection: stop the advertisement stream once connected.
            advertisement.cancel()
            ref_dut, dut_ref = ref_dut_res.connection, dut_ref_res.connection

            # Wait for the link to get encrypted
            connection = self.ref.device.lookup_connection(int.from_bytes(ref_dut.cookie.value, 'big'))
            assert_is_not_none(connection)
            assert connection

            self.ref.device.smp_manager.request_pairing(connection)

            def on_connection_encryption_change() -> None:
                assert isinstance(self.ref, BumblePandoraDevice)
                self.ref.device.smp_manager.request_pairing(connection)

            connection.on('connection_encryption_change', on_connection_encryption_change)

            # Fail if repairing is initiated
            fut = asyncio.get_running_loop().create_future()

            class Session(smp.Session):

                def on_smp_pairing_request_command(self, command: smp.SMP_Pairing_Request_Command) -> None:
                    # Only the first outcome counts; a late event must not raise
                    # InvalidStateError from inside the callback.
                    if not fut.done():
                        fut.set_result(False)

            self.ref.device.smp_session_proxy = Session

            # Pass if the link is encrypted again
            def on_connection_encryption_key_refresh() -> None:
                if not fut.done():
                    fut.set_result(True)

            connection.on('connection_encryption_key_refresh', on_connection_encryption_key_refresh)

            assert_true(await fut, "Repairing initiated")
        finally:
            pairing.cancel()
            with suppress(asyncio.CancelledError, futures.CancelledError):
                await pairing
265
266
def _run_from_command_line() -> None:
    """Entry point used when this module is executed directly.

    Enables DEBUG-level root logging, then hands control to Mobly's test runner.
    """
    logging.basicConfig(level=logging.DEBUG)
    test_runner.main()  # type: ignore


if __name__ == '__main__':
    _run_from_command_line()
270