1#!/usr/bin/env python
2#
3# Copyright 2010 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18import errno
19import six.moves.http_client
20import os
21import socket
22import unittest
23
24from protorpc import messages
25from protorpc import protobuf
26from protorpc import protojson
27from protorpc import remote
28from protorpc import test_util
29from protorpc import transport
30from protorpc import webapp_test_util
31from protorpc.wsgi import util as wsgi_util
32
33import mox
34
# Module-level package name; presumably read by protorpc to qualify the
# message/service definitions declared in this module -- confirm.
package = "transport_test"
36
37
class ModuleInterfaceTest(test_util.ModuleInterfaceTest, test_util.TestCase):
  """Runs the shared test_util module-interface checks against `transport`."""

  # Module whose exported interface the inherited checks inspect.
  MODULE = transport
42
43
class Message(messages.Message):
  """Minimal test message holding a single string field."""

  # String payload stored under field number 1.
  value = messages.StringField(1)
47
48
class Service(remote.Service):
  """Stub service whose only purpose is to provide remote-method metadata."""

  @remote.method(Message, Message)
  def method(self, request):
    """No-op body; this file only reads `Service.method.remote`."""
    return None
54
55
# TODO: Remove once transport.Rpc no longer needs to be subclassed.
class TestRpc(transport.Rpc):
  """Rpc whose wait implementation only records that it was called."""

  # Class-level default; _wait_impl sets an instance attribute to True.
  waited = False

  def _wait_impl(self):
    # There is nothing to wait for; just note that a wait was requested.
    self.waited = True
63
64
class RpcTest(test_util.TestCase):
  """Tests the state transitions of transport.Rpc via TestRpc."""

  def setUp(self):
    self.request = Message(value=u'request')
    self.response = Message(value=u'response')
    self.status = remote.RpcStatus(state=remote.RpcState.APPLICATION_ERROR,
                                   error_message='an error',
                                   error_name='blam')

    self.rpc = TestRpc(self.request)

  def testConstructor(self):
    self.assertEqual(self.request, self.rpc.request)
    self.assertEqual(remote.RpcState.RUNNING, self.rpc.state)
    self.assertIsNone(self.rpc.error_message)
    self.assertIsNone(self.rpc.error_name)

  def testResponse(self):
    # Previously named "response": unittest never collected it, and setUp's
    # self.response attribute shadowed it on every instance.
    self.assertFalse(self.rpc.waited)
    self.assertIsNone(self.rpc.response)
    # Reading .response on a RUNNING rpc must trigger _wait_impl.
    self.assertTrue(self.rpc.waited)

  def testSetResponse(self):
    self.rpc.set_response(self.response)

    self.assertEqual(self.request, self.rpc.request)
    self.assertEqual(remote.RpcState.OK, self.rpc.state)
    self.assertEqual(self.response, self.rpc.response)
    self.assertIsNone(self.rpc.error_message)
    self.assertIsNone(self.rpc.error_name)

  def testSetResponseAlreadySet(self):
    self.rpc.set_response(self.response)

    self.assertRaisesWithRegexpMatch(
      transport.RpcStateError,
      'RPC must be in RUNNING state to change to OK',
      self.rpc.set_response,
      self.response)

  def testSetResponseAlreadyError(self):
    self.rpc.set_status(self.status)

    self.assertRaisesWithRegexpMatch(
      transport.RpcStateError,
      'RPC must be in RUNNING state to change to OK',
      self.rpc.set_response,
      self.response)

  def testSetStatus(self):
    self.rpc.set_status(self.status)

    self.assertEqual(self.request, self.rpc.request)
    self.assertEqual(remote.RpcState.APPLICATION_ERROR, self.rpc.state)
    self.assertEqual('an error', self.rpc.error_message)
    self.assertEqual('blam', self.rpc.error_name)
    self.assertRaisesWithRegexpMatch(remote.ApplicationError,
                                     'an error',
                                     getattr, self.rpc, 'response')

  def testSetStatusAlreadySet(self):
    # Once a response is set, setting a status must be rejected. (This test
    # previously re-ran testSetResponseAlreadySet by calling set_response.)
    self.rpc.set_response(self.response)

    self.assertRaisesWithRegexpMatch(
      transport.RpcStateError,
      'RPC must be in RUNNING state to change to APPLICATION_ERROR',
      self.rpc.set_status,
      self.status)

  def testSetNonMessage(self):
    self.assertRaisesWithRegexpMatch(
      TypeError,
      'Expected Message type, received 10',
      self.rpc.set_response,
      10)

  def testSetStatusAlreadyError(self):
    # A second status on an already-failed rpc must be rejected. (This test
    # previously duplicated testSetResponseAlreadyError.)
    self.rpc.set_status(self.status)

    self.assertRaisesWithRegexpMatch(
      transport.RpcStateError,
      'RPC must be in RUNNING state to change to APPLICATION_ERROR',
      self.rpc.set_status,
      self.status)

  def testSetUninitializedStatus(self):
    self.assertRaises(messages.ValidationError,
                      self.rpc.set_status,
                      remote.RpcStatus())
