1#!/usr/bin/env python
2#
3# Copyright 2010 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18"""Tests for apitools.base.protorpclite.messages."""
19import pickle
20import re
21import sys
22import types
23import unittest
24
25import six
26
27from apitools.base.protorpclite import descriptor
28from apitools.base.protorpclite import message_types
29from apitools.base.protorpclite import messages
30from apitools.base.protorpclite import test_util
31
32# This package plays lots of games with modifying global variables inside
33# test cases. Hence:
34# pylint:disable=function-redefined
35# pylint:disable=global-variable-not-assigned
36# pylint:disable=global-variable-undefined
37# pylint:disable=redefined-outer-name
38# pylint:disable=undefined-variable
39# pylint:disable=unused-variable
40# pylint:disable=too-many-lines
41
# Python 3 has no separate ``long`` type; alias it to ``int`` there so the
# tests below can call ``long(...)`` on either major version. The conditional
# expression only evaluates the builtin ``long`` on Python 2.
long = int if sys.version_info[0] >= 3 else long
46
47
class ModuleInterfaceTest(test_util.ModuleInterfaceTest, test_util.TestCase):
    """Run the shared module-interface checks against ``messages``."""

    # Module inspected by the inherited test_util.ModuleInterfaceTest checks.
    MODULE = messages
52
53
class ValidationErrorTest(test_util.TestCase):
    """Tests for the string form of messages.ValidationError."""

    def testStr_NoFieldName(self):
        """Test string version of ValidationError when no name provided."""
        self.assertEqual('Validation error',
                         str(messages.ValidationError('Validation error')))

    def testStr_FieldName(self):
        """Test string version of ValidationError when a field name is set."""
        validation_error = messages.ValidationError('Validation error')
        validation_error.field_name = 'a_field'
        # The expected value shows that field_name is not included in str().
        self.assertEqual('Validation error', str(validation_error))
66
67
class EnumTest(test_util.TestCase):
    """Tests for messages.Enum: definition, look-up, naming and comparison."""

    def setUp(self):
        """Set up tests."""
        # Redefine the Color class before every test so that a change to
        # it (an error) in one test does not affect other tests.
        global Color  # pylint:disable=global-variable-not-assigned

        # pylint:disable=unused-variable
        class Color(messages.Enum):
            RED = 20
            ORANGE = 2
            YELLOW = 40
            GREEN = 4
            BLUE = 50
            INDIGO = 5
            VIOLET = 80

    def testNames(self):
        """Test that names iterates over enum names."""
        self.assertEqual(
            {'BLUE', 'GREEN', 'INDIGO', 'ORANGE', 'RED', 'VIOLET', 'YELLOW'},
            set(Color.names()))

    def testNumbers(self):
        """Test that numbers iterates over enum numbers."""
        self.assertEqual({2, 4, 5, 20, 40, 50, 80}, set(Color.numbers()))

    def testIterate(self):
        """Test that __iter__ iterates over all enum values."""
        self.assertEqual(set(Color),
                         {Color.RED,
                          Color.ORANGE,
                          Color.YELLOW,
                          Color.GREEN,
                          Color.BLUE,
                          Color.INDIGO,
                          Color.VIOLET})

    def testNaturalOrder(self):
        """Test that natural order enumeration is in numeric order."""
        self.assertEqual([Color.ORANGE,
                          Color.GREEN,
                          Color.INDIGO,
                          Color.RED,
                          Color.YELLOW,
                          Color.BLUE,
                          Color.VIOLET],
                         sorted(Color))

    def testByName(self):
        """Test look-up by name."""
        self.assertEqual(Color.RED, Color.lookup_by_name('RED'))
        self.assertRaises(KeyError, Color.lookup_by_name, 20)
        self.assertRaises(KeyError, Color.lookup_by_name, Color.RED)

    def testByNumber(self):
        """Test look-up by number."""
        self.assertRaises(KeyError, Color.lookup_by_number, 'RED')
        self.assertEqual(Color.RED, Color.lookup_by_number(20))
        self.assertRaises(KeyError, Color.lookup_by_number, Color.RED)

    def testConstructor(self):
        """Test that the constructor looks up by name or number."""
        self.assertEqual(Color.RED, Color('RED'))
        self.assertEqual(Color.RED, Color(u'RED'))
        self.assertEqual(Color.RED, Color(20))
        if six.PY2:
            # Only Python 2 has a long type distinct from int.
            self.assertEqual(Color.RED, Color(long(20)))
        self.assertEqual(Color.RED, Color(Color.RED))
        self.assertRaises(TypeError, Color, 'Not exists')
        self.assertRaises(TypeError, Color, 'Red')
        self.assertRaises(TypeError, Color, 100)
        self.assertRaises(TypeError, Color, 10.0)

    def testLen(self):
        """Test that len function works to count enums."""
        self.assertEqual(7, len(Color))

    def testNoSubclasses(self):
        """Test that it is not possible to sub-class enum classes."""
        def declare_subclass():
            class MoreColor(Color):
                pass
        self.assertRaises(messages.EnumDefinitionError,
                          declare_subclass)

    def testClassNotMutable(self):
        """Test that enum classes themselves are not mutable."""
        self.assertRaises(AttributeError,
                          setattr,
                          Color,
                          'something_new',
                          10)

    def testInstancesMutable(self):
        """Test that enum instances are not mutable (despite the name)."""
        self.assertRaises(TypeError,
                          setattr,
                          Color.RED,
                          'something_new',
                          10)

    def testDefEnum(self):
        """Test def_enum works by building enum class from dict."""
        WeekDay = messages.Enum.def_enum({'Monday': 1,
                                          'Tuesday': 2,
                                          'Wednesday': 3,
                                          'Thursday': 4,
                                          'Friday': 6,
                                          'Saturday': 7,
                                          'Sunday': 8},
                                         'WeekDay')
        self.assertEqual('Wednesday', WeekDay(3).name)
        self.assertEqual(6, WeekDay('Friday').number)
        self.assertEqual(WeekDay.Sunday, WeekDay('Sunday'))

    def testNonInt(self):
        """Test that non-integer values are rejected by enum def."""
        self.assertRaises(messages.EnumDefinitionError,
                          messages.Enum.def_enum,
                          {'Bad': '1'},
                          'BadEnum')

    def testNegativeInt(self):
        """Test that negative numbers are rejected by enum def."""
        self.assertRaises(messages.EnumDefinitionError,
                          messages.Enum.def_enum,
                          {'Bad': -1},
                          'BadEnum')

    def testLowerBound(self):
        """Test that zero is accepted by enum def."""
        class NotImportant(messages.Enum):
            """Testing for value zero"""
            VALUE = 0

        self.assertEqual(0, int(NotImportant.VALUE))

    def testTooLargeInt(self):
        """Test that numbers too large are rejected."""
        self.assertRaises(messages.EnumDefinitionError,
                          messages.Enum.def_enum,
                          {'Bad': (2 ** 29)},
                          'BadEnum')

    def testRepeatedInt(self):
        """Test duplicated numbers are forbidden."""
        self.assertRaises(messages.EnumDefinitionError,
                          messages.Enum.def_enum,
                          {'Ok': 1, 'Repeated': 1},
                          'BadEnum')

    def testStr(self):
        """Test converting to string."""
        self.assertEqual('RED', str(Color.RED))
        self.assertEqual('ORANGE', str(Color.ORANGE))

    def testInt(self):
        """Test converting to int."""
        self.assertEqual(20, int(Color.RED))
        self.assertEqual(2, int(Color.ORANGE))

    def testRepr(self):
        """Test enum representation."""
        self.assertEqual('Color(RED, 20)', repr(Color.RED))
        self.assertEqual('Color(YELLOW, 40)', repr(Color.YELLOW))

    def testDocstring(self):
        """Test that docstring is supported ok."""
        class NotImportant(messages.Enum):
            """I have a docstring."""

            VALUE1 = 1

        self.assertEqual('I have a docstring.', NotImportant.__doc__)

    def testDeleteEnumValue(self):
        """Test that enum values cannot be deleted."""
        self.assertRaises(TypeError, delattr, Color, 'RED')

    def testEnumName(self):
        """Test enum name."""
        module_name = test_util.get_module_name(EnumTest)
        self.assertEqual('%s.Color' % module_name, Color.definition_name())
        self.assertEqual(module_name, Color.outer_definition_name())
        self.assertEqual(module_name, Color.definition_package())

    def testDefinitionName_OverrideModule(self):
        """Test enum module is overridden by module package name."""
        global package
        try:
            # A module-level ``package`` variable replaces the module name.
            package = 'my.package'
            self.assertEqual('my.package.Color', Color.definition_name())
            self.assertEqual('my.package', Color.outer_definition_name())
            self.assertEqual('my.package', Color.definition_package())
        finally:
            # Always remove the global so other tests see the module name.
            del package

    def testDefinitionName_NoModule(self):
        """Test what happens when there is no module for enum."""
        class Enum1(messages.Enum):
            pass

        # Swap sys.modules for a copy so that removing this module's entry
        # does not touch the real module table; restored in ``finally``.
        original_modules = sys.modules
        sys.modules = dict(sys.modules)
        try:
            del sys.modules[__name__]
            self.assertEqual('Enum1', Enum1.definition_name())
            self.assertEqual(None, Enum1.outer_definition_name())
            self.assertEqual(None, Enum1.definition_package())
            self.assertEqual(six.text_type, type(Enum1.definition_name()))
        finally:
            sys.modules = original_modules

    def testDefinitionName_Nested(self):
        """Test nested Enum names."""
        class MyMessage(messages.Message):

            class NestedEnum(messages.Enum):

                pass

            class NestedMessage(messages.Message):

                class NestedEnum(messages.Enum):

                    pass

        module_name = test_util.get_module_name(EnumTest)
        self.assertEqual('%s.MyMessage.NestedEnum' % module_name,
                         MyMessage.NestedEnum.definition_name())
        self.assertEqual('%s.MyMessage' % module_name,
                         MyMessage.NestedEnum.outer_definition_name())
        self.assertEqual(module_name,
                         MyMessage.NestedEnum.definition_package())

        self.assertEqual(
            '%s.MyMessage.NestedMessage.NestedEnum' % module_name,
            MyMessage.NestedMessage.NestedEnum.definition_name())
        self.assertEqual(
            '%s.MyMessage.NestedMessage' % module_name,
            MyMessage.NestedMessage.NestedEnum.outer_definition_name())
        self.assertEqual(
            module_name,
            MyMessage.NestedMessage.NestedEnum.definition_package())

    def testMessageDefinition(self):
        """Test that enumeration knows its enclosing message definition."""
        class OuterEnum(messages.Enum):
            pass

        self.assertEqual(None, OuterEnum.message_definition())

        class OuterMessage(messages.Message):

            class InnerEnum(messages.Enum):
                pass

        self.assertEqual(
            OuterMessage, OuterMessage.InnerEnum.message_definition())

    def testComparison(self):
        """Test comparing various enums to different types."""
        class Enum1(messages.Enum):
            VAL1 = 1
            VAL2 = 2

        class Enum2(messages.Enum):
            VAL1 = 1

        self.assertEqual(Enum1.VAL1, Enum1.VAL1)
        self.assertNotEqual(Enum1.VAL1, Enum1.VAL2)
        # Same name and number, but a different enum class.
        self.assertNotEqual(Enum1.VAL1, Enum2.VAL1)
        self.assertNotEqual(Enum1.VAL1, 'VAL1')
        self.assertNotEqual(Enum1.VAL1, 1)
        self.assertNotEqual(Enum1.VAL1, 2)
        self.assertNotEqual(Enum1.VAL1, None)

        self.assertLess(Enum1.VAL1, Enum1.VAL2)
        self.assertGreater(Enum1.VAL2, Enum1.VAL1)

        # Comparison with the plain number on the left-hand side.
        self.assertNotEqual(1, Enum2.VAL1)

    def testPickle(self):
        """Testing pickling and unpickling of Enum instances."""
        colors = list(Color)
        unpickled = pickle.loads(pickle.dumps(colors))
        self.assertEqual(colors, unpickled)
        # Unpickling shouldn't create new enum instances.
        for color, restored in zip(colors, unpickled):
            self.assertIs(color, restored)
