1#!/usr/bin/env python
2#
3# Copyright 2010 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18"""Testing utilities for the webapp libraries.
19
20  GetDefaultEnvironment: Method for easily setting up CGI environment.
21  RequestHandlerTestBase: Base class for setting up handler tests.
22"""
23
24__author__ = 'rafek@google.com (Rafe Kaplan)'
25
26import cStringIO
27import threading
28import urllib2
29from wsgiref import simple_server
30from wsgiref import validate
31
32from . import protojson
33from . import remote
34from . import test_util
35from . import transport
36from .webapp import service_handlers
37from .webapp.google_imports import webapp
38
39
class TestService(remote.Service):
  """Minimal service exposing a single echo-style method.

  NOTE(review): a second ``TestService`` class is defined further down in this
  module and rebinds the name, so this definition is shadowed once the module
  has finished importing.
  """

  @remote.method(test_util.OptionalMessage, test_util.OptionalMessage)
  def optional_message(self, request):
    """Echo the request, prefixing a non-empty string_value with '+'.

    Args:
      request: Incoming OptionalMessage; it is modified in place.

    Returns:
      The same request object that was passed in.
    """
    value = request.string_value
    if value:
      request.string_value = '+%s' % value
    return request
49
50
def GetDefaultEnvironment():
  """Build a default CGI/WSGI environment for handler tests.

  A new dictionary is built on every call, with fresh (empty) 'wsgi.input'
  and 'wsgi.errors' streams, so callers may freely mutate the result.

  Returns:
    dict mapping CGI and WSGI environment variable names to default values.
  """
  return {
    'LC_NUMERIC': 'C',
    'wsgi.multiprocess': True,
    'SERVER_PROTOCOL': 'HTTP/1.0',
    'SERVER_SOFTWARE': 'Dev AppServer 0.1',
    'SCRIPT_NAME': '',
    'LOGNAME': 'nickjohnson',
    'USER': 'nickjohnson',
    'QUERY_STRING': 'foo=bar&foo=baz&foo2=123',
    'PATH': '/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/bin/X11',
    'LANG': 'en_US',
    'LANGUAGE': 'en',
    'REMOTE_ADDR': '127.0.0.1',
    'LC_MONETARY': 'C',
    'CONTENT_TYPE': 'application/x-www-form-urlencoded',
    'wsgi.url_scheme': 'http',
    'SERVER_PORT': '8080',
    'HOME': '/home/mruser',
    'USERNAME': 'mruser',
    'CONTENT_LENGTH': '',
    'USER_IS_ADMIN': '1',
    'PYTHONPATH': '/tmp/setup',
    'LC_TIME': 'C',
    'HTTP_USER_AGENT': 'Mozilla/5.0 (X11; U; Linux i686 (x86_64); en-US; '
        'rv:1.8.1.6) Gecko/20070725 Firefox/2.0.0.6',
    'wsgi.multithread': False,
    'wsgi.version': (1, 0),
    'USER_EMAIL': 'test@example.com',
    # This entry used to repeat the 'USER_EMAIL' key, which made the dict
    # literal silently replace the e-mail address above with '112'.
    'USER_ID': '112',
    'wsgi.input': cStringIO.StringIO(),
    'PATH_TRANSLATED': '/tmp/request.py',
    'SERVER_NAME': 'localhost',
    'GATEWAY_INTERFACE': 'CGI/1.1',
    'wsgi.run_once': True,
    'LC_COLLATE': 'C',
    'HOSTNAME': 'myhost',
    'wsgi.errors': cStringIO.StringIO(),
    'PWD': '/tmp',
    'REQUEST_METHOD': 'GET',
    'MAIL': '/dev/null',
    'MAILCHECK': '0',
    'USER_NICKNAME': 'test',
    'HTTP_COOKIE': 'dev_appserver_login="test:test@example.com:True"',
    'PATH_INFO': '/tmp/myhandler',
  }
98
99
class RequestHandlerTestBase(test_util.TestCase):
  """Base class for writing RequestHandler tests.

  To test a specific request handler override CreateRequestHandler.
  To change the environment for that handler override GetEnvironment.

  After setUp (or any call to ResetHandler) the test has three attributes:
    request: webapp.Request built from the test environment.
    response: webapp.Response passed to the handler.
    handler: Handler returned by CreateRequestHandler, already initialized
      with the request and response above.
  """

  def setUp(self):
    """Set up test for request handler."""
    self.ResetHandler()

  def GetEnvironment(self):
    """Get environment.

    Override for more specific configurations.

    Returns:
      dict of CGI environment.
    """
    return GetDefaultEnvironment()

  def CreateRequestHandler(self):
    """Create RequestHandler instances.

    Override to create more specific kinds of RequestHandler instances.

    Returns:
      RequestHandler instance used in test.
    """
    return webapp.RequestHandler()

  def CheckResponse(self,
                    expected_status,
                    expected_headers,
                    expected_content):
    """Check that the web response is as expected.

    Drives self.handler.response.wsgi_write with a start_response callable
    that asserts on the status and headers; the write callable it hands back
    asserts on the body.

    Args:
      expected_status: Expected status message.
      expected_headers: Dictionary of expected headers.  Will ignore unexpected
        headers and only check the value of those expected.  Header names are
        matched case-insensitively.
      expected_content: Expected body.
    """
    # HTTP header names are case-insensitive.  Response header names are
    # lower-cased before lookup, so the expected names must be lower-cased as
    # well, otherwise a key such as 'Content-Type' could never match.
    normalized_expected = dict(
        (name.lower(), value) for name, value in expected_headers.items())

    def check_content(content):
      self.assertEquals(expected_content, content)

    def start_response(status, headers):
      self.assertEquals(expected_status, status)

      found_keys = set()
      for name, value in headers:
        name = name.lower()
        if name in normalized_expected:
          found_keys.add(name)
          self.assertEquals(normalized_expected[name], value)

      missing_headers = set(normalized_expected) - found_keys
      if missing_headers:
        self.fail('Expected keys %r not found' % (list(missing_headers),))

      return check_content

    self.handler.response.wsgi_write(start_response)

  def ResetHandler(self, change_environ=None):
    """Reset this tests environment with environment changes.

    Resets the entire test with a new handler which includes some changes to
    the default request environment.

    Args:
      change_environ: Dictionary of values that are added to default
        environment.
    """
    environment = self.GetEnvironment()
    environment.update(change_environ or {})

    self.request = webapp.Request(environment)
    self.response = webapp.Response()
    self.handler = self.CreateRequestHandler()
    self.handler.initialize(self.request, self.response)
185
186
class SyncedWSGIServer(simple_server.WSGIServer):
  """WSGI server subclass that adds no behavior of its own."""
189
190
class ServerThread(threading.Thread):
  """Thread responsible for managing wsgi server.

  The thread does not block forever waiting for connections.  It repeatedly
  calls server.handle_request() with a short accept timeout, checking between
  calls whether shutdown() has been requested.  Once shutdown is requested the
  loop exits and the listening socket is closed.
  """

  def __init__(self, server, *args, **kwargs):
    """Constructor.

    Args:
      server: The WSGI server that is served by this thread.
      As per threading.Thread base class.

    State:
      __serving: Server is still expected to be serving.  When False server
        knows to shut itself down.
    """
    self.server = server
    # This timeout is for the socket when a connection is made.
    self.server.socket.settimeout(None)
    # This timeout is for when waiting for a connection.  This allows
    # server.handle_request() to listen for a short time, then timeout,
    # allowing the server to check for shutdown.
    self.server.timeout = 0.05
    self.__serving = True

    super(ServerThread, self).__init__(*args, **kwargs)

  def shutdown(self):
    """Notify server that it must shutdown gracefully.

    Only clears the serving flag; it does not wait for the thread to exit.
    """
    self.__serving = False

  def run(self):
    """Handle incoming requests until shutdown, then close the server."""
    server = self.server
    try:
      while self.__serving:
        server.handle_request()
    finally:
      # Release the listening socket.  Dropping the reference alone leaves the
      # socket open until it is garbage collected.
      server.server_close()
      self.server = None
233
234
class TestService(remote.Service):
  """Service exercised by the end to end tests.

  Offers echo-style methods, a method that reports the constructor argument,
  and several methods that deliberately raise errors.
  """

  def __init__(self, message='uninitialized'):
    """Constructor.

    Args:
      message: Value reported back by init_parameter.
    """
    self.__message = message

  @remote.method(test_util.OptionalMessage, test_util.OptionalMessage)
  def optional_message(self, request):
    """Echo the request, prefixing a non-empty string_value with '+'."""
    value = request.string_value
    if value:
      request.string_value = '+%s' % value
    return request

  @remote.method(response_type=test_util.OptionalMessage)
  def init_parameter(self, request):
    """Return a message whose string_value is the constructor's message."""
    response = test_util.OptionalMessage(string_value=self.__message)
    return response

  @remote.method(test_util.NestedMessage, test_util.NestedMessage)
  def nested_message(self, request):
    """Echo the request with '+' unconditionally prefixed to string_value."""
    prefixed = '+%s' % request.string_value
    request.string_value = prefixed
    return request

  @remote.method()
  def raise_application_error(self, request):
    """Always raise remote.ApplicationError with name 'ERROR_NAME'."""
    error = remote.ApplicationError('This is an application error',
                                    'ERROR_NAME')
    raise error

  @remote.method()
  def raise_unexpected_error(self, request):
    """Always raise a TypeError."""
    error = TypeError('Unexpected error')
    raise error

  @remote.method()
  def raise_rpc_error(self, request):
    """Always raise remote.NetworkError."""
    error = remote.NetworkError('Uncaught network error')
    raise error

  @remote.method(response_type=test_util.NestedMessage)
  def return_bad_message(self, request):
    """Return a NestedMessage constructed with no field values.

    NOTE(review): presumably invalid because a required field is left unset;
    confirm against test_util.NestedMessage.
    """
    response = test_util.NestedMessage()
    return response