154
155
class TransportTest(test_util.TestCase):
  """Tests protocol selection and send_rpc dispatch of transport.Transport."""

  def setUp(self):
    # Install a fresh default protocol registry so protocols registered by
    # one test (e.g. testProtocolByName) cannot leak into another.
    remote.Protocols.set_default(remote.Protocols.new_default())

  def do_test(self, protocol, trans):
    """Checks that `trans` uses `protocol` and routes send_rpc to _start_rpc.

    Args:
      protocol: Protocol module expected as trans.protocol.
      trans: Transport under test; its _start_rpc is replaced by a fake.
    """
    request = Message(value=u'request')
    response = Message(value=u'response')

    self.assertEqual(protocol, trans.protocol)

    def transport_rpc(remote_info, rpc_request):
      # Parameter is remote_info, not remote, so the module isn't shadowed.
      self.assertEqual(Service.method.remote, remote_info)
      self.assertEqual(request, rpc_request)
      rpc = TestRpc(request)
      rpc.set_response(response)
      return rpc
    trans._start_rpc = transport_rpc

    rpc = trans.send_rpc(Service.method.remote, request)
    self.assertEqual(response, rpc.response)

  def testDefaultProtocol(self):
    trans = transport.Transport()
    self.do_test(protobuf, trans)
    self.assertEqual(protobuf, trans.protocol_config.protocol)
    self.assertEqual('default', trans.protocol_config.name)

  def testAlternateProtocol(self):
    trans = transport.Transport(protocol=protojson)
    self.do_test(protojson, trans)
    self.assertEqual(protojson, trans.protocol_config.protocol)
    self.assertEqual('default', trans.protocol_config.name)

  def testProtocolConfig(self):
    protocol_config = remote.ProtocolConfig(
      protojson, 'protoconfig', 'image/png')
    trans = transport.Transport(protocol=protocol_config)
    self.do_test(protojson, trans)
    self.assertIs(protocol_config, trans.protocol_config)

  def testProtocolByName(self):
    remote.Protocols.get_default().add_protocol(
      protojson, 'png', 'image/png', ())
    trans = transport.Transport(protocol='png')
    self.do_test(protojson, trans)
209
210
@remote.method(Message, Message)
def my_method(self, request):
  """Placeholder remote method that exists only for `my_method.remote`.

  HttpTransportTest sends RPCs described by `my_method.remote` to a test
  server, so reaching this body means something called it by mistake.
  """
  self.fail('self.my_method should not be directly invoked.')
214
215
class FakeConnectionClass(object):
  """Holds separate mox mocks for a request and a response."""

  def __init__(self, mox):
    # `mox` is presumably a mox.Mox instance; the parameter name shadows the
    # mox module inside this method.
    make_mock = mox.CreateMockAnything
    self.request = make_mock()
    self.response = make_mock()
