1#!/usr/bin/python2
2
3# Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
4# Use of this source code is governed by a BSD-style license that can be
5# found in the LICENSE file.
6
7from __future__ import absolute_import
8from __future__ import division
9from __future__ import print_function
10
11import logging
12from six.moves import range
13import socket
14import sys
15import time
16
17import common
18
19from autotest_lib.client.cros import dhcp_handling_rule
20from autotest_lib.client.cros import dhcp_packet
21from autotest_lib.client.cros import dhcp_test_server
22
# Directory holding captured DHCP packets used as fixtures.  open() resolves
# this relative path against the current working directory.
TEST_DATA_PATH_PREFIX = "client/cros/dhcp_test_data/"

# Classless static route option payload (RFC 3442 encoding): for each route a
# prefix-width byte, the significant octets of the destination, then the
# 4-byte router address.
#   0x12 (/18) | 0a 09 c0 (10.9.192.0)      | ac 1f 9b 0a (172.31.155.10)
#   0x00 (/0)  | (no destination octets)    | c0 a8 00 fe (192.168.0.254)
TEST_CLASSLESS_STATIC_ROUTE_DATA = (
        "\x12\x0a\x09\xc0\xac\x1f\x9b\x0a"
        "\x00\xc0\xa8\x00\xfe"
)

# The (prefix width, destination, router) tuples encoded above.
TEST_CLASSLESS_STATIC_ROUTE_LIST_PARSED = [
        (18, "10.9.192.0", "172.31.155.10"),
        (0, "0.0.0.0", "192.168.0.254"),
]

# Domain search list in DNS wire format.  The trailing "\xC0\x04" is a
# compression pointer to offset 4 (the "\x06google" label), so the second
# name expands to "marketing.google.com".
TEST_DOMAIN_SEARCH_LIST_COMPRESSED = (
        "\x03eng\x06google\x03com\x00\x09marketing\xC0\x04"
)

# The names encoded above, in order.
TEST_DOMAIN_SEARCH_LIST_PARSED = ("eng.google.com", "marketing.google.com")

# Serialization does not emit the compression pointers the RFC allows; every
# name is written out in full.  That is still a valid encoding and is
# sufficient for these tests.
TEST_DOMAIN_SEARCH_LIST_EXPECTED = (
        "\x03eng\x06google\x03com\x00\x09marketing\x06google\x03com\x00"
)

# The same two names split across two instances of option code 0x77 ("w"),
# each prefixed with its own length byte (0x10 = 16, 0x16 = 22).
# test_broken_domain_search_list_parsing expects them to parse back as a
# single option holding both names.
TEST_DOMAIN_SEARCH_LIST1 = "w\x10\x03eng\x06google\x03com\x00"
TEST_DOMAIN_SEARCH_LIST2 = "w\x16\x09marketing\x06google\x03com\x00"
49
def bin2hex(byte_str, justification=20):
    """
    Render binary data as rows of "xNN" hex byte tokens.

    Each byte becomes "x" followed by its (at least) two-digit lowercase hex
    value.  Tokens are grouped into lines of at most `justification` tokens,
    and the lines are joined with newlines.

    Args:
        byte_str: The data to render.  Accepts a byte sequence whose items
            are ints (``bytes``/``bytearray`` on Python 3, ``bytearray`` on
            Python 2) or a ``str`` of single characters (Python 2 byte
            strings, or the str constants used in this file).
        justification: Maximum number of byte tokens per output line.  Must
            be at least 1.

    Returns:
        The formatted string.  Empty input yields "".

    Raises:
        ValueError: If justification is less than 1.
    """
    if justification < 1:
        raise ValueError(
                "justification must be at least 1, got %r" % (justification,))
    # Iterating bytes on Python 3 yields ints, while iterating str yields
    # 1-character strings.  ord() accepts only the latter, so normalize here.
    tokens = ["x%02x" % (c if isinstance(c, int) else ord(c))
              for c in byte_str]
    return "\n".join("".join(tokens[i:i + justification])
                     for i in range(0, len(tokens), justification))
60
def test_packet_serialization():
    """
    Check that a captured DHCP discovery packet round-trips byte-for-byte.

    Reads TEST_DATA_PATH_PREFIX + "dhcp_discovery.log" (relative to the
    current working directory) and parses it into a DhcpPacket.  The packet
    is then serialized back and compared with the original bytes.

    Returns:
        True if serialization reproduces the input exactly.  False otherwise,
        after printing a diagnostic.
    """
    # `with` closes the fixture even if read() raises.
    with open(TEST_DATA_PATH_PREFIX + "dhcp_discovery.log", "rb") as log_file:
        binary_discovery_packet = log_file.read()
    discovery_packet = dhcp_packet.DhcpPacket(byte_str=binary_discovery_packet)
    if not discovery_packet.is_valid:
        print("Failed to parse the captured discovery packet.")
        return False
    generated_string = discovery_packet.to_binary_string()
    if generated_string is None:
        print("Failed to generate string from packet object.")
        return False
    if generated_string != binary_discovery_packet:
        print("Packets didn't match: ")
        print("Generated: \n%s" % bin2hex(generated_string))
        print("Expected: \n%s" % bin2hex(binary_discovery_packet))
        return False
    print("test_packet_serialization PASSED")
    return True
