1# Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5import dbus
6import logging
7import socket
8import time
9import urllib2
10
11import common
12
13# Import 'flimflam_test_path' first in order to import 'routing'.
14# Disable warning about flimflam_test_path not being used since it is used
15# to find routing but not explicitly used as a module.
16# pylint: disable-msg=W0611
17import flimflam_test_path
18import routing
19
20from autotest_lib.client.bin import utils
21from autotest_lib.client.common_lib import error
22from autotest_lib.client.cros.cellular import mm
23
24
25def _Bug24628WorkaroundEnable(modem):
26    """Enables a modem.  Try again if a SerialResponseTimeout is received."""
27    # http://code.google.com/p/chromium-os/issues/detail?id=24628
28    tries = 5
29    while tries > 0:
30        try:
31            modem.Enable(True)
32            return
33        except dbus.exceptions.DBusException, e:
34            logging.error('Enable failed: %s', e)
35            tries -= 1
36            if tries > 0:
37                logging.error('_Bug24628WorkaroundEnable:  sleeping')
38                time.sleep(6)
39                logging.error('_Bug24628WorkaroundEnable:  retrying')
40            else:
41                raise
42
43
def _WaitForShillDevicePowered(cm_device, powered):
    """
    Polls a shill device until its 'Powered' property matches |powered|.

    @param cm_device: shill Device object whose properties are polled.
    @param powered: Expected truth value of the 'Powered' property.
    @raises utils.TimeoutError: if 'Powered' does not reach |powered|
            within 30 seconds.
    """
    if powered:
        message = 'Timed out waiting for shill device enable'
    else:
        message = 'Timed out waiting for shill device disable'
    utils.poll_for_condition(
        lambda: bool(cm_device.GetProperties()['Powered']) == powered,
        exception=utils.TimeoutError(message),
        sleep_interval=1,
        timeout=30)


# TODO(rochberg):  Move modem-specific functions to cellular/cell_utils
def ResetAllModems(conn_mgr):
    """
    Disables/Enables cycle all modems to ensure valid starting state.

    Turns off AutoConnect on the cellular service (enabling the cellular
    technology first if no service is found). Then, for every modem
    ModemManager enumerates, it does the following in order:
      - waits for any in-flight connect/disconnect to finish;
      - disables the modem and waits for shill to report it powered off;
      - re-enables the modem and waits for shill to report it powered on.

    @param conn_mgr: Connection manager (shill)
    @raises error.TestError: if shill has no device for a modem, or the
            modem does not report the expected disabled/enabled state.
    @raises utils.TimeoutError: if the modem or shill does not settle
            within the polling timeouts.
    """
    service = conn_mgr.FindCellularService()
    if not service:
        conn_mgr.EnableTechnology('cellular')
        service = conn_mgr.FindCellularService()

    logging.info('ResetAllModems: found service %s', service)

    try:
        if service:
            service.SetProperty('AutoConnect', False)
    except dbus.exceptions.DBusException as e:
        # The service object may disappear, we can safely ignore it.
        if e.get_dbus_name() != 'org.freedesktop.DBus.Error.UnknownMethod':
            raise

    for manager, path in mm.EnumerateDevices():
        modem = manager.GetModem(path)
        version = modem.GetVersion()
        # Icera modems behave weirdly if we cancel the operation while the
        # modem is connecting or disconnecting. Work around the issue by waiting
        # until the connect/disconnect operation completes.
        # TODO(benchan): Remove this workaround once the issue is addressed
        # on the modem side.
        utils.poll_for_condition(
            lambda: not modem.IsConnectingOrDisconnecting(),
            exception=utils.TimeoutError('Timed out waiting for modem to ' +
                                         'finish connecting/disconnecting'),
            sleep_interval=1,
            timeout=30)
        modem.Enable(False)
        # Although we disable at the ModemManager level, we need to wait for
        # shill to process the disable to ensure the modem is in a stable state
        # before continuing else we may end up trying to enable a modem that
        # is still in the process of disabling.
        cm_device = conn_mgr.FindElementByPropertySubstring('Device',
                                                            'DBus.Object',
                                                            path)
        if cm_device is None:
            # NOTE(review): assumes the lookup returns None when no shill
            # device matches; fail with a clear message instead of an
            # AttributeError from inside the polling lambda.
            raise error.TestError('No shill device found for modem %s' % path)
        _WaitForShillDevicePowered(cm_device, False)
        if not modem.IsDisabled():
            raise error.TestError('Modem %s did not become disabled' % path)

        if 'Y3300XXKB1' in version:
            _Bug24628WorkaroundEnable(modem)
        else:
            modem.Enable(True)
        # Wait for shill to process the enable for the same reason as above
        # (during disable); this applies regardless of how the modem was
        # enabled.
        _WaitForShillDevicePowered(cm_device, True)
        if not modem.IsEnabled():
            raise error.TestError('Modem %s did not become enabled' % path)
