# Lint as: python2, python3
# Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Spins up a trivial HTTP cgi form listener in a thread.

   The HTTPListener class is a utility for use with test cases that
   need to call back to the Autotest test case with some form value, e.g.
   http://localhost:nnnn/?status="Browser started!"
"""

import cgi, errno, logging, os, posixpath, six.moves.SimpleHTTPServer, socket, ssl, sys
import functools
import threading, six.moves.urllib.parse
from six.moves import urllib
from six.moves.BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
from six.moves.socketserver import BaseServer, ThreadingMixIn


def _handle_http_errors(func):
    """Decorator function for cleaner presentation of certain exceptions.

    Broken-pipe and connection-reset errors (the client went away) are
    logged as a single line instead of producing a stack trace; every other
    IOError is re-raised.
    """
    @functools.wraps(func)
    def wrapper(self):
        try:
            func(self)
        except IOError as e:
            if e.errno in (errno.EPIPE, errno.ECONNRESET):
                # Pass the message as an argument, not as the format string,
                # so a '%' inside the error text cannot break formatting.
                self.log_error('%s', e)
            else:
                raise

    return wrapper


class FormHandler(six.moves.SimpleHTTPServer.SimpleHTTPRequestHandler):
    """Echoes form POST parameters back and serves GET/HEAD from a docroot.

    POST key=value parameters are recorded on the server and echoed back in
    the response. If a form field is a file upload, the file is written to
    disk under the name contained in the upload's 'filename' field.

    GET requests are routed to a registered URL handler when one exists;
    otherwise they are served from the server's docroot (see translate_path).

    The handler reads attributes that HTTPListener.config_server stores on
    the server object: docroot, _urls, _wait_urls, _url_handlers and
    _form_entries.
    """

    six.moves.SimpleHTTPServer.SimpleHTTPRequestHandler.extensions_map.update({
        '.webm': 'video/webm',
    })

    # Override the default logging methods to use the logging module directly.
    def log_error(self, format, *args):
        logging.warning('(httpd error) %s - - [%s] %s',
                        self.address_string(), self.log_date_time_string(),
                        format % args)

    def log_message(self, format, *args):
        logging.debug('%s - - [%s] %s',
                      self.address_string(), self.log_date_time_string(),
                      format % args)

    @_handle_http_errors
    def do_POST(self):
        """Records posted form fields, then dispatches or echoes the form."""
        form = cgi.FieldStorage(
            fp=self.rfile,
            headers=self.headers,
            environ={'REQUEST_METHOD': 'POST',
                     'CONTENT_TYPE': self.headers['Content-Type']})
        # FieldStorage.keys() and its truth value both raise TypeError when
        # the body was not parsed as a form (form.list is None), so test the
        # parsed field list directly.
        if form.list:
            for field in form.keys():
                # getvalue() returns a list for repeated fields instead of
                # failing the way form[field].value does.
                self.server._form_entries[field] = form.getvalue(field)
        path = six.moves.urllib.parse.urlparse(self.path)[2]
        if path in self.server._url_handlers:
            self.server._url_handlers[path](self, form)
        else:
            # Echo back information about what was posted in the form.
            self.write_post_response(form)
        self._fire_event()


    def write_post_response(self, form):
        """Called to fill out the response to an HTTP POST.

        Override this method to give custom responses.
        """
        # Send response boilerplate
        self.send_response(200)
        self.end_headers()
        self.wfile.write(('Hello from Autotest!\nClient: %s\n' %
                          str(self.client_address)).encode('utf-8'))
        self.wfile.write(('Request for path: %s\n' % self.path).encode('utf-8'))
        self.wfile.write(b'Got form data:\n')

        # form.keys() raises TypeError for non-form bodies; see do_POST.
        if not form.list:
            return
        for field in form.keys():
            field_items = form[field]
            # Repeated fields come back as a list of items.
            if not isinstance(field_items, list):
                field_items = [field_items]
            for field_item in field_items:
                if field_item.filename:
                    # The field contains an uploaded file.
                    upload = field_item.file.read()
                    self.wfile.write(('\tUploaded %s (%d bytes)<br>' %
                                      (field, len(upload))).encode('utf-8'))
                    # Write the submitted file to the client-specified
                    # filename. Binary mode because the upload is raw bytes.
                    # NOTE(review): the name comes from the request and is
                    # used as-is; only acceptable for trusted test clients.
                    with open(field_item.filename, 'wb') as output:
                        output.write(upload)
                else:
                    self.wfile.write(('\t%s=%s<br>' % (
                        field, field_item.value)).encode('utf-8'))


    def translate_path(self, path):
        """Override SimpleHTTPServer's translate_path to serve
        from an arbitrary docroot.

        Query parameters are dropped, and '.' / '..' path components are
        skipped while the path is rebuilt under self.server.docroot.
        """
        # abandon query parameters
        path = six.moves.urllib.parse.urlparse(path)[2]
        path = posixpath.normpath(urllib.parse.unquote(path))
        words = [word for word in path.split('/') if word]
        path = self.server.docroot
        for word in words:
            _, word = os.path.splitdrive(word)
            _, word = os.path.split(word)
            if word in (os.curdir, os.pardir):
                continue
            path = os.path.join(path, word)
        logging.debug('Translated path: %s', path)
        return path


    def _fire_event(self):
        """Sets (at most once) the event registered for this request's path."""
        # pop() is atomic, so two concurrent request threads for the same
        # wait url cannot both try to delete the entry.
        pending = self.server._wait_urls.pop(self.path, None)
        if pending is not None:
            _, event = pending
            event.set()
        elif self.path not in self.server._urls:
            # if the url is not in _urls, this means it was neither setup
            # as a permanent, or event url.
            logging.debug('URL %s not in watch list', self.path)


    @_handle_http_errors
    def do_GET(self):
        """Dispatches to a registered URL handler or serves from docroot."""
        split_url = six.moves.urllib.parse.urlsplit(self.path)
        path = split_url[2]
        # Strip off query parameters to ensure that the url path
        # matches any registered events.
        self.path = path
        args = six.moves.urllib.parse.parse_qs(split_url[3])
        if path in self.server._url_handlers:
            self.server._url_handlers[path](self, args)
        else:
            six.moves.SimpleHTTPServer.SimpleHTTPRequestHandler.do_GET(self)
        self._fire_event()


    @_handle_http_errors
    def do_HEAD(self):
        six.moves.SimpleHTTPServer.SimpleHTTPRequestHandler.do_HEAD(self)


