from unittest import TestCase
from wsgiref.util import setup_testing_defaults
from wsgiref.headers import Headers
from wsgiref.handlers import BaseHandler, BaseCGIHandler
from wsgiref import util
from wsgiref.validate import validator
from wsgiref.simple_server import WSGIServer, WSGIRequestHandler
from wsgiref.simple_server import make_server
from StringIO import StringIO
from SocketServer import BaseServer

import os
import re
import sys

from test import test_support

class MockServer(WSGIServer):
    """Non-socket HTTP server

    Calls BaseServer.__init__ directly and overrides server_bind so that
    only the address is recorded; no listening socket is involved.
    """

    def __init__(self, server_address, RequestHandlerClass):
        BaseServer.__init__(self, server_address, RequestHandlerClass)
        self.server_bind()

    def server_bind(self):
        # Record name/port and build the base WSGI environ without binding.
        host, port = self.server_address
        self.server_name = host
        self.server_port = port
        self.setup_environ()


class MockHandler(WSGIRequestHandler):
    """Non-socket HTTP handler

    The 'request' is an (rfile, wfile) pair of in-memory files (see
    run_amock) instead of a connected socket.
    """
    def setup(self):
        self.connection = self.request
        self.rfile, self.wfile = self.connection

    def finish(self):
        # Intentionally a no-op: presumably so the in-memory output buffer
        # is left open for run_amock to read afterwards.
        pass


def hello_app(environ,start_response):
    """Minimal WSGI app with a fixed Date header so output is comparable."""
    start_response("200 OK", [
        ('Content-Type','text/plain'),
        ('Date','Mon, 05 Jun 2006 18:49:54 GMT')
    ])
    return ["Hello, world!"]

def run_amock(app=hello_app, data="GET / HTTP/1.0\n\n"):
    """Serve one in-memory request with 'app'.

    Returns a (response_output, captured_stderr) pair of strings. sys.stderr
    is temporarily redirected so server error logging can be inspected.
    """
    server = make_server("", 80, app, MockServer, MockHandler)
    inp, out, err, olderr = StringIO(data), StringIO(), StringIO(), sys.stderr
    sys.stderr = err

    try:
        server.finish_request((inp,out), ("127.0.0.1",8888))
    finally:
        sys.stderr = olderr

    return out.getvalue(), err.getvalue()


def compare_generic_iter(make_it,match):
    """Utility to compare a generic iterator with an iterable

    Tests the iterator both through __getitem__ (old-style sequence
    protocol) and through iter()/next(). 'make_it' must be a function
    returning a fresh iterator to be tested (since the iterator is tested
    twice); 'match' must be re-iterable for the same reason.

    Raises AssertionError on any mismatch.
    """

    # Sequence protocol: items must come back in order via indexing, and
    # one index past the end must raise IndexError.
    it = make_it()
    n = 0
    for item in match:
        if it[n] != item:
            raise AssertionError("Wrong item from __getitem__", n, it)
        n += 1
    try:
        it[n]
    except IndexError:
        pass
    else:
        raise AssertionError("Too many items from __getitem__",it)

    # Iterator protocol: iter() must return the object itself, next() must
    # produce the same items, then raise StopIteration.
    it = make_it()
    if iter(it) is not it:
        raise AssertionError("iter() did not return the iterator itself", it)
    for item in match:
        if it.next() != item:
            raise AssertionError("Wrong item from .next()", it)
    try:
        it.next()
    except StopIteration:
        pass
    else:
        raise AssertionError("Too many items from .next()",it)


class IntegrationTests(TestCase):

    def check_hello(self, out, has_length=True):
        self.assertEqual(out,
            "HTTP/1.0 200 OK\r\n"
            "Server: WSGIServer/0.1 Python/"+sys.version.split()[0]+"\r\n"
            "Content-Type: text/plain\r\n"
            "Date: Mon, 05 Jun 2006 18:49:54 GMT\r\n" +
            ("Content-Length: 13\r\n" if has_length else "") +
            "\r\n"
            "Hello, world!"
        )

    def test_plain_hello(self):
        out, err = run_amock()
        self.check_hello(out)

    def test_request_length(self):
        out, err = run_amock(data="GET " + ("x" * 65537) + " HTTP/1.0\n\n")
        self.assertEqual(out.splitlines()[0],
                         "HTTP/1.0 414 Request-URI Too Long")

    def test_validated_hello(self):
        out, err = run_amock(validator(hello_app))
        # the middleware doesn't support len(), so content-length isn't there
        self.check_hello(out, has_length=False)

    def test_simple_validation_error(self):
        def bad_app(environ,start_response):
            start_response("200 OK", ('Content-Type','text/plain'))
            return ["Hello, world!"]
        out, err = run_amock(validator(bad_app))
        # Compare against the handler's own error body rather than a copy of
        # its text, so the test cannot drift from the library (as
        # testBasicErrorOutput also does via h.error_body).
        self.assertTrue(out.endswith(BaseHandler.error_body))
        self.assertEqual(
            err.splitlines()[-2],
            "AssertionError: Headers (('Content-Type', 'text/plain')) must"
            " be of type list: <type 'tuple'>"
        )