109
110
class IpTablesContext(object):
    """Context manager that manages iptables rules.

    Rules added via AllowHost() are recorded and deleted again when the
    context exits.
    """
    IPTABLES = '/sbin/iptables'

    def __init__(self, initial_allowed_host=None):
        """
        @param initial_allowed_host: Optional host passed to AllowHost() when
                the context is entered.
        """
        self.initial_allowed_host = initial_allowed_host
        # Rule specs (without the -A/-D verb) that this context added.
        self.rules = []

    def _IpTables(self, command):
        """
        Runs iptables with |command| and returns its output.

        @param command: Argument string appended to the iptables binary path.
        @return: Output of the iptables invocation.
        """
        # Run, log, return output
        return utils.system_output('%s %s' % (self.IPTABLES, command),
                                   retain_output=True)

    def _RemoveRule(self, rule):
        """Deletes |rule| from iptables and stops tracking it."""
        self._IpTables('-D ' + rule)
        self.rules.remove(rule)

    def AllowHost(self, host):
        """
        Allows the specified host through the firewall.

        Appends TCP and UDP ACCEPT rules for |host| to the INPUT chain,
        skipping any rule that `iptables -S INPUT` already lists.

        @param host: Name of host to allow
        """
        # The current rule set is the same for both protocols, so list it once.
        output = self._IpTables('-S INPUT')
        current = [x.rstrip() for x in output.splitlines()]
        logging.debug('current: %s', current)
        for proto in ['tcp', 'udp']:
            rule = 'INPUT -s %s/32 -p %s -m %s -j ACCEPT' % (host, proto, proto)
            if '-A ' + rule in current:
                # Already have the rule
                logging.info('Not adding redundant %s', rule)
                continue
            self._IpTables('-A ' + rule)
            self.rules.append(rule)

    def _CleanupRules(self):
        """Deletes every rule this context added, most recent first."""
        # Iterate over a copy: _RemoveRule() removes from self.rules, and
        # mutating a list while iterating it skips every other element.
        for rule in reversed(list(self.rules)):
            self._RemoveRule(rule)

    def __enter__(self):
        if self.initial_allowed_host:
            self.AllowHost(self.initial_allowed_host)
        return self

    def __exit__(self, exception, value, traceback):
        self._CleanupRules()
        # Never suppress an exception raised inside the with-block.
        return False
158
159
def NameServersForService(conn_mgr, service):
    """
    Returns the list of name servers used by a connected service.

    Name servers are collected from each IPConfig of the service's device,
    in the order the device lists its IPConfigs. IPConfigs that cannot be
    looked up are logged and skipped.

    @param conn_mgr: Connection manager (shill)
    @param service: The connected service object
    @return: List of name servers used by |service|; empty if the service's
            device cannot be found.
    """
    service_properties = service.GetProperties(utf8_strings=True)
    device_path = service_properties['Device']
    device = conn_mgr.GetObjectInterface('Device', device_path)
    if device is None:
        logging.error('No device for service %s', service)
        return []

    properties = device.GetProperties(utf8_strings=True)

    hosts = []
    for path in properties.get('IPConfigs', []):
        ipconfig = conn_mgr.GetObjectInterface('IPConfig', path)
        if ipconfig is None:
            # Same lookup as the Device above, which can return None.
            logging.error('No IPConfig at %s', path)
            continue
        ipconfig_properties = ipconfig.GetProperties(utf8_strings=True)
        hosts.extend(ipconfig_properties.get('NameServers', []))

    logging.info('Name servers: %s', ', '.join(hosts))

    return hosts
