1# Lint as: python2, python3
2# Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
3# Use of this source code is governed by a BSD-style license that can be
4# found in the LICENSE file.
5
6"""Spins up a trivial HTTP cgi form listener in a thread.
7
   The HTTPListener class is a utility for use with test cases that
   need something to call back to the Autotest test case with some form
   value, e.g. http://localhost:nnnn/?status="Browser started!"
11"""
12
13import cgi, errno, logging, os, posixpath, six.moves.SimpleHTTPServer, socket, ssl, sys
14import threading, six.moves.urllib.parse
15from six.moves import urllib
16from six.moves.BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
17from six.moves.socketserver import BaseServer, ThreadingMixIn
18
19
def _handle_http_errors(func):
    """Decorator function for cleaner presentation of certain exceptions.

    Wraps a request-handler method that takes only `self`. An IOError whose
    errno is EPIPE or ECONNRESET is reported through `self.log_error` as a
    single line; every other exception propagates unchanged.

    Args:
        func: handler method to wrap, called as func(self).

    Returns:
        The wrapping function. It returns whatever `func` returns, or None
        when one of the IOErrors above was logged instead.
    """
    # Imported here so the decorator does not rely on the module import list.
    import functools

    # Preserve the handler's __name__/__doc__ on the wrapper.
    @functools.wraps(func)
    def wrapper(self):
        try:
            return func(self)
        except IOError as e:
            if e.errno in (errno.EPIPE, errno.ECONNRESET):
                # Instead of dumping a stack trace, a single line is sufficient.
                # The message is passed as an argument rather than as the
                # format string, so a '%' in it cannot break the formatting.
                self.log_error('%s', str(e))
            else:
                raise

    return wrapper
