1#!/usr/bin/env python
2#
3# Copyright 2010 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18"""Tests for protorpc.descriptor."""
19
20__author__ = 'rafek@google.com (Rafe Kaplan)'
21
22
23import types
24import unittest
25
26from protorpc import descriptor
27from protorpc import message_types
28from protorpc import messages
29from protorpc import registry
30from protorpc import remote
31from protorpc import test_util
32
33
# Cyrillic spelling of "Russia" (U+0420 U+043E U+0441 U+0441 U+0438 U+044F).
# A non-ASCII text value, used as a StringField default in DescribeFieldTest.
RUSSIA = u'\u0420\u043e\u0441\u0441\u0438\u044f'
35
36
class ModuleInterfaceTest(test_util.ModuleInterfaceTest, test_util.TestCase):
  """Runs the shared test_util.ModuleInterfaceTest checks on descriptor."""

  # The module under inspection by the shared interface checks.
  MODULE = descriptor
41
42
class DescribeEnumValueTest(test_util.TestCase):
  """Tests for descriptor.describe_enum_value."""

  def testDescribe(self):
    class MyEnum(messages.Enum):
      MY_NAME = 10

    # An enum value is described by its name and its number.
    want = descriptor.EnumValueDescriptor(name='MY_NAME', number=10)

    got = descriptor.describe_enum_value(MyEnum.MY_NAME)
    got.check_initialized()
    self.assertEquals(want, got)
56
57
class DescribeEnumTest(test_util.TestCase):
  """Tests for descriptor.describe_enum."""

  def testEmptyEnum(self):
    class EmptyEnum(messages.Enum):
      pass

    want = descriptor.EnumDescriptor(name='EmptyEnum')

    got = descriptor.describe_enum(EmptyEnum)
    got.check_initialized()
    self.assertEquals(want, got)

  def testNestedEnum(self):
    class MyScope(messages.Message):
      class NestedEnum(messages.Enum):
        pass

    # Only the enum's own name is expected, not the enclosing scope's.
    want = descriptor.EnumDescriptor(name='NestedEnum')

    got = descriptor.describe_enum(MyScope.NestedEnum)
    got.check_initialized()
    self.assertEquals(want, got)

  def testEnumWithItems(self):
    class EnumWithItems(messages.Enum):
      A = 3
      B = 1
      C = 2

    # Values are expected ordered by number (B=1, C=2, A=3), not by the
    # order in which they were declared.
    ordered_values = (('B', 1), ('C', 2), ('A', 3))
    want = descriptor.EnumDescriptor(name='EnumWithItems')
    want.values = [
      descriptor.EnumValueDescriptor(name=value_name, number=value_number)
      for value_name, value_number in ordered_values]

    got = descriptor.describe_enum(EnumWithItems)
    got.check_initialized()
    self.assertEquals(want, got)
