1#! /usr/bin/env python 2# -*- coding: utf-8 -*- 3# 4# Protocol Buffers - Google's data interchange format 5# Copyright 2008 Google Inc. All rights reserved. 6# https://developers.google.com/protocol-buffers/ 7# 8# Redistribution and use in source and binary forms, with or without 9# modification, are permitted provided that the following conditions are 10# met: 11# 12# * Redistributions of source code must retain the above copyright 13# notice, this list of conditions and the following disclaimer. 14# * Redistributions in binary form must reproduce the above 15# copyright notice, this list of conditions and the following disclaimer 16# in the documentation and/or other materials provided with the 17# distribution. 18# * Neither the name of Google Inc. nor the names of its 19# contributors may be used to endorse or promote products derived from 20# this software without specific prior written permission. 21# 22# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 23# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 24# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 25# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 26# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 27# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 28# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 29# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 30# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 31# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 32# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
"""Test for preservation of unknown fields in the pure Python implementation.

Tests that decode the internal ``_unknown_fields`` storage (through the
``GetField`` helpers) are skipped under version 2 of the C++ implementation
(see SkipIfCppImplementation). The remaining tests exercise serialization,
equality, merging, clearing and discarding of unknown fields through the
regular message API.
"""

__author__ = 'bohdank@google.com (Bohdan Koval)'

try:
  import unittest2 as unittest  #PY26
except ImportError:
  import unittest
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2
from google.protobuf.internal import api_implementation
from google.protobuf.internal import encoder
from google.protobuf.internal import message_set_extensions_pb2
from google.protobuf.internal import missing_enum_values_pb2
from google.protobuf.internal import test_util
from google.protobuf.internal import type_checkers


def SkipIfCppImplementation(func):
  """Decorates a test so it is skipped under the C++ (version 2) backend.

  Args:
    func: The test method to wrap.

  Returns:
    ``func`` wrapped with ``unittest.skipIf``. The skip condition is
    ``api_implementation.Type() == 'cpp'`` and
    ``api_implementation.Version() == 2``.
  """
  return unittest.skipIf(
      api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
      'C++ implementation does not expose unknown fields to Python')(func)


class UnknownFieldsTest(unittest.TestCase):
  """Round-trip and API behavior of messages that carry unknown fields."""

  def setUp(self):
    """Parses a fully populated TestAllTypes payload into TestEmptyMessage.

    Every field of the payload is then retained only as an unknown field
    (testListFields asserts that none of them is recognized).
    """
    self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
    self.all_fields = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(self.all_fields)
    self.all_fields_data = self.all_fields.SerializeToString()
    self.empty_message = unittest_pb2.TestEmptyMessage()
    self.empty_message.ParseFromString(self.all_fields_data)

  def testSerialize(self):
    # Re-serializing the message must reproduce the original bytes exactly.
    data = self.empty_message.SerializeToString()

    # Don't use assertEqual because we don't want to dump raw binary data to
    # stdout.
    self.assertTrue(data == self.all_fields_data)

  def testSerializeProto3(self):
    # Verify that proto3 doesn't preserve unknown fields.
    message = unittest_proto3_arena_pb2.TestEmptyMessage()
    message.ParseFromString(self.all_fields_data)
    self.assertEqual(0, len(message.SerializeToString()))

  def testByteSize(self):
    # Unknown fields count toward ByteSize() just like known ones.
    self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize())

  def testListFields(self):
    # Make sure ListFields doesn't return unknown fields.
    self.assertEqual(0, len(self.empty_message.ListFields()))

  def testSerializeMessageSetWireFormatUnknownExtension(self):
    # Create a message using the message set wire format with an unknown
    # message.
    raw = unittest_mset_pb2.RawMessageSet()

    # Add an unknown extension.
    item = raw.item.add()
    item.type_id = 98418603
    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
    message1.i = 12345
    item.message = message1.SerializeToString()

    serialized = raw.SerializeToString()

    # Parse message using the message set wire format.
    proto = message_set_extensions_pb2.TestMessageSet()
    proto.MergeFromString(serialized)

    # Verify that the unknown extension is serialized unchanged.
    reserialized = proto.SerializeToString()
    new_raw = unittest_mset_pb2.RawMessageSet()
    new_raw.MergeFromString(reserialized)
    self.assertEqual(raw, new_raw)

  def testEquals(self):
    # Two messages holding the same unknown fields compare equal.
    message = unittest_pb2.TestEmptyMessage()
    message.ParseFromString(self.all_fields_data)
    self.assertEqual(self.empty_message, message)

    # Reparse from a payload without optional_string. The unknown-field
    # contents now differ, so the messages must compare unequal.
    self.all_fields.ClearField('optional_string')
    message.ParseFromString(self.all_fields.SerializeToString())
    self.assertNotEqual(self.empty_message, message)

  def testDiscardUnknownFields(self):
    self.empty_message.DiscardUnknownFields()
    self.assertEqual(b'', self.empty_message.SerializeToString())
    # Test message field and repeated message field.
    message = unittest_pb2.TestAllTypes()
    other_message = unittest_pb2.TestAllTypes()
    other_message.optional_string = 'discard'
    # Parsing a TestAllTypes payload into the nested message types leaves
    # the nested messages holding data only as unknown fields. The
    # assertNotEqual checks below confirm they serialize to something
    # non-empty before the discard.
    message.optional_nested_message.ParseFromString(
        other_message.SerializeToString())
    message.repeated_nested_message.add().ParseFromString(
        other_message.SerializeToString())
    self.assertNotEqual(
        b'', message.optional_nested_message.SerializeToString())
    self.assertNotEqual(
        b'', message.repeated_nested_message[0].SerializeToString())
    # Discarding on the parent must also clear the sub-messages.
    message.DiscardUnknownFields()
    self.assertEqual(b'', message.optional_nested_message.SerializeToString())
    self.assertEqual(
        b'', message.repeated_nested_message[0].SerializeToString())