362
363
class FieldListTest(test_util.TestCase):
    """Tests for messages.FieldList construction, mutation and pickling."""

    def setUp(self):
        """Create the repeated integer field shared by the tests."""
        self.integer_field = messages.IntegerField(1, repeated=True)

    @staticmethod
    def _invalid_integer_regex(value):
        """Return the escaped message expected for a str passed as integer.

        Args:
          value: The offending string value, e.g. '10'.

        Returns:
          A regular expression matching the expected ValidationError text.
        """
        return re.escape("Expected type %r "
                         "for IntegerField, found %s (type %r)"
                         % (six.integer_types, value, str))

    def testConstructor(self):
        """Test construction from a list, a tuple and an empty list."""
        self.assertEqual([1, 2, 3],
                         messages.FieldList(self.integer_field, [1, 2, 3]))
        self.assertEqual([1, 2, 3],
                         messages.FieldList(self.integer_field, (1, 2, 3)))
        self.assertEqual([], messages.FieldList(self.integer_field, []))

    def testNone(self):
        """Test that None is rejected as the initial sequence."""
        self.assertRaises(TypeError, messages.FieldList,
                          self.integer_field, None)

    def testDoNotAutoConvertString(self):
        """Test that a string is not treated as a sequence of characters."""
        string_field = messages.StringField(1, repeated=True)
        self.assertRaises(messages.ValidationError,
                          messages.FieldList, string_field, 'abc')

    def testConstructorCopies(self):
        """Test that the constructor always produces a new list object."""
        a_list = [1, 3, 6]
        field_list = messages.FieldList(self.integer_field, a_list)
        self.assertIsNot(a_list, field_list)
        self.assertIsNot(field_list,
                         messages.FieldList(self.integer_field, field_list))

    def testNonRepeatedField(self):
        """Test that a non-repeated field is rejected."""
        self.assertRaisesWithRegexpMatch(
            messages.FieldDefinitionError,
            'FieldList may only accept repeated fields',
            messages.FieldList,
            messages.IntegerField(1),
            [])

    def testConstructor_InvalidValues(self):
        """Test that elements of the wrong type are rejected on creation."""
        self.assertRaisesWithRegexpMatch(
            messages.ValidationError,
            self._invalid_integer_regex("1"),
            messages.FieldList, self.integer_field, ["1", "2", "3"])

    def testConstructor_Scalars(self):
        """Test that scalars and bare iterators are rejected on creation."""
        self.assertRaisesWithRegexpMatch(
            messages.ValidationError,
            "IntegerField is repeated. Found: 3",
            messages.FieldList, self.integer_field, 3)

        self.assertRaisesWithRegexpMatch(
            messages.ValidationError,
            ("IntegerField is repeated. Found: "
             "<(list[_]?|sequence)iterator object"),
            messages.FieldList, self.integer_field, iter([1, 2, 3]))

    def testSetSlice(self):
        """Test slice assignment with valid values."""
        field_list = messages.FieldList(self.integer_field, [1, 2, 3, 4, 5])
        field_list[1:3] = [10, 20]
        self.assertEqual([1, 10, 20, 4, 5], field_list)

    def testSetSlice_InvalidValues(self):
        """Test that slice assignment validates the new values."""
        field_list = messages.FieldList(self.integer_field, [1, 2, 3, 4, 5])

        def setslice():
            field_list[1:3] = ['10', '20']

        self.assertRaisesWithRegexpMatch(
            messages.ValidationError,
            self._invalid_integer_regex('10'),
            setslice)

    def testSetItem(self):
        """Test item assignment with a valid value."""
        field_list = messages.FieldList(self.integer_field, [2])
        field_list[0] = 10
        self.assertEqual([10], field_list)

    def testSetItem_InvalidValues(self):
        """Test that item assignment validates the new value."""
        field_list = messages.FieldList(self.integer_field, [2])

        def setitem():
            field_list[0] = '10'
        self.assertRaisesWithRegexpMatch(
            messages.ValidationError,
            self._invalid_integer_regex('10'),
            setitem)

    def testAppend(self):
        """Test append with a valid value."""
        field_list = messages.FieldList(self.integer_field, [2])
        field_list.append(10)
        self.assertEqual([2, 10], field_list)

    def testAppend_InvalidValues(self):
        """Test that append validates the new value."""
        field_list = messages.FieldList(self.integer_field, [2])
        # NOTE(review): this sets an ad-hoc attribute on the list; the
        # expected message below does not mention it -- confirm it is needed.
        field_list.name = 'a_field'

        def append():
            field_list.append('10')
        self.assertRaisesWithRegexpMatch(
            messages.ValidationError,
            self._invalid_integer_regex('10'),
            append)

    def testExtend(self):
        """Test extend with valid values."""
        field_list = messages.FieldList(self.integer_field, [2])
        field_list.extend([10])
        self.assertEqual([2, 10], field_list)

    def testExtend_InvalidValues(self):
        """Test that extend validates the new values."""
        field_list = messages.FieldList(self.integer_field, [2])

        def extend():
            field_list.extend(['10'])
        self.assertRaisesWithRegexpMatch(
            messages.ValidationError,
            self._invalid_integer_regex('10'),
            extend)

    def testInsert(self):
        """Test insert with a valid value."""
        field_list = messages.FieldList(self.integer_field, [2, 3])
        field_list.insert(1, 10)
        self.assertEqual([2, 10, 3], field_list)

    def testInsert_InvalidValues(self):
        """Test that insert validates the new value."""
        field_list = messages.FieldList(self.integer_field, [2, 3])

        def insert():
            field_list.insert(1, '10')
        self.assertRaisesWithRegexpMatch(
            messages.ValidationError,
            self._invalid_integer_regex('10'),
            insert)

    def testPickle(self):
        """Testing pickling and unpickling of FieldList instances."""
        field_list = messages.FieldList(self.integer_field, [1, 2, 3, 4, 5])
        unpickled = pickle.loads(pickle.dumps(field_list))
        self.assertEqual(field_list, unpickled)
        # The field definition must survive the round trip as well.
        self.assertIsInstance(unpickled.field, messages.IntegerField)
        self.assertEqual(1, unpickled.field.number)
        self.assertTrue(unpickled.field.repeated)