79
def test_classless_static_route_parsing():
    """
    Check that ClasslessStaticRoutesOption.unpack() decodes
    TEST_CLASSLESS_STATIC_ROUTE_DATA into TEST_CLASSLESS_STATIC_ROUTE_LIST_PARSED.

    Returns:
        True on a match.  False otherwise, after printing both lists.
    """
    parsed_routes = dhcp_packet.ClasslessStaticRoutesOption.unpack(
            TEST_CLASSLESS_STATIC_ROUTE_DATA)
    if parsed_routes != TEST_CLASSLESS_STATIC_ROUTE_LIST_PARSED:
        # This test decodes static routes, not a domain list; say so.
        print("Parsed binary classless static route list and got %s but "
              "expected %s" %
              (repr(parsed_routes),
               repr(TEST_CLASSLESS_STATIC_ROUTE_LIST_PARSED)))
        return False
    print("test_classless_static_route_parsing PASSED")
    return True
90
def test_classless_static_route_serialization():
    """
    Check that ClasslessStaticRoutesOption.pack() encodes
    TEST_CLASSLESS_STATIC_ROUTE_LIST_PARSED as TEST_CLASSLESS_STATIC_ROUTE_DATA.

    Returns:
        True on a match.  False otherwise, after printing the expected and
        actual encodings as hex dumps.
    """
    routes = TEST_CLASSLESS_STATIC_ROUTE_LIST_PARSED
    expected = TEST_CLASSLESS_STATIC_ROUTE_DATA
    packed = dhcp_packet.ClasslessStaticRoutesOption.pack(routes)
    if packed == expected:
        print("test_classless_static_route_serialization PASSED")
        return True
    # A row width of 100 bytes keeps each short hex dump on one line.
    hex_actual = bin2hex(packed, 100)
    hex_expected = bin2hex(expected, 100)
    print("Expected to serialize %s to %s but instead got %s." %
          (repr(routes), hex_expected, hex_actual))
    return False
104
def test_domain_search_list_parsing():
    """
    Check that DomainListOption.unpack() decodes the compressed wire-format
    list TEST_DOMAIN_SEARCH_LIST_COMPRESSED into
    TEST_DOMAIN_SEARCH_LIST_PARSED, in order.

    Returns:
        True on a match.  False otherwise, after printing the parsed and
        expected names.
    """
    parsed_domains = dhcp_packet.DomainListOption.unpack(
            TEST_DOMAIN_SEARCH_LIST_COMPRESSED)
    # Compare as a tuple so the check is order-sensitive and matches the
    # type of the expected constant.
    parsed_domains = tuple(parsed_domains)
    if parsed_domains != TEST_DOMAIN_SEARCH_LIST_PARSED:
        # Report the parsed-name tuple being compared against, not the
        # binary serialization constant.
        print("Parsed binary domain list and got %s but expected %s" %
              (parsed_domains, TEST_DOMAIN_SEARCH_LIST_PARSED))
        return False
    print("test_domain_search_list_parsing PASSED")
    return True
116
def test_domain_search_list_serialization():
    """
    Check that DomainListOption.pack() encodes TEST_DOMAIN_SEARCH_LIST_PARSED
    as the uncompressed wire format TEST_DOMAIN_SEARCH_LIST_EXPECTED.

    Returns:
        True on a match.  False otherwise, after printing the expected and
        actual encodings as hex dumps.
    """
    domains = TEST_DOMAIN_SEARCH_LIST_PARSED
    expected = TEST_DOMAIN_SEARCH_LIST_EXPECTED
    packed = dhcp_packet.DomainListOption.pack(domains)
    if packed == expected:
        print("test_domain_search_list_serialization PASSED")
        return True
    # A row width of 100 bytes keeps each short hex dump on one line.
    hex_actual = bin2hex(packed, 100)
    hex_expected = bin2hex(expected, 100)
    print("Expected to serialize %s to %s but instead got %s." %
          (domains, hex_expected, hex_actual))
    return False