class UtilityTests(TestCase):

    def checkShift(self,sn_in,pi_in,part,sn_out,pi_out):
        env = {'SCRIPT_NAME':sn_in,'PATH_INFO':pi_in}
        util.setup_testing_defaults(env)
        self.assertEqual(util.shift_path_info(env),part)
        self.assertEqual(env['PATH_INFO'],pi_out)
        self.assertEqual(env['SCRIPT_NAME'],sn_out)
        return env

    def checkDefault(self, key, value, alt=None):
        # Check defaulting when empty
        env = {}
        util.setup_testing_defaults(env)
        if isinstance(value, StringIO):
            self.assertIsInstance(env[key], StringIO)
        else:
            self.assertEqual(env[key], value)

        # Check existing value
        env = {key:alt}
        util.setup_testing_defaults(env)
        self.assertIs(env[key], alt)

    def checkCrossDefault(self,key,value,**kw):
        util.setup_testing_defaults(kw)
        self.assertEqual(kw[key],value)

    def checkAppURI(self,uri,**kw):
        util.setup_testing_defaults(kw)
        self.assertEqual(util.application_uri(kw),uri)

    def checkReqURI(self,uri,query=1,**kw):
        util.setup_testing_defaults(kw)
        self.assertEqual(util.request_uri(kw,query),uri)

    def checkFW(self,text,size,match):

        def make_it(text=text,size=size):
            return util.FileWrapper(StringIO(text),size)

        compare_generic_iter(make_it,match)

        it = make_it()
        self.assertFalse(it.filelike.closed)

        for item in it:
            pass

        # Exhausting the wrapper must not close the file; only close() does.
        self.assertFalse(it.filelike.closed)

        it.close()
        self.assertTrue(it.filelike.closed)

    def testSimpleShifts(self):
        self.checkShift('','/', '', '/', '')
        self.checkShift('','/x', 'x', '/x', '')
        self.checkShift('/','', None, '/', '')
        self.checkShift('/a','/x/y', 'x', '/a/x', '/y')
        self.checkShift('/a','/x/', 'x', '/a/x', '/')

    def testNormalizedShifts(self):
        self.checkShift('/a/b', '/../y', '..', '/a', '/y')
        self.checkShift('', '/../y', '..', '', '/y')
        self.checkShift('/a/b', '//y', 'y', '/a/b/y', '')
        self.checkShift('/a/b', '//y/', 'y', '/a/b/y', '/')
        self.checkShift('/a/b', '/./y', 'y', '/a/b/y', '')
        self.checkShift('/a/b', '/./y/', 'y', '/a/b/y', '/')
        self.checkShift('/a/b', '///./..//y/.//', '..', '/a', '/y/')
        self.checkShift('/a/b', '///', '', '/a/b/', '')
        self.checkShift('/a/b', '/.//', '', '/a/b/', '')
        self.checkShift('/a/b', '/x//', 'x', '/a/b/x', '/')
        self.checkShift('/a/b', '/.', None, '/a/b', '')

    def testDefaults(self):
        for key, value in [
            ('SERVER_NAME','127.0.0.1'),
            ('SERVER_PORT', '80'),
            ('SERVER_PROTOCOL','HTTP/1.0'),
            ('HTTP_HOST','127.0.0.1'),
            ('REQUEST_METHOD','GET'),
            ('SCRIPT_NAME',''),
            ('PATH_INFO','/'),
            ('wsgi.version', (1,0)),
            ('wsgi.run_once', 0),
            ('wsgi.multithread', 0),
            ('wsgi.multiprocess', 0),
            ('wsgi.input', StringIO("")),
            ('wsgi.errors', StringIO()),
            ('wsgi.url_scheme','http'),
        ]:
            self.checkDefault(key,value)

    def testCrossDefaults(self):
        self.checkCrossDefault('HTTP_HOST',"foo.bar",SERVER_NAME="foo.bar")
        self.checkCrossDefault('wsgi.url_scheme',"https",HTTPS="on")
        self.checkCrossDefault('wsgi.url_scheme',"https",HTTPS="1")
        self.checkCrossDefault('wsgi.url_scheme',"https",HTTPS="yes")
        self.checkCrossDefault('wsgi.url_scheme',"http",HTTPS="foo")
        self.checkCrossDefault('SERVER_PORT',"80",HTTPS="foo")
        self.checkCrossDefault('SERVER_PORT',"443",HTTPS="on")

    def testGuessScheme(self):
        self.assertEqual(util.guess_scheme({}), "http")
        self.assertEqual(util.guess_scheme({'HTTPS':"foo"}), "http")
        self.assertEqual(util.guess_scheme({'HTTPS':"on"}), "https")
        self.assertEqual(util.guess_scheme({'HTTPS':"yes"}), "https")
        self.assertEqual(util.guess_scheme({'HTTPS':"1"}), "https")

    def testAppURIs(self):
        self.checkAppURI("http://127.0.0.1/")
        self.checkAppURI("http://127.0.0.1/spam", SCRIPT_NAME="/spam")
        self.checkAppURI("http://127.0.0.1/sp%E4m", SCRIPT_NAME="/sp\xe4m")
        self.checkAppURI("http://spam.example.com:2071/",
            HTTP_HOST="spam.example.com:2071", SERVER_PORT="2071")
        self.checkAppURI("http://spam.example.com/",
            SERVER_NAME="spam.example.com")
        self.checkAppURI("http://127.0.0.1/",
            HTTP_HOST="127.0.0.1", SERVER_NAME="spam.example.com")
        self.checkAppURI("https://127.0.0.1/", HTTPS="on")
        self.checkAppURI("http://127.0.0.1:8000/", SERVER_PORT="8000",
            HTTP_HOST=None)

    def testReqURIs(self):
        self.checkReqURI("http://127.0.0.1/")
        self.checkReqURI("http://127.0.0.1/spam", SCRIPT_NAME="/spam")
        self.checkReqURI("http://127.0.0.1/sp%E4m", SCRIPT_NAME="/sp\xe4m")
        self.checkReqURI("http://127.0.0.1/spammity/spam",
            SCRIPT_NAME="/spammity", PATH_INFO="/spam")
        self.checkReqURI("http://127.0.0.1/spammity/sp%E4m",
            SCRIPT_NAME="/spammity", PATH_INFO="/sp\xe4m")
        self.checkReqURI("http://127.0.0.1/spammity/spam;ham",
            SCRIPT_NAME="/spammity", PATH_INFO="/spam;ham")
        self.checkReqURI("http://127.0.0.1/spammity/spam;cookie=1234,5678",
            SCRIPT_NAME="/spammity", PATH_INFO="/spam;cookie=1234,5678")
        self.checkReqURI("http://127.0.0.1/spammity/spam?say=ni",
            SCRIPT_NAME="/spammity", PATH_INFO="/spam",QUERY_STRING="say=ni")
        self.checkReqURI("http://127.0.0.1/spammity/spam?s%E4y=ni",
            SCRIPT_NAME="/spammity", PATH_INFO="/spam",QUERY_STRING="s%E4y=ni")
        self.checkReqURI("http://127.0.0.1/spammity/spam", 0,
            SCRIPT_NAME="/spammity", PATH_INFO="/spam",QUERY_STRING="say=ni")

    def testFileWrapper(self):
        self.checkFW("xyz"*50, 120, ["xyz"*40,"xyz"*10])

    def testHopByHop(self):
        for hop in (
            "Connection Keep-Alive Proxy-Authenticate Proxy-Authorization "
            "TE Trailers Transfer-Encoding Upgrade"
        ).split():
            for alt in hop, hop.title(), hop.upper(), hop.lower():
                self.assertTrue(util.is_hop_by_hop(alt))

        # Not comprehensive, just a few random header names
        for hop in (
            "Accept Cache-Control Date Pragma Trailer Via Warning"
        ).split():
            for alt in hop, hop.title(), hop.upper(), hop.lower():
                self.assertFalse(util.is_hop_by_hop(alt))