221
222
class HttpTransportTest(webapp_test_util.WebServerTestBase):
  """Tests transport.HttpTransport against a local WSGI test server."""

  def setUp(self):
    # Intentionally does not call the parent setUp: the tests below start
    # their own server through ResetServer, so only these fields are needed.
    self.schema = 'http'
    self.server = None

    self.request = Message(value=u'The request value')
    self.encoded_request = protojson.encode_message(self.request)

    self.response = Message(value=u'The response value')
    self.encoded_response = protojson.encode_message(self.response)

  def testCallSucceeds(self):
    self.ResetServer(wsgi_util.static_page(self.encoded_response,
                                           content_type='application/json'))

    rpc = self.connection.send_rpc(my_method.remote, self.request)
    self.assertEqual(self.response, rpc.response)

  def testHttps(self):
    self.schema = 'https'
    self.ResetServer(wsgi_util.static_page(self.encoded_response,
                                           content_type='application/json'))

    # Fake HTTPS connection factory: records that it was used, but really
    # opens a plain HTTP connection to the test server.
    self.used_https = False
    def https_connection(*args, **kwargs):
      self.used_https = True
      return six.moves.http_client.HTTPConnection(*args, **kwargs)

    original_https_connection = six.moves.http_client.HTTPSConnection
    six.moves.http_client.HTTPSConnection = https_connection
    try:
      rpc = self.connection.send_rpc(my_method.remote, self.request)
    finally:
      six.moves.http_client.HTTPSConnection = original_https_connection
    self.assertEqual(self.response, rpc.response)
    self.assertTrue(self.used_https)

  def testHttpSocketError(self):
    self.ResetServer(wsgi_util.static_page(self.encoded_response,
                                           content_type='application/json'))

    # An unusable port, so the connection attempt fails at the socket level.
    bad_transport = transport.HttpTransport('http://localhost:-1/blar')
    try:
      bad_transport.send_rpc(my_method.remote, self.request)
    except remote.NetworkError as err:
      self.assertTrue(str(err).startswith('Socket error: error ('))
      self.assertEqual(errno.ECONNREFUSED, err.cause.errno)
    else:
      self.fail('Expected error')

  def testHttpRequestError(self):
    self.ResetServer(wsgi_util.static_page(self.encoded_response,
                                           content_type='application/json'))

    # Make every HTTPConnection.request call fail with a non-socket error.
    def request_error(*args, **kwargs):
      raise TypeError('Generic Error')
    original_request = six.moves.http_client.HTTPConnection.request
    six.moves.http_client.HTTPConnection.request = request_error
    try:
      try:
        self.connection.send_rpc(my_method.remote, self.request)
      except remote.NetworkError as err:
        self.assertEqual('Error communicating with HTTP server', str(err))
        self.assertEqual(TypeError, type(err.cause))
        self.assertEqual('Generic Error', str(err.cause))
      else:
        self.fail('Expected error')
    finally:
      six.moves.http_client.HTTPConnection.request = original_request

  def testHandleGenericServiceError(self):
    self.ResetServer(wsgi_util.error(
        six.moves.http_client.INTERNAL_SERVER_ERROR,
        'arbitrary error',
        content_type='text/plain'))

    rpc = self.connection.send_rpc(my_method.remote, self.request)
    try:
      rpc.response
    except remote.ServerError as err:
      self.assertEqual('HTTP Error 500: arbitrary error', str(err).strip())
    else:
      self.fail('Expected ServerError')

  def testHandleGenericServiceErrorNoMessage(self):
    self.ResetServer(wsgi_util.error(six.moves.http_client.NOT_IMPLEMENTED,
                                     ' ',
                                     content_type='text/plain'))

    rpc = self.connection.send_rpc(my_method.remote, self.request)
    try:
      rpc.response
    except remote.ServerError as err:
      self.assertEqual('HTTP Error 501: Not Implemented', str(err).strip())
    else:
      self.fail('Expected ServerError')

  def testHandleStatusContent(self):
    self.ResetServer(wsgi_util.static_page(
        '{"state": "REQUEST_ERROR",'
        ' "error_message": "a request error"'
        '}',
        status=six.moves.http_client.BAD_REQUEST,
        content_type='application/json'))

    rpc = self.connection.send_rpc(my_method.remote, self.request)
    try:
      rpc.response
    except remote.RequestError as err:
      self.assertEqual('a request error', str(err))
    else:
      self.fail('Expected RequestError')

  def testHandleApplicationError(self):
    self.ResetServer(wsgi_util.static_page(
        '{"state": "APPLICATION_ERROR",'
        ' "error_message": "an app error",'
        ' "error_name": "MY_ERROR_NAME"}',
        status=six.moves.http_client.BAD_REQUEST,
        content_type='application/json'))

    rpc = self.connection.send_rpc(my_method.remote, self.request)
    try:
      rpc.response
    except remote.ApplicationError as err:
      self.assertEqual('an app error', str(err))
      self.assertEqual('MY_ERROR_NAME', err.error_name)
    else:
      # Message previously said "RequestError" (copy-paste slip).
      self.fail('Expected ApplicationError')

  def testHandleUnparsableErrorContent(self):
    self.ResetServer(wsgi_util.static_page(
        'oops',
        status=six.moves.http_client.BAD_REQUEST,
        content_type='application/json'))

    rpc = self.connection.send_rpc(my_method.remote, self.request)
    try:
      rpc.response
    except remote.ServerError as err:
      self.assertEqual('HTTP Error 400: oops', str(err))
    else:
      self.fail('Expected ServerError')

  def testHandleEmptyBadRpcStatus(self):
    self.ResetServer(wsgi_util.static_page(
        '{"error_message": "x"}',
        status=six.moves.http_client.BAD_REQUEST,
        content_type='application/json'))

    rpc = self.connection.send_rpc(my_method.remote, self.request)
    try:
      rpc.response
    except remote.ServerError as err:
      self.assertEqual('HTTP Error 400: {"error_message": "x"}', str(err))
    else:
      self.fail('Expected ServerError')

  def testUseProtocolConfigContentType(self):
    expected_content_type = 'image/png'
    def expect_content_type(environ, start_response):
      # Echo back an empty page with the request's content type after
      # checking that the transport sent the configured one.
      self.assertEqual(expected_content_type, environ['CONTENT_TYPE'])
      app = wsgi_util.static_page('', content_type=environ['CONTENT_TYPE'])
      return app(environ, start_response)

    self.ResetServer(expect_content_type)

    protocol_config = remote.ProtocolConfig(protojson, 'json', 'image/png')
    self.connection = self.CreateTransport(self.service_url, protocol_config)

    rpc = self.connection.send_rpc(my_method.remote, self.request)
    self.assertEqual(Message(), rpc.response)