516
517
518class FieldTest(test_util.TestCase):
519
520    def ActionOnAllFieldClasses(self, action):
521        """Test all field classes except Message and Enum.
522
523        Message and Enum require separate tests.
524
525        Args:
526          action: Callable that takes the field class as a parameter.
527        """
528        classes = (messages.IntegerField,
529                   messages.FloatField,
530                   messages.BooleanField,
531                   messages.BytesField,
532                   messages.StringField)
533        for field_class in classes:
534            action(field_class)
535
536    def testNumberAttribute(self):
537        """Test setting the number attribute."""
538        def action(field_class):
539            # Check range.
540            self.assertRaises(messages.InvalidNumberError,
541                              field_class,
542                              0)
543            self.assertRaises(messages.InvalidNumberError,
544                              field_class,
545                              -1)
546            self.assertRaises(messages.InvalidNumberError,
547                              field_class,
548                              messages.MAX_FIELD_NUMBER + 1)
549
550            # Check reserved.
551            self.assertRaises(messages.InvalidNumberError,
552                              field_class,
553                              messages.FIRST_RESERVED_FIELD_NUMBER)
554            self.assertRaises(messages.InvalidNumberError,
555                              field_class,
556                              messages.LAST_RESERVED_FIELD_NUMBER)
557            self.assertRaises(messages.InvalidNumberError,
558                              field_class,
559                              '1')
560
561            # This one should work.
562            field_class(number=1)
563        self.ActionOnAllFieldClasses(action)
564
565    def testRequiredAndRepeated(self):
566        """Test setting the required and repeated fields."""
567        def action(field_class):
568            field_class(1, required=True)
569            field_class(1, repeated=True)
570            self.assertRaises(messages.FieldDefinitionError,
571                              field_class,
572                              1,
573                              required=True,
574                              repeated=True)
575        self.ActionOnAllFieldClasses(action)
576
577    def testInvalidVariant(self):
578        """Test field with invalid variants."""
579        def action(field_class):
580            if field_class is not message_types.DateTimeField:
581                self.assertRaises(messages.InvalidVariantError,
582                                  field_class,
583                                  1,
584                                  variant=messages.Variant.ENUM)
585        self.ActionOnAllFieldClasses(action)
586
587    def testDefaultVariant(self):
588        """Test that default variant is used when not set."""
589        def action(field_class):
590            field = field_class(1)
591            self.assertEquals(field_class.DEFAULT_VARIANT, field.variant)
592
593        self.ActionOnAllFieldClasses(action)
594
595    def testAlternateVariant(self):
596        """Test that default variant is used when not set."""
597        field = messages.IntegerField(1, variant=messages.Variant.UINT32)
598        self.assertEquals(messages.Variant.UINT32, field.variant)
599
600    def testDefaultFields_Single(self):
601        """Test default field is correct type (single)."""
602        defaults = {
603            messages.IntegerField: 10,
604            messages.FloatField: 1.5,
605            messages.BooleanField: False,
606            messages.BytesField: b'abc',
607            messages.StringField: u'abc',
608        }
609
610        def action(field_class):
611            field_class(1, default=defaults[field_class])
612        self.ActionOnAllFieldClasses(action)
613
614        # Run defaults test again checking for str/unicode compatiblity.
615        defaults[messages.StringField] = 'abc'
616        self.ActionOnAllFieldClasses(action)
617
618    def testStringField_BadUnicodeInDefault(self):
619        """Test binary values in string field."""
620        self.assertRaisesWithRegexpMatch(
621            messages.InvalidDefaultError,
622            r"Invalid default value for StringField:.*: "
623            r"Field encountered non-UTF-8 string .*: "
624            r"'utf.?8' codec can't decode byte 0xc3 in position 0: "
625            r"invalid continuation byte",
626            messages.StringField, 1, default=b'\xc3\x28')
627
628    def testDefaultFields_InvalidSingle(self):
629        """Test default field is correct type (invalid single)."""
630        def action(field_class):
631            self.assertRaises(messages.InvalidDefaultError,
632                              field_class,
633                              1,
634                              default=object())
635        self.ActionOnAllFieldClasses(action)
636
637    def testDefaultFields_InvalidRepeated(self):
638        """Test default field does not accept defaults."""
639        self.assertRaisesWithRegexpMatch(
640            messages.FieldDefinitionError,
641            'Repeated fields may not have defaults',
642            messages.StringField, 1, repeated=True, default=[1, 2, 3])
643
644    def testDefaultFields_None(self):
645        """Test none is always acceptable."""
646        def action(field_class):
647            field_class(1, default=None)
648            field_class(1, required=True, default=None)
649            field_class(1, repeated=True, default=None)
650        self.ActionOnAllFieldClasses(action)
651
652    def testDefaultFields_Enum(self):
653        """Test the default for enum fields."""
654        class Symbol(messages.Enum):
655
656            ALPHA = 1
657            BETA = 2
658            GAMMA = 3
659
660        field = messages.EnumField(Symbol, 1, default=Symbol.ALPHA)
661
662        self.assertEquals(Symbol.ALPHA, field.default)
663
664    def testDefaultFields_EnumStringDelayedResolution(self):
665        """Test that enum fields resolve default strings."""
666        field = messages.EnumField(
667            'apitools.base.protorpclite.descriptor.FieldDescriptor.Label',
668            1,
669            default='OPTIONAL')
670
671        self.assertEquals(
672            descriptor.FieldDescriptor.Label.OPTIONAL, field.default)
673
674    def testDefaultFields_EnumIntDelayedResolution(self):
675        """Test that enum fields resolve default integers."""
676        field = messages.EnumField(
677            'apitools.base.protorpclite.descriptor.FieldDescriptor.Label',
678            1,
679            default=2)
680
681        self.assertEquals(
682            descriptor.FieldDescriptor.Label.REQUIRED, field.default)
683
684    def testDefaultFields_EnumOkIfTypeKnown(self):
685        """Test enum fields accept valid default values when type is known."""
686        field = messages.EnumField(descriptor.FieldDescriptor.Label,
687                                   1,
688                                   default='REPEATED')
689
690        self.assertEquals(
691            descriptor.FieldDescriptor.Label.REPEATED, field.default)
692
693    def testDefaultFields_EnumForceCheckIfTypeKnown(self):
694        """Test that enum fields validate default values if type is known."""
695        self.assertRaisesWithRegexpMatch(TypeError,
696                                         'No such value for NOT_A_LABEL in '
697                                         'Enum Label',
698                                         messages.EnumField,
699                                         descriptor.FieldDescriptor.Label,
700                                         1,
701                                         default='NOT_A_LABEL')
702
703    def testDefaultFields_EnumInvalidDelayedResolution(self):
704        """Test that enum fields raise errors upon delayed resolution error."""
705        field = messages.EnumField(
706            'apitools.base.protorpclite.descriptor.FieldDescriptor.Label',
707            1,
708            default=200)
709
710        self.assertRaisesWithRegexpMatch(TypeError,
711                                         'No such value for 200 in Enum Label',
712                                         getattr,
713                                         field,
714                                         'default')
715
716    def testValidate_Valid(self):
717        """Test validation of valid values."""
718        values = {
719            messages.IntegerField: 10,
720            messages.FloatField: 1.5,
721            messages.BooleanField: False,
722            messages.BytesField: b'abc',
723            messages.StringField: u'abc',
724        }
725
726        def action(field_class):
727            # Optional.
728            field = field_class(1)
729            field.validate(values[field_class])
730
731            # Required.
732            field = field_class(1, required=True)
733            field.validate(values[field_class])
734
735            # Repeated.
736            field = field_class(1, repeated=True)
737            field.validate([])
738            field.validate(())
739            field.validate([values[field_class]])
740            field.validate((values[field_class],))
741
742            # Right value, but not repeated.
743            self.assertRaises(messages.ValidationError,
744                              field.validate,
745                              values[field_class])
746            self.assertRaises(messages.ValidationError,
747                              field.validate,
748                              values[field_class])
749
750        self.ActionOnAllFieldClasses(action)
751
752    def testValidate_Invalid(self):
753        """Test validation of valid values."""
754        values = {
755            messages.IntegerField: "10",
756            messages.FloatField: "blah",
757            messages.BooleanField: 0,
758            messages.BytesField: 10.20,
759            messages.StringField: 42,
760        }
761
762        def action(field_class):
763            # Optional.
764            field = field_class(1)
765            self.assertRaises(messages.ValidationError,
766                              field.validate,
767                              values[field_class])
768
769            # Required.
770            field = field_class(1, required=True)
771            self.assertRaises(messages.ValidationError,
772                              field.validate,
773                              values[field_class])
774
775            # Repeated.
776            field = field_class(1, repeated=True)
777            self.assertRaises(messages.ValidationError,
778                              field.validate,
779                              [values[field_class]])
780            self.assertRaises(messages.ValidationError,
781                              field.validate,
782                              (values[field_class],))
783        self.ActionOnAllFieldClasses(action)
784
785    def testValidate_None(self):
786        """Test that None is valid for non-required fields."""
787        def action(field_class):
788            # Optional.
789            field = field_class(1)
790            field.validate(None)
791
792            # Required.
793            field = field_class(1, required=True)
794            self.assertRaisesWithRegexpMatch(messages.ValidationError,
795                                             'Required field is missing',
796                                             field.validate,
797                                             None)
798
799            # Repeated.
800            field = field_class(1, repeated=True)
801            field.validate(None)
802            self.assertRaisesWithRegexpMatch(
803                messages.ValidationError,
804                'Repeated values for %s may '
805                'not be None' % field_class.__name__,
806                field.validate,
807                [None])
808            self.assertRaises(messages.ValidationError,
809                              field.validate,
810                              (None,))
811        self.ActionOnAllFieldClasses(action)
812
813    def testValidateElement(self):
814        """Test validation of valid values."""
815        values = {
816            messages.IntegerField: (10, -1, 0),
817            messages.FloatField: (1.5, -1.5, 3),  # for json it is all a number
818            messages.BooleanField: (True, False),
819            messages.BytesField: (b'abc',),
820            messages.StringField: (u'abc',),
821        }
822
823        def action(field_class):
824            # Optional.
825            field = field_class(1)
826            for value in values[field_class]:
827                field.validate_element(value)
828
829            # Required.
830            field = field_class(1, required=True)
831            for value in values[field_class]:
832                field.validate_element(value)
833
834            # Repeated.
835            field = field_class(1, repeated=True)
836            self.assertRaises(messages.ValidationError,
837                              field.validate_element,
838                              [])
839            self.assertRaises(messages.ValidationError,
840                              field.validate_element,
841                              ())
842            for value in values[field_class]:
843                field.validate_element(value)
844
845            # Right value, but repeated.
846            self.assertRaises(messages.ValidationError,
847                              field.validate_element,
848                              list(values[field_class]))  # testing list
849            self.assertRaises(messages.ValidationError,
850                              field.validate_element,
851                              values[field_class])  # testing tuple
852
853        self.ActionOnAllFieldClasses(action)
854
855    def testValidateCastingElement(self):
856        field = messages.FloatField(1)
857        self.assertEquals(type(field.validate_element(12)), float)
858        self.assertEquals(type(field.validate_element(12.0)), float)
859        field = messages.IntegerField(1)
860        self.assertEquals(type(field.validate_element(12)), int)
861        self.assertRaises(messages.ValidationError,
862                          field.validate_element,
863                          12.0)  # should fails from float to int
864
865    def testReadOnly(self):
866        """Test that objects are all read-only."""
867        def action(field_class):
868            field = field_class(10)
869            self.assertRaises(AttributeError,
870                              setattr,
871                              field,
872                              'number',
873                              20)
874            self.assertRaises(AttributeError,
875                              setattr,
876                              field,
877                              'anything_else',
878                              'whatever')
879        self.ActionOnAllFieldClasses(action)
880
881    def testMessageField(self):
882        """Test the construction of message fields."""
883        self.assertRaises(messages.FieldDefinitionError,
884                          messages.MessageField,
885                          str,
886                          10)
887
888        self.assertRaises(messages.FieldDefinitionError,
889                          messages.MessageField,
890                          messages.Message,
891                          10)
892
893        class MyMessage(messages.Message):
894            pass
895
896        field = messages.MessageField(MyMessage, 10)
897        self.assertEquals(MyMessage, field.type)
898
899    def testMessageField_ForwardReference(self):
900        """Test the construction of forward reference message fields."""
901        global MyMessage
902        global ForwardMessage
903        try:
904            class MyMessage(messages.Message):
905
906                self_reference = messages.MessageField('MyMessage', 1)
907                forward = messages.MessageField('ForwardMessage', 2)
908                nested = messages.MessageField(
909                    'ForwardMessage.NestedMessage', 3)
910                inner = messages.MessageField('Inner', 4)
911
912                class Inner(messages.Message):
913
914                    sibling = messages.MessageField('Sibling', 1)
915
916                class Sibling(messages.Message):
917
918                    pass
919
920            class ForwardMessage(messages.Message):
921
922                class NestedMessage(messages.Message):
923
924                    pass
925
926            self.assertEquals(MyMessage,
927                              MyMessage.field_by_name('self_reference').type)
928
929            self.assertEquals(ForwardMessage,
930                              MyMessage.field_by_name('forward').type)
931
932            self.assertEquals(ForwardMessage.NestedMessage,
933                              MyMessage.field_by_name('nested').type)
934
935            self.assertEquals(MyMessage.Inner,
936                              MyMessage.field_by_name('inner').type)
937
938            self.assertEquals(MyMessage.Sibling,
939                              MyMessage.Inner.field_by_name('sibling').type)
940        finally:
941            try:
942                del MyMessage
943                del ForwardMessage
944            except:  # pylint:disable=bare-except
945                pass
946
947    def testMessageField_WrongType(self):
948        """Test that forward referencing the wrong type raises an error."""
949        global AnEnum
950        try:
951            class AnEnum(messages.Enum):
952                pass
953
954            class AnotherMessage(messages.Message):
955
956                a_field = messages.MessageField('AnEnum', 1)
957
958            self.assertRaises(messages.FieldDefinitionError,
959                              getattr,
960                              AnotherMessage.field_by_name('a_field'),
961                              'type')
962        finally:
963            del AnEnum
964
965    def testMessageFieldValidate(self):
966        """Test validation on message field."""
967        class MyMessage(messages.Message):
968            pass
969
970        class AnotherMessage(messages.Message):
971            pass
972
973        field = messages.MessageField(MyMessage, 10)
974        field.validate(MyMessage())
975
976        self.assertRaises(messages.ValidationError,
977                          field.validate,
978                          AnotherMessage())
979
980    def testMessageFieldMessageType(self):
981        """Test message_type property."""
982        class MyMessage(messages.Message):
983            pass
984
985        class HasMessage(messages.Message):
986            field = messages.MessageField(MyMessage, 1)
987
988        self.assertEqual(HasMessage.field.type, HasMessage.field.message_type)
989
990    def testMessageFieldValueFromMessage(self):
991        class MyMessage(messages.Message):
992            pass
993
994        class HasMessage(messages.Message):
995            field = messages.MessageField(MyMessage, 1)
996
997        instance = MyMessage()
998
999        self.assertTrue(
1000            instance is HasMessage.field.value_from_message(instance))
1001
1002    def testMessageFieldValueFromMessageWrongType(self):
1003        class MyMessage(messages.Message):
1004            pass
1005
1006        class HasMessage(messages.Message):
1007            field = messages.MessageField(MyMessage, 1)
1008
1009        self.assertRaisesWithRegexpMatch(
1010            messages.DecodeError,
1011            'Expected type MyMessage, got int: 10',
1012            HasMessage.field.value_from_message, 10)
1013
1014    def testMessageFieldValueToMessage(self):
1015        class MyMessage(messages.Message):
1016            pass
1017
1018        class HasMessage(messages.Message):
1019            field = messages.MessageField(MyMessage, 1)
1020
1021        instance = MyMessage()
1022
1023        self.assertTrue(
1024            instance is HasMessage.field.value_to_message(instance))
1025
1026    def testMessageFieldValueToMessageWrongType(self):
1027        class MyMessage(messages.Message):
1028            pass
1029
1030        class MyOtherMessage(messages.Message):
1031            pass
1032
1033        class HasMessage(messages.Message):
1034            field = messages.MessageField(MyMessage, 1)
1035
1036        instance = MyOtherMessage()
1037
1038        self.assertRaisesWithRegexpMatch(
1039            messages.EncodeError,
1040            'Expected type MyMessage, got MyOtherMessage: <MyOtherMessage>',
1041            HasMessage.field.value_to_message, instance)
1042
1043    def testIntegerField_AllowLong(self):
1044        """Test that the integer field allows for longs."""
1045        if six.PY2:
1046            messages.IntegerField(10, default=long(10))
1047
1048    def testMessageFieldValidate_Initialized(self):
1049        """Test validation on message field."""
1050        class MyMessage(messages.Message):
1051            field1 = messages.IntegerField(1, required=True)
1052
1053        field = messages.MessageField(MyMessage, 10)
1054
1055        # Will validate messages where is_initialized() is False.
1056        message = MyMessage()
1057        field.validate(message)
1058        message.field1 = 20
1059        field.validate(message)
1060
1061    def testEnumField(self):
1062        """Test the construction of enum fields."""
1063        self.assertRaises(messages.FieldDefinitionError,
1064                          messages.EnumField,
1065                          str,
1066                          10)
1067
1068        self.assertRaises(messages.FieldDefinitionError,
1069                          messages.EnumField,
1070                          messages.Enum,
1071                          10)
1072
1073        class Color(messages.Enum):
1074            RED = 1
1075            GREEN = 2
1076            BLUE = 3
1077
1078        field = messages.EnumField(Color, 10)
1079        self.assertEquals(Color, field.type)
1080
1081        class Another(messages.Enum):
1082            VALUE = 1
1083
1084        self.assertRaises(messages.InvalidDefaultError,
1085                          messages.EnumField,
1086                          Color,
1087                          10,
1088                          default=Another.VALUE)
1089
1090    def testEnumField_ForwardReference(self):
1091        """Test the construction of forward reference enum fields."""
1092        global MyMessage
1093        global ForwardEnum
1094        global ForwardMessage
1095        try:
1096            class MyMessage(messages.Message):
1097
1098                forward = messages.EnumField('ForwardEnum', 1)
1099                nested = messages.EnumField('ForwardMessage.NestedEnum', 2)
1100                inner = messages.EnumField('Inner', 3)
1101
1102                class Inner(messages.Enum):
1103                    pass
1104
1105            class ForwardEnum(messages.Enum):
1106                pass
1107
1108            class ForwardMessage(messages.Message):
1109
1110                class NestedEnum(messages.Enum):
1111                    pass
1112
1113            self.assertEquals(ForwardEnum,
1114                              MyMessage.field_by_name('forward').type)
1115
1116            self.assertEquals(ForwardMessage.NestedEnum,
1117                              MyMessage.field_by_name('nested').type)
1118
1119            self.assertEquals(MyMessage.Inner,
1120                              MyMessage.field_by_name('inner').type)
1121        finally:
1122            try:
1123                del MyMessage
1124                del ForwardEnum
1125                del ForwardMessage
1126            except:  # pylint:disable=bare-except
1127                pass
1128
1129    def testEnumField_WrongType(self):
1130        """Test that forward referencing the wrong type raises an error."""
1131        global AMessage
1132        try:
1133            class AMessage(messages.Message):
1134                pass
1135
1136            class AnotherMessage(messages.Message):
1137
1138                a_field = messages.EnumField('AMessage', 1)
1139
1140            self.assertRaises(messages.FieldDefinitionError,
1141                              getattr,
1142                              AnotherMessage.field_by_name('a_field'),
1143                              'type')
1144        finally:
1145            del AMessage
1146
1147    def testMessageDefinition(self):
1148        """Test that message definition is set on fields."""
1149        class MyMessage(messages.Message):
1150
1151            my_field = messages.StringField(1)
1152
1153        self.assertEquals(
1154            MyMessage,
1155            MyMessage.field_by_name('my_field').message_definition())
1156
1157    def testNoneAssignment(self):
1158        """Test that assigning None does not change comparison."""
1159        class MyMessage(messages.Message):
1160
1161            my_field = messages.StringField(1)
1162
1163        m1 = MyMessage()
1164        m2 = MyMessage()
1165        m2.my_field = None
1166        self.assertEquals(m1, m2)
1167
1168    def testNonUtf8Str(self):
1169        """Test validation fails for non-UTF-8 StringField values."""
1170        class Thing(messages.Message):
1171            string_field = messages.StringField(2)
1172
1173        thing = Thing()
1174        self.assertRaisesWithRegexpMatch(
1175            messages.ValidationError,
1176            'Field string_field encountered non-UTF-8 string',
1177            setattr, thing, 'string_field', test_util.BINARY)
1178
1179
1180class MessageTest(test_util.TestCase):
1181    """Tests for message class."""
1182
    def CreateMessageClass(self):
        """Creates a simple message class with 3 fields.

        Fields are defined in alphabetical order but with conflicting numeric
        order: the names a3, b1 and c2 carry the field numbers 3, 1 and 2.

        Returns:
          The newly defined ComplexMessage class; every call defines a new
          class.
        """
        class ComplexMessage(messages.Message):
            a3 = messages.IntegerField(3)
            b1 = messages.StringField(1)
            c2 = messages.StringField(2)

        return ComplexMessage
