1#!/usr/bin/env python
2# Copyright 2010 Google Inc. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#      http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import daemonserver
17import errno
18import logging
19import socket
20import SocketServer
21import threading
22import time
23
24from third_party.dns import flags
25from third_party.dns import message
26from third_party.dns import rcode
27from third_party.dns import resolver
28from third_party.dns import rdatatype
29from third_party import ipaddr
30
31
32
class DnsProxyException(Exception):
  """Error raised by the DNS proxy (e.g. bad name server, bind failure)."""
35
36
# Standard DNS port; RealDnsLookup compares the proxy port against it to
# detect a proxy that would be used as its own name server.
DEFAULT_DNS_PORT = 53
38
39
class RealDnsLookup(object):
  """Resolve hostnames through real name servers, caching the results.

  Instances are callable: lookup(hostname) returns an IP string, or None
  when the lookup fails.
  """

  def __init__(self, name_servers, dns_forwarding, proxy_host, proxy_port):
    """Initialize RealDnsLookup.

    Args:
      name_servers: a list of name server addresses used for real lookups.
      dns_forwarding: truthy when DNS forwarding is enabled
        (NOTE(review): exact semantics are defined by the caller).
      proxy_host: the host the DNS proxy is bound to.
      proxy_port: the port the DNS proxy is bound to.
    Raises:
      DnsProxyException: if forwarding is enabled and the proxy itself (on
        the default DNS port) is listed as one of the name servers.
    """
    if (proxy_host in name_servers and proxy_port == DEFAULT_DNS_PORT and
        dns_forwarding):
      # Use %-interpolation: str.format() would leave the '%s' untouched.
      raise DnsProxyException(
          'Invalid nameserver: %s (causes an infinite loop)' % (proxy_host,))
    self.resolver = resolver.get_default_resolver()
    self.resolver.nameservers = name_servers
    self.dns_cache_lock = threading.Lock()
    self.dns_cache = {}

  @staticmethod
  def _IsIPAddress(hostname):
    """Return True if socket.inet_aton accepts hostname as an IPv4 address."""
    try:
      socket.inet_aton(hostname)
      return True
    except socket.error:
      return False

  def __call__(self, hostname, rdtype=rdatatype.A):
    """Return real IP for a host.

    Args:
      hostname: a hostname ending with a period (e.g. "www.google.com.")
      rdtype: the query type (1 for 'A', 28 for 'AAAA')
    Returns:
      the IP address as a string (e.g. "192.168.25.2"), or None if the
      lookup fails.
    """
    if self._IsIPAddress(hostname):
      return hostname
    # NOTE(review): the cache is keyed by hostname only, so results for
    # different rdtype values share one entry.
    with self.dns_cache_lock:
      ip = self.dns_cache.get(hostname)
    if ip:
      return ip
    try:
      answers = self.resolver.query(hostname, rdtype)
    except resolver.NXDOMAIN:
      return None
    except resolver.NoNameservers:
      logging.debug('_real_dns_lookup(%s) -> No nameserver.',
                    hostname)
      return None
    except (resolver.NoAnswer, resolver.Timeout) as ex:
      logging.debug('_real_dns_lookup(%s) -> None (%s)',
                    hostname, ex.__class__.__name__)
      return None
    if answers:
      ip = str(answers[0])
    with self.dns_cache_lock:
      self.dns_cache[hostname] = ip
    return ip

  def ClearCache(self):
    """Clear the dns cache."""
    with self.dns_cache_lock:
      self.dns_cache.clear()
100
101
class ReplayDnsLookup(object):
  """Answer every lookup with the replay host, subject to optional filters.

  Each filter is called as filter(hostname, default_ip=current_ip) and its
  return value becomes the current IP handed to the next filter.
  """

  def __init__(self, replay_ip, filters=None):
    """Initialize ReplayDnsLookup.

    Args:
      replay_ip: the IP address returned when no filter overrides it.
      filters: optional sequence of filter callables, applied in order.
    """
    self.replay_ip = replay_ip
    self.filters = filters if filters else []

  def __call__(self, hostname):
    """Return the IP for hostname after running it through every filter."""
    resolved = self.replay_ip
    for dns_filter in self.filters:
      resolved = dns_filter(hostname, default_ip=resolved)
    return resolved