129
def test_broken_domain_search_list_parsing():
    """
    Check that a domain search list split across two instances of the same
    option is parsed back as a single option holding both names.

    The synthetic packet is built from four parts, in order:
      * 240 zero bytes (presumably the fixed BOOTP header plus the magic
        cookie; TODO confirm against dhcp_packet).
      * TEST_DOMAIN_SEARCH_LIST1.
      * TEST_DOMAIN_SEARCH_LIST2.
      * A trailing 0xff byte (the DHCP END option).

    Returns:
        True if exactly one option was parsed and its value equals
        TEST_DOMAIN_SEARCH_LIST_PARSED.  False otherwise, after printing a
        diagnostic.
    """
    byte_string = ('\x00' * 240 + TEST_DOMAIN_SEARCH_LIST1 +
                   TEST_DOMAIN_SEARCH_LIST2 + '\xff')
    packet = dhcp_packet.DhcpPacket(byte_str=byte_string)
    # Inspect the private option map to see how many options were parsed.
    option_values = list(packet._options.values())
    if len(option_values) != 1:
        print("Expected exactly one parsed option (the merged domain search "
              "list) but found %d" % len(option_values))
        return False
    parsed_domains = tuple(option_values[0])
    if parsed_domains != TEST_DOMAIN_SEARCH_LIST_PARSED:
        print("Parsed split domain search list and got %s but expected %s" %
              (parsed_domains, TEST_DOMAIN_SEARCH_LIST_PARSED))
        return False
    print("test_broken_domain_search_list_parsing PASSED")
    return True
143
def receive_packet(a_socket, timeout_seconds=1.0):
    """
    Wait up to `timeout_seconds` for one datagram and parse it as a DHCP packet.

    The socket is polled with recvfrom(1024).  socket.timeout is treated as
    "nothing yet" and polling continues until the deadline.  The socket's own
    timeout (set by the caller) controls how often the deadline is checked.

    Args:
        a_socket: A bound datagram socket (anything with recvfrom()).
        timeout_seconds: Overall wall-clock budget for the wait.

    Returns:
        The parsed dhcp_packet.DhcpPacket.  Returns None (after printing why)
        if nothing arrived before the deadline or the payload was not a valid
        DHCP packet.
    """
    deadline = time.time() + timeout_seconds
    data = None
    while data is None and time.time() < deadline:
        try:
            data, _sender = a_socket.recvfrom(1024)
        except socket.timeout:
            continue  # Many short timeouts are expected while polling.

    if data is None:
        print("Timed out before we received a response from the server.")
        return None

    print("Client received a packet of length %d from the server." % len(data))
    response = dhcp_packet.DhcpPacket(byte_str=data)
    if response.is_valid:
        return response
    print("Received an invalid response from DHCP server.")
    return None