186
187
def CheckInterfaceForDestination(host, expected_interface):
    """
    Checks that routes for host go through a given interface.

    The concern here is that our network setup may have gone wrong
    and our test connections may go over some other network than
    the one we're trying to test.  So we take all the IP addresses
    for the supplied host and make sure they go through the given
    network interface.

    @param host: Destination host
    @param expected_interface: Expected interface name
    @raises: error.TestFail if the routes for the given host go through
            a different interface than the expected one.

    """
    # addrinfo records: (family, type, proto, canonname, (addr, port))
    # getaddrinfo() returns one record per socket type/protocol, so the same
    # address typically appears several times; check each address once,
    # keeping resolution order.
    server_addresses = []
    for record in socket.getaddrinfo(host, 80):
        address = record[4][0]
        if address not in server_addresses:
            server_addresses.append(address)

    routes = routing.NetworkRoutes()
    for address in server_addresses:
        interface = routes.getRouteFor(address).interface
        logging.info('interface for %s: %s', address, interface)
        if interface != expected_interface:
            raise error.TestFail('Target server %s uses interface %s '
                                 '(%s expected).' %
                                 (address, interface, expected_interface))
216
217
# Default url_pattern for FetchUrl(); the %d placeholder is filled in with
# the number of bytes to fetch.
FETCH_URL_PATTERN_FOR_TEST = (
    'http://testing-chargen.appspot.com/download?size=%d')
220
def FetchUrl(url_pattern, bytes_to_fetch=10, fetch_timeout=10):
    """
    Fetches a specified number of bytes from a URL.

    @param url_pattern: URL pattern for fetching a specified number of bytes.
            %d in the pattern is to be filled in with the number of bytes to
            fetch.
    @param bytes_to_fetch: Number of bytes to fetch.
    @param fetch_timeout: Number of seconds to wait for the fetch to complete
            before it times out.
    @return: The time in seconds spent for fetching the specified number of
            bytes.
    @raises: error.TestError if one of the following happens:
            - The fetch takes no time.
            - The number of bytes fetched differs from the specified
              number.
            - The fetch takes longer than fetch_timeout seconds.
            Errors raised by urllib2.urlopen() or by reading the response
            propagate unchanged.

    """
    # Limit the amount of bytes to read at a time.
    _MAX_FETCH_READ_BYTES = 1024 * 1024

    url = url_pattern % bytes_to_fetch
    logging.info('FetchUrl %s', url)
    start_time = time.time()
    result = urllib2.urlopen(url, timeout=fetch_timeout)
    try:
        # Initialized before the loop so that a zero-byte request still has
        # a defined fetch time.
        fetch_time = time.time() - start_time
        bytes_fetched = 0
        while bytes_fetched < bytes_to_fetch:
            bytes_left = bytes_to_fetch - bytes_fetched
            bytes_to_read = min(bytes_left, _MAX_FETCH_READ_BYTES)
            bytes_read = len(result.read(bytes_to_read))
            bytes_fetched += bytes_read
            if bytes_read != bytes_to_read:
                raise error.TestError('FetchUrl tried to read %d bytes, but got '
                                      '%d bytes instead.' %
                                      (bytes_to_read, bytes_read))
            fetch_time = time.time() - start_time
            if fetch_time > fetch_timeout:
                raise error.TestError('FetchUrl exceeded timeout.')
    finally:
        # Release the connection even when a check above fails.
        result.close()

    if fetch_time <= 0:
        # Documented contract: callers must not receive a zero duration.
        raise error.TestError('FetchUrl took no time.')

    return fetch_time
261