class HeaderTests(TestCase):

    def testMappingInterface(self):
        test = [('x','y')]
        self.assertEqual(len(Headers([])),0)
        self.assertEqual(len(Headers(test[:])),1)
        self.assertEqual(Headers(test[:]).keys(), ['x'])
        self.assertEqual(Headers(test[:]).values(), ['y'])
        self.assertEqual(Headers(test[:]).items(), test)
        self.assertIsNot(Headers(test).items(), test) # must be copy!

        h=Headers([])
        del h['foo'] # should not raise an error

        h['Foo'] = 'bar'
        for m in h.has_key, h.__contains__, h.get, h.get_all, h.__getitem__:
            self.assertTrue(m('foo'))
            self.assertTrue(m('Foo'))
            self.assertTrue(m('FOO'))
            self.assertFalse(m('bar'))

        self.assertEqual(h['foo'],'bar')
        h['foo'] = 'baz'
        self.assertEqual(h['FOO'],'baz')
        self.assertEqual(h.get_all('foo'),['baz'])

        self.assertEqual(h.get("foo","whee"), "baz")
        self.assertEqual(h.get("zoo","whee"), "whee")
        self.assertEqual(h.setdefault("foo","whee"), "baz")
        self.assertEqual(h.setdefault("zoo","whee"), "whee")
        self.assertEqual(h["foo"],"baz")
        self.assertEqual(h["zoo"],"whee")

    def testRequireList(self):
        self.assertRaises(TypeError, Headers, "foo")


    def testExtras(self):
        h = Headers([])
        self.assertEqual(str(h),'\r\n')

        h.add_header('foo','bar',baz="spam")
        self.assertEqual(h['foo'], 'bar; baz="spam"')
        self.assertEqual(str(h),'foo: bar; baz="spam"\r\n\r\n')

        h.add_header('Foo','bar',cheese=None)
        self.assertEqual(h.get_all('foo'),
            ['bar; baz="spam"', 'bar; cheese'])

        self.assertEqual(str(h),
            'foo: bar; baz="spam"\r\n'
            'Foo: bar; cheese\r\n'
            '\r\n'
        )