113
114
class PrivateIpFilter(object):
  """Filter that lets private hosts keep their real IP.

  Hosts found in the given http_archive keep the default (Web proxy) IP and
  are never looked up. Any other host is resolved for real: a private
  address is returned as-is, a public one falls back to the default IP.

  This only supports IPv4 lookups.
  """

  def __init__(self, real_dns_lookup, http_archive):
    """Initialize PrivateIpFilter.

    Args:
      real_dns_lookup: a function that resolves a host to an IP.
      http_archive: an iterable of requests, each with a "host" attribute.
        Hosts in the archive always resolve to the default IP.
    """
    self.real_dns_lookup = real_dns_lookup
    self.http_archive = http_archive
    self.InitializeArchiveHosts()

  def __call__(self, host, default_ip):
    """Return the real IPv4 for private hosts and default_ip otherwise.

    Args:
      host: a hostname ending with a period (e.g. "www.google.com.")
      default_ip: the IP to use for archived and non-private hosts.
    Returns:
      IP address as a string or None (if lookup fails)
    """
    if host in self.archive_hosts:
      return default_ip
    real_ip = self.real_dns_lookup(host)
    if not real_ip:
      # The real lookup failed, so report the failure to the caller.
      return None
    if ipaddr.IPAddress(real_ip).is_private:
      return real_ip
    return default_ip

  def InitializeArchiveHosts(self):
    """Recompute the archive_hosts from the http_archive."""
    hosts = set()
    for request in self.http_archive:
      # Drop any ":port" suffix and add the trailing period used in lookups.
      bare_host = request.host.split(':')[0]
      hosts.add('%s.' % bare_host)
    self.archive_hosts = hosts
157
158
class DelayFilter(object):
  """Add a delay to replayed lookups.

  The filter never changes the IP it is given; it only sleeps before
  returning it when not in record mode.
  """

  def __init__(self, is_record_mode, delay_ms):
    """Initialize DelayFilter.

    Args:
      is_record_mode: if true, lookups are not delayed.
      delay_ms: the delay, in milliseconds, added to each replayed lookup
        (anything accepted by int()).
    """
    self.is_record_mode = is_record_mode
    self.delay_ms = int(delay_ms)

  def __call__(self, host, default_ip):
    """Sleep for the configured delay in replay mode; return default_ip.

    Args:
      host: the hostname being resolved (unused).
      default_ip: the IP to pass through unchanged.
    Returns:
      default_ip
    """
    if not self.is_record_mode and self.delay_ms > 0:
      # time.sleep() takes seconds, so convert from milliseconds.
      time.sleep(self.delay_ms / 1000.0)
    return default_ip

  def SetRecordMode(self):
    """Switch to record mode (lookups are not delayed)."""
    self.is_record_mode = True

  def SetReplayMode(self):
    """Switch to replay mode (lookups are delayed)."""
    self.is_record_mode = False