class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
    """HTTPServer that handles each request in a separate thread."""

    def __init__(self, server_address, HandlerClass):
        HTTPServer.__init__(self, server_address, HandlerClass)


class HTTPListener(object):
    """Runs a FormHandler-backed HTTP server on a background thread.

    Args:
        port: port to listen on; 0 lets the OS choose one.
        docroot: directory that GET requests are served from.
        wait_urls: optional dict of url -> (matchParams, threading.Event);
            a fresh dict is used when omitted.
        url_handlers: optional dict of url path -> handler(request, data);
            a fresh dict is used when omitted.
    """

    # Point default docroot to a non-existent directory (instead of None) to
    # avoid exceptions when page content is served through handlers only.
    def __init__(self, port=0, docroot='/_', wait_urls=None, url_handlers=None):
        self._server = ThreadedHTTPServer(('', port), FormHandler)
        self.config_server(self._server, docroot, wait_urls, url_handlers)

    def config_server(self, server, docroot, wait_urls, url_handlers):
        """Attaches listener state to |server| and prepares its thread."""
        # Stuff some convenient data fields into the server object. A new
        # dict per server keeps registrations from leaking between listeners.
        server.docroot = docroot
        server._urls = set()
        server._wait_urls = {} if wait_urls is None else wait_urls
        server._url_handlers = {} if url_handlers is None else url_handlers
        server._form_entries = {}
        self._server_thread = threading.Thread(target=server.serve_forever)

    def add_url(self, url):
        """
        Add a url to the urls that the http server is actively watching for.

        Not adding a url via add_url or add_wait_url, and only installing a
        handler will still result in that handler being executed, but this
        server will warn in the debug logs that it does not expect that url.

        Args:
            url (string): url suffix to listen to
        """
        self._server._urls.add(url)

    def add_wait_url(self, url='/', matchParams=None):
        """
        Add a wait url to the urls that the http server is aware of.

        Not adding a url via add_url or add_wait_url, and only installing a
        handler will still result in that handler being executed, but this
        server will warn in the debug logs that it does not expect that url.

        Args:
            url (string): url suffix to listen to
            matchParams (dictionary): an unused dictionary; defaults to {}

        Returns:
            An event object. Call wait() on it to block until the server
            receives the first request for the wait url.
        """
        if matchParams is None:
            matchParams = {}
        event = threading.Event()
        self._server._wait_urls[url] = (matchParams, event)
        self._server._urls.add(url)
        return event

    def add_url_handler(self, url, handler_func):
        """Registers handler_func(request_handler, data) for a url path.

        For GET the data is the parsed query dict; for POST it is the
        cgi.FieldStorage for the request body.
        """
        self._server._url_handlers[url] = handler_func

    def clear_form_entries(self):
        self._server._form_entries = {}


    def get_form_entries(self):
        """Returns a dictionary of all field=values received by the server.
        """
        return self._server._form_entries


    def run(self):
        logging.debug('http server on %s:%d',
                      self._server.server_name, self._server.server_port)
        self._server_thread.start()


    def stop(self):
        """Stops serving, closes the socket and joins the server thread."""
        # shutdown() waits for serve_forever() to exit and join() requires a
        # started thread, so both are skipped when run() was never called.
        started = self._server_thread.ident is not None
        if started:
            self._server.shutdown()
        self._server.socket.close()
        if started:
            self._server_thread.join()