33
34
class FormHandler(six.moves.SimpleHTTPServer.SimpleHTTPRequestHandler):
    """Request handler that records posted form fields and serves a docroot.

    POST requests are parsed as forms: every field is stored in the owning
    server's `_form_entries` dict, then either the handler registered for
    the URL path in `_url_handlers` is called, or the key=value parameters
    are echoed back in the response.

    If the form submission is a file upload, the file will be written
    to disk with the name contained in the 'filename' field.

    GET requests go to a registered URL handler when there is one, and
    otherwise to the stock SimpleHTTPRequestHandler file serving, rooted at
    the server's `docroot`.

    The handler expects its server object to carry the attributes
    `docroot`, `_urls`, `_wait_urls`, `_url_handlers` and `_form_entries`.
    """

    # Note: this updates the table on the base class itself, not a copy.
    six.moves.SimpleHTTPServer.SimpleHTTPRequestHandler.extensions_map.update({
        '.webm': 'video/webm',
    })

    # Override the default logging methods to use the logging module directly.
    def log_error(self, format, *args):
        """Log an httpd error line at WARNING level."""
        logging.warning("(httpd error) %s - - [%s] %s\n",
                        self.address_string(), self.log_date_time_string(),
                        format % args)

    def log_message(self, format, *args):
        """Log an httpd access line at DEBUG level."""
        logging.debug("%s - - [%s] %s\n",
                      self.address_string(), self.log_date_time_string(),
                      format % args)

    @staticmethod
    def _form_keys(form):
        """Return the list of field names in `form`, or [] if it has none.

        You'd think form.keys() would just return [], like it does for empty
        python dicts; you'd be wrong. It raises TypeError if called when it
        has no keys, and truth-testing such a form can raise TypeError as
        well, so both are guarded here.
        """
        try:
            return list(form.keys()) if form else []
        except TypeError:
            return []

    @_handle_http_errors
    def do_POST(self):
        """Parse the posted form, record its fields and dispatch the request."""
        form = cgi.FieldStorage(
            fp=self.rfile,
            headers=self.headers,
            environ={'REQUEST_METHOD': 'POST',
                     'CONTENT_TYPE': self.headers['Content-Type']})
        for field in self._form_keys(form):
            # getvalue() gives the same result as form[field].value for a
            # single field, and a list of values when the field is repeated
            # (where form[field] is a list and has no .value).
            self.server._form_entries[field] = form.getvalue(field)
        path = six.moves.urllib.parse.urlparse(self.path)[2]
        url_handlers = self.server._url_handlers
        if path in url_handlers:
            url_handlers[path](self, form)
        else:
            # Echo back information about what was posted in the form.
            self.write_post_response(form)
        # NOTE(review): unlike do_GET, self.path still includes any query
        # string here, so wait URLs are matched against the full request
        # path -- confirm this is intended.
        self._fire_event()


    def write_post_response(self, form):
        """Called to fill out the response to an HTTP POST.

        Override this class to give custom responses.

        Args:
            form: the parsed form (a cgi.FieldStorage when called from
                do_POST).
        """
        # Send response boilerplate
        self.send_response(200)
        self.end_headers()
        self.wfile.write(('Hello from Autotest!\nClient: %s\n' %
                         str(self.client_address)).encode('utf-8'))
        self.wfile.write(('Request for path: %s\n' % self.path).encode('utf-8'))
        self.wfile.write(b'Got form data:\n')

        for field in self._form_keys(form):
            field_items = form[field]
            # A repeated field comes back as a list of items.
            if not isinstance(field_items, list):
                field_items = [field_items]
            for field_item in field_items:
                if field_item.filename:
                    # The field contains an uploaded file
                    upload = field_item.file.read()
                    self.wfile.write(('\tUploaded %s (%d bytes)<br>' %
                                     (field, len(upload))).encode('utf-8'))
                    if not isinstance(upload, bytes):
                        upload = upload.encode('utf-8')
                    # Write submitted file to specified filename. Binary mode
                    # keeps the payload intact, and the context manager
                    # closes the file even if the write fails.
                    # NOTE(security): the destination comes straight from the
                    # request, so a client can choose any writable path. Only
                    # use this handler with trusted test clients.
                    with open(field_item.filename, 'wb') as upload_file:
                        upload_file.write(upload)
                    del upload
                else:
                    self.wfile.write(
                        ('\t%s=%s<br>' % (field, field_item.value))
                        .encode('utf-8'))


    def translate_path(self, path):
        """Override SimpleHTTPRequestHandler's translate_path to serve
        from arbitrary docroot

        Args:
            path: request path, possibly with query parameters.

        Returns:
            The server's docroot joined with the path components, skipping
            '.' and '..' components.
        """
        # abandon query parameters
        path = six.moves.urllib.parse.urlparse(path)[2]
        path = posixpath.normpath(urllib.parse.unquote(path))
        words = [_f for _f in path.split('/') if _f]
        path = self.server.docroot
        for word in words:
            # Drop any drive prefix and directory part from the component.
            _, word = os.path.splitdrive(word)
            _, word = os.path.split(word)
            if word in (os.curdir, os.pardir):
                continue
            path = os.path.join(path, word)
        logging.debug('Translated path: %s', path)
        return path


    def _fire_event(self):
        """Set and unregister the wait event registered for self.path, if any."""
        # pop() looks up and removes in one step: handler threads share this
        # dict, and a separate check-then-delete can raise KeyError when two
        # requests for the same wait url race.
        entry = self.server._wait_urls.pop(self.path, None)
        if entry is not None:
            _, e = entry
            e.set()
        elif self.path not in self.server._urls:
            # if the url is not in _urls, this means it was neither setup
            # as a permanent, or event url.
            logging.debug('URL %s not in watch list', self.path)


    @_handle_http_errors
    def do_GET(self):
        """Dispatch a GET to a registered URL handler or to file serving.

        A registered handler is called as handler(self, args), where args is
        the parsed query string (a dict mapping names to lists of values).
        """
        split_url = six.moves.urllib.parse.urlsplit(self.path)
        path = split_url[2]
        # Strip off query parameters to ensure that the url path
        # matches any registered events.
        self.path = path
        args = six.moves.urllib.parse.parse_qs(split_url[3])
        url_handlers = self.server._url_handlers
        if path in url_handlers:
            url_handlers[path](self, args)
        else:
            six.moves.SimpleHTTPServer.SimpleHTTPRequestHandler.do_GET(self)
        self._fire_event()


    @_handle_http_errors
    def do_HEAD(self):
        """Serve HEAD via the stock handler, with terse dropped-connection logs."""
        six.moves.SimpleHTTPServer.SimpleHTTPRequestHandler.do_HEAD(self)