109
110
class DescribeFieldTest(test_util.TestCase):
  """Tests for descriptor.describe_field."""

  def testLabel(self):
    """Test that the required/repeated flags map to the descriptor label."""
    for repeated, required, expected_label in (
        (True, False, descriptor.FieldDescriptor.Label.REPEATED),
        (False, True, descriptor.FieldDescriptor.Label.REQUIRED),
        (False, False, descriptor.FieldDescriptor.Label.OPTIONAL)):
      field = messages.IntegerField(10, required=required, repeated=repeated)
      field.name = 'a_field'

      expected = descriptor.FieldDescriptor()
      expected.name = 'a_field'
      expected.number = 10
      expected.label = expected_label
      expected.variant = descriptor.FieldDescriptor.Variant.INT64

      described = descriptor.describe_field(field)
      described.check_initialized()
      self.assertEquals(expected, described)

  def testDefault(self):
    """Test that field defaults are described in their string form."""
    for field_class, default, expected_default in (
        (messages.IntegerField, 200, '200'),
        (messages.FloatField, 1.5, '1.5'),
        (messages.FloatField, 1e6, '1000000.0'),
        (messages.BooleanField, True, 'true'),
        (messages.BooleanField, False, 'false'),
        # Non-ASCII bytes are expected in backslash-escaped form.
        (messages.BytesField, 'ab\xF1', 'ab\\xf1'),
        (messages.StringField, RUSSIA, RUSSIA),
        ):
      field = field_class(10, default=default)
      field.name = u'a_field'

      expected = descriptor.FieldDescriptor()
      expected.name = u'a_field'
      expected.number = 10
      expected.label = descriptor.FieldDescriptor.Label.OPTIONAL
      expected.variant = field_class.DEFAULT_VARIANT
      expected.default_value = expected_default

      described = descriptor.describe_field(field)
      described.check_initialized()
      self.assertEquals(expected, described)

  def testDefault_EnumField(self):
    """Test that an enum default is described by its number."""
    class MyEnum(messages.Enum):
      VAL = 1

    module_name = test_util.get_module_name(MyEnum)
    field = messages.EnumField(MyEnum, 10, default=MyEnum.VAL)
    field.name = 'a_field'

    expected = descriptor.FieldDescriptor()
    expected.name = 'a_field'
    expected.number = 10
    expected.label = descriptor.FieldDescriptor.Label.OPTIONAL
    expected.variant = messages.EnumField.DEFAULT_VARIANT
    expected.type_name = '%s.MyEnum' % module_name
    expected.default_value = '1'

    described = descriptor.describe_field(field)
    # Like every other test here, require a fully initialized descriptor.
    described.check_initialized()
    self.assertEquals(expected, described)

  def testMessageField(self):
    """Test that a message field records the message's definition name."""
    field = messages.MessageField(descriptor.FieldDescriptor, 10)
    field.name = 'a_field'

    expected = descriptor.FieldDescriptor()
    expected.name = 'a_field'
    expected.number = 10
    expected.label = descriptor.FieldDescriptor.Label.OPTIONAL
    expected.variant = messages.MessageField.DEFAULT_VARIANT
    expected.type_name = 'protorpc.descriptor.FieldDescriptor'

    described = descriptor.describe_field(field)
    described.check_initialized()
    self.assertEquals(expected, described)

  def testDateTimeField(self):
    """Test that a DateTimeField is described as a DateTimeMessage field."""
    field = message_types.DateTimeField(20)
    field.name = 'a_timestamp'

    expected = descriptor.FieldDescriptor()
    expected.name = 'a_timestamp'
    expected.number = 20
    expected.label = descriptor.FieldDescriptor.Label.OPTIONAL
    expected.variant = messages.MessageField.DEFAULT_VARIANT
    expected.type_name = 'protorpc.message_types.DateTimeMessage'

    described = descriptor.describe_field(field)
    described.check_initialized()
    self.assertEquals(expected, described)
204
205
class DescribeMessageTest(test_util.TestCase):
  """Tests for descriptor.describe_message."""

  def testEmptyDefinition(self):
    class MyMessage(messages.Message):
      pass

    want = descriptor.MessageDescriptor(name='MyMessage')

    got = descriptor.describe_message(MyMessage)
    got.check_initialized()
    self.assertEquals(want, got)

  def testDefinitionWithFields(self):
    class MessageWithFields(messages.Message):
      field1 = messages.IntegerField(10)
      field2 = messages.StringField(30)
      field3 = messages.IntegerField(20)

    # Fields are expected ordered by field number (10, 20, 30).
    want = descriptor.MessageDescriptor(name='MessageWithFields')
    want.fields = [
      descriptor.describe_field(MessageWithFields.field_by_name(field_name))
      for field_name in ('field1', 'field3', 'field2')]

    got = descriptor.describe_message(MessageWithFields)
    got.check_initialized()
    self.assertEquals(want, got)

  def testNestedEnum(self):
    class MessageWithEnum(messages.Message):
      class Mood(messages.Enum):
        GOOD = 1
        BAD = 2
        UGLY = 3

      class Music(messages.Enum):
        CLASSIC = 1
        JAZZ = 2
        BLUES = 3

    want = descriptor.MessageDescriptor(name='MessageWithEnum')
    want.enum_types = [
      descriptor.describe_enum(nested_enum)
      for nested_enum in (MessageWithEnum.Mood, MessageWithEnum.Music)]

    got = descriptor.describe_message(MessageWithEnum)
    got.check_initialized()
    self.assertEquals(want, got)

  def testNestedMessage(self):
    class MessageWithMessage(messages.Message):
      class Nesty(messages.Message):
        pass

    want = descriptor.MessageDescriptor(name='MessageWithMessage')
    want.message_types = [
      descriptor.describe_message(MessageWithMessage.Nesty)]

    got = descriptor.describe_message(MessageWithMessage)
    got.check_initialized()
    self.assertEquals(want, got)