def _wrap_server_socket(sock, cert_path, key_path):
    """Wraps |sock| for server-side TLS with the given certificate and key."""
    if hasattr(ssl, 'SSLContext'):
        # ssl.wrap_socket() is deprecated (removed in Python 3.12); let the
        # context negotiate the protocol version instead of pinning TLSv1.
        protocol = getattr(ssl, 'PROTOCOL_TLS_SERVER', ssl.PROTOCOL_SSLv23)
        context = ssl.SSLContext(protocol)
        context.load_cert_chain(certfile=cert_path, keyfile=key_path)
        return context.wrap_socket(sock, server_side=True)
    # Legacy path for Python builds that predate SSLContext.
    return ssl.wrap_socket(sock,
                           server_side=True,
                           ssl_version=ssl.PROTOCOL_TLSv1,
                           certfile=cert_path,
                           keyfile=key_path)


class SecureHTTPServer(ThreadingMixIn, HTTPServer):
    """Threaded HTTP server whose listening socket is wrapped in TLS."""

    def __init__(self, server_address, HandlerClass, cert_path, key_path):
        raw_socket = socket.socket(self.address_family, self.socket_type)
        self.socket = _wrap_server_socket(raw_socket, cert_path, key_path)
        # Bypass TCPServer.__init__, which would create a plain socket.
        BaseServer.__init__(self, server_address, HandlerClass)
        self.server_bind()
        self.server_activate()


class SecureHTTPRequestHandler(FormHandler):
    """FormHandler variant for connections accepted by SecureHTTPServer."""

    def setup(self):
        file_object = getattr(socket, '_fileobject', None)
        if file_object is None:
            # Python 3 has no socket._fileobject; the stock setup() builds
            # rfile/wfile from SSL sockets correctly.
            FormHandler.setup(self)
            return
        self.connection = self.request
        self.rfile = file_object(self.request, 'rb', self.rbufsize)
        self.wfile = file_object(self.request, 'wb', self.wbufsize)


class SecureHTTPListener(HTTPListener):
    """HTTPListener that serves over TLS using the given cert and key."""

    def __init__(self,
                 cert_path='/etc/login_trust_root.pem',
                 key_path='/etc/mock_server.key',
                 port=0,
                 docroot='/_',
                 wait_urls=None,
                 url_handlers=None):
        self._server = SecureHTTPServer(('', port),
                                        SecureHTTPRequestHandler,
                                        cert_path,
                                        key_path)
        self.config_server(self._server, docroot, wait_urls, url_handlers)


    def getsockname(self):
        """Returns the (host, port) address the TLS socket is bound to."""
        return self._server.socket.getsockname()