176
177
class UdpDnsHandler(SocketServer.DatagramRequestHandler):
  """Answer DNS queries with the IP returned by the server's dns_lookup.

  Possible alternative implementation:
  http://howl.play-bow.org/pipermail/dnspython-users/2010-February/000119.html
  """

  STANDARD_QUERY_OPERATION_CODE = 0

  # Length in bytes of the fixed DNS message header (RFC 1035, 4.1.1).
  HEADER_LENGTH = 12

  # QTYPE and QCLASS (2 bytes each) that follow the name in a question.
  QUESTION_FIXED_LENGTH = 4

  def handle(self):
    """Handle a DNS query.

    IPv6 requests (with rdtype AAAA) receive mismatched IPv4 responses
    (with rdtype A). To properly support IPv6, the http proxy would
    need both types of addresses. By default, Windows XP does not
    support IPv6.

    Requests that are too short to hold a DNS header, or whose question
    name is truncated, are logged and dropped without a response payload.
    """
    self.data = self.rfile.read()
    self.domain = ''
    if len(self.data) < self.HEADER_LENGTH:
      logging.debug('dnsproxy: ignoring short DNS request (%d bytes)',
                    len(self.data))
      return
    # NOTE: despite their names, transaction_id and flags hold the first and
    # second byte of the 16-bit message ID; both are echoed in the response.
    self.transaction_id = self.data[0]
    self.flags = self.data[1]
    self.qa_counts = self.data[4:6]
    # The opcode is stored in bits 3-6 of the third header byte.
    operation_code = (ord(self.data[2]) >> 3) & 15
    if operation_code == self.STANDARD_QUERY_OPERATION_CODE:
      try:
        # Keep only the first question. Anything after it (e.g. an EDNS0
        # OPT record in the additional section) must not be echoed in front
        # of the answer record built by get_dns_response.
        self.wire_domain = self._first_question(
            self.data[self.HEADER_LENGTH:])
        self.domain = self._domain(self.wire_domain)
      except IndexError:
        logging.debug('dnsproxy: ignoring DNS request with truncated name')
        self.domain = ''
        return
    else:
      logging.debug("DNS request with non-zero operation code: %s",
                    operation_code)
    ip = self.server.dns_lookup(self.domain)
    if ip is None:
      logging.debug('dnsproxy: %s -> NXDOMAIN', self.domain)
      response = self.get_dns_no_such_name_response()
    else:
      if ip == self.server.server_address[0]:
        logging.debug('dnsproxy: %s -> %s (replay web proxy)', self.domain, ip)
      else:
        logging.debug('dnsproxy: %s -> %s', self.domain, ip)
      response = self.get_dns_response(ip)
    self.wfile.write(response)

  @classmethod
  def _first_question(cls, payload):
    """Return the first question entry (name, QTYPE, QCLASS) of payload.

    Args:
      payload: the DNS message contents following the fixed header.
    Returns:
      the leading part of payload covering the first question.
    Raises:
      IndexError: if the name is not terminated within payload.
    """
    index = 0
    length = ord(payload[index])
    while length:
      index += length + 1
      length = ord(payload[index])
    # index now points at the zero byte that terminates the name.
    return payload[:index + 1 + cls.QUESTION_FIXED_LENGTH]

  @classmethod
  def _domain(cls, wire_domain):
    """Decode a wire-format name into dotted form ("www.google.com.").

    Args:
      wire_domain: length-prefixed labels terminated by a zero byte.
        Compression pointers are not interpreted.
    Returns:
      the labels joined with a trailing period each ('' for the root).
    Raises:
      IndexError: if the terminating zero byte is missing.
    """
    labels = []
    index = 0
    length = ord(wire_domain[index])
    while length:
      labels.append(wire_domain[index + 1:index + length + 1])
      index += length + 1
      length = ord(wire_domain[index])
    return ''.join(label + '.' for label in labels)

  def get_dns_response(self, ip):
    """Build an "A" record response for the current query.

    Args:
      ip: the IPv4 address (dotted string) to answer with.
    Returns:
      the response packet, or '' if no domain was parsed from the query.
    """
    if not self.domain:
      return ''
    return ''.join([
        self.transaction_id,
        self.flags,
        '\x81\x80',          # standard query response, no error
        '\x00\x01',          # question count (only the first is echoed)
        '\x00\x01',          # answer count
        '\x00\x00\x00\x00',  # authority and additional record counts
        self.wire_domain,
        '\xc0\x0c',          # pointer to domain name
        '\x00\x01',          # resource record type ("A" host address)
        '\x00\x01',          # class of the data
        '\x00\x00\x00\x3c',  # ttl (seconds)
        '\x00\x04',          # resource data length (4 bytes for ip)
        socket.inet_aton(ip),
    ])

  def get_dns_no_such_name_response(self):
    """Return an NXDOMAIN response (wire format) for the current query."""
    query_message = message.from_wire(self.data)
    response_message = message.make_response(query_message)
    response_message.flags |= flags.AA | flags.RA
    response_message.set_rcode(rcode.NXDOMAIN)
    return response_message.to_wire()
254
255
class DnsProxyServer(SocketServer.ThreadingUDPServer,
                     daemonserver.DaemonServer):
  """Threaded UDP DNS server that answers queries via a lookup callable."""

  # Increase the request queue size. The default value, 5, is set in
  # SocketServer.TCPServer (the parent of BaseHTTPServer.HTTPServer).
  # Since we're intercepting many domains through this single server,
  # it is quite possible to get more than 5 concurrent requests.
  request_queue_size = 256

  # Allow sockets to be reused. See
  # http://svn.python.org/projects/python/trunk/Lib/SocketServer.py for more
  # details.
  allow_reuse_address = True

  # Don't prevent python from exiting when there is thread activity.
  daemon_threads = True

  def __init__(self, host='', port=53, dns_lookup=None):
    """Initialize DnsProxyServer.

    Args:
      host: a host string (name or IP) to bind the dns proxy and to which
        DNS requests will be resolved.
      port: an integer port on which to bind the proxy.
      dns_lookup: a callable taking a hostname and returning an IP string,
        or None if the name does not exist. Defaults to always returning
        the address the server is bound to.
    Raises:
      DnsProxyException: if binding fails with a permission error.
    """
    try:
      SocketServer.ThreadingUDPServer.__init__(
          self, (host, port), UdpDnsHandler)
    except socket.error as ex:
      # Read errno from the exception rather than unpacking it, so errors
      # that do not carry exactly (errno, message) are re-raised intact.
      if ex.errno == errno.EACCES:
        raise DnsProxyException(
            'Unable to bind DNS server on (%s:%s)' % (host, port))
      raise
    self.dns_lookup = dns_lookup or (lambda host: self.server_address[0])
    self.server_port = self.server_address[1]
    logging.warning('DNS server started on %s:%d',
                    self.server_address[0], self.server_address[1])

  def cleanup(self):
    """Stop serving and close the socket; ignore a KeyboardInterrupt."""
    try:
      self.shutdown()
      self.server_close()
    except KeyboardInterrupt:
      pass
    logging.info('Stopped DNS server')
301