1# Copyright 2014 The Chromium OS Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5import logging
6import socket
7import struct
8import time
9
10from autotest_lib.client.common_lib import error
11from autotest_lib.client.common_lib.cros.network import interface
12
13
class InterfaceHost(object):
    """A host for use with ZeroconfDaemon that binds to an interface.

    Hands out at most one InterfaceDatagramSocket at a time, created for the
    IPv4 address of a named network interface, and pumps traffic on it via
    run_until().

    """

    @property
    def ip_addr(self):
        """Get the IP address of the interface we're bound to."""
        return self._interface.ipv4_address


    def __init__(self, interface_name):
        """Construct an instance.

        @param interface_name: string name of the network interface whose
                IPv4 address sockets will be bound to.

        """
        self._interface = interface.Interface(interface_name)
        self._socket = None


    def close(self):
        """Close the underlying socket, if any.

        Safe to call more than once.  After closing, socket() may be called
        again to obtain a fresh socket.

        """
        if self._socket is not None:
            self._socket.close()
            # Forget the closed socket so socket() can hand out a new one and
            # repeated close() calls are no-ops.
            self._socket = None


    def socket(self, family, sock_type):
        """Get a socket bound to this interface.

        Only supports IPv4 UDP sockets on broadcast addresses.

        @param family: must be socket.AF_INET.
        @param sock_type: must be socket.SOCK_DGRAM.
        @return: the InterfaceDatagramSocket created for this interface.
        @raises error.TestError: for any other family/type, or if a socket
                has already been handed out and not yet closed.

        """
        if family != socket.AF_INET or sock_type != socket.SOCK_DGRAM:
            raise error.TestError('InterfaceHost only understands UDP sockets.')
        if self._socket is not None:
            raise error.TestError('InterfaceHost only supports a single '
                                  'multicast socket.')

        self._socket = InterfaceDatagramSocket(self.ip_addr)
        return self._socket


    def run_until(self, predicate, timeout_seconds):
        """Handle traffic from our socket until |predicate|() is true.

        @param predicate: function without arguments that returns True or False.
        @param timeout_seconds: number of seconds to wait for predicate to
                                become True.
        @return: tuple(success, duration) where success is True iff predicate()
                 became true before |timeout_seconds| passed.
        @raises error.TestError: if traffic must be pumped but socket() has
                not been called (or the socket was closed).

        """
        start_time = time.time()
        duration = lambda: time.time() - start_time
        while duration() < timeout_seconds:
            if predicate():
                return True, duration()
            # Only require a socket once we actually need to read from it, so
            # an already-true predicate still succeeds without one.
            if self._socket is None:
                raise error.TestError('Must call socket() before '
                                      'run_until().')
            # Assume this takes non-trivial time (it blocks on a receive
            # timeout), so don't sleep here.
            self._socket.run_once()
        return False, duration()
71
72
class InterfaceDatagramSocket(object):
    """Multicast UDP socket bound to a particular network interface.

    Uses two underlying sockets: one joined to a multicast group for
    receiving, and one bound to the interface's own IP for sending, so that
    responses are sent from that interface.

    """

    # Wait for a UDP frame to appear for this long before timing out.
    TIMEOUT_VALUE_SECONDS = 0.5
    # Maximum number of bytes read per received datagram; the excess of a
    # longer datagram is discarded by recvfrom().
    RECV_BUFFER_SIZE_BYTES = 2048

    def __init__(self, interface_ip):
        """Construct an instance.

        @param interface_ip: IPv4 address of the local interface, as a string
                like '192.168.0.10'.

        """
        self._interface_ip = interface_ip
        self._recv_callback = None
        self._recv_sock = None
        self._send_sock = None


    def close(self):
        """Close state associated with this object.

        Safe to call more than once.  Afterwards listen() may be called again.

        """
        if self._recv_sock is not None:
            # Closing the socket drops membership groups.
            self._recv_sock.close()
            self._recv_sock = None
        if self._send_sock is not None:
            self._send_sock.close()
            self._send_sock = None
        # Reset so a later listen() is not mistaken for a second call.
        self._recv_callback = None


    def listen(self, ip_addr, port, recv_callback):
        """Bind and listen on the ip_addr:port.

        @param ip_addr: Multicast group IP (e.g. '224.0.0.251')
        @param port: Local destination port number.
        @param recv_callback: A callback function that accepts three arguments,
                              the received string, the sender IPv4 address and
                              the sender port number.
        @raises error.TestError: if called twice, or if |ip_addr| is not an
                IPv4 multicast address.  Socket errors from setup propagate
                after any partially created sockets have been closed.

        """
        if self._recv_callback is not None:
            raise error.TestError('listen() called twice on '
                                  'InterfaceDatagramSocket.')
        # Multicast addresses are in 224.0.0.0 - 239.255.255.255 (rfc5771)
        group_bytes = socket.inet_aton(ip_addr)
        # bytearray indexing yields an int on both Python 2 and 3, whereas
        # ord(bytes[0]) fails on Python 3 where bytes[0] is already an int.
        ip_addr_prefix = bytearray(group_bytes)[0]
        if ip_addr_prefix < 224 or ip_addr_prefix > 239:
            raise error.TestError('Invalid multicast address.')

        try:
            # Set up a socket to receive just traffic from the given address.
            self._recv_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            self._recv_sock.setsockopt(socket.SOL_SOCKET,
                                       socket.SO_REUSEADDR, 1)
            self._recv_sock.setsockopt(socket.IPPROTO_IP,
                                       socket.IP_ADD_MEMBERSHIP,
                                       group_bytes +
                                       socket.inet_aton(self._interface_ip))
            self._recv_sock.settimeout(self.TIMEOUT_VALUE_SECONDS)
            self._recv_sock.bind((ip_addr, port))
            # When we send responses, we want to send them from this
            # particular interface.  The easiest way to do this is bind a
            # socket directly to the IP for the interface.  We're going to
            # ignore messages sent to this socket.
            self._send_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            self._send_sock.setsockopt(socket.SOL_SOCKET,
                                       socket.SO_REUSEADDR, 1)
            self._send_sock.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_TTL,
                                       struct.pack('b', 1))
            self._send_sock.bind((self._interface_ip, port))
        except Exception:
            # Don't leak half-initialized sockets or leave the object in a
            # state where listen() can never be retried.
            self.close()
            raise
        self._recv_callback = recv_callback


    def run_once(self):
        """Receive one pending frame if available, return after timeout otw.

        Waits up to the receive socket's timeout (TIMEOUT_VALUE_SECONDS, set
        by listen()) for a frame; a received frame is passed to the callback
        given to listen().

        @raises error.TestError: if listen() has not been called.

        """
        if self._recv_sock is None:
            raise error.TestError('Must listen() on socket before recv\'ing.')
        try:
            data, sender_addr = self._recv_sock.recvfrom(
                    self.RECV_BUFFER_SIZE_BYTES)
        except socket.timeout:
            return
        if len(sender_addr) != 2:
            # The callback expects exactly (ip, port); don't invoke it with a
            # mismatched argument list.
            logging.error('Unexpected address: %r', sender_addr)
            return
        self._recv_callback(data, *sender_addr)


    def send(self, data, ip_addr, port):
        """Send |data| to an IPv4 address from the bound interface.

        @param data: string of raw bytes to send.
        @param ip_addr: string like '239.192.1.100'.
        @param port: int like 50000.
        @raises error.TestError: if listen() has not been called.

        """
        if self._send_sock is None:
            raise error.TestError('Must listen() on socket before sending.')
        self._send_sock.sendto(data, (ip_addr, port))
163