271
272
class AlternateService(remote.Service):
  """Service used for requesting non-existent methods.

  Declares a method name that TestService does not define.
  """

  @remote.method()
  def does_not_exist(self, request):
    """Always raise NotImplementedError."""
    error = NotImplementedError('Not implemented')
    raise error
279
280
class WebServerTestBase(test_util.TestCase):
  """Base class for tests that run against a live local WSGI server.

  setUp starts a server on an unused port and tearDown shuts it down.
  Sub-classes must provide CreateWsgiApplication, which StartWebServer calls
  when no application is passed explicitly.

  Attributes set up for each test:
    server: ServerThread running the web server.
    application: WSGI application served by the web server.
    port: Port the web server listens on.
    schema: URL scheme used when building service URLs.
    connection: Transport pointed at service_url.
    bad_path_connection: Transport pointed at service_url with '_x' appended.
    bad_path_stub: TestService stub using bad_path_connection.
  """

  SERVICE_PATH = '/my/service'

  def setUp(self):
    """Start the web server and create the connections used by tests."""
    self.server = None
    self.schema = 'http'
    try:
      self.ResetServer()

      self.bad_path_connection = self.CreateTransport(self.service_url + '_x')
      self.bad_path_stub = TestService.Stub(self.bad_path_connection)
      super(WebServerTestBase, self).setUp()
    except BaseException:
      # unittest does not call tearDown when setUp raises, so stop the server
      # thread here; otherwise the non-daemon thread would keep running.
      if self.server:
        self.server.shutdown()
      raise

  def tearDown(self):
    """Shut the web server down."""
    if self.server:
      self.server.shutdown()
    super(WebServerTestBase, self).tearDown()

  def ResetServer(self, application=None):
    """Reset web server.

    Shuts down existing server if necessary and starts a new one.

    Args:
      application: Optional WSGI function.  If none provided will use
        tests CreateWsgiApplication method.
    """
    if self.server:
      self.server.shutdown()

    self.port = test_util.pick_unused_port()
    self.server, self.application = self.StartWebServer(self.port, application)

    self.connection = self.CreateTransport(self.service_url)

  def CreateTransport(self, service_url, protocol=protojson):
    """Create a new transportation object.

    Args:
      service_url: URL of the service the transport sends requests to.
      protocol: Protocol module used to encode messages.

    Returns:
      New transport.HttpTransport instance.
    """
    return transport.HttpTransport(service_url, protocol=protocol)

  def StartWebServer(self, port, application=None):
    """Start web server.

    Args:
      port: Port to start application on.
      application: Optional WSGI function.  If none provided will use
        tests CreateWsgiApplication method.

    Returns:
      A tuple (server, application):
        server: An instance of ServerThread.
        application: Application that web server responds with.
    """
    if not application:
      application = self.CreateWsgiApplication()
    # Wrap the application so WSGI protocol violations are reported.
    validated_application = validate.validator(application)
    server = simple_server.make_server('localhost', port, validated_application)
    server = ServerThread(server)
    server.start()
    return server, application

  def make_service_url(self, path):
    """Make service URL using current schema and port."""
    return '%s://localhost:%d%s' % (self.schema, self.port, path)

  @property
  def service_url(self):
    """URL of the service mounted at SERVICE_PATH."""
    return self.make_service_url(self.SERVICE_PATH)