274
275
class DescribeMethodTest(test_util.TestCase):
  """Test describing remote methods."""

  def testDescribe(self):
    class Request(messages.Message):
      pass

    class Response(messages.Message):
      pass

    @remote.method(Request, Response)
    def remote_method(request):
      pass

    # The expected type names are the message class names qualified by the
    # module name reported by test_util.get_module_name.
    qualifier = test_util.get_module_name(DescribeMethodTest)
    want = descriptor.MethodDescriptor(
      name='remote_method',
      request_type='%s.Request' % qualifier,
      response_type='%s.Response' % qualifier)

    got = descriptor.describe_method(remote_method)
    got.check_initialized()
    self.assertEquals(want, got)
299
300
class DescribeServiceTest(test_util.TestCase):
  """Test describing service classes."""

  def testDescribe(self):
    class Request1(messages.Message):
      pass

    class Response1(messages.Message):
      pass

    class Request2(messages.Message):
      pass

    class Response2(messages.Message):
      pass

    class MyService(remote.Service):

      @remote.method(Request1, Response1)
      def method1(self, request):
        pass

      @remote.method(Request2, Response2)
      def method2(self, request):
        pass

    # Each remote method is expected to be described, method1 first.
    want = descriptor.ServiceDescriptor(name='MyService')
    want.methods = [
      descriptor.describe_method(service_method)
      for service_method in (MyService.method1, MyService.method2)]

    got = descriptor.describe_service(MyService)
    got.check_initialized()
    self.assertEquals(want, got)
337
338
class DescribeFileTest(test_util.TestCase):
  """Test describing modules."""

  def LoadModule(self, module_name, source):
    """Build a module object by executing Python source.

    Args:
      module_name: Used as __name__ while executing source and as the name
        of the returned module.
      source: Python source to execute. The names `messages` and `remote`
        are pre-bound in its namespace.

    Returns:
      A types.ModuleType carrying every name from the executed namespace
      (including the pre-bound `messages` and `remote`) as attributes.
    """
    result = {'__name__': module_name,
              'messages': messages,
              'remote': remote,
              }
    exec(source, result)

    module = types.ModuleType(module_name)
    for name, value in result.items():
      setattr(module, name, value)

    return module

  def testEmptyModule(self):
    """Test describing an empty file."""
    module = types.ModuleType('my.package.name')

    expected = descriptor.FileDescriptor()
    expected.package = 'my.package.name'

    described = descriptor.describe_file(module)
    described.check_initialized()
    self.assertEquals(expected, described)

  def testNoPackageName(self):
    """Test describing a module whose name is empty (no package is set)."""
    module = types.ModuleType('')

    expected = descriptor.FileDescriptor()

    described = descriptor.describe_file(module)
    described.check_initialized()
    self.assertEquals(expected, described)

  def testPackageName(self):
    """Test using the 'package' module attribute."""
    module = types.ModuleType('my.module.name')
    module.package = 'my.package.name'

    expected = descriptor.FileDescriptor()
    expected.package = 'my.package.name'

    described = descriptor.describe_file(module)
    described.check_initialized()
    self.assertEquals(expected, described)

  def testMain(self):
    """Test that a __main__ module's package comes from its file name."""
    module = types.ModuleType('__main__')
    module.__file__ = '/blim/blam/bloom/my_package.py'

    expected = descriptor.FileDescriptor()
    expected.package = 'my_package'

    described = descriptor.describe_file(module)
    described.check_initialized()
    self.assertEquals(expected, described)

  def testMessages(self):
    """Test that messages are described."""
    module = self.LoadModule('my.package',
                             'class Message1(messages.Message): pass\n'
                             'class Message2(messages.Message): pass\n')

    message1 = descriptor.MessageDescriptor()
    message1.name = 'Message1'

    message2 = descriptor.MessageDescriptor()
    message2.name = 'Message2'

    expected = descriptor.FileDescriptor()
    expected.package = 'my.package'
    expected.message_types = [message1, message2]

    described = descriptor.describe_file(module)
    described.check_initialized()
    self.assertEquals(expected, described)

  def testEnums(self):
    """Test that enums are described."""
    module = self.LoadModule('my.package',
                             'class Enum1(messages.Enum): pass\n'
                             'class Enum2(messages.Enum): pass\n')

    enum1 = descriptor.EnumDescriptor()
    enum1.name = 'Enum1'

    enum2 = descriptor.EnumDescriptor()
    enum2.name = 'Enum2'

    expected = descriptor.FileDescriptor()
    expected.package = 'my.package'
    expected.enum_types = [enum1, enum2]

    described = descriptor.describe_file(module)
    described.check_initialized()
    self.assertEquals(expected, described)

  def testServices(self):
    """Test that services are described."""
    module = self.LoadModule('my.package',
                             'class Service1(remote.Service): pass\n'
                             'class Service2(remote.Service): pass\n')

    service1 = descriptor.ServiceDescriptor()
    service1.name = 'Service1'

    service2 = descriptor.ServiceDescriptor()
    service2.name = 'Service2'

    expected = descriptor.FileDescriptor()
    expected.package = 'my.package'
    expected.service_types = [service1, service2]

    described = descriptor.describe_file(module)
    described.check_initialized()
    self.assertEquals(expected, described)