1195
1196    def testSameNumbers(self):
1197        """Test that cannot assign two fields with same numbers."""
1198
1199        def action():
1200            class BadMessage(messages.Message):
1201                f1 = messages.IntegerField(1)
1202                f2 = messages.IntegerField(1)
1203        self.assertRaises(messages.DuplicateNumberError,
1204                          action)
1205
1206    def testStrictAssignment(self):
1207        """Tests that cannot assign to unknown or non-reserved attributes."""
1208        class SimpleMessage(messages.Message):
1209            field = messages.IntegerField(1)
1210
1211        simple_message = SimpleMessage()
1212        self.assertRaises(AttributeError,
1213                          setattr,
1214                          simple_message,
1215                          'does_not_exist',
1216                          10)
1217
1218    def testListAssignmentDoesNotCopy(self):
1219        class SimpleMessage(messages.Message):
1220            repeated = messages.IntegerField(1, repeated=True)
1221
1222        message = SimpleMessage()
1223        original = message.repeated
1224        message.repeated = []
1225        self.assertFalse(original is message.repeated)
1226
1227    def testValidate_Optional(self):
1228        """Tests validation of optional fields."""
1229        class SimpleMessage(messages.Message):
1230            non_required = messages.IntegerField(1)
1231
1232        simple_message = SimpleMessage()
1233        simple_message.check_initialized()
1234        simple_message.non_required = 10
1235        simple_message.check_initialized()
1236
1237    def testValidate_Required(self):
1238        """Tests validation of required fields."""
1239        class SimpleMessage(messages.Message):
1240            required = messages.IntegerField(1, required=True)
1241
1242        simple_message = SimpleMessage()
1243        self.assertRaises(messages.ValidationError,
1244                          simple_message.check_initialized)
1245        simple_message.required = 10
1246        simple_message.check_initialized()
1247
1248    def testValidate_Repeated(self):
1249        """Tests validation of repeated fields."""
1250        class SimpleMessage(messages.Message):
1251            repeated = messages.IntegerField(1, repeated=True)
1252
1253        simple_message = SimpleMessage()
1254
1255        # Check valid values.
1256        for valid_value in [], [10], [10, 20], (), (10,), (10, 20):
1257            simple_message.repeated = valid_value
1258            simple_message.check_initialized()
1259
1260        # Check cleared.
1261        simple_message.repeated = []
1262        simple_message.check_initialized()
1263
1264        # Check invalid values.
1265        for invalid_value in 10, ['10', '20'], [None], (None,):
1266            self.assertRaises(
1267                messages.ValidationError,
1268                setattr, simple_message, 'repeated', invalid_value)
1269
1270    def testIsInitialized(self):
1271        """Tests is_initialized."""
1272        class SimpleMessage(messages.Message):
1273            required = messages.IntegerField(1, required=True)
1274
1275        simple_message = SimpleMessage()
1276        self.assertFalse(simple_message.is_initialized())
1277
1278        simple_message.required = 10
1279
1280        self.assertTrue(simple_message.is_initialized())
1281
1282    def testIsInitializedNestedField(self):
1283        """Tests is_initialized for nested fields."""
1284        class SimpleMessage(messages.Message):
1285            required = messages.IntegerField(1, required=True)
1286
1287        class NestedMessage(messages.Message):
1288            simple = messages.MessageField(SimpleMessage, 1)
1289
1290        simple_message = SimpleMessage()
1291        self.assertFalse(simple_message.is_initialized())
1292        nested_message = NestedMessage(simple=simple_message)
1293        self.assertFalse(nested_message.is_initialized())
1294
1295        simple_message.required = 10
1296
1297        self.assertTrue(simple_message.is_initialized())
1298        self.assertTrue(nested_message.is_initialized())
1299
1300    def testInitializeNestedFieldFromDict(self):
1301        """Tests initializing nested fields from dict."""
1302        class SimpleMessage(messages.Message):
1303            required = messages.IntegerField(1, required=True)
1304
1305        class NestedMessage(messages.Message):
1306            simple = messages.MessageField(SimpleMessage, 1)
1307
1308        class RepeatedMessage(messages.Message):
1309            simple = messages.MessageField(SimpleMessage, 1, repeated=True)
1310
1311        nested_message1 = NestedMessage(simple={'required': 10})
1312        self.assertTrue(nested_message1.is_initialized())
1313        self.assertTrue(nested_message1.simple.is_initialized())
1314
1315        nested_message2 = NestedMessage()
1316        nested_message2.simple = {'required': 10}
1317        self.assertTrue(nested_message2.is_initialized())
1318        self.assertTrue(nested_message2.simple.is_initialized())
1319
1320        repeated_values = [{}, {'required': 10}, SimpleMessage(required=20)]
1321
1322        repeated_message1 = RepeatedMessage(simple=repeated_values)
1323        self.assertEquals(3, len(repeated_message1.simple))
1324        self.assertFalse(repeated_message1.is_initialized())
1325
1326        repeated_message1.simple[0].required = 0
1327        self.assertTrue(repeated_message1.is_initialized())
1328
1329        repeated_message2 = RepeatedMessage()
1330        repeated_message2.simple = repeated_values
1331        self.assertEquals(3, len(repeated_message2.simple))
1332        self.assertFalse(repeated_message2.is_initialized())
1333
1334        repeated_message2.simple[0].required = 0
1335        self.assertTrue(repeated_message2.is_initialized())
1336
1337    def testNestedMethodsNotAllowed(self):
1338        """Test that method definitions on Message classes are not allowed."""
1339        def action():
1340            class WithMethods(messages.Message):
1341
1342                def not_allowed(self):
1343                    pass
1344
1345        self.assertRaises(messages.MessageDefinitionError,
1346                          action)
1347
1348    def testNestedAttributesNotAllowed(self):
1349        """Test attribute assignment on Message classes is not allowed."""
1350        def int_attribute():
1351            class WithMethods(messages.Message):
1352                not_allowed = 1
1353
1354        def string_attribute():
1355            class WithMethods(messages.Message):
1356                not_allowed = 'not allowed'
1357
1358        def enum_attribute():
1359            class WithMethods(messages.Message):
1360                not_allowed = Color.RED
1361
1362        for action in (int_attribute, string_attribute, enum_attribute):
1363            self.assertRaises(messages.MessageDefinitionError,
1364                              action)
1365
1366    def testNameIsSetOnFields(self):
1367        """Make sure name is set on fields after Message class init."""
1368        class HasNamedFields(messages.Message):
1369            field = messages.StringField(1)
1370
1371        self.assertEquals('field', HasNamedFields.field_by_number(1).name)
1372
1373    def testSubclassingMessageDisallowed(self):
1374        """Not permitted to create sub-classes of message classes."""
1375        class SuperClass(messages.Message):
1376            pass
1377
1378        def action():
1379            class SubClass(SuperClass):
1380                pass
1381
1382        self.assertRaises(messages.MessageDefinitionError,
1383                          action)
1384
1385    def testAllFields(self):
1386        """Test all_fields method."""
1387        ComplexMessage = self.CreateMessageClass()
1388        fields = list(ComplexMessage.all_fields())
1389
1390        # Order does not matter, so sort now.
1391        fields = sorted(fields, key=lambda f: f.name)
1392
1393        self.assertEquals(3, len(fields))
1394        self.assertEquals('a3', fields[0].name)
1395        self.assertEquals('b1', fields[1].name)
1396        self.assertEquals('c2', fields[2].name)
1397
1398    def testFieldByName(self):
1399        """Test getting field by name."""
1400        ComplexMessage = self.CreateMessageClass()
1401
1402        self.assertEquals(3, ComplexMessage.field_by_name('a3').number)
1403        self.assertEquals(1, ComplexMessage.field_by_name('b1').number)
1404        self.assertEquals(2, ComplexMessage.field_by_name('c2').number)
1405
1406        self.assertRaises(KeyError,
1407                          ComplexMessage.field_by_name,
1408                          'unknown')
1409
1410    def testFieldByNumber(self):
1411        """Test getting field by number."""
1412        ComplexMessage = self.CreateMessageClass()
1413
1414        self.assertEquals('a3', ComplexMessage.field_by_number(3).name)
1415        self.assertEquals('b1', ComplexMessage.field_by_number(1).name)
1416        self.assertEquals('c2', ComplexMessage.field_by_number(2).name)
1417
1418        self.assertRaises(KeyError,
1419                          ComplexMessage.field_by_number,
1420                          4)
1421
1422    def testGetAssignedValue(self):
1423        """Test getting the assigned value of a field."""
1424        class SomeMessage(messages.Message):
1425            a_value = messages.StringField(1, default=u'a default')
1426
1427        message = SomeMessage()
1428        self.assertEquals(None, message.get_assigned_value('a_value'))
1429
1430        message.a_value = u'a string'
1431        self.assertEquals(u'a string', message.get_assigned_value('a_value'))
1432
1433        message.a_value = u'a default'
1434        self.assertEquals(u'a default', message.get_assigned_value('a_value'))
1435
1436        self.assertRaisesWithRegexpMatch(
1437            AttributeError,
1438            'Message SomeMessage has no field no_such_field',
1439            message.get_assigned_value,
1440            'no_such_field')
1441
1442    def testReset(self):
1443        """Test resetting a field value."""
1444        class SomeMessage(messages.Message):
1445            a_value = messages.StringField(1, default=u'a default')
1446            repeated = messages.IntegerField(2, repeated=True)
1447
1448        message = SomeMessage()
1449
1450        self.assertRaises(AttributeError, message.reset, 'unknown')
1451
1452        self.assertEquals(u'a default', message.a_value)
1453        message.reset('a_value')
1454        self.assertEquals(u'a default', message.a_value)
1455
1456        message.a_value = u'a new value'
1457        self.assertEquals(u'a new value', message.a_value)
1458        message.reset('a_value')
1459        self.assertEquals(u'a default', message.a_value)
1460
1461        message.repeated = [1, 2, 3]
1462        self.assertEquals([1, 2, 3], message.repeated)
1463        saved = message.repeated
1464        message.reset('repeated')
1465        self.assertEquals([], message.repeated)
1466        self.assertIsInstance(message.repeated, messages.FieldList)
1467        self.assertEquals([1, 2, 3], saved)
1468
1469    def testAllowNestedEnums(self):
1470        """Test allowing nested enums in a message definition."""
1471        class Trade(messages.Message):
1472
1473            class Duration(messages.Enum):
1474                GTC = 1
1475                DAY = 2
1476
1477            class Currency(messages.Enum):
1478                USD = 1
1479                GBP = 2
1480                INR = 3
1481
1482        # Sorted by name order seems to be the only feasible option.
1483        self.assertEquals(['Currency', 'Duration'], Trade.__enums__)
1484
1485        # Message definition will now be set on Enumerated objects.
1486        self.assertEquals(Trade, Trade.Duration.message_definition())
1487
1488    def testAllowNestedMessages(self):
1489        """Test allowing nested messages in a message definition."""
1490        class Trade(messages.Message):
1491
1492            class Lot(messages.Message):
1493                pass
1494
1495            class Agent(messages.Message):
1496                pass
1497
1498        # Sorted by name order seems to be the only feasible option.
1499        self.assertEquals(['Agent', 'Lot'], Trade.__messages__)
1500        self.assertEquals(Trade, Trade.Agent.message_definition())
1501        self.assertEquals(Trade, Trade.Lot.message_definition())
1502
1503        # But not Message itself.
1504        def action():
1505            class Trade(messages.Message):
1506                NiceTry = messages.Message
1507        self.assertRaises(messages.MessageDefinitionError, action)
1508
1509    def testDisallowClassAssignments(self):
1510        """Test setting class attributes may not happen."""
1511        class MyMessage(messages.Message):
1512            pass
1513
1514        self.assertRaises(AttributeError,
1515                          setattr,
1516                          MyMessage,
1517                          'x',
1518                          'do not assign')
1519
1520    def testEquality(self):
1521        """Test message class equality."""
1522        # Comparison against enums must work.
1523        class MyEnum(messages.Enum):
1524            val1 = 1
1525            val2 = 2
1526
1527        # Comparisons against nested messages must work.
1528        class AnotherMessage(messages.Message):
1529            string = messages.StringField(1)
1530
1531        class MyMessage(messages.Message):
1532            field1 = messages.IntegerField(1)
1533            field2 = messages.EnumField(MyEnum, 2)
1534            field3 = messages.MessageField(AnotherMessage, 3)
1535
1536        message1 = MyMessage()
1537
1538        self.assertNotEquals('hi', message1)
1539        self.assertNotEquals(AnotherMessage(), message1)
1540        self.assertEquals(message1, message1)
1541
1542        message2 = MyMessage()
1543
1544        self.assertEquals(message1, message2)
1545
1546        message1.field1 = 10
1547        self.assertNotEquals(message1, message2)
1548
1549        message2.field1 = 20
1550        self.assertNotEquals(message1, message2)
1551
1552        message2.field1 = 10
1553        self.assertEquals(message1, message2)
1554
1555        message1.field2 = MyEnum.val1
1556        self.assertNotEquals(message1, message2)
1557
1558        message2.field2 = MyEnum.val2
1559        self.assertNotEquals(message1, message2)
1560
1561        message2.field2 = MyEnum.val1
1562        self.assertEquals(message1, message2)
1563
1564        message1.field3 = AnotherMessage()
1565        message1.field3.string = 'value1'
1566        self.assertNotEquals(message1, message2)
1567
1568        message2.field3 = AnotherMessage()
1569        message2.field3.string = 'value2'
1570        self.assertNotEquals(message1, message2)
1571
1572        message2.field3.string = 'value1'
1573        self.assertEquals(message1, message2)
1574
1575    def testEqualityWithUnknowns(self):
1576        """Test message class equality with unknown fields."""
1577
1578        class MyMessage(messages.Message):
1579            field1 = messages.IntegerField(1)
1580
1581        message1 = MyMessage()
1582        message2 = MyMessage()
1583        self.assertEquals(message1, message2)
1584        message1.set_unrecognized_field('unknown1', 'value1',
1585                                        messages.Variant.STRING)
1586        self.assertEquals(message1, message2)
1587
1588        message1.set_unrecognized_field('unknown2', ['asdf', 3],
1589                                        messages.Variant.STRING)
1590        message1.set_unrecognized_field('unknown3', 4.7,
1591                                        messages.Variant.DOUBLE)
1592        self.assertEquals(message1, message2)
1593
1594    def testUnrecognizedFieldInvalidVariant(self):
1595        class MyMessage(messages.Message):
1596            field1 = messages.IntegerField(1)
1597
1598        message1 = MyMessage()
1599        self.assertRaises(
1600            TypeError, message1.set_unrecognized_field, 'unknown4',
1601            {'unhandled': 'type'}, None)
1602        self.assertRaises(
1603            TypeError, message1.set_unrecognized_field, 'unknown4',
1604            {'unhandled': 'type'}, 123)
1605
1606    def testRepr(self):
1607        """Test represtation of Message object."""
1608        class MyMessage(messages.Message):
1609            integer_value = messages.IntegerField(1)
1610            string_value = messages.StringField(2)
1611            unassigned = messages.StringField(3)
1612            unassigned_with_default = messages.StringField(
1613                4, default=u'a default')
1614
1615        my_message = MyMessage()
1616        my_message.integer_value = 42
1617        my_message.string_value = u'A string'
1618
1619        pat = re.compile(r"<MyMessage\n integer_value: 42\n"
1620                         " string_value: [u]?'A string'>")
1621        self.assertTrue(pat.match(repr(my_message)) is not None)
1622
1623    def testValidation(self):
1624        """Test validation of message values."""
1625        # Test optional.
1626        class SubMessage(messages.Message):
1627            pass
1628
1629        class Message(messages.Message):
1630            val = messages.MessageField(SubMessage, 1)
1631
1632        message = Message()
1633
1634        message_field = messages.MessageField(Message, 1)
1635        message_field.validate(message)
1636        message.val = SubMessage()
1637        message_field.validate(message)
1638        self.assertRaises(messages.ValidationError,
1639                          setattr, message, 'val', [SubMessage()])
1640
1641        # Test required.
1642        class Message(messages.Message):
1643            val = messages.MessageField(SubMessage, 1, required=True)
1644
1645        message = Message()
1646
1647        message_field = messages.MessageField(Message, 1)
1648        message_field.validate(message)
1649        message.val = SubMessage()
1650        message_field.validate(message)
1651        self.assertRaises(messages.ValidationError,
1652                          setattr, message, 'val', [SubMessage()])
1653
1654        # Test repeated.
1655        class Message(messages.Message):
1656            val = messages.MessageField(SubMessage, 1, repeated=True)
1657
1658        message = Message()
1659
1660        message_field = messages.MessageField(Message, 1)
1661        message_field.validate(message)
1662        self.assertRaisesWithRegexpMatch(
1663            messages.ValidationError,
1664            "Field val is repeated. Found: <SubMessage>",
1665            setattr, message, 'val', SubMessage())
1666        message.val = [SubMessage()]
1667        message_field.validate(message)
1668
1669    def testDefinitionName(self):
1670        """Test message name."""
1671        class MyMessage(messages.Message):
1672            pass
1673
1674        module_name = test_util.get_module_name(FieldTest)
1675        self.assertEquals('%s.MyMessage' % module_name,
1676                          MyMessage.definition_name())
1677        self.assertEquals(module_name, MyMessage.outer_definition_name())
1678        self.assertEquals(module_name, MyMessage.definition_package())
1679
1680        self.assertEquals(six.text_type, type(MyMessage.definition_name()))
1681        self.assertEquals(six.text_type, type(
1682            MyMessage.outer_definition_name()))
1683        self.assertEquals(six.text_type, type(MyMessage.definition_package()))
1684
1685    def testDefinitionName_OverrideModule(self):
1686        """Test message module is overriden by module package name."""
1687        class MyMessage(messages.Message):
1688            pass
1689
1690        global package
1691        package = 'my.package'
1692
1693        try:
1694            self.assertEquals('my.package.MyMessage',
1695                              MyMessage.definition_name())
1696            self.assertEquals('my.package', MyMessage.outer_definition_name())
1697            self.assertEquals('my.package', MyMessage.definition_package())
1698
1699            self.assertEquals(six.text_type, type(MyMessage.definition_name()))
1700            self.assertEquals(six.text_type, type(
1701                MyMessage.outer_definition_name()))
1702            self.assertEquals(six.text_type, type(
1703                MyMessage.definition_package()))
1704        finally:
1705            del package
1706
1707    def testDefinitionName_NoModule(self):
1708        """Test what happens when there is no module for message."""
1709        class MyMessage(messages.Message):
1710            pass
1711
1712        original_modules = sys.modules
1713        sys.modules = dict(sys.modules)
1714        try:
1715            del sys.modules[__name__]
1716            self.assertEquals('MyMessage', MyMessage.definition_name())
1717            self.assertEquals(None, MyMessage.outer_definition_name())
1718            self.assertEquals(None, MyMessage.definition_package())
1719
1720            self.assertEquals(six.text_type, type(MyMessage.definition_name()))
1721        finally:
1722            sys.modules = original_modules
1723
1724    def testDefinitionName_Nested(self):
1725        """Test nested message names."""
1726        class MyMessage(messages.Message):
1727
1728            class NestedMessage(messages.Message):
1729
1730                class NestedMessage(messages.Message):
1731
1732                    pass
1733
1734        module_name = test_util.get_module_name(MessageTest)
1735        self.assertEquals('%s.MyMessage.NestedMessage' % module_name,
1736                          MyMessage.NestedMessage.definition_name())
1737        self.assertEquals('%s.MyMessage' % module_name,
1738                          MyMessage.NestedMessage.outer_definition_name())
1739        self.assertEquals(module_name,
1740                          MyMessage.NestedMessage.definition_package())
1741
1742        self.assertEquals(
1743            '%s.MyMessage.NestedMessage.NestedMessage' % module_name,
1744            MyMessage.NestedMessage.NestedMessage.definition_name())
1745        self.assertEquals(
1746            '%s.MyMessage.NestedMessage' % module_name,
1747            MyMessage.NestedMessage.NestedMessage.outer_definition_name())
1748        self.assertEquals(
1749            module_name,
1750            MyMessage.NestedMessage.NestedMessage.definition_package())
1751
1752    def testMessageDefinition(self):
1753        """Test that enumeration knows its enclosing message definition."""
1754        class OuterMessage(messages.Message):
1755
1756            class InnerMessage(messages.Message):
1757                pass
1758
1759        self.assertEquals(None, OuterMessage.message_definition())
1760        self.assertEquals(OuterMessage,
1761                          OuterMessage.InnerMessage.message_definition())
1762
1763    def testConstructorKwargs(self):
1764        """Test kwargs via constructor."""
1765        class SomeMessage(messages.Message):
1766            name = messages.StringField(1)
1767            number = messages.IntegerField(2)
1768
1769        expected = SomeMessage()
1770        expected.name = 'my name'
1771        expected.number = 200
1772        self.assertEquals(expected, SomeMessage(name='my name', number=200))
1773
1774    def testConstructorNotAField(self):
1775        """Test kwargs via constructor with wrong names."""
1776        class SomeMessage(messages.Message):
1777            pass
1778
1779        self.assertRaisesWithRegexpMatch(
1780            AttributeError,
1781            ('May not assign arbitrary value does_not_exist to message '
1782             'SomeMessage'),
1783            SomeMessage,
1784            does_not_exist=10)
1785
1786    def testGetUnsetRepeatedValue(self):
1787        class SomeMessage(messages.Message):
1788            repeated = messages.IntegerField(1, repeated=True)
1789
1790        instance = SomeMessage()
1791        self.assertEquals([], instance.repeated)
1792        self.assertTrue(isinstance(instance.repeated, messages.FieldList))
1793
1794    def testCompareAutoInitializedRepeatedFields(self):
1795        class SomeMessage(messages.Message):
1796            repeated = messages.IntegerField(1, repeated=True)
1797
1798        message1 = SomeMessage(repeated=[])
1799        message2 = SomeMessage()
1800        self.assertEquals(message1, message2)
1801
1802    def testUnknownValues(self):
1803        """Test message class equality with unknown fields."""
1804        class MyMessage(messages.Message):
1805            field1 = messages.IntegerField(1)
1806
1807        message = MyMessage()
1808        self.assertEquals([], message.all_unrecognized_fields())
1809        self.assertEquals((None, None),
1810                          message.get_unrecognized_field_info('doesntexist'))
1811        self.assertEquals((None, None),
1812                          message.get_unrecognized_field_info(
1813                              'doesntexist', None, None))
1814        self.assertEquals(('defaultvalue', 'defaultwire'),
1815                          message.get_unrecognized_field_info(
1816                              'doesntexist', 'defaultvalue', 'defaultwire'))
1817        self.assertEquals((3, None),
1818                          message.get_unrecognized_field_info(
1819                              'doesntexist', value_default=3))
1820
1821        message.set_unrecognized_field('exists', 9.5, messages.Variant.DOUBLE)
1822        self.assertEquals(1, len(message.all_unrecognized_fields()))
1823        self.assertTrue('exists' in message.all_unrecognized_fields())
1824        self.assertEquals((9.5, messages.Variant.DOUBLE),
1825                          message.get_unrecognized_field_info('exists'))
1826        self.assertEquals((9.5, messages.Variant.DOUBLE),
1827                          message.get_unrecognized_field_info('exists', 'type',
1828                                                              1234))
1829        self.assertEquals(
1830            (1234, None),
1831            message.get_unrecognized_field_info('doesntexist', 1234))
1832
1833        message.set_unrecognized_field(
1834            'another', 'value', messages.Variant.STRING)
1835        self.assertEquals(2, len(message.all_unrecognized_fields()))
1836        self.assertTrue('exists' in message.all_unrecognized_fields())
1837        self.assertTrue('another' in message.all_unrecognized_fields())
1838        self.assertEquals((9.5, messages.Variant.DOUBLE),
1839                          message.get_unrecognized_field_info('exists'))
1840        self.assertEquals(('value', messages.Variant.STRING),
1841                          message.get_unrecognized_field_info('another'))
1842
1843        message.set_unrecognized_field('typetest1', ['list', 0, ('test',)],
1844                                       messages.Variant.STRING)
1845        self.assertEquals((['list', 0, ('test',)], messages.Variant.STRING),
1846                          message.get_unrecognized_field_info('typetest1'))
1847        message.set_unrecognized_field(
1848            'typetest2', '', messages.Variant.STRING)
1849        self.assertEquals(('', messages.Variant.STRING),
1850                          message.get_unrecognized_field_info('typetest2'))
1851
1852    def testPickle(self):
1853        """Testing pickling and unpickling of Message instances."""
1854        global MyEnum
1855        global AnotherMessage
1856        global MyMessage
1857
1858        class MyEnum(messages.Enum):
1859            val1 = 1
1860            val2 = 2
1861
1862        class AnotherMessage(messages.Message):
1863            string = messages.StringField(1, repeated=True)
1864
1865        class MyMessage(messages.Message):
1866            field1 = messages.IntegerField(1)
1867            field2 = messages.EnumField(MyEnum, 2)
1868            field3 = messages.MessageField(AnotherMessage, 3)
1869
1870        message = MyMessage(field1=1, field2=MyEnum.val2,
1871                            field3=AnotherMessage(string=['a', 'b', 'c']))
1872        message.set_unrecognized_field(
1873            'exists', 'value', messages.Variant.STRING)
1874        message.set_unrecognized_field('repeated', ['list', 0, ('test',)],
1875                                       messages.Variant.STRING)
1876        unpickled = pickle.loads(pickle.dumps(message))
1877        self.assertEquals(message, unpickled)
1878        self.assertTrue(AnotherMessage.string is unpickled.field3.string.field)
1879        self.assertTrue('exists' in message.all_unrecognized_fields())
1880        self.assertEquals(('value', messages.Variant.STRING),
1881                          message.get_unrecognized_field_info('exists'))
1882        self.assertEquals((['list', 0, ('test',)], messages.Variant.STRING),
1883                          message.get_unrecognized_field_info('repeated'))
1884
1885
class FindDefinitionTest(test_util.TestCase):
    """Test finding definitions relative to various definitions and modules.

    The tests build a private module namespace in ``self.modules`` and pass
    ``self.Importer`` to ``messages.find_definition`` so that lookups are
    resolved against that namespace.
    """

    def setUp(self):
        """Set up module-space.  Starts off empty."""
        self.modules = {}

    def DefineModule(self, name):
        """Define a module and its parents in module space.

        Modules that are already defined in self.modules are not re-created.

        Args:
          name: Fully qualified name of modules to create.

        Returns:
          Deepest nested module.  For example:

            DefineModule('a.b.c')  # Returns c.
        """
        full_path = []
        for node in name.split('.'):
            full_path.append(node)
            full_name = '.'.join(full_path)
            # Only build a module object when one is actually missing.
            if full_name not in self.modules:
                self.modules[full_name] = types.ModuleType(full_name)
        return self.modules[name]

    def DefineMessage(self, module, name, children=None, add_to_module=True):
        """Define a new Message class in the context of a module.

        Used for easily describing complex Message hierarchy. Message
        is defined including all child definitions.

        Args:
          module: Fully qualified name of module to place Message class in.
          name: Name of Message to define within module.
          children: Define any level of nesting of children
            definitions. To define a message, map the name to another
            dictionary. The dictionary can itself contain additional
            definitions, and so on. To map to an Enum, define the Enum
            class separately and map it by name. The mapping passed in
            is not modified.
          add_to_module: If True, new Message class is added to
            module. If False, new Message is not added.

        Returns:
          The newly created Message class.
        """
        # Make sure module exists.
        module_instance = self.DefineModule(module)

        # Build the class namespace in a fresh dict so the caller's
        # ``children`` mapping is left untouched. Nested dictionaries are
        # recursively turned into child Message classes.
        namespace = {}
        for attribute, value in (children or {}).items():
            if isinstance(value, dict):
                value = self.DefineMessage(module, attribute, value, False)
            namespace[attribute] = value

        # Override default __module__ variable.
        namespace['__module__'] = module

        # Instantiate and possibly add to module.
        message_class = type(name, (messages.Message,), namespace)
        if add_to_module:
            setattr(module_instance, name, message_class)
        return message_class

    # pylint:disable=unused-argument
    # pylint:disable=redefined-builtin
    def Importer(self, module, globals='', locals='', fromlist=None):
        """Importer function.

        Acts like __import__. Only loads modules from self.modules.
        Does not try to load real modules defined elsewhere. Does not
        try to handle relative imports.

        Args:
          module: Fully qualified name of module to load from self.modules.
          globals: Unused; accepted to mirror the __import__ signature.
          locals: Unused; accepted to mirror the __import__ signature.
          fromlist: When None, only the first component of the dotted
            module name is looked up; otherwise the full name is used.

        Returns:
          The module object registered in self.modules.

        Raises:
          ImportError: If the requested name is not in self.modules.
        """
        if fromlist is None:
            module = module.split('.')[0]
        try:
            return self.modules[module]
        except KeyError:
            raise ImportError('No module named %s' % module)
    # pylint:enable=unused-argument
    # pylint:enable=redefined-builtin

    def testNoSuchModule(self):
        """Test searching for definitions that do not exist."""
        self.assertRaises(messages.DefinitionNotFoundError,
                          messages.find_definition,
                          'does.not.exist',
                          importer=self.Importer)

    def testRefersToModule(self):
        """Test that referring to a module does not return that module."""
        self.DefineModule('i.am.a.module')
        self.assertRaises(messages.DefinitionNotFoundError,
                          messages.find_definition,
                          'i.am.a.module',
                          importer=self.Importer)

    def testNoDefinition(self):
        """Test not finding a definition in an existing module."""
        self.DefineModule('i.am.a.module')
        self.assertRaises(messages.DefinitionNotFoundError,
                          messages.find_definition,
                          'i.am.a.module.MyMessage',
                          importer=self.Importer)

    def testNotADefinition(self):
        """Test trying to fetch something that is not a definition."""
        module = self.DefineModule('i.am.a.module')
        setattr(module, 'A', 'a string')
        self.assertRaises(messages.DefinitionNotFoundError,
                          messages.find_definition,
                          'i.am.a.module.A',
                          importer=self.Importer)

    def testGlobalFind(self):
        """Test finding definitions from fully qualified module names."""
        A = self.DefineMessage('a.b.c', 'A', {})
        self.assertEqual(A, messages.find_definition('a.b.c.A',
                                                     importer=self.Importer))
        B = self.DefineMessage('a.b.c', 'B', {'C': {}})
        self.assertEqual(
            B.C,
            messages.find_definition('a.b.c.B.C', importer=self.Importer))

    def testRelativeToModule(self):
        """Test finding definitions relative to modules."""
        # Define modules.
        a = self.DefineModule('a')
        b = self.DefineModule('a.b')
        c = self.DefineModule('a.b.c')

        # Define messages.
        A = self.DefineMessage('a', 'A')
        B = self.DefineMessage('a.b', 'B')
        C = self.DefineMessage('a.b.c', 'C')
        D = self.DefineMessage('a.b.d', 'D')

        # Find A, B, C and D relative to a.
        self.assertEqual(A, messages.find_definition(
            'A', a, importer=self.Importer))
        self.assertEqual(B, messages.find_definition(
            'b.B', a, importer=self.Importer))
        self.assertEqual(C, messages.find_definition(
            'b.c.C', a, importer=self.Importer))
        self.assertEqual(D, messages.find_definition(
            'b.d.D', a, importer=self.Importer))

        # Find A, B, C and D relative to b.
        self.assertEqual(A, messages.find_definition(
            'A', b, importer=self.Importer))
        self.assertEqual(B, messages.find_definition(
            'B', b, importer=self.Importer))
        self.assertEqual(C, messages.find_definition(
            'c.C', b, importer=self.Importer))
        self.assertEqual(D, messages.find_definition(
            'd.D', b, importer=self.Importer))

        # Find A, B, C and D relative to c.  Module d is the same case as c.
        self.assertEqual(A, messages.find_definition(
            'A', c, importer=self.Importer))
        self.assertEqual(B, messages.find_definition(
            'B', c, importer=self.Importer))
        self.assertEqual(C, messages.find_definition(
            'C', c, importer=self.Importer))
        self.assertEqual(D, messages.find_definition(
            'd.D', c, importer=self.Importer))

    def testRelativeToMessages(self):
        """Test finding definitions relative to Message definitions."""
        A = self.DefineMessage('a.b', 'A', {'B': {'C': {}, 'D': {}}})
        B = A.B
        C = A.B.C
        D = A.B.D

        # Find relative to A.
        self.assertEqual(A, messages.find_definition(
            'A', A, importer=self.Importer))
        self.assertEqual(B, messages.find_definition(
            'B', A, importer=self.Importer))
        self.assertEqual(C, messages.find_definition(
            'B.C', A, importer=self.Importer))
        self.assertEqual(D, messages.find_definition(
            'B.D', A, importer=self.Importer))

        # Find relative to B.
        self.assertEqual(A, messages.find_definition(
            'A', B, importer=self.Importer))
        self.assertEqual(B, messages.find_definition(
            'B', B, importer=self.Importer))
        self.assertEqual(C, messages.find_definition(
            'C', B, importer=self.Importer))
        self.assertEqual(D, messages.find_definition(
            'D', B, importer=self.Importer))

        # Find relative to C.
        self.assertEqual(A, messages.find_definition(
            'A', C, importer=self.Importer))
        self.assertEqual(B, messages.find_definition(
            'B', C, importer=self.Importer))
        self.assertEqual(C, messages.find_definition(
            'C', C, importer=self.Importer))
        self.assertEqual(D, messages.find_definition(
            'D', C, importer=self.Importer))

        # Find relative to C using names that begin with module b (the
        # messages live in module a.b).
        self.assertEqual(A, messages.find_definition(
            'b.A', C, importer=self.Importer))
        self.assertEqual(B, messages.find_definition(
            'b.A.B', C, importer=self.Importer))
        self.assertEqual(C, messages.find_definition(
            'b.A.B.C', C, importer=self.Importer))
        self.assertEqual(D, messages.find_definition(
            'b.A.B.D', C, importer=self.Importer))

    def testAbsoluteReference(self):
        """Test finding absolute definition names."""
        # Define modules.
        a = self.DefineModule('a')
        self.DefineModule('a.a')

        # Define messages.
        aA = self.DefineMessage('a', 'A')
        aaA = self.DefineMessage('a.a', 'A')

        # Always find a.A.
        self.assertEqual(aA, messages.find_definition('.a.A', None,
                                                      importer=self.Importer))
        self.assertEqual(aA, messages.find_definition('.a.A', a,
                                                      importer=self.Importer))
        self.assertEqual(aA, messages.find_definition('.a.A', aA,
                                                      importer=self.Importer))
        self.assertEqual(aA, messages.find_definition('.a.A', aaA,
                                                      importer=self.Importer))

    def testFindEnum(self):
        """Test that Enums are found."""
        class Color(messages.Enum):
            pass
        A = self.DefineMessage('a', 'A', {'Color': Color})

        self.assertEqual(
            Color,
            messages.find_definition('Color', A, importer=self.Importer))

    def testFalseScope(self):
        """Test Message definitions nested in strange objects are hidden."""
        # X is bound as a module global so the second lookup below can reach
        # it through this module.
        global X

        class X(object):

            class A(messages.Message):
                pass

        self.assertRaises(TypeError, messages.find_definition, 'A', X)
        self.assertRaises(messages.DefinitionNotFoundError,
                          messages.find_definition,
                          'X.A', sys.modules[__name__])

    def testSearchAttributeFirst(self):
        """Make sure not faked out by module, but continues searching."""
        A = self.DefineMessage('a', 'A')
        # A module whose dotted name collides with the message's name.
        self.DefineModule('a.A')

        self.assertEqual(A, messages.find_definition(
            'a.A', None, importer=self.Importer))
2156
def main():
    """Run this module's tests with the unittest command-line runner."""
    unittest.main()
2159
2160
# Run the tests only when this file is executed as a script.
if __name__ == '__main__':
    main()
2163