1#!/usr/bin/env python
2#
3# Copyright 2010 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18"""Test utilities for message testing.
19
20Includes module interface test to ensure that public parts of module are
21correctly declared in __all__.
22
23Includes message types that correspond to those defined in
24services_test.proto.
25
26Includes additional test utilities to make sure encoding/decoding libraries
27conform.
28"""
29from six.moves import range
30
31__author__ = 'rafek@google.com (Rafe Kaplan)'
32
33import cgi
34import datetime
35import inspect
36import os
37import re
38import socket
39import types
40import unittest2 as unittest
41
42import six
43
44from . import message_types
45from . import messages
46from . import util
47
# The Russian word "русский" ("Russian"), spelled in Cyrillic, as a unicode
# literal.
RUSSIAN = u'\u0440\u0443\u0441\u0441\u043a\u0438\u0439'

# Every byte value from 0 to 255, each one followed by a null byte
# (512 bytes in total).
BINARY = bytes(bytearray(
    byte for value in range(256) for byte in (value, 0)))
53
54
class TestCase(unittest.TestCase):
  """Base test case adding extra assertion helpers used by protorpc tests."""

  def assertRaisesWithRegexpMatch(self,
                                  exception,
                                  regexp,
                                  function,
                                  *params,
                                  **kwargs):
    """Check that exception is raised and text matches regular expression.

    The expression is matched with re.match, i.e. anchored at the start of
    the exception's string form.

    Args:
      exception: Exception type (or tuple of types) that is expected.
      regexp: String regular expression that is expected in error message.
      function: Callable to test.
      params: Parameters to forward to function.
      kwargs: Keyword arguments to forward to function.
    """
    try:
      function(*params, **kwargs)
    except exception as err:
      match = bool(re.match(regexp, str(err)))
      self.assertTrue(match, 'Expected match "%s", found "%s"' % (regexp,
                                                                  err))
    else:
      # Failing from the else clause (not inside the try) keeps the
      # AssertionError raised by self.fail from being caught when
      # `exception` is a broad type such as Exception or AssertionError.
      if isinstance(exception, tuple):
        expected = ' or '.join(e.__name__ for e in exception)
      else:
        expected = exception.__name__
      self.fail('Expected exception %s was not raised' % (expected,))

  def assertHeaderSame(self, header1, header2):
    """Check that two HTTP headers are the same.

    Both the main value and the parameter dictionary produced by
    cgi.parse_header must be equal.

    Args:
      header1: Header value string 1.
      header2: Header value string 2.
    """
    value1, params1 = cgi.parse_header(header1)
    value2, params2 = cgi.parse_header(header2)
    self.assertEqual(value1, value2)
    self.assertEqual(params1, params2)

  def assertIterEqual(self, iter1, iter2):
    """Check that two iterators or iterables are equal independent of order.

    Similar to Python 2.7 assertItemsEqual.  Named differently in order to
    avoid potential conflict.  Items only need to support == (they need not
    be hashable), and duplicates must appear the same number of times.

    Args:
      iter1: An iterator or iterable.
      iter2: An iterator or iterable.
    """
    # Materialize iter1 before iter2 so both are consumed in a fixed order.
    list1 = list(iter1)
    remaining2 = list(iter2)

    unmatched1 = []
    for item1 in list1:
      for index, item2 in enumerate(remaining2):
        if item1 == item2:
          # Each item of iter2 may satisfy only one item of iter1.
          del remaining2[index]
          break
      else:
        unmatched1.append(item1)

    error_message = []
    for item in unmatched1:
      error_message.append(
          '  Item from iter1 not found in iter2: %r' % item)
    for item in remaining2:
      error_message.append(
          '  Item from iter2 not found in iter1: %r' % item)
    if error_message:
      self.fail('Collections not equivalent:\n' + '\n'.join(error_message))
126
127
class ModuleInterfaceTest(object):
  """Test to ensure module interface is carefully constructed.

  A module interface is the set of public objects listed in the module __all__
  attribute.  Modules that are considered public should have this interface
  carefully declared.  At all times, the __all__ attribute should have objects
  intended to be publicly used and all other objects in the module should be
  considered unused.

  Protected attributes (those beginning with '_') and other imported modules
  should not be part of this set of variables.  An exception is for variables
  that begin and end with '__' which are implicitly part of the interface
  (eg. __name__, __file__, __all__ itself, etc.).

  Modules that are imported into the tested modules are an exception and may
  be left out of the __all__ definition. The test is done by checking the value
  of what would otherwise be a public name and not allowing it to be exported
  if it is an instance of a module.  Modules that are explicitly exported are
  for the time being not permitted.  Names bound by "from __future__ import"
  statements are likewise not required to appear in __all__.

  To use this test class a module should define a new class that inherits first
  from ModuleInterfaceTest and then from test_util.TestCase.  No other tests
  should be added to this test case, making the order of inheritance less
  important, but if setUp for some reason is overridden, it is important that
  ModuleInterfaceTest is first in the list so that its setUp method is
  invoked.

  Multiple inheritance is required so that ModuleInterfaceTest is not itself
  a test, and is not itself executed as one.

  The test class is expected to have the following class attributes defined:

    MODULE: A reference to the module that is being validated for interface
      correctness.

  Example:
    Module definition (hello.py):

      import sys

      __all__ = ['hello']

      def _get_outputter():
        return sys.stdout

      def hello():
        _get_outputter().write('Hello\n')

    Test definition:

      import unittest
      from protorpc import test_util

      import hello

      class ModuleInterfaceTest(test_util.ModuleInterfaceTest,
                                test_util.TestCase):

        MODULE = hello


      class HelloTest(test_util.TestCase):
        ... Test 'hello' module ...


      if __name__ == '__main__':
        unittest.main()
  """

  def setUp(self):
    """Make sure that MODULE is defined.

    This is a basic configuration test for the test itself so does not
    get its own test case.
    """
    if not hasattr(self, 'MODULE'):
      self.fail(
          "You must define 'MODULE' on ModuleInterfaceTest sub-class %s." %
          type(self).__name__)

  def testAllExist(self):
    """Test that all attributes defined in __all__ exist."""
    missing_attributes = [attribute for attribute in self.MODULE.__all__
                          if not hasattr(self.MODULE, attribute)]
    if missing_attributes:
      self.fail('%s of __all__ are not defined in module.' %
                missing_attributes)

  def testAllExported(self):
    """Test that all public attributes not imported are in __all__.

    Public names bound to modules are exempt, as are names bound by
    "from __future__ import" statements.
    """
    # Local import: only needed to recognize __future__ feature objects.
    import __future__

    missing_attributes = []
    for attribute in dir(self.MODULE):
      if attribute.startswith('_') or attribute in self.MODULE.__all__:
        continue
      value = getattr(self.MODULE, attribute)
      if isinstance(value, types.ModuleType):
        continue
      # 'with_statement' is exempt unconditionally; other future feature
      # names are exempt only when bound to the real __future__ object, so a
      # user-defined public name that happens to share the name is caught.
      if (attribute == 'with_statement' or
          (attribute in __future__.all_feature_names and
           value is getattr(__future__, attribute))):
        continue
      missing_attributes.append(attribute)
    if missing_attributes:
      self.fail('%s are not modules and not defined in __all__.' %
                missing_attributes)

  def testNoExportedProtectedVariables(self):
    """Test that there are no protected variables listed in __all__."""
    protected_variables = [attribute for attribute in self.MODULE.__all__
                           if attribute.startswith('_')]
    if protected_variables:
      self.fail('%s are protected variables and may not be exported.' %
                protected_variables)

  def testNoExportedModules(self):
    """Test that no modules exist in __all__."""
    exported_modules = []
    for attribute in self.MODULE.__all__:
      try:
        value = getattr(self.MODULE, attribute)
      except AttributeError:
        # This is a different error case tested for in testAllExist.
        pass
      else:
        if isinstance(value, types.ModuleType):
          exported_modules.append(attribute)
    if exported_modules:
      self.fail('%s are modules and may not be exported.' % exported_modules)
256
257
class NestedMessage(messages.Message):
  """Simple message that gets nested in another message.

  Used as the nested type of HasNestedMessage.  Because its only field is
  required, a NestedMessage with a_value unset is expected to fail validation
  when encoded (see testEncodeInvalidMessage and testEncodeUninitialized).
  """

  a_value = messages.StringField(1, required=True)
262
263
class HasNestedMessage(messages.Message):
  """Message that has another message nested in it.

  Holds NestedMessage both as a singular field and as a repeated field.
  """

  # A single nested NestedMessage.
  nested = messages.MessageField(NestedMessage, 1)
  # A list of NestedMessage instances.
  repeated_nested = messages.MessageField(NestedMessage, 2, repeated=True)
269
270
class HasDefault(messages.Message):
  """Has a default value.

  Used to check that the default is not encoded when nothing is assigned,
  but is encoded when it is explicitly assigned (see testDoNotSendDefault and
  testSendDefaultExplicitlyAssigned).
  """

  a_value = messages.StringField(1, default=u'a default')
275
276
class OptionalMessage(messages.Message):
  """Contains one non-repeated field per supported field type and variant.

  None of the fields is declared required, so an empty instance is valid.
  """

  class SimpleEnum(messages.Enum):
    """Simple enumeration type."""
    VAL1 = 1
    VAL2 = 2

  double_value = messages.FloatField(1, variant=messages.Variant.DOUBLE)
  float_value = messages.FloatField(2, variant=messages.Variant.FLOAT)
  int64_value = messages.IntegerField(3, variant=messages.Variant.INT64)
  uint64_value = messages.IntegerField(4, variant=messages.Variant.UINT64)
  int32_value = messages.IntegerField(5, variant=messages.Variant.INT32)
  bool_value = messages.BooleanField(6, variant=messages.Variant.BOOL)
  string_value = messages.StringField(7, variant=messages.Variant.STRING)
  bytes_value = messages.BytesField(8, variant=messages.Variant.BYTES)
  enum_value = messages.EnumField(SimpleEnum, 10)

  # TODO(rafek): Add support for these variants.
  # Until then the numbers 9, 11 and 12 are skipped above.
  # uint32_value = messages.IntegerField(9, variant=messages.Variant.UINT32)
  # sint32_value = messages.IntegerField(11, variant=messages.Variant.SINT32)
  # sint64_value = messages.IntegerField(12, variant=messages.Variant.SINT64)
299
300
class RepeatedMessage(messages.Message):
  """Contains all message types as repeated fields.

  This is the repeated counterpart of OptionalMessage: the same field types,
  variants and numbers, each declared with repeated=True.  Its SimpleEnum is a
  separate class from OptionalMessage.SimpleEnum, with the same values.
  """

  class SimpleEnum(messages.Enum):
    """Simple enumeration type."""
    VAL1 = 1
    VAL2 = 2

  double_value = messages.FloatField(1,
                                     variant=messages.Variant.DOUBLE,
                                     repeated=True)
  float_value = messages.FloatField(2,
                                    variant=messages.Variant.FLOAT,
                                    repeated=True)
  int64_value = messages.IntegerField(3,
                                      variant=messages.Variant.INT64,
                                      repeated=True)
  uint64_value = messages.IntegerField(4,
                                       variant=messages.Variant.UINT64,
                                       repeated=True)
  int32_value = messages.IntegerField(5,
                                      variant=messages.Variant.INT32,
                                      repeated=True)
  bool_value = messages.BooleanField(6,
                                     variant=messages.Variant.BOOL,
                                     repeated=True)
  string_value = messages.StringField(7,
                                      variant=messages.Variant.STRING,
                                      repeated=True)
  bytes_value = messages.BytesField(8,
                                    variant=messages.Variant.BYTES,
                                    repeated=True)
  # The commented-out fields below are the variants not yet supported (see
  # the TODO in OptionalMessage); they are kept at their reserved numbers.
  #uint32_value = messages.IntegerField(9, variant=messages.Variant.UINT32)
  enum_value = messages.EnumField(SimpleEnum,
                                  10,
                                  repeated=True)
  #sint32_value = messages.IntegerField(11, variant=messages.Variant.SINT32)
  #sint64_value = messages.IntegerField(12, variant=messages.Variant.SINT64)
339
340
class HasOptionalNestedMessage(messages.Message):
  """Message with singular and repeated OptionalMessage fields.

  Used to check encoding of nested empty messages (see
  testEncodingNestedEmptyMessage and testEncodingRepeatedNestedEmptyMessage).
  """

  # A single nested OptionalMessage.
  nested = messages.MessageField(OptionalMessage, 1)
  # A list of OptionalMessage instances.
  repeated_nested = messages.MessageField(OptionalMessage, 2, repeated=True)
345
346
class ProtoConformanceTestBase(object):
  """Protocol conformance test base class.

  Each supported protocol should implement two functions that support encoding
  and decoding of Message objects in that format:

    encode_message(message) - Serialize to encoding.
    decode_message(message_type, encoded_message) - Deserialize from encoding.

  Tests for the modules where these functions are implemented should extend
  this class in order to support basic behavioral expectations.  This ensures
  that protocols correctly encode and decode messages transparently to the
  caller.

  In order to support these tests, the subclass should also extend the
  TestCase class and set the PROTOLIB class attribute to the protocol module
  under test.  That module provides encode_message, decode_message and a
  CONTENT_TYPE string.  The subclass must also define the following class
  attributes, which hold the encoded version of certain protocol buffers:

    encoded_partial:
      <OptionalMessage
        double_value: 1.23
        int64_value: -100000000000
        int32_value: 1020
        string_value: u"a string"
        enum_value: OptionalMessage.SimpleEnum.VAL2
        >

    encoded_full:
      <OptionalMessage
        double_value: 1.23
        float_value: -2.5
        int64_value: -100000000000
        uint64_value: 102020202020
        int32_value: 1020
        bool_value: true
        string_value: u"a string\u044f"
        bytes_value: b"a bytes\xff\xfe"
        enum_value: OptionalMessage.SimpleEnum.VAL2
        >

    encoded_repeated:
      <RepeatedMessage
        double_value: [1.23, 2.3]
        float_value: [-2.5, 0.5]
        int64_value: [-100000000000, 20]
        uint64_value: [102020202020, 10]
        int32_value: [1020, 718]
        bool_value: [true, false]
        string_value: [u"a string\u044f", u"another string"]
        bytes_value: [b"a bytes\xff\xfe", b"another bytes"]
        enum_value: [RepeatedMessage.SimpleEnum.VAL2,
                     RepeatedMessage.SimpleEnum.VAL1]
        >

    encoded_nested:
      <HasNestedMessage
        nested: <NestedMessage
          a_value: "a string"
          >
        >

    encoded_repeated_nested:
      <HasNestedMessage
        repeated_nested: [
            <NestedMessage a_value: "a string">,
            <NestedMessage a_value: "another string">
          ]
        >

    encoded_string_types:
      <OptionalMessage
        string_value: "Latin"
        >

    unexpected_tag_message:
      An encoded message that has an undefined tag or number in the stream.

    encoded_default_assigned:
      <HasDefault
        a_value: "a default"
        >

    encoded_nested_empty:
      <HasOptionalNestedMessage
        nested: <OptionalMessage>
        >

    encoded_repeated_nested_empty:
      <HasOptionalNestedMessage
        repeated_nested: [<OptionalMessage>, <OptionalMessage>]
        >

    encoded_invalid_enum:
      <OptionalMessage
        enum_value: (invalid value for serialization type)
        >

  encoded_empty_message defaults to '' and can be overridden by protocols
  whose encoding of an empty message is different.
  """

  encoded_empty_message = ''

  def testEncodeInvalidMessage(self):
    """Test that encoding a message missing a required field fails."""
    message = NestedMessage()
    self.assertRaises(messages.ValidationError,
                      self.PROTOLIB.encode_message, message)

  def CompareEncoded(self, expected_encoded, actual_encoded):
    """Compare two encoded protocol values.

    Can be overridden by sub-classes to special case comparison.
    For example, to eliminate white space from output that is not
    relevant to encoding.

    Args:
      expected_encoded: Expected string encoded value.
      actual_encoded: Actual string encoded value.
    """
    self.assertEqual(expected_encoded, actual_encoded)

  def EncodeDecode(self, encoded, expected_message):
    """Check that encoded decodes to expected_message and re-encodes back.

    Args:
      encoded: Encoded form of expected_message in the protocol under test.
      expected_message: Message instance; its type is used for decoding.
    """
    message = self.PROTOLIB.decode_message(type(expected_message), encoded)
    self.assertEqual(expected_message, message)
    # Re-encoding goes through CompareEncoded so subclasses can relax the
    # comparison (e.g. ignore insignificant whitespace).
    self.CompareEncoded(encoded, self.PROTOLIB.encode_message(message))

  def testEmptyMessage(self):
    """Test a message with no fields set."""
    self.EncodeDecode(self.encoded_empty_message, OptionalMessage())

  def testPartial(self):
    """Test message with a few values set."""
    message = OptionalMessage()
    message.double_value = 1.23
    message.int64_value = -100000000000
    message.int32_value = 1020
    message.string_value = u'a string'
    message.enum_value = OptionalMessage.SimpleEnum.VAL2

    self.EncodeDecode(self.encoded_partial, message)

  def testFull(self):
    """Test all types."""
    message = OptionalMessage()
    message.double_value = 1.23
    message.float_value = -2.5
    message.int64_value = -100000000000
    message.uint64_value = 102020202020
    message.int32_value = 1020
    message.bool_value = True
    message.string_value = u'a string\u044f'
    message.bytes_value = b'a bytes\xff\xfe'
    message.enum_value = OptionalMessage.SimpleEnum.VAL2

    self.EncodeDecode(self.encoded_full, message)

  def testRepeated(self):
    """Test repeated fields."""
    message = RepeatedMessage()
    message.double_value = [1.23, 2.3]
    message.float_value = [-2.5, 0.5]
    message.int64_value = [-100000000000, 20]
    message.uint64_value = [102020202020, 10]
    message.int32_value = [1020, 718]
    message.bool_value = [True, False]
    message.string_value = [u'a string\u044f', u'another string']
    message.bytes_value = [b'a bytes\xff\xfe', b'another bytes']
    message.enum_value = [RepeatedMessage.SimpleEnum.VAL2,
                          RepeatedMessage.SimpleEnum.VAL1]

    self.EncodeDecode(self.encoded_repeated, message)

  def testNested(self):
    """Test nested messages."""
    nested_message = NestedMessage()
    nested_message.a_value = u'a string'

    message = HasNestedMessage()
    message.nested = nested_message

    self.EncodeDecode(self.encoded_nested, message)

  def testRepeatedNested(self):
    """Test repeated nested messages."""
    nested_message1 = NestedMessage()
    nested_message1.a_value = u'a string'
    nested_message2 = NestedMessage()
    nested_message2.a_value = u'another string'

    message = HasNestedMessage()
    message.repeated_nested = [nested_message1, nested_message2]

    self.EncodeDecode(self.encoded_repeated_nested, message)

  def testStringTypes(self):
    """Test that encoding str on StringField works."""
    message = OptionalMessage()
    message.string_value = 'Latin'
    self.EncodeDecode(self.encoded_string_types, message)

  def testEncodeUninitialized(self):
    """Test that cannot encode uninitialized message."""
    required = NestedMessage()
    self.assertRaisesWithRegexpMatch(messages.ValidationError,
                                     "Message NestedMessage is missing "
                                     "required field a_value",
                                     self.PROTOLIB.encode_message,
                                     required)

  def testUnexpectedField(self):
    """Test decoding and encoding unexpected fields."""
    loaded_message = self.PROTOLIB.decode_message(OptionalMessage,
                                                  self.unexpected_tag_message)
    # Message should be equal to an empty message, since unknown values aren't
    # included in equality.
    self.assertEqual(OptionalMessage(), loaded_message)
    # Verify that the encoded message matches the source, including the
    # unknown value.
    self.assertEqual(self.unexpected_tag_message,
                     self.PROTOLIB.encode_message(loaded_message))

  def testDoNotSendDefault(self):
    """Test that default is not sent when nothing is assigned."""
    self.EncodeDecode(self.encoded_empty_message, HasDefault())

  def testSendDefaultExplicitlyAssigned(self):
    """Test that default is sent when explicitly assigned."""
    message = HasDefault()

    message.a_value = HasDefault.a_value.default

    self.EncodeDecode(self.encoded_default_assigned, message)

  def testEncodingNestedEmptyMessage(self):
    """Test encoding a nested empty message."""
    message = HasOptionalNestedMessage()
    message.nested = OptionalMessage()

    self.EncodeDecode(self.encoded_nested_empty, message)

  def testEncodingRepeatedNestedEmptyMessage(self):
    """Test encoding repeated nested empty messages."""
    message = HasOptionalNestedMessage()
    message.repeated_nested = [OptionalMessage(), OptionalMessage()]

    self.EncodeDecode(self.encoded_repeated_nested_empty, message)

  def testContentType(self):
    """Test that the protocol module declares a str CONTENT_TYPE."""
    self.assertTrue(isinstance(self.PROTOLIB.CONTENT_TYPE, str))

  def testDecodeInvalidEnumType(self):
    """Test that decoding an invalid enum value raises DecodeError."""
    self.assertRaisesWithRegexpMatch(messages.DecodeError,
                                     'Invalid enum value ',
                                     self.PROTOLIB.decode_message,
                                     OptionalMessage,
                                     self.encoded_invalid_enum)

  def testDateTimeNoTimeZone(self):
    """Test that DateTimeFields are encoded/decoded correctly."""

    class MyMessage(messages.Message):
      value = message_types.DateTimeField(1)

    value = datetime.datetime(2013, 1, 3, 11, 36, 30, 123000)
    message = MyMessage(value=value)
    decoded = self.PROTOLIB.decode_message(
        MyMessage, self.PROTOLIB.encode_message(message))
    self.assertEqual(decoded.value, value)

  def testDateTimeWithTimeZone(self):
    """Test DateTimeFields with time zones."""

    class MyMessage(messages.Message):
      value = message_types.DateTimeField(1)

    value = datetime.datetime(2013, 1, 3, 11, 36, 30, 123000,
                              util.TimeZoneOffset(8 * 60))
    message = MyMessage(value=value)
    decoded = self.PROTOLIB.decode_message(
        MyMessage, self.PROTOLIB.encode_message(message))
    self.assertEqual(decoded.value, value)
613
614
def do_with(context, function, *args, **kwargs):
  """Simulate a with statement.

  Avoids the need for "from __future__ import with_statement".  Follows the
  with-statement protocol:
    - context.__exit__ is called exactly once.
    - If function raised, __exit__ receives the exception details.
    - The exception is re-raised unless __exit__ returns a true value.

  Does not support simulation of 'as'.

  Args:
    context: Context object normally used with 'with'.
    function: Callable to invoke.  Replaces with-block.
    *args: Positional arguments forwarded to function.
    **kwargs: Keyword arguments forwarded to function.

  Returns:
    The return value of function, or None if function raised and the
    exception was suppressed by context.__exit__.
  """
  # Local import: sys is not among this module's top-level imports.
  import sys

  context.__enter__()
  try:
    result = function(*args, **kwargs)
  except BaseException:
    # Like 'with': give __exit__ the exception and re-raise unless it
    # reports the exception as handled.
    if not context.__exit__(*sys.exc_info()):
      raise
    return None
  context.__exit__(None, None, None)
  return result
633
634
def pick_unused_port():
  """Return a TCP port on localhost that was free at the time of the call.

  Derived from Damon Kohler's example:

    http://code.activestate.com/recipes/531822-pick-unused-port

  The probe socket is closed before returning, so the port is not reserved:
  another process could claim it before the caller binds to it.
  """
  probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  try:
    # Binding to port 0 lets the operating system choose a free port.
    probe.bind(('localhost', 0))
    return probe.getsockname()[1]
  finally:
    probe.close()
649
650
def get_module_name(module_attribute):
  """Return the name of the module that defines module_attribute.

  When the defining module is run as a script (its __module__ is
  "__main__"), the script's file name up to the first '.' is used instead.

  Args:
    module_attribute: An attribute of the module.

  Returns:
    The fully qualified module name or simple module name where
    'module_attribute' is defined if the module name is "__main__".
  """
  if module_attribute.__module__ != '__main__':
    return module_attribute.__module__
  file_name = os.path.basename(inspect.getfile(module_attribute))
  return file_name.partition('.')[0]
667