class ErrorHandler(BaseCGIHandler):
    """Simple handler subclass for testing BaseHandler"""

    # BaseHandler records the OS environment at import time, but envvars
    # might have been changed later by other tests, which trips up
    # HandlerTests.testEnviron().
    os_environ = dict(os.environ.items())

    def __init__(self,**kw):
        setup_testing_defaults(kw)
        BaseCGIHandler.__init__(
            self, StringIO(''), StringIO(), StringIO(), kw,
            multithread=True, multiprocess=True
        )

class TestHandler(ErrorHandler):
    """Simple handler subclass for testing BaseHandler, w/error passthru"""

    def handle_error(self):
        raise # for testing, we want to see what's happening


class HandlerTests(TestCase):

    def checkEnvironAttrs(self, handler):
        env = handler.environ
        for attr in [
            'version','multithread','multiprocess','run_once','file_wrapper'
        ]:
            if attr=='file_wrapper' and handler.wsgi_file_wrapper is None:
                continue
            self.assertEqual(getattr(handler,'wsgi_'+attr),env['wsgi.'+attr])

    def checkOSEnviron(self,handler):
        empty = {}; setup_testing_defaults(empty)
        env = handler.environ
        # setup_environ() copies handler.os_environ (a snapshot), so compare
        # against that snapshot rather than the live os.environ, which other
        # tests may have modified since the snapshot was taken.
        for k,v in handler.os_environ.items():
            if k not in empty:
                self.assertEqual(env[k],v)
        for k,v in empty.items():
            self.assertIn(k, env)

    def testEnviron(self):
        h = TestHandler(X="Y")
        h.setup_environ()
        self.checkEnvironAttrs(h)
        self.checkOSEnviron(h)
        self.assertEqual(h.environ["X"],"Y")

    def testCGIEnviron(self):
        h = BaseCGIHandler(None,None,None,{})
        h.setup_environ()
        for key in 'wsgi.url_scheme', 'wsgi.input', 'wsgi.errors':
            self.assertIn(key, h.environ)

    def testScheme(self):
        h=TestHandler(HTTPS="on"); h.setup_environ()
        self.assertEqual(h.environ['wsgi.url_scheme'],'https')
        h=TestHandler(); h.setup_environ()
        self.assertEqual(h.environ['wsgi.url_scheme'],'http')

    def testAbstractMethods(self):
        h = BaseHandler()
        for name in [
            '_flush','get_stdin','get_stderr','add_cgi_vars'
        ]:
            self.assertRaises(NotImplementedError, getattr(h,name))
        self.assertRaises(NotImplementedError, h._write, "test")

    def testContentLength(self):
        # Demo one reason iteration is better than write()... ;)

        def trivial_app1(e,s):
            s('200 OK',[])
            return [e['wsgi.url_scheme']]

        def trivial_app2(e,s):
            s('200 OK',[])(e['wsgi.url_scheme'])
            return []

        def trivial_app4(e,s):
            # Simulate a response to a HEAD request
            s('200 OK',[('Content-Length', '12345')])
            return []

        h = TestHandler()
        h.run(trivial_app1)
        self.assertEqual(h.stdout.getvalue(),
            "Status: 200 OK\r\n"
            "Content-Length: 4\r\n"
            "\r\n"
            "http")

        h = TestHandler()
        h.run(trivial_app2)
        self.assertEqual(h.stdout.getvalue(),
            "Status: 200 OK\r\n"
            "\r\n"
            "http")


        h = TestHandler()
        h.run(trivial_app4)
        self.assertEqual(h.stdout.getvalue(),
            b'Status: 200 OK\r\n'
            b'Content-Length: 12345\r\n'
            b'\r\n')

    def testBasicErrorOutput(self):

        def non_error_app(e,s):
            s('200 OK',[])
            return []

        def error_app(e,s):
            raise AssertionError("This should be caught by handler")

        h = ErrorHandler()
        h.run(non_error_app)
        self.assertEqual(h.stdout.getvalue(),
            "Status: 200 OK\r\n"
            "Content-Length: 0\r\n"
            "\r\n")
        self.assertEqual(h.stderr.getvalue(),"")

        h = ErrorHandler()
        h.run(error_app)
        self.assertEqual(h.stdout.getvalue(),
            "Status: %s\r\n"
            "Content-Type: text/plain\r\n"
            "Content-Length: %d\r\n"
            "\r\n%s" % (h.error_status,len(h.error_body),h.error_body))

        self.assertNotEqual(h.stderr.getvalue().find("AssertionError"), -1)

    def testErrorAfterOutput(self):
        MSG = "Some output has been sent"
        def error_app(e,s):
            s("200 OK",[])(MSG)
            raise AssertionError("This should be caught by handler")

        h = ErrorHandler()
        h.run(error_app)
        # Headers were already sent, so no error page can replace them; the
        # partial output stays and the traceback goes to stderr.
        self.assertEqual(h.stdout.getvalue(),
            "Status: 200 OK\r\n"
            "\r\n"+MSG)
        self.assertNotEqual(h.stderr.getvalue().find("AssertionError"), -1)

    def testHeaderFormats(self):

        def non_error_app(e,s):
            s('200 OK',[])
            return []

        stdpat = (
            r"HTTP/%s 200 OK\r\n"
            r"Date: \w{3}, [ 0123]\d \w{3} \d{4} \d\d:\d\d:\d\d GMT\r\n"
            r"%s" r"Content-Length: 0\r\n" r"\r\n"
        )
        shortpat = (
            "Status: 200 OK\r\n" "Content-Length: 0\r\n" "\r\n"
        )

        for ssw in "FooBar/1.0", None:
            # Escape the server string: it is spliced into a regex and
            # contains metacharacters ('.').
            sw = "Server: %s\r\n" % re.escape(ssw) if ssw else ""

            for version in "1.0", "1.1":
                for proto in "HTTP/0.9", "HTTP/1.0", "HTTP/1.1":

                    h = TestHandler(SERVER_PROTOCOL=proto)
                    h.origin_server = False
                    h.http_version = version
                    h.server_software = ssw
                    h.run(non_error_app)
                    self.assertEqual(shortpat,h.stdout.getvalue())

                    h = TestHandler(SERVER_PROTOCOL=proto)
                    h.origin_server = True
                    h.http_version = version
                    h.server_software = ssw
                    h.run(non_error_app)
                    if proto=="HTTP/0.9":
                        self.assertEqual(h.stdout.getvalue(),"")
                    else:
                        # Escape the version too, so "1.0" cannot match "1x0".
                        pattern = stdpat % (re.escape(version), sw)
                        self.assertTrue(
                            re.match(pattern, h.stdout.getvalue()),
                            (pattern, h.stdout.getvalue())
                        )

    def testCloseOnError(self):
        side_effects = {'close_called': False}
        MSG = b"Some output has been sent"
        def error_app(e,s):
            s("200 OK",[])(MSG)
            class CrashyIterable(object):
                def __iter__(self):
                    while True:
                        yield b'blah'
                        raise AssertionError("This should be caught by handler")

                def close(self):
                    side_effects['close_called'] = True
            return CrashyIterable()

        h = ErrorHandler()
        h.run(error_app)
        # The handler must call the app iterable's close() even when
        # iterating it raised.
        self.assertEqual(side_effects['close_called'], True)


def test_main():
    test_support.run_unittest(__name__)

if __name__ == "__main__":
    test_main()