#!/usr/bin/env python
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests for protorpc.transport."""

import errno
import os
import socket
import unittest

import mox
import six.moves.http_client

from protorpc import messages
from protorpc import protobuf
from protorpc import protojson
from protorpc import remote
from protorpc import test_util
from protorpc import transport
from protorpc import webapp_test_util
from protorpc.wsgi import util as wsgi_util

package = 'transport_test'


class ModuleInterfaceTest(test_util.ModuleInterfaceTest,
                          test_util.TestCase):
  """Checks the public interface of the transport module."""

  MODULE = transport


class Message(messages.Message):
  """Minimal message with a single string field used by these tests."""

  value = messages.StringField(1)


class Service(remote.Service):
  """Service whose only method supplies remote info for transport tests."""

  @remote.method(Message, Message)
  def method(self, request):
    pass


# Remove when RPC is no longer subclassed.
class TestRpc(transport.Rpc):
  """Rpc subclass whose _wait_impl only records that it was invoked."""

  waited = False

  def _wait_impl(self):
    self.waited = True


class RpcTest(test_util.TestCase):
  """Tests for the state transitions of transport.Rpc."""

  def setUp(self):
    self.request = Message(value=u'request')
    self.response = Message(value=u'response')
    self.status = remote.RpcStatus(state=remote.RpcState.APPLICATION_ERROR,
                                   error_message='an error',
                                   error_name='blam')

    self.rpc = TestRpc(self.request)

  def testConstructor(self):
    self.assertEqual(self.request, self.rpc.request)
    self.assertEqual(remote.RpcState.RUNNING, self.rpc.state)
    self.assertEqual(None, self.rpc.error_message)
    self.assertEqual(None, self.rpc.error_name)

  def testResponse(self):
    # Must carry the `test` prefix to be collected; reading `response` on a
    # still-RUNNING rpc is expected to go through _wait_impl.
    self.assertFalse(self.rpc.waited)
    self.assertEqual(None, self.rpc.response)
    self.assertTrue(self.rpc.waited)

  def testSetResponse(self):
    self.rpc.set_response(self.response)

    self.assertEqual(self.request, self.rpc.request)
    self.assertEqual(remote.RpcState.OK, self.rpc.state)
    self.assertEqual(self.response, self.rpc.response)
    self.assertEqual(None, self.rpc.error_message)
    self.assertEqual(None, self.rpc.error_name)

  def testSetResponseAlreadySet(self):
    self.rpc.set_response(self.response)

    self.assertRaisesWithRegexpMatch(
        transport.RpcStateError,
        'RPC must be in RUNNING state to change to OK',
        self.rpc.set_response,
        self.response)

  def testSetResponseAlreadyError(self):
    self.rpc.set_status(self.status)

    self.assertRaisesWithRegexpMatch(
        transport.RpcStateError,
        'RPC must be in RUNNING state to change to OK',
        self.rpc.set_response,
        self.response)

  def testSetStatus(self):
    self.rpc.set_status(self.status)

    self.assertEqual(self.request, self.rpc.request)
    self.assertEqual(remote.RpcState.APPLICATION_ERROR, self.rpc.state)
    self.assertEqual('an error', self.rpc.error_message)
    self.assertEqual('blam', self.rpc.error_name)
    self.assertRaisesWithRegexpMatch(remote.ApplicationError,
                                     'an error',
                                     getattr, self.rpc, 'response')

  def testSetStatusAlreadySet(self):
    # set_status must also refuse to leave a completed (OK) rpc.
    self.rpc.set_response(self.response)

    self.assertRaisesWithRegexpMatch(
        transport.RpcStateError,
        'RPC must be in RUNNING state to change to APPLICATION_ERROR',
        self.rpc.set_status,
        self.status)

  def testSetNonMessage(self):
    self.assertRaisesWithRegexpMatch(
        TypeError,
        'Expected Message type, received 10',
        self.rpc.set_response,
        10)

  def testSetStatusAlreadyError(self):
    # set_status must refuse to overwrite an existing error status.
    self.rpc.set_status(self.status)

    self.assertRaisesWithRegexpMatch(
        transport.RpcStateError,
        'RPC must be in RUNNING state to change to APPLICATION_ERROR',
        self.rpc.set_status,
        self.status)

  def testSetUninitializedStatus(self):
    self.assertRaises(messages.ValidationError,
                      self.rpc.set_status,
                      remote.RpcStatus())


class TransportTest(test_util.TestCase):
  """Tests protocol selection on the base transport.Transport class."""

  def setUp(self):
    remote.Protocols.set_default(remote.Protocols.new_default())

  def do_test(self, protocol, trans):
    """Send one RPC through `trans` with _start_rpc stubbed out.

    Args:
      protocol: Protocol module `trans` is expected to report.
      trans: Transport instance under test; its _start_rpc is replaced.
    """
    request = Message()
    request.value = u'request'

    response = Message()
    response.value = u'response'

    self.assertEqual(protocol, trans.protocol)

    # `remote` here is the remote info passed by send_rpc; it shadows the
    # module of the same name only inside this closure.
    def transport_rpc(remote, rpc_request):
      self.assertEqual(remote, Service.method.remote)
      self.assertEqual(request, rpc_request)
      rpc = TestRpc(request)
      rpc.set_response(response)
      return rpc
    trans._start_rpc = transport_rpc

    rpc = trans.send_rpc(Service.method.remote, request)
    self.assertEqual(response, rpc.response)

  def testDefaultProtocol(self):
    trans = transport.Transport()
    self.do_test(protobuf, trans)
    self.assertEqual(protobuf, trans.protocol_config.protocol)
    self.assertEqual('default', trans.protocol_config.name)

  def testAlternateProtocol(self):
    trans = transport.Transport(protocol=protojson)
    self.do_test(protojson, trans)
    self.assertEqual(protojson, trans.protocol_config.protocol)
    self.assertEqual('default', trans.protocol_config.name)

  def testProtocolConfig(self):
    protocol_config = remote.ProtocolConfig(
        protojson, 'protoconfig', 'image/png')
    trans = transport.Transport(protocol=protocol_config)
    self.do_test(protojson, trans)
    self.assertIs(trans.protocol_config, protocol_config)

  def testProtocolByName(self):
    remote.Protocols.get_default().add_protocol(
        protojson, 'png', 'image/png', ())
    trans = transport.Transport(protocol='png')
    self.do_test(protojson, trans)


@remote.method(Message, Message)
def my_method(self, request):
  """Remote method used only for its `.remote` info; never invoked."""
  self.fail('self.my_method should not be directly invoked.')


class FakeConnectionClass(object):
  """Holder of mock request/response objects (not used by the tests below)."""

  def __init__(self, mox):
    # `mox` is presumably a mox.Mox instance; it shadows the module name here.
    self.request = mox.CreateMockAnything()
    self.response = mox.CreateMockAnything()


class HttpTransportTest(webapp_test_util.WebServerTestBase):
  """Tests transport.HttpTransport against a local WSGI server."""

  def setUp(self):
    # Do not need much parent construction functionality.

    self.schema = 'http'
    self.server = None

    self.request = Message(value=u'The request value')
    self.encoded_request = protojson.encode_message(self.request)

    self.response = Message(value=u'The response value')
    self.encoded_response = protojson.encode_message(self.response)

  def testCallSucceeds(self):
    self.ResetServer(wsgi_util.static_page(self.encoded_response,
                                           content_type='application/json'))

    rpc = self.connection.send_rpc(my_method.remote, self.request)
    self.assertEqual(self.response, rpc.response)

  def testHttps(self):
    self.schema = 'https'
    self.ResetServer(wsgi_util.static_page(self.encoded_response,
                                           content_type='application/json'))

    # Create a fake https connection function that really just calls http.
    self.used_https = False
    def https_connection(*args, **kwargs):
      self.used_https = True
      return six.moves.http_client.HTTPConnection(*args, **kwargs)

    original_https_connection = six.moves.http_client.HTTPSConnection
    six.moves.http_client.HTTPSConnection = https_connection
    try:
      rpc = self.connection.send_rpc(my_method.remote, self.request)
    finally:
      # Always restore the patched class so other tests are unaffected.
      six.moves.http_client.HTTPSConnection = original_https_connection
    self.assertEqual(self.response, rpc.response)
    self.assertTrue(self.used_https)

  def testHttpSocketError(self):
    self.ResetServer(wsgi_util.static_page(self.encoded_response,
                                           content_type='application/json'))

    bad_transport = transport.HttpTransport('http://localhost:-1/blar')
    try:
      bad_transport.send_rpc(my_method.remote, self.request)
    except remote.NetworkError as err:
      self.assertTrue(str(err).startswith('Socket error: error ('))
      self.assertEqual(errno.ECONNREFUSED, err.cause.errno)
    else:
      self.fail('Expected error')

  def testHttpRequestError(self):
    self.ResetServer(wsgi_util.static_page(self.encoded_response,
                                           content_type='application/json'))

    def request_error(*args, **kwargs):
      raise TypeError('Generic Error')
    original_request = six.moves.http_client.HTTPConnection.request
    six.moves.http_client.HTTPConnection.request = request_error
    try:
      try:
        self.connection.send_rpc(my_method.remote, self.request)
      except remote.NetworkError as err:
        self.assertEqual('Error communicating with HTTP server', str(err))
        self.assertEqual(TypeError, type(err.cause))
        self.assertEqual('Generic Error', str(err.cause))
      else:
        self.fail('Expected error')
    finally:
      six.moves.http_client.HTTPConnection.request = original_request

  def testHandleGenericServiceError(self):
    self.ResetServer(wsgi_util.error(six.moves.http_client.INTERNAL_SERVER_ERROR,
                                     'arbitrary error',
                                     content_type='text/plain'))

    rpc = self.connection.send_rpc(my_method.remote, self.request)
    try:
      rpc.response
    except remote.ServerError as err:
      self.assertEqual('HTTP Error 500: arbitrary error', str(err).strip())
    else:
      self.fail('Expected ServerError')

  def testHandleGenericServiceErrorNoMessage(self):
    self.ResetServer(wsgi_util.error(six.moves.http_client.NOT_IMPLEMENTED,
                                     ' ',
                                     content_type='text/plain'))

    rpc = self.connection.send_rpc(my_method.remote, self.request)
    try:
      rpc.response
    except remote.ServerError as err:
      self.assertEqual('HTTP Error 501: Not Implemented', str(err).strip())
    else:
      self.fail('Expected ServerError')

  def testHandleStatusContent(self):
    self.ResetServer(wsgi_util.static_page('{"state": "REQUEST_ERROR",'
                                           ' "error_message": "a request error"'
                                           '}',
                                           status=six.moves.http_client.BAD_REQUEST,
                                           content_type='application/json'))

    rpc = self.connection.send_rpc(my_method.remote, self.request)
    try:
      rpc.response
    except remote.RequestError as err:
      self.assertEqual('a request error', str(err))
    else:
      self.fail('Expected RequestError')

  def testHandleApplicationError(self):
    self.ResetServer(wsgi_util.static_page('{"state": "APPLICATION_ERROR",'
                                           ' "error_message": "an app error",'
                                           ' "error_name": "MY_ERROR_NAME"}',
                                           status=six.moves.http_client.BAD_REQUEST,
                                           content_type='application/json'))

    rpc = self.connection.send_rpc(my_method.remote, self.request)
    try:
      rpc.response
    except remote.ApplicationError as err:
      self.assertEqual('an app error', str(err))
      self.assertEqual('MY_ERROR_NAME', err.error_name)
    else:
      self.fail('Expected ApplicationError')

  def testHandleUnparsableErrorContent(self):
    self.ResetServer(wsgi_util.static_page('oops',
                                           status=six.moves.http_client.BAD_REQUEST,
                                           content_type='application/json'))

    rpc = self.connection.send_rpc(my_method.remote, self.request)
    try:
      rpc.response
    except remote.ServerError as err:
      self.assertEqual('HTTP Error 400: oops', str(err))
    else:
      self.fail('Expected ServerError')

  def testHandleEmptyBadRpcStatus(self):
    self.ResetServer(wsgi_util.static_page('{"error_message": "x"}',
                                           status=six.moves.http_client.BAD_REQUEST,
                                           content_type='application/json'))

    rpc = self.connection.send_rpc(my_method.remote, self.request)
    try:
      rpc.response
    except remote.ServerError as err:
      self.assertEqual('HTTP Error 400: {"error_message": "x"}', str(err))
    else:
      self.fail('Expected ServerError')

  def testUseProtocolConfigContentType(self):
    expected_content_type = 'image/png'
    def expect_content_type(environ, start_response):
      # Echo the request content type back after checking it.
      self.assertEqual(expected_content_type, environ['CONTENT_TYPE'])
      app = wsgi_util.static_page('', content_type=environ['CONTENT_TYPE'])
      return app(environ, start_response)

    self.ResetServer(expect_content_type)

    protocol_config = remote.ProtocolConfig(protojson, 'json', 'image/png')
    self.connection = self.CreateTransport(self.service_url, protocol_config)

    rpc = self.connection.send_rpc(my_method.remote, self.request)
    self.assertEqual(Message(), rpc.response)


class SimpleRequest(messages.Message):
  """Request for LocalService.call_method."""

  content = messages.StringField(1)


class SimpleResponse(messages.Message):
  """Response echoing the request content plus request-state details."""

  content = messages.StringField(1)
  factory_value = messages.StringField(2)
  remote_host = messages.StringField(3)
  remote_address = messages.StringField(4)
  server_host = messages.StringField(5)
  server_port = messages.IntegerField(6)


class LocalService(remote.Service):
  """Service exercised in-process through transport.LocalTransport."""

  def __init__(self, factory_value='default'):
    self.factory_value = factory_value

  @remote.method(SimpleRequest, SimpleResponse)
  def call_method(self, request):
    return SimpleResponse(content=request.content,
                          factory_value=self.factory_value,
                          remote_host=self.request_state.remote_host,
                          remote_address=self.request_state.remote_address,
                          server_host=self.request_state.server_host,
                          server_port=self.request_state.server_port)

  @remote.method()
  def raise_totally_unexpected(self, request):
    raise TypeError('Kablam')

  @remote.method()
  def raise_unexpected(self, request):
    raise remote.RequestError('Huh?')

  @remote.method()
  def raise_application_error(self, request):
    raise remote.ApplicationError('App error', 10)


class LocalTransportTest(test_util.TestCase):
  """Tests transport.LocalTransport with LocalService."""

  def CreateService(self, factory_value='default'):
    # Unused helper; always returns None.
    return

  def testBasicCallWithClass(self):
    stub = LocalService.Stub(transport.LocalTransport(LocalService))
    response = stub.call_method(content='Hello')
    self.assertEqual(SimpleResponse(content='Hello',
                                    factory_value='default',
                                    remote_host=os.uname()[1],
                                    remote_address='127.0.0.1',
                                    server_host=os.uname()[1],
                                    server_port=-1),
                     response)

  def testBasicCallWithFactory(self):
    stub = LocalService.Stub(
        transport.LocalTransport(LocalService.new_factory('assigned')))
    response = stub.call_method(content='Hello')
    self.assertEqual(SimpleResponse(content='Hello',
                                    factory_value='assigned',
                                    remote_host=os.uname()[1],
                                    remote_address='127.0.0.1',
                                    server_host=os.uname()[1],
                                    server_port=-1),
                     response)

  def testTotallyUnexpectedError(self):
    stub = LocalService.Stub(transport.LocalTransport(LocalService))
    self.assertRaisesWithRegexpMatch(
        remote.ServerError,
        'Unexpected error TypeError: Kablam',
        stub.raise_totally_unexpected)

  def testUnexpectedError(self):
    stub = LocalService.Stub(transport.LocalTransport(LocalService))
    self.assertRaisesWithRegexpMatch(
        remote.ServerError,
        'Unexpected error RequestError: Huh?',
        stub.raise_unexpected)

  def testApplicationError(self):
    stub = LocalService.Stub(transport.LocalTransport(LocalService))
    self.assertRaisesWithRegexpMatch(
        remote.ApplicationError,
        'App error',
        stub.raise_application_error)


def main():
  unittest.main()


if __name__ == '__main__':
  main()