1# Copyright 2014 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5import contextlib
6import logging
7import socket
8
9from telemetry.core import exceptions
10from telemetry.core import util
11from telemetry.internal import forwarders
12
13
class Error(Exception):
  """Root of the exception hierarchy raised by this module."""
17
18
class PortsMismatchError(Error):
  """Signals that a port pair's local and remote ports differ."""
22
23
class ConnectionError(Error):
  """Signals that a TCP connection to a local port could not be made.

  Note: this name shadows the Python 3 builtin ConnectionError within this
  module; it derives from this module's Error, not from OSError.
  """
27
28
class DoNothingForwarderFactory(forwarders.ForwarderFactory):
  """ForwarderFactory whose forwarders perform no actual forwarding."""

  def Create(self, port_pairs):
    """Builds a DoNothingForwarder, which validates |port_pairs| on creation."""
    forwarder = DoNothingForwarder(port_pairs)
    return forwarder
33
34
class DoNothingForwarder(forwarders.Forwarder):
  """Check that no forwarding is needed for the given port pairs.

  The local and remote ports must be equal. Otherwise, the "do nothing"
  forwarder does not make sense. (Raises PortsMismatchError.)

  Also, check that every non-DNS port accepts TCP connections on host_ip.
  (Raises ConnectionError.)
  """

  # Seconds to wait for each port to start accepting TCP connections.
  _CONNECTION_TIMEOUT_SECONDS = 10

  def __init__(self, port_pairs):
    super(DoNothingForwarder, self).__init__(port_pairs)
    self._CheckPortPairs()

  def _CheckPortPairs(self):
    """Validates every configured port pair.

    Falsy (unset) port pairs are skipped. DNS pairs are checked only for
    matching ports; no connection test is attempted for them.

    Raises:
      PortsMismatchError: if a pair's local and remote ports differ.
      ConnectionError: if no TCP connection to host_ip:local_port could be
          established before the wait timed out.
    """
    # namedtuple._asdict() is a public method. The method starts with an
    # underscore to avoid conflicts with attribute names.
    # pylint: disable=protected-access
    for protocol, port_pair in self._port_pairs._asdict().items():
      if not port_pair:
        continue
      local_port, remote_port = port_pair
      if local_port != remote_port:
        raise PortsMismatchError('Local port forwarding is not supported')
      if protocol == 'dns':
        logging.debug('Connection test SKIPPED for DNS: %s:%d',
                      self.host_ip, local_port)
        continue
      try:
        self._WaitForConnectionEstablished(
            (self.host_ip, local_port),
            timeout=self._CONNECTION_TIMEOUT_SECONDS)
      except exceptions.TimeoutException:
        # Format eagerly: unlike logging calls, exception constructors do not
        # interpolate extra positional arguments into the message.
        raise ConnectionError(
            'Unable to connect to %s address: %s:%d' %
            (protocol.upper(), self.host_ip, local_port))
      logging.debug(
          'Connection test succeeded for %s: %s:%d',
          protocol.upper(), self.host_ip, local_port)

  def _WaitForConnectionEstablished(self, address, timeout):
    """Polls |address| until a TCP connect succeeds or |timeout| elapses.

    Args:
      address: (host, port) tuple passed to socket.connect_ex.
      timeout: overall wait budget, forwarded to util.WaitFor.
    """
    def CanConnect():
      # A fresh socket per attempt; closing() releases it after each probe.
      # NOTE(review): the socket has no per-attempt timeout (unless a global
      # default is set), so a single connect to an unreachable host may block
      # longer than |timeout| — confirm this is acceptable for remote hosts.
      with contextlib.closing(socket.socket()) as s:
        return s.connect_ex(address) == 0
    util.WaitFor(CanConnect, timeout)
78