class UnknownFieldsAccessorsTest(unittest.TestCase):
  """Inspects and manipulates unknown fields parsed into TestEmptyMessage."""

  def setUp(self):
    """Builds the unknown-field fixture shared by the accessor tests.

    Outside the C++ implementation, also keeps a reference to the internal
    ``_unknown_fields`` storage so that GetField() can decode it.
    """
    self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
    self.all_fields = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(self.all_fields)
    self.all_fields_data = self.all_fields.SerializeToString()
    self.empty_message = unittest_pb2.TestEmptyMessage()
    self.empty_message.ParseFromString(self.all_fields_data)
    if api_implementation.Type() != 'cpp':
      # _unknown_fields is an implementation detail.
      self.unknown_fields = self.empty_message._unknown_fields

  # All the tests that use GetField() check an implementation detail of the
  # Python implementation, which stores unknown fields as serialized strings.
  # These tests are skipped under the C++ implementation: there it's enough to
  # check that the message is correctly serialized.

  def GetField(self, name):
    """Decodes the unknown-field value stored for a TestAllTypes field.

    Args:
      name: Name of a field of TestAllTypes.

    Returns:
      The decoded value, produced by the TestAllTypes decoder registered for
      the field's tag.

    Raises:
      KeyError: If ``name`` is not a TestAllTypes field, or if no unknown
        field with the matching tag was stored.
    """
    field_descriptor = self.descriptor.fields_by_name[name]
    wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
    field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
    result_dict = {}
    # self.unknown_fields is iterated as (tag_bytes, value) pairs. Each entry
    # whose tag matches is decoded into result_dict, keyed by the descriptor.
    for tag_bytes, value in self.unknown_fields:
      if tag_bytes == field_tag:
        decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0]
        decoder(value, 0, len(value), self.all_fields, result_dict)
    return result_dict[field_descriptor]

  @SkipIfCppImplementation
  def testEnum(self):
    value = self.GetField('optional_nested_enum')
    self.assertEqual(self.all_fields.optional_nested_enum, value)

  @SkipIfCppImplementation
  def testRepeatedEnum(self):
    value = self.GetField('repeated_nested_enum')
    self.assertEqual(self.all_fields.repeated_nested_enum, value)

  @SkipIfCppImplementation
  def testVarint(self):
    value = self.GetField('optional_int32')
    self.assertEqual(self.all_fields.optional_int32, value)

  @SkipIfCppImplementation
  def testFixed32(self):
    value = self.GetField('optional_fixed32')
    self.assertEqual(self.all_fields.optional_fixed32, value)

  @SkipIfCppImplementation
  def testFixed64(self):
    value = self.GetField('optional_fixed64')
    self.assertEqual(self.all_fields.optional_fixed64, value)

  @SkipIfCppImplementation
  def testLengthDelimited(self):
    value = self.GetField('optional_string')
    self.assertEqual(self.all_fields.optional_string, value)

  @SkipIfCppImplementation
  def testGroup(self):
    value = self.GetField('optionalgroup')
    self.assertEqual(self.all_fields.optionalgroup, value)

  def testCopyFrom(self):
    # CopyFrom carries unknown fields over to the destination message.
    message = unittest_pb2.TestEmptyMessage()
    message.CopyFrom(self.empty_message)
    self.assertEqual(message.SerializeToString(), self.all_fields_data)

  def testMergeFrom(self):
    # source holds (as unknown fields) optional_int32=1, optional_uint32=2.
    message = unittest_pb2.TestAllTypes()
    message.optional_int32 = 1
    message.optional_uint32 = 2
    source = unittest_pb2.TestEmptyMessage()
    source.ParseFromString(message.SerializeToString())

    # destination holds optional_int64=3, optional_uint32=4.
    message.ClearField('optional_int32')
    message.optional_int64 = 3
    message.optional_uint32 = 4
    destination = unittest_pb2.TestEmptyMessage()
    destination.ParseFromString(message.SerializeToString())

    destination.MergeFrom(source)
    # Check that the fields were correctly merged, even though they are stored
    # in the unknown field set. For the field present on both sides
    # (optional_uint32), the value from source (2) is expected to win.
    message.ParseFromString(destination.SerializeToString())
    self.assertEqual(message.optional_int32, 1)
    self.assertEqual(message.optional_uint32, 2)
    self.assertEqual(message.optional_int64, 3)

  def testClear(self):
    self.empty_message.Clear()
    # All cleared, even unknown fields.
    self.assertEqual(self.empty_message.SerializeToString(), b'')

  def testUnknownExtensions(self):
    # Unknown fields are also preserved by a message that declares
    # extension ranges.
    message = unittest_pb2.TestEmptyMessageWithExtensions()
    message.ParseFromString(self.all_fields_data)
    self.assertEqual(message.SerializeToString(), self.all_fields_data)