164
165
class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
    """HTTPServer combined with ThreadingMixIn for request handling."""

    def __init__(self, server_address, HandlerClass):
        # Explicit base-class call instead of super(): the Python 2
        # socketserver classes are classic classes.
        HTTPServer.__init__(self, server_address,
                            RequestHandlerClass=HandlerClass)
169
170
class HTTPListener(object):
    """Runs a ThreadedHTTPServer with FormHandler on a background thread.

    Tests register plain urls, wait urls (each paired with a
    threading.Event) and per-url handler functions, call run() to start
    serving, and stop() to shut the server down.
    """

    # Point default docroot to a non-existent directory (instead of None) to
    # avoid exceptions when page content is served through handlers only.
    def __init__(self, port=0, docroot='/_', wait_urls=None, url_handlers=None):
        """Create the server socket; serving starts only on run().

        Args:
            port (int): TCP port to bind.
            docroot (string): directory that GET requests are served from.
            wait_urls (dictionary): initial url -> (matchParams, event) map;
                a new empty dict when None.
            url_handlers (dictionary): initial url -> handler function map;
                a new empty dict when None.
        """
        self._server = ThreadedHTTPServer(('', port), FormHandler)
        self.config_server(self._server, docroot, wait_urls, url_handlers)

    def config_server(self, server, docroot, wait_urls, url_handlers):
        """Attach the per-listener state to the server and create its thread.

        Args:
            server: accepted for interface compatibility; the fields are
                stuffed into self._server, which is what this class passes.
            docroot (string): directory that GET requests are served from.
            wait_urls (dictionary): url -> (matchParams, event) map, or None.
            url_handlers (dictionary): url -> handler function map, or None.
        """
        # Stuff some convenient data fields into the server object.
        self._server.docroot = docroot
        self._server._urls = set()
        # None means "no initial entries". Each listener gets its own dict, so
        # registrations never leak between instances. A dict passed in by the
        # caller is used as-is.
        self._server._wait_urls = {} if wait_urls is None else wait_urls
        self._server._url_handlers = (
            {} if url_handlers is None else url_handlers)
        self._server._form_entries = {}
        self._server_thread = threading.Thread(
            target=self._server.serve_forever)

    def add_url(self, url):
        """
          Add a url to the urls that the http server is actively watching for.

          Not adding a url via add_url or add_wait_url, and only installing a
          handler will still result in that handler being executed, but this
          server will warn in the debug logs that it does not expect that url.

          Args:
            url (string): url suffix to listen to
        """
        self._server._urls.add(url)

    def add_wait_url(self, url='/', matchParams=None):
        """
          Add a wait url to the urls that the http server is aware of.

          Not adding a url via add_url or add_wait_url, and only installing a
          handler will still result in that handler being executed, but this
          server will warn in the debug logs that it does not expect that url.

          Args:
            url (string): url suffix to listen to
            matchParams (dictionary): an unused dictionary; a new empty one
                is stored when omitted

          Returns:
            e, an event object. Call e.wait() on the object to wait (block)
            until the server receives the first request for the wait url.

        """
        if matchParams is None:
            matchParams = {}
        e = threading.Event()
        self._server._wait_urls[url] = (matchParams, e)
        self._server._urls.add(url)
        return e

    def add_url_handler(self, url, handler_func):
        """Register handler_func to be called for requests to url."""
        self._server._url_handlers[url] = handler_func

    def clear_form_entries(self):
        """Forget all form fields recorded so far."""
        self._server._form_entries = {}


    def get_form_entries(self):
        """Returns a dictionary of all field=values received by the server.
        """
        return self._server._form_entries


    def run(self):
        """Start serving requests on the background thread."""
        logging.debug('http server on %s:%d',
                      self._server.server_name, self._server.server_port)
        self._server_thread.start()


    def stop(self):
        """Stop serving, close the listening socket and join the thread.

        Safe to call when run() was never called: shutdown() would wait
        forever for a serve_forever() loop that never started, and joining an
        unstarted thread raises RuntimeError, so both are skipped then.
        """
        serving = self._server_thread.is_alive()
        if serving:
            self._server.shutdown()
        self._server.socket.close()
        if serving:
            self._server_thread.join()