459
460
class DescribeFileSetTest(test_util.TestCase):
  """Test describing multiple modules."""

  def testNoModules(self):
    """Test what happens when no modules provided."""
    described = descriptor.describe_file_set([])
    described.check_initialized()
    # The described FileSet.files will be None.
    self.assertEquals(descriptor.FileSet(), described)

  def testWithModules(self):
    """Test that each provided module is described, in order."""
    # Two distinct modules, so the test can tell them apart and catch
    # reordering or duplication of file descriptors.
    modules = [types.ModuleType('package1'), types.ModuleType('package2')]

    file1 = descriptor.FileDescriptor()
    file1.package = 'package1'
    file2 = descriptor.FileDescriptor()
    file2.package = 'package2'

    expected = descriptor.FileSet()
    expected.files = [file1, file2]

    described = descriptor.describe_file_set(modules)
    described.check_initialized()
    self.assertEquals(expected, described)
486
487
class DescribeTest(test_util.TestCase):
  """Test that descriptor.describe dispatches to the matching describe_*."""

  def testModule(self):
    self.assertEquals(descriptor.describe_file(test_util),
                      descriptor.describe(test_util))

  def testMethod(self):
    class Param(messages.Message):
      pass

    class Service(remote.Service):

      # Same (self, request) signature as the other remote methods here.
      @remote.method(Param, Param)
      def fn(self, request):
        return Param()

    self.assertEquals(descriptor.describe_method(Service.fn),
                      descriptor.describe(Service.fn))

  def testField(self):
    self.assertEquals(
      descriptor.describe_field(test_util.NestedMessage.a_value),
      descriptor.describe(test_util.NestedMessage.a_value))

  def testEnumValue(self):
    self.assertEquals(
      descriptor.describe_enum_value(
        test_util.OptionalMessage.SimpleEnum.VAL1),
      descriptor.describe(test_util.OptionalMessage.SimpleEnum.VAL1))

  def testMessage(self):
    self.assertEquals(descriptor.describe_message(test_util.NestedMessage),
                      descriptor.describe(test_util.NestedMessage))

  def testEnum(self):
    self.assertEquals(
      descriptor.describe_enum(test_util.OptionalMessage.SimpleEnum),
      descriptor.describe(test_util.OptionalMessage.SimpleEnum))

  def testService(self):
    class Service(remote.Service):
      pass

    self.assertEquals(descriptor.describe_service(Service),
                      descriptor.describe(Service))

  def testUndescribable(self):
    """Test that values describe() does not recognize yield None."""
    class NonService(object):

      def fn(self):
        pass

    for value in (NonService,
                  NonService.fn,
                  1,
                  'string',
                  1.2,
                  None):
      self.assertEquals(None, descriptor.describe(value))