class UnknownEnumValuesTest(unittest.TestCase):
  """Handling of enum values the parsing message type does not recognize."""

  def setUp(self):
    """Parses TestEnumValues data into TestMissingEnumValues.

    testUnknownEnumValue asserts that the parsed optional enum is not
    recognized by TestMissingEnumValues. Presumably that type's enum lacks
    the ZERO/ONE values used here (TODO confirm against
    missing_enum_values.proto). Outside the C++ implementation, also keeps a
    reference to the internal ``_unknown_fields`` storage for GetField().
    """
    self.descriptor = missing_enum_values_pb2.TestEnumValues.DESCRIPTOR

    self.message = missing_enum_values_pb2.TestEnumValues()
    self.message.optional_nested_enum = (
        missing_enum_values_pb2.TestEnumValues.ZERO)
    self.message.repeated_nested_enum.extend([
        missing_enum_values_pb2.TestEnumValues.ZERO,
        missing_enum_values_pb2.TestEnumValues.ONE,
        ])
    self.message.packed_nested_enum.extend([
        missing_enum_values_pb2.TestEnumValues.ZERO,
        missing_enum_values_pb2.TestEnumValues.ONE,
        ])
    self.message_data = self.message.SerializeToString()
    self.missing_message = missing_enum_values_pb2.TestMissingEnumValues()
    self.missing_message.ParseFromString(self.message_data)
    if api_implementation.Type() != 'cpp':
      # _unknown_fields is an implementation detail.
      self.unknown_fields = self.missing_message._unknown_fields

  # All the tests that use GetField() check an implementation detail of the
  # Python implementation, which stores unknown fields as serialized strings.
  # These tests are skipped under the C++ implementation: there it's enough to
  # check that the message is correctly serialized.

  def GetField(self, name):
    """Decodes the unknown-field value stored for a TestEnumValues field.

    Args:
      name: Name of a field of TestEnumValues.

    Returns:
      The decoded value, produced by the TestEnumValues decoder registered
      for the field's tag.

    Raises:
      KeyError: If ``name`` is not a TestEnumValues field, or if no unknown
        field with the matching tag was stored.
    """
    field_descriptor = self.descriptor.fields_by_name[name]
    wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
    field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
    result_dict = {}
    # Decode every matching (tag_bytes, value) entry into result_dict, keyed
    # by the field descriptor.
    for tag_bytes, value in self.unknown_fields:
      if tag_bytes == field_tag:
        decoder = missing_enum_values_pb2.TestEnumValues._decoders_by_tag[
            tag_bytes][0]
        decoder(value, 0, len(value), self.message, result_dict)
    return result_dict[field_descriptor]

  def testUnknownParseMismatchEnumValue(self):
    just_string = missing_enum_values_pb2.JustString()
    just_string.dummy = 'blah'

    missing = missing_enum_values_pb2.TestEnumValues()
    # The parse is invalid, storing the string proto into the set of
    # unknown fields.
    missing.ParseFromString(just_string.SerializeToString())

    # Fetching the enum field shouldn't crash, instead returning the
    # default value.
    self.assertEqual(missing.optional_nested_enum, 0)

  @SkipIfCppImplementation
  def testUnknownEnumValue(self):
    # The unrecognized value is not set on the message but kept as unknown.
    self.assertFalse(self.missing_message.HasField('optional_nested_enum'))
    value = self.GetField('optional_nested_enum')
    self.assertEqual(self.message.optional_nested_enum, value)

  @SkipIfCppImplementation
  def testUnknownRepeatedEnumValue(self):
    value = self.GetField('repeated_nested_enum')
    self.assertEqual(self.message.repeated_nested_enum, value)

  @SkipIfCppImplementation
  def testUnknownPackedEnumValue(self):
    value = self.GetField('packed_nested_enum')
    self.assertEqual(self.message.packed_nested_enum, value)

  def testRoundTrip(self):
    # Serializing the message with unknown enum values and reparsing it as
    # TestEnumValues must recover the original message.
    new_message = missing_enum_values_pb2.TestEnumValues()
    new_message.ParseFromString(self.missing_message.SerializeToString())
    self.assertEqual(self.message, new_message)


if __name__ == '__main__':
  unittest.main()