394
395
class SimpleRequest(messages.Message):
  """Request for LocalService.call_method carrying a text payload."""

  # Payload echoed back in SimpleResponse.content by LocalService.
  content = messages.StringField(1)
399
400
class SimpleResponse(messages.Message):
  """Response echoing the request payload plus request-state details."""

  # Field numbers are part of the message definition; keep them stable.
  content = messages.StringField(1)         # Echo of SimpleRequest.content.
  factory_value = messages.StringField(2)   # Value the service was built with.
  remote_host = messages.StringField(3)     # From the service request_state.
  remote_address = messages.StringField(4)  # From the service request_state.
  server_host = messages.StringField(5)     # From the service request_state.
  server_port = messages.IntegerField(6)    # From the service request_state.
409
410
class LocalService(remote.Service):
  """In-process service used to exercise transport.LocalTransport."""

  def __init__(self, factory_value='default'):
    # Echoed by call_method so tests can see which factory created us.
    self.factory_value = factory_value

  @remote.method(SimpleRequest, SimpleResponse)
  def call_method(self, request):
    """Echoes request.content along with factory value and request state."""
    state = self.request_state
    return SimpleResponse(content=request.content,
                          factory_value=self.factory_value,
                          remote_host=state.remote_host,
                          remote_address=state.remote_address,
                          server_host=state.server_host,
                          server_port=state.server_port)

  @remote.method()
  def raise_totally_unexpected(self, request):
    """Raises a non-RPC exception (expected to surface as ServerError)."""
    raise TypeError('Kablam')

  @remote.method()
  def raise_unexpected(self, request):
    """Raises RequestError from inside the service implementation."""
    raise remote.RequestError('Huh?')

  @remote.method()
  def raise_application_error(self, request):
    """Raises an ApplicationError with error name 10."""
    raise remote.ApplicationError('App error', 10)
436
437
class LocalTransportTest(test_util.TestCase):
  """Tests transport.LocalTransport calling LocalService in-process."""

  def CreateService(self, factory_value='default'):
    # NOTE(review): unused no-op that always returns None; kept only so the
    # class interface does not change. Consider deleting it.
    return None

  def _expected_response(self, factory_value):
    """Returns the SimpleResponse these tests expect from call_method.

    Args:
      factory_value: Expected value of SimpleResponse.factory_value.
    """
    host = os.uname()[1]
    return SimpleResponse(content='Hello',
                          factory_value=factory_value,
                          remote_host=host,
                          remote_address='127.0.0.1',
                          server_host=host,
                          server_port=-1)

  def testBasicCallWithClass(self):
    stub = LocalService.Stub(transport.LocalTransport(LocalService))
    response = stub.call_method(content='Hello')
    self.assertEqual(self._expected_response('default'), response)

  def testBasicCallWithFactory(self):
    stub = LocalService.Stub(
      transport.LocalTransport(LocalService.new_factory('assigned')))
    response = stub.call_method(content='Hello')
    self.assertEqual(self._expected_response('assigned'), response)

  def testTotallyUnexpectedError(self):
    stub = LocalService.Stub(transport.LocalTransport(LocalService))
    self.assertRaisesWithRegexpMatch(
      remote.ServerError,
      'Unexpected error TypeError: Kablam',
      stub.raise_totally_unexpected)

  def testUnexpectedError(self):
    stub = LocalService.Stub(transport.LocalTransport(LocalService))
    # The '?' is escaped: the pattern is a regular expression, and an
    # unescaped '?' would make the final 'h' optional.
    self.assertRaisesWithRegexpMatch(
      remote.ServerError,
      r'Unexpected error RequestError: Huh\?',
      stub.raise_unexpected)

  def testApplicationError(self):
    stub = LocalService.Stub(transport.LocalTransport(LocalService))
    self.assertRaisesWithRegexpMatch(
      remote.ApplicationError,
      'App error',
      stub.raise_application_error)
486
487
def main():
  """Hands control to unittest's command-line runner (tests in __main__)."""
  unittest.main()


# Run the tests only when this file is executed as a script.
if __name__ == '__main__':
  main()
494