554
555
class ModuleFinderTest(test_util.TestCase):
  """Tests for descriptor.import_descriptor_loader."""

  def testFindModule(self):
    want = descriptor.describe_file(registry)
    got = descriptor.import_descriptor_loader('protorpc.registry')
    self.assertEquals(want, got)

  def testFindMessage(self):
    want = descriptor.describe_message(descriptor.FileSet)
    got = descriptor.import_descriptor_loader('protorpc.descriptor.FileSet')
    self.assertEquals(want, got)

  def testFindField(self):
    want = descriptor.describe_field(descriptor.FileSet.files)
    got = descriptor.import_descriptor_loader(
      'protorpc.descriptor.FileSet.files')
    self.assertEquals(want, got)

  def testFindEnumValue(self):
    want = descriptor.describe_enum_value(
      test_util.OptionalMessage.SimpleEnum.VAL1)
    got = descriptor.import_descriptor_loader(
      'protorpc.test_util.OptionalMessage.SimpleEnum.VAL1')
    self.assertEquals(want, got)

  def testFindMethod(self):
    want = descriptor.describe_method(registry.RegistryService.services)
    got = descriptor.import_descriptor_loader(
      'protorpc.registry.RegistryService.services')
    self.assertEquals(want, got)

  def testFindService(self):
    want = descriptor.describe_service(registry.RegistryService)
    got = descriptor.import_descriptor_loader(
      'protorpc.registry.RegistryService')
    self.assertEquals(want, got)

  def testFindWithAbsoluteName(self):
    # A leading '.' is expected to resolve the same definition.
    want = descriptor.describe_service(registry.RegistryService)
    got = descriptor.import_descriptor_loader(
      '.protorpc.registry.RegistryService')
    self.assertEquals(want, got)

  def testFindWrongThings(self):
    unresolvable = ('a', 'protorpc.registry.RegistryService.__init__', '')
    for bad_name in unresolvable:
      self.assertRaisesWithRegexpMatch(
        messages.DefinitionNotFoundError,
        'Could not find definition for %s' % bad_name,
        descriptor.import_descriptor_loader, bad_name)
600
601
class DescriptorLibraryTest(test_util.TestCase):
  """Tests for descriptor.DescriptorLibrary.lookup_package."""

  def setUp(self):
    # One message descriptor, registered both under a dotted name whose
    # prefix ('not.real') the library cannot resolve and under a bare name.
    self.packageless = descriptor.MessageDescriptor(name='Packageless')
    self.library = descriptor.DescriptorLibrary(descriptors={
      'not.real.Packageless': self.packageless,
      'Packageless': self.packageless,
    })

  def testLookupPackage(self):
    # (expected package, name passed to lookup_package)
    cases = (
      ('csv', 'csv'),
      ('protorpc', 'protorpc'),
      ('protorpc.registry', 'protorpc.registry'),
      ('protorpc.registry', '.protorpc.registry'),
      ('protorpc.registry', 'protorpc.registry.RegistryService'),
      ('protorpc.registry', 'protorpc.registry.RegistryService.services'),
    )
    for want, lookup_name in cases:
      self.assertEquals(want, self.library.lookup_package(lookup_name))

  def testLookupNonPackages(self):
    for bad_name in ('', 'a', 'protorpc.descriptor.DescriptorLibrary'):
      self.assertRaisesWithRegexpMatch(
        messages.DefinitionNotFoundError,
        'Could not find definition for %s' % bad_name,
        self.library.lookup_package, bad_name)

  def testNoPackage(self):
    # The dotted registration fails because its prefix cannot be found...
    self.assertRaisesWithRegexpMatch(
      messages.DefinitionNotFoundError,
      'Could not find definition for not.real',
      self.library.lookup_package, 'not.real.Packageless')

    # ...while the bare registration resolves to no package at all.
    self.assertEquals(None, self.library.lookup_package('Packageless'))
642
643
def main():
  """Hand control to unittest's command-line test runner."""
  unittest.main()


if __name__ == '__main__':
  main()
650