246
247
class SecureHTTPServer(ThreadingMixIn, HTTPServer):
    """Threaded HTTP server whose listening socket is wrapped in TLS."""

    def __init__(self, server_address, HandlerClass, cert_path, key_path):
        """Create the TLS listening socket, then bind and activate it.

        Args:
            server_address: (host, port) tuple to bind.
            HandlerClass: request handler class.
            cert_path (string): path to the server certificate file.
            key_path (string): path to the matching private key file.
        """
        _socket = socket.socket(self.address_family, self.socket_type)
        try:
            self.socket = self._wrap_server_socket(_socket, cert_path,
                                                   key_path)
        except Exception:
            # Don't leak the descriptor if the cert/key cannot be loaded.
            _socket.close()
            raise
        # BaseServer.__init__ is used (not HTTPServer.__init__) because the
        # latter would create its own plain socket and bind it.
        BaseServer.__init__(self, server_address, HandlerClass)
        try:
            self.server_bind()
            self.server_activate()
        except Exception:
            self.socket.close()
            raise

    @staticmethod
    def _wrap_server_socket(sock, cert_path, key_path):
        """Return `sock` wrapped for server-side TLS with the given cert/key."""
        if not hasattr(ssl, 'SSLContext'):
            # Interpreters without SSLContext only offer the module-level
            # helper.
            return ssl.wrap_socket(sock,
                                   server_side=True,
                                   ssl_version=ssl.PROTOCOL_TLSv1,
                                   certfile=cert_path,
                                   keyfile=key_path)
        # ssl.wrap_socket() is deprecated and absent from newer Python 3
        # releases, so an SSLContext is used where available. The protocol
        # version is negotiated instead of being pinned to TLSv1.
        protocol = None
        for name in ('PROTOCOL_TLS_SERVER', 'PROTOCOL_TLS', 'PROTOCOL_SSLv23'):
            protocol = getattr(ssl, name, None)
            if protocol is not None:
                break
        context = ssl.SSLContext(protocol)
        # With the oldest of the constants above, SSLv2/SSLv3 may still be
        # enabled by default; switch them off where the flags exist.
        context.options |= (getattr(ssl, 'OP_NO_SSLv2', 0) |
                            getattr(ssl, 'OP_NO_SSLv3', 0))
        context.load_cert_chain(certfile=cert_path, keyfile=key_path)
        return context.wrap_socket(sock, server_side=True)
259
260
class SecureHTTPRequestHandler(FormHandler):
    """FormHandler variant used with SecureHTTPServer.

    Logging is inherited from FormHandler; the overrides that used to live
    here were identical copies.
    """

    def setup(self):
        """Create the rfile/wfile streams for the accepted connection."""
        file_factory = getattr(socket, '_fileobject', None)
        if file_factory is None:
            # socket._fileobject only exists on Python 2. Elsewhere, fall
            # back to the inherited setup().
            FormHandler.setup(self)
            return
        self.connection = self.request
        self.rfile = file_factory(self.request, 'rb', self.rbufsize)
        self.wfile = file_factory(self.request, 'wb', self.wbufsize)
277
278
class SecureHTTPListener(HTTPListener):
    """HTTPListener that serves over TLS through SecureHTTPServer."""

    def __init__(self,
                 cert_path='/etc/login_trust_root.pem',
                 key_path='/etc/mock_server.key',
                 port=0,
                 docroot='/_',
                 wait_urls=None,
                 url_handlers=None):
        """Create the TLS server; serving starts only on run().

        Args:
            cert_path (string): path to the server certificate file.
            key_path (string): path to the matching private key file.
            port (int): TCP port to bind.
            docroot (string): directory that GET requests are served from.
            wait_urls (dictionary): initial url -> (matchParams, event) map;
                a new empty dict when None.
            url_handlers (dictionary): initial url -> handler function map;
                a new empty dict when None.
        """
        # Fresh dicts per instance: a shared default dict would carry urls
        # and handlers registered on one listener over to the next.
        if wait_urls is None:
            wait_urls = {}
        if url_handlers is None:
            url_handlers = {}
        self._server = SecureHTTPServer(('', port),
                                        SecureHTTPRequestHandler,
                                        cert_path,
                                        key_path)
        self.config_server(self._server, docroot, wait_urls, url_handlers)


    def getsockname(self):
        """Return the address the listening socket is bound to."""
        return self._server.socket.getsockname()
296
297