163
def test_simple_server_exchange(server):
    """
    Run one DISCOVER -> OFFER -> REQUEST -> ACK exchange against `server`.

    Two handling rules are installed on the server: respond to the discovery,
    then respond to the request (marked as the final handler).  The client
    side is then played over a loopback UDP socket bound to 127.0.0.1:8068,
    sending to 127.0.0.1:8067.

    Args:
        server: A started dhcp_test_server.DhcpTestServer whose ingress port
            is 8067 and broadcast port is 8068 (see test_server_dialogue).

    Returns:
        True if the client received the expected OFFER and ACK and the server
        reports the test passed.  False otherwise, after printing why.
    """
    intended_ip = "127.0.0.42"
    subnet_mask = "255.255.255.0"
    server_ip = "127.0.0.1"
    lease_time_seconds = 60
    test_timeout = 3.0
    mac_addr = "\x01\x02\x03\x04\x05\x06"
    # Build up our packets and have them request some default option values,
    # like the IP we're being assigned and the address of the server assigning
    # it.
    discovery_message = dhcp_packet.DhcpPacket.create_discovery_packet(mac_addr)
    discovery_message.set_option(
            dhcp_packet.OPTION_PARAMETER_REQUEST_LIST,
            dhcp_packet.OPTION_VALUE_PARAMETER_REQUEST_LIST_DEFAULT)
    request_message = dhcp_packet.DhcpPacket.create_request_packet(
            discovery_message.transaction_id,
            mac_addr)
    request_message.set_option(
            dhcp_packet.OPTION_PARAMETER_REQUEST_LIST,
            dhcp_packet.OPTION_VALUE_PARAMETER_REQUEST_LIST_DEFAULT)
    # This is the pool of settings the DHCP server will seem to draw from to
    # answer queries from the client.  This information is written into packets
    # through the handling rules.
    dhcp_server_config = {
            dhcp_packet.OPTION_SERVER_ID : server_ip,
            dhcp_packet.OPTION_SUBNET_MASK : subnet_mask,
            dhcp_packet.OPTION_IP_LEASE_TIME : lease_time_seconds,
            dhcp_packet.OPTION_REQUESTED_IP : intended_ip,
            }
    # Build up the handling rules for the server and start the test.
    rules = []
    rules.append(dhcp_handling_rule.DhcpHandlingRule_RespondToDiscovery(
            intended_ip,
            server_ip,
            dhcp_server_config, {}))
    rules.append(dhcp_handling_rule.DhcpHandlingRule_RespondToRequest(
            intended_ip,
            server_ip,
            dhcp_server_config, {}))
    rules[-1].is_final_handler = True
    server.start_test(rules, test_timeout)
    # Because we don't want to require root permissions to run these tests,
    # listen on the loopback device, don't broadcast, and don't use reserved
    # ports (like the actual DHCP ports).  Use 8068/8067 instead.
    client_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        client_socket.bind(("127.0.0.1", 8068))
        # A short socket timeout lets receive_packet() re-check its deadline.
        client_socket.settimeout(0.1)
        client_socket.sendto(discovery_message.to_binary_string(),
                             (server_ip, 8067))

        offer_packet = receive_packet(client_socket)
        if offer_packet is None:
            return False

        if offer_packet.message_type != dhcp_packet.MESSAGE_TYPE_OFFER:
            print("Type of DHCP response is not offer.")
            return False

        if offer_packet.get_field(dhcp_packet.FIELD_YOUR_IP) != intended_ip:
            print("Server didn't offer the IP we expected.")
            return False

        print("Offer looks good to the client, sending request.")
        # In real tests, dhcpcd formats all the DISCOVERY and REQUEST
        # messages.  In our unit test we do it ourselves by echoing the
        # offered values back in the REQUEST.
        for option in (dhcp_packet.OPTION_SERVER_ID,
                       dhcp_packet.OPTION_SUBNET_MASK,
                       dhcp_packet.OPTION_IP_LEASE_TIME,
                       dhcp_packet.OPTION_REQUESTED_IP):
            request_message.set_option(option, offer_packet.get_option(option))
        # Send the REQUEST message.
        client_socket.sendto(request_message.to_binary_string(),
                             (server_ip, 8067))
        ack_packet = receive_packet(client_socket)
        if ack_packet is None:
            return False

        if ack_packet.message_type != dhcp_packet.MESSAGE_TYPE_ACK:
            print("Type of DHCP response is not acknowledgement.")
            return False

        if ack_packet.get_field(dhcp_packet.FIELD_YOUR_IP) != intended_ip:
            print("Server didn't give us the IP we expected.")
            return False

        print("Waiting for the server to finish.")
        server.wait_for_test_to_finish()
        print("Server agrees that the test is over.")
        if not server.last_test_passed:
            print("Server is unhappy with the test result.")
            return False

        print("test_simple_server_exchange PASSED.")
        return True
    finally:
        # Release port 8068 on every exit path, including early failures and
        # exceptions, so a rerun in the same process can bind it again.
        client_socket.close()
265
def test_server_dialogue():
    """
    Start a loopback DhcpTestServer, run test_simple_server_exchange against
    it, and stop the server.

    The server listens on 127.0.0.1:8067 and "broadcasts" to 127.0.0.1:8068,
    so no root privileges or reserved DHCP ports are needed.

    Returns:
        The exchange's result.  False if the server is not healthy after
        start().
    """
    server = dhcp_test_server.DhcpTestServer(ingress_address="127.0.0.1",
                                             ingress_port=8067,
                                             broadcast_address="127.0.0.1",
                                             broadcast_port=8068)
    server.start()
    ret = False
    try:
        if server.is_healthy:
            ret = test_simple_server_exchange(server)
        else:
            print("Server isn't healthy, aborting.")
    finally:
        # Stop the server even if the exchange raised, so it is never left
        # running behind an exception.
        print("Sending server stop() signal.")
        server.stop()
        print("Stop signal sent.")
    return ret
281
def run_tests():
    """
    Run every test in this file and report an overall verdict.

    Sends DEBUG-level output from the "dhcp" logger to a stream handler,
    then runs each test in turn.  Every test runs even after an earlier one
    fails.

    Returns:
        0 if every test passed, -1 otherwise (used as the process exit code).
    """
    dhcp_logger = logging.getLogger("dhcp")
    dhcp_logger.setLevel(logging.DEBUG)
    handler = logging.StreamHandler()
    handler.setLevel(logging.DEBUG)
    dhcp_logger.addHandler(handler)

    # Evaluate every test up front (a list literal runs them in order and
    # never short-circuits) so each reports its own result.
    results = [
            test_packet_serialization(),
            test_classless_static_route_parsing(),
            test_classless_static_route_serialization(),
            test_domain_search_list_parsing(),
            test_domain_search_list_serialization(),
            test_broken_domain_search_list_parsing(),
            test_server_dialogue(),
    ]
    if not all(results):
        print("Some tests FAILED")
        return -1
    print("All tests PASSED.")
    return 0
301
if __name__ == "__main__":
    # Use run_tests()'s status (0 on success, -1 on failure) as the process
    # exit code.
    exit_status = run_tests()
    sys.exit(exit_status)
304