347
348
class EndToEndTestBase(WebServerTestBase):
  """Base class for end to end tests against TestService over HTTP.

  Attributes set up for each test, in addition to those of the base class:
    stub: TestService stub using the default connection.
    other_connection: Transport pointed at other_service_url.
    other_stub: TestService stub using other_connection.
    mismatched_stub: AlternateService stub using the default connection.
  """

  # Sub-classes may override to create alternate configurations.
  DEFAULT_MAPPING = service_handlers.service_mapping(
    [('/my/service', TestService),
     ('/my/other_service', TestService.new_factory('initialized')),
    ])

  def setUp(self):
    """Start the server and create the stubs used by tests."""
    super(EndToEndTestBase, self).setUp()

    self.stub = TestService.Stub(self.connection)

    self.other_connection = self.CreateTransport(self.other_service_url)
    self.other_stub = TestService.Stub(self.other_connection)

    self.mismatched_stub = AlternateService.Stub(self.connection)

  @property
  def other_service_url(self):
    """URL of the service mounted at '/my/other_service'."""
    # Built with make_service_url so that it follows self.schema, the same
    # way service_url does, instead of hard-coding 'http'.
    return self.make_service_url('/my/other_service')

  def CreateWsgiApplication(self):
    """Create WSGI application used on the server side for testing."""
    return webapp.WSGIApplication(self.DEFAULT_MAPPING, True)

  def DoRawRequest(self,
                   method,
                   content='',
                   content_type='application/json',
                   headers=None):
    """Send a raw HTTP request to a method of the default service.

    Args:
      method: Remote method name; appended to service_url after a '.'.
      content: Request body.
      content_type: Value of the content-type header.
      headers: Optional dictionary of additional request headers.  It is not
        modified.

    Returns:
      Response object returned by urllib2.urlopen.

    Raises:
      urllib2.HTTPError: Propagated from urllib2.urlopen.
    """
    # Work on a copy so the caller's dictionary is left untouched.
    request_headers = dict(headers or {})
    request_headers.update({'content-length': len(content or ''),
                            'content-type': content_type,
                           })
    request = urllib2.Request('%s.%s' % (self.service_url, method),
                              content,
                              request_headers)
    return urllib2.urlopen(request)

  def RawRequestError(self,
                      method,
                      content=None,
                      content_type='application/json',
                      headers=None):
    """Send a raw request that is expected to fail with an HTTP error.

    Fails the test if the request does not raise urllib2.HTTPError.

    Args:
      method: Remote method name; appended to service_url after a '.'.
      content: Request body.
      content_type: Value of the content-type header.
      headers: Optional dictionary of additional request headers.

    Returns:
      Tuple (code, content, headers) taken from the HTTP error.
    """
    try:
      self.DoRawRequest(method, content, content_type, headers)
    except urllib2.HTTPError as err:
      return err.code, err.read(), err.headers
    self.fail('Expected HTTP error')
399