1# Copyright (C) 2024 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""
15Copied from tools/rootcanal/scripts/test_channel.py
16"""
17
18import socket
19import enum
20from time import sleep
21
22
class Connection:
    """Thin wrapper around a TCP client socket connected to ``localhost``."""

    def __init__(self, port):
        """Connect to ``localhost:port``.

        Raises:
            OSError: if the connection cannot be established. The socket is
                closed before the error propagates so its descriptor is not
                leaked.
        """
        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self._socket.connect(("localhost", port))
        except OSError:
            # Release the file descriptor now instead of waiting for the GC.
            self._socket.close()
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def close(self):
        """Close the underlying socket."""
        self._socket.close()

    def send(self, data):
        """Send all of *data*.

        Args:
            data: a ``str`` (encoded as UTF-8 before sending) or a bytes-like
                object (sent verbatim).
        """
        if isinstance(data, str):
            data = data.encode()
        self._socket.sendall(data)

    def receive(self, size):
        """Return at most *size* bytes; ``b""`` once the peer has closed."""
        return self._socket.recv(size)


class TestChannel:
    """Client for the RootCanal test-channel command protocol.

    A command is framed as: one octet holding the byte length of the name,
    the UTF-8 name, one octet holding the argument count, then for each
    argument one octet holding its byte length followed by its UTF-8 bytes.
    Each response is prefixed by its length as a 4-byte little-endian integer.
    """

    # Every length field in a command frame is a single unsigned octet.
    _MAX_FIELD_SIZE = 255

    def __init__(self, port):
        self._connection = Connection(port)
        self._closed = False

    def close(self):
        """Close the connection; later commands become no-ops."""
        self._connection.close()
        self._closed = True

    def send_command(self, name, args):
        """Send command *name* with *args* and return the decoded response.

        Args:
            name: command name.
            args: iterable of arguments; each one is converted with ``str``.

        Returns:
            The response decoded as UTF-8, or ``None`` when the channel is
            already closed or the command is ``CLOSE_TEST_CHANNEL`` (no
            response is read for it).

        Raises:
            ValueError: if a length does not fit in one octet.
            UnicodeError: if the name or an argument cannot be encoded.
        """
        args = [str(arg) for arg in args]
        name_size = len(name)
        args_size = len(args)
        self.lint_command(name, args, name_size, args_size)
        if self._closed:
            return None
        encoded_name = name.encode()
        encoded_args = [arg.encode() for arg in args]
        # Build the frame as bytes: chr(n).encode() would emit two bytes for
        # any n >= 128, and len(str) counts characters rather than bytes.
        command = bytearray([len(encoded_name)])
        command += encoded_name
        command.append(len(encoded_args))
        for encoded_arg in encoded_args:
            command.append(len(encoded_arg))
            command += encoded_arg
        self._connection.send(bytes(command))
        if name != "CLOSE_TEST_CHANNEL":
            return self.receive_response().decode()
        return None

    def receive_response(self):
        """Read one length-prefixed response.

        Returns:
            The raw response bytes, ``b"Closed"`` if this channel was closed,
            or a diagnostic message if the peer sent nothing. If the peer
            disconnects mid-response, the bytes received so far are returned.
        """
        if self._closed:
            return b"Closed"
        size_bytes = self._receive_exactly(4)
        if not size_bytes:
            return b"No response, assuming that the connection is broken"
        # All four prefix bytes are significant, least-significant first.
        response_size = int.from_bytes(size_bytes, "little")
        return self._receive_exactly(response_size)

    def _receive_exactly(self, size):
        """Read until *size* bytes arrive or the peer closes the connection."""
        # A single recv() may return fewer bytes than requested, so loop.
        buffer = bytearray()
        while len(buffer) < size:
            chunk = self._connection.receive(size - len(buffer))
            if not chunk:
                break
            buffer += chunk
        return bytes(buffer)

    def lint_command(self, name, args, name_size, args_size):
        """Validate a command before it is framed.

        Args:
            name: command name.
            args: list of ``str`` arguments.
            name_size: ``len(name)``, checked for consistency.
            args_size: ``len(args)``, checked for consistency.

        Raises:
            UnicodeError: if the name or an argument cannot be UTF-8 encoded.
            ValueError: if the encoded name, the argument count, or an encoded
                argument is longer than 255.
        """
        assert name_size == len(name) and args_size == len(args)
        try:
            encoded_name = name.encode()
            encoded_args = [arg.encode() for arg in args]
        except UnicodeError:
            print("Unrecognized characters.")
            raise
        # The single-octet size fields describe encoded bytes, not characters.
        limit = self._MAX_FIELD_SIZE
        if len(encoded_name) > limit or args_size > limit:
            raise ValueError(
                "Command name length and argument count must each fit in one octet."
            )
        for encoded_arg in encoded_args:
            if len(encoded_arg) > limit:
                raise ValueError("Each argument length must fit in one octet.")
89
90
@enum.unique
class Dongle(enum.Enum):
    """Controller configurations selectable for the PTS tester's dongle.

    The value of each member is the configuration name that
    ``RootCanal.select_pts_dongle`` passes to RootCanal's
    ``set_device_configuration`` command.
    """

    DEFAULT = "default"
    LAIRD_BL654 = "laird_bl654"
    CSR_RCK_PTS_DONGLE = "csr_rck_pts_dongle"
95
96
class RootCanal:
    """Drives a RootCanal instance through its test channel."""

    def __init__(self, port):
        """Open the test channel on *port* and discard its first response.

        If reading that first response raises, the channel is closed before
        the exception propagates.
        """
        self.channel = TestChannel(port)
        self.disconnected_dev_phys = None

        try:
            # discard initialization messages
            self.channel.receive_response()
        except BaseException:
            self.channel.close()
            raise

    def close(self):
        """Close the test channel."""
        self.channel.close()

    def select_pts_dongle(self, dongle: "Dongle"):
        """Use the control port to dynamically reconfigure the controller
        properties for the dongle used by the PTS tester.

        This method will cause a Reset on the controller.
        This method shall exclusively be called from the test_started
        interaction.

        Raises:
            ValueError: if RootCanal reports no devices.
        """
        # The PTS is the device with the highest ID,
        # Android is always first to connect to root-canal.
        devices, _ = self._read_device_list()
        if not devices:
            raise ValueError("RootCanal reported no devices; cannot locate the PTS device.")
        pts_id = max(device_id for device_id, _ in devices)
        self.channel.send_command("set_device_configuration", [pts_id, dongle.value])

    def move_out_of_range(self):
        """Space out the connected devices to generate a supervision
        timeout for all existing connections."""
        # Disconnect all devices from all phys.
        devices, phys = self._read_device_list()
        for device_id, _ in devices:
            for phy_id, _, phy_devices in phys:
                if device_id in phy_devices:
                    self.channel.send_command("del_device_from_phy", [device_id, phy_id])

    def move_in_range(self):
        """Move the connected devices to the same point to ensure
        the reconnection of previous links."""
        # Reconnect all devices to all phys.
        # Devices whose name does not start with "hci_device" (e.g. beacons)
        # are only added back to LE phys.
        devices, phys = self._read_device_list()
        for device_id, device_name in devices:
            target_phys = {"LOW_ENERGY"}
            if device_name.startswith("hci_device"):
                target_phys.add("BR_EDR")

            for phy_id, phy_name, phy_devices in phys:
                if phy_name in target_phys and device_id not in phy_devices:
                    self.channel.send_command("add_device_to_phy", [device_id, phy_id])

    def _read_device_list(self):
        """Query the list of connected devices and phys.

        Returns:
            ``(devices, phys)`` where ``devices`` is a list of
            ``(device_id, device_name)`` tuples and ``phys`` is a list of
            ``(phy_id, phy_name, device_ids)`` tuples. Both lists are empty
            when the channel returns no response (e.g. it is closed).
        """
        response = self.channel.send_command("list", [])

        devices = []
        phys = []
        category = None

        for raw_line in (response or "").split("\n"):
            line = raw_line.strip()
            if not line:
                continue
            if line.startswith(("Devices", "Phys")):
                category = line.split(":")[0]
            elif category == "Devices":
                parts = line.split(":")
                devices.append((int(parts[0]), parts[1]))
            elif category == "Phys":
                parts = line.split(":")
                # A phy line without a device field has no attached devices.
                device_field = parts[2] if len(parts) > 2 else ""
                phy_devices = [int(token) for token in device_field.split(",") if token.strip()]
                phys.append((int(parts[0]), parts[1], phy_devices))

        return devices, phys
174