1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 //     * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 //     * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 //     * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 
31 // Author: kenton@google.com (Kenton Varda)
32 //  Based on original Protocol Buffers design by
33 //  Sanjay Ghemawat, Jeff Dean, and others.
34 //
35 // Contains methods defined in extension_set.h which cannot be part of the
36 // lite library because they use descriptors or reflection.
37 
38 #include <google/protobuf/io/zero_copy_stream_impl_lite.h>
39 #include <google/protobuf/descriptor.h>
40 #include <google/protobuf/extension_set.h>
41 #include <google/protobuf/message.h>
42 #include <google/protobuf/repeated_field.h>
43 #include <google/protobuf/wire_format.h>
44 #include <google/protobuf/wire_format_lite_inl.h>
45 
46 namespace google {
47 
48 namespace protobuf {
49 namespace internal {
50 
51 // A FieldSkipper used to store unknown MessageSet fields into UnknownFieldSet.
52 class MessageSetFieldSkipper
53     : public UnknownFieldSetFieldSkipper {
54  public:
MessageSetFieldSkipper(UnknownFieldSet * unknown_fields)55   explicit MessageSetFieldSkipper(UnknownFieldSet* unknown_fields)
56       : UnknownFieldSetFieldSkipper(unknown_fields) {}
~MessageSetFieldSkipper()57   virtual ~MessageSetFieldSkipper() {}
58 
59   virtual bool SkipMessageSetField(io::CodedInputStream* input,
60                                    int field_number);
61 };
SkipMessageSetField(io::CodedInputStream * input,int field_number)62 bool MessageSetFieldSkipper::SkipMessageSetField(
63     io::CodedInputStream* input, int field_number) {
64   uint32 length;
65   if (!input->ReadVarint32(&length)) return false;
66   if (unknown_fields_ == NULL) {
67     return input->Skip(length);
68   } else {
69     return input->ReadString(
70         unknown_fields_->AddLengthDelimited(field_number), length);
71   }
72 }
73 
74 
75 // Implementation of ExtensionFinder which finds extensions in a given
76 // DescriptorPool, using the given MessageFactory to construct sub-objects.
77 // This class is implemented in extension_set_heavy.cc.
78 class DescriptorPoolExtensionFinder : public ExtensionFinder {
79  public:
DescriptorPoolExtensionFinder(const DescriptorPool * pool,MessageFactory * factory,const Descriptor * containing_type)80   DescriptorPoolExtensionFinder(const DescriptorPool* pool,
81                                 MessageFactory* factory,
82                                 const Descriptor* containing_type)
83       : pool_(pool), factory_(factory), containing_type_(containing_type) {}
~DescriptorPoolExtensionFinder()84   virtual ~DescriptorPoolExtensionFinder() {}
85 
86   virtual bool Find(int number, ExtensionInfo* output);
87 
88  private:
89   const DescriptorPool* pool_;
90   MessageFactory* factory_;
91   const Descriptor* containing_type_;
92 };
93 
AppendToList(const Descriptor * containing_type,const DescriptorPool * pool,std::vector<const FieldDescriptor * > * output) const94 void ExtensionSet::AppendToList(
95     const Descriptor* containing_type,
96     const DescriptorPool* pool,
97     std::vector<const FieldDescriptor*>* output) const {
98   for (map<int, Extension>::const_iterator iter = extensions_.begin();
99        iter != extensions_.end(); ++iter) {
100     bool has = false;
101     if (iter->second.is_repeated) {
102       has = iter->second.GetSize() > 0;
103     } else {
104       has = !iter->second.is_cleared;
105     }
106 
107     if (has) {
108       // TODO(kenton): Looking up each field by number is somewhat unfortunate.
109       //   Is there a better way?  The problem is that descriptors are lazily-
110       //   initialized, so they might not even be constructed until
111       //   AppendToList() is called.
112 
113       if (iter->second.descriptor == NULL) {
114         output->push_back(pool->FindExtensionByNumber(
115             containing_type, iter->first));
116       } else {
117         output->push_back(iter->second.descriptor);
118       }
119     }
120   }
121 }
122 
real_type(FieldType type)123 inline FieldDescriptor::Type real_type(FieldType type) {
124   GOOGLE_DCHECK(type > 0 && type <= FieldDescriptor::MAX_TYPE);
125   return static_cast<FieldDescriptor::Type>(type);
126 }
127 
cpp_type(FieldType type)128 inline FieldDescriptor::CppType cpp_type(FieldType type) {
129   return FieldDescriptor::TypeToCppType(
130       static_cast<FieldDescriptor::Type>(type));
131 }
132 
field_type(FieldType type)133 inline WireFormatLite::FieldType field_type(FieldType type) {
134   GOOGLE_DCHECK(type > 0 && type <= WireFormatLite::MAX_FIELD_TYPE);
135   return static_cast<WireFormatLite::FieldType>(type);
136 }
137 
// Debug-only check that EXTENSION has the expected label (OPTIONAL or
// REPEATED) and the expected C++ type category (e.g. MESSAGE).
//
// The two checks are wrapped in do { ... } while (0) so that the macro
// behaves as a single statement; otherwise only the first check would be
// governed by an unbraced if/else at the call site.  Callers must terminate
// the invocation with a semicolon.
#define GOOGLE_DCHECK_TYPE(EXTENSION, LABEL, CPPTYPE)                          \
  do {                                                                         \
    GOOGLE_DCHECK_EQ((EXTENSION).is_repeated ? FieldDescriptor::LABEL_REPEATED \
                                    : FieldDescriptor::LABEL_OPTIONAL,         \
              FieldDescriptor::LABEL_##LABEL);                                 \
    GOOGLE_DCHECK_EQ(cpp_type((EXTENSION).type),                               \
              FieldDescriptor::CPPTYPE_##CPPTYPE);                             \
  } while (0)
143 
GetMessage(int number,const Descriptor * message_type,MessageFactory * factory) const144 const MessageLite& ExtensionSet::GetMessage(int number,
145                                             const Descriptor* message_type,
146                                             MessageFactory* factory) const {
147   map<int, Extension>::const_iterator iter = extensions_.find(number);
148   if (iter == extensions_.end() || iter->second.is_cleared) {
149     // Not present.  Return the default value.
150     return *factory->GetPrototype(message_type);
151   } else {
152     GOOGLE_DCHECK_TYPE(iter->second, OPTIONAL, MESSAGE);
153     if (iter->second.is_lazy) {
154       return iter->second.lazymessage_value->GetMessage(
155           *factory->GetPrototype(message_type));
156     } else {
157       return *iter->second.message_value;
158     }
159   }
160 }
161 
MutableMessage(const FieldDescriptor * descriptor,MessageFactory * factory)162 MessageLite* ExtensionSet::MutableMessage(const FieldDescriptor* descriptor,
163                                           MessageFactory* factory) {
164   Extension* extension;
165   if (MaybeNewExtension(descriptor->number(), descriptor, &extension)) {
166     extension->type = descriptor->type();
167     GOOGLE_DCHECK_EQ(cpp_type(extension->type), FieldDescriptor::CPPTYPE_MESSAGE);
168     extension->is_repeated = false;
169     extension->is_packed = false;
170     const MessageLite* prototype =
171         factory->GetPrototype(descriptor->message_type());
172     extension->is_lazy = false;
173     extension->message_value = prototype->New(arena_);
174     extension->is_cleared = false;
175     return extension->message_value;
176   } else {
177     GOOGLE_DCHECK_TYPE(*extension, OPTIONAL, MESSAGE);
178     extension->is_cleared = false;
179     if (extension->is_lazy) {
180       return extension->lazymessage_value->MutableMessage(
181           *factory->GetPrototype(descriptor->message_type()));
182     } else {
183       return extension->message_value;
184     }
185   }
186 }
187 
ReleaseMessage(const FieldDescriptor * descriptor,MessageFactory * factory)188 MessageLite* ExtensionSet::ReleaseMessage(const FieldDescriptor* descriptor,
189                                           MessageFactory* factory) {
190   map<int, Extension>::iterator iter = extensions_.find(descriptor->number());
191   if (iter == extensions_.end()) {
192     // Not present.  Return NULL.
193     return NULL;
194   } else {
195     GOOGLE_DCHECK_TYPE(iter->second, OPTIONAL, MESSAGE);
196     MessageLite* ret = NULL;
197     if (iter->second.is_lazy) {
198       ret = iter->second.lazymessage_value->ReleaseMessage(
199           *factory->GetPrototype(descriptor->message_type()));
200       if (arena_ == NULL) {
201         delete iter->second.lazymessage_value;
202       }
203     } else {
204       if (arena_ != NULL) {
205         ret = (iter->second.message_value)->New();
206         ret->CheckTypeAndMergeFrom(*(iter->second.message_value));
207       } else {
208         ret = iter->second.message_value;
209       }
210     }
211     extensions_.erase(descriptor->number());
212     return ret;
213   }
214 }
215 
MaybeNewRepeatedExtension(const FieldDescriptor * descriptor)216 ExtensionSet::Extension* ExtensionSet::MaybeNewRepeatedExtension(const FieldDescriptor* descriptor) {
217   Extension* extension;
218   if (MaybeNewExtension(descriptor->number(), descriptor, &extension)) {
219     extension->type = descriptor->type();
220     GOOGLE_DCHECK_EQ(cpp_type(extension->type), FieldDescriptor::CPPTYPE_MESSAGE);
221     extension->is_repeated = true;
222     extension->repeated_message_value =
223         ::google::protobuf::Arena::CreateMessage<RepeatedPtrField<MessageLite> >(arena_);
224   } else {
225     GOOGLE_DCHECK_TYPE(*extension, REPEATED, MESSAGE);
226   }
227   return extension;
228 }
229 
AddMessage(const FieldDescriptor * descriptor,MessageFactory * factory)230 MessageLite* ExtensionSet::AddMessage(const FieldDescriptor* descriptor,
231                                       MessageFactory* factory) {
232   Extension* extension = MaybeNewRepeatedExtension(descriptor);
233 
234   // RepeatedPtrField<Message> does not know how to Add() since it cannot
235   // allocate an abstract object, so we have to be tricky.
236   MessageLite* result = extension->repeated_message_value
237       ->AddFromCleared<GenericTypeHandler<MessageLite> >();
238   if (result == NULL) {
239     const MessageLite* prototype;
240     if (extension->repeated_message_value->size() == 0) {
241       prototype = factory->GetPrototype(descriptor->message_type());
242       GOOGLE_CHECK(prototype != NULL);
243     } else {
244       prototype = &extension->repeated_message_value->Get(0);
245     }
246     result = prototype->New(arena_);
247     extension->repeated_message_value->AddAllocated(result);
248   }
249   return result;
250 }
251 
AddAllocatedMessage(const FieldDescriptor * descriptor,MessageLite * new_entry)252 void ExtensionSet::AddAllocatedMessage(const FieldDescriptor* descriptor,
253                                        MessageLite* new_entry) {
254   Extension* extension = MaybeNewRepeatedExtension(descriptor);
255 
256   extension->repeated_message_value->AddAllocated(new_entry);
257 }
258 
ValidateEnumUsingDescriptor(const void * arg,int number)259 static bool ValidateEnumUsingDescriptor(const void* arg, int number) {
260   return reinterpret_cast<const EnumDescriptor*>(arg)
261       ->FindValueByNumber(number) != NULL;
262 }
263 
Find(int number,ExtensionInfo * output)264 bool DescriptorPoolExtensionFinder::Find(int number, ExtensionInfo* output) {
265   const FieldDescriptor* extension =
266       pool_->FindExtensionByNumber(containing_type_, number);
267   if (extension == NULL) {
268     return false;
269   } else {
270     output->type = extension->type();
271     output->is_repeated = extension->is_repeated();
272     output->is_packed = extension->options().packed();
273     output->descriptor = extension;
274     if (extension->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
275       output->message_prototype =
276           factory_->GetPrototype(extension->message_type());
277       GOOGLE_CHECK(output->message_prototype != NULL)
278           << "Extension factory's GetPrototype() returned NULL for extension: "
279           << extension->full_name();
280     } else if (extension->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
281       output->enum_validity_check.func = ValidateEnumUsingDescriptor;
282       output->enum_validity_check.arg = extension->enum_type();
283     }
284 
285     return true;
286   }
287 }
288 
ParseField(uint32 tag,io::CodedInputStream * input,const Message * containing_type,UnknownFieldSet * unknown_fields)289 bool ExtensionSet::ParseField(uint32 tag, io::CodedInputStream* input,
290                               const Message* containing_type,
291                               UnknownFieldSet* unknown_fields) {
292   UnknownFieldSetFieldSkipper skipper(unknown_fields);
293   if (input->GetExtensionPool() == NULL) {
294     GeneratedExtensionFinder finder(containing_type);
295     return ParseField(tag, input, &finder, &skipper);
296   } else {
297     DescriptorPoolExtensionFinder finder(input->GetExtensionPool(),
298                                          input->GetExtensionFactory(),
299                                          containing_type->GetDescriptor());
300     return ParseField(tag, input, &finder, &skipper);
301   }
302 }
303 
ParseMessageSet(io::CodedInputStream * input,const Message * containing_type,UnknownFieldSet * unknown_fields)304 bool ExtensionSet::ParseMessageSet(io::CodedInputStream* input,
305                                    const Message* containing_type,
306                                    UnknownFieldSet* unknown_fields) {
307   MessageSetFieldSkipper skipper(unknown_fields);
308   if (input->GetExtensionPool() == NULL) {
309     GeneratedExtensionFinder finder(containing_type);
310     return ParseMessageSet(input, &finder, &skipper);
311   } else {
312     DescriptorPoolExtensionFinder finder(input->GetExtensionPool(),
313                                          input->GetExtensionFactory(),
314                                          containing_type->GetDescriptor());
315     return ParseMessageSet(input, &finder, &skipper);
316   }
317 }
318 
SpaceUsedExcludingSelf() const319 int ExtensionSet::SpaceUsedExcludingSelf() const {
320   int total_size =
321       extensions_.size() * sizeof(map<int, Extension>::value_type);
322   for (map<int, Extension>::const_iterator iter = extensions_.begin(),
323        end = extensions_.end();
324        iter != end;
325        ++iter) {
326     total_size += iter->second.SpaceUsedExcludingSelf();
327   }
328   return total_size;
329 }
330 
RepeatedMessage_SpaceUsedExcludingSelf(RepeatedPtrFieldBase * field)331 inline int ExtensionSet::RepeatedMessage_SpaceUsedExcludingSelf(
332     RepeatedPtrFieldBase* field) {
333   return field->SpaceUsedExcludingSelf<GenericTypeHandler<Message> >();
334 }
335 
SpaceUsedExcludingSelf() const336 int ExtensionSet::Extension::SpaceUsedExcludingSelf() const {
337   int total_size = 0;
338   if (is_repeated) {
339     switch (cpp_type(type)) {
340 #define HANDLE_TYPE(UPPERCASE, LOWERCASE)                          \
341       case FieldDescriptor::CPPTYPE_##UPPERCASE:                   \
342         total_size += sizeof(*repeated_##LOWERCASE##_value) +      \
343             repeated_##LOWERCASE##_value->SpaceUsedExcludingSelf();\
344         break
345 
346       HANDLE_TYPE(  INT32,   int32);
347       HANDLE_TYPE(  INT64,   int64);
348       HANDLE_TYPE( UINT32,  uint32);
349       HANDLE_TYPE( UINT64,  uint64);
350       HANDLE_TYPE(  FLOAT,   float);
351       HANDLE_TYPE( DOUBLE,  double);
352       HANDLE_TYPE(   BOOL,    bool);
353       HANDLE_TYPE(   ENUM,    enum);
354       HANDLE_TYPE( STRING,  string);
355 #undef HANDLE_TYPE
356 
357       case FieldDescriptor::CPPTYPE_MESSAGE:
358         // repeated_message_value is actually a RepeatedPtrField<MessageLite>,
359         // but MessageLite has no SpaceUsed(), so we must directly call
360         // RepeatedPtrFieldBase::SpaceUsedExcludingSelf() with a different type
361         // handler.
362         total_size += sizeof(*repeated_message_value) +
363             RepeatedMessage_SpaceUsedExcludingSelf(repeated_message_value);
364         break;
365     }
366   } else {
367     switch (cpp_type(type)) {
368       case FieldDescriptor::CPPTYPE_STRING:
369         total_size += sizeof(*string_value) +
370                       StringSpaceUsedExcludingSelf(*string_value);
371         break;
372       case FieldDescriptor::CPPTYPE_MESSAGE:
373         if (is_lazy) {
374           total_size += lazymessage_value->SpaceUsed();
375         } else {
376           total_size += down_cast<Message*>(message_value)->SpaceUsed();
377         }
378         break;
379       default:
380         // No extra storage costs for primitive types.
381         break;
382     }
383   }
384   return total_size;
385 }
386 
387 // The Serialize*ToArray methods are only needed in the heavy library, as
388 // the lite library only generates SerializeWithCachedSizes.
SerializeWithCachedSizesToArray(int start_field_number,int end_field_number,uint8 * target) const389 uint8* ExtensionSet::SerializeWithCachedSizesToArray(
390     int start_field_number, int end_field_number,
391     uint8* target) const {
392   map<int, Extension>::const_iterator iter;
393   for (iter = extensions_.lower_bound(start_field_number);
394        iter != extensions_.end() && iter->first < end_field_number;
395        ++iter) {
396     target = iter->second.SerializeFieldWithCachedSizesToArray(iter->first,
397                                                                target);
398   }
399   return target;
400 }
401 
SerializeMessageSetWithCachedSizesToArray(uint8 * target) const402 uint8* ExtensionSet::SerializeMessageSetWithCachedSizesToArray(
403     uint8* target) const {
404   map<int, Extension>::const_iterator iter;
405   for (iter = extensions_.begin(); iter != extensions_.end(); ++iter) {
406     target = iter->second.SerializeMessageSetItemWithCachedSizesToArray(
407         iter->first, target);
408   }
409   return target;
410 }
411 
SerializeFieldWithCachedSizesToArray(int number,uint8 * target) const412 uint8* ExtensionSet::Extension::SerializeFieldWithCachedSizesToArray(
413     int number, uint8* target) const {
414   if (is_repeated) {
415     if (is_packed) {
416       if (cached_size == 0) return target;
417 
418       target = WireFormatLite::WriteTagToArray(number,
419           WireFormatLite::WIRETYPE_LENGTH_DELIMITED, target);
420       target = WireFormatLite::WriteInt32NoTagToArray(cached_size, target);
421 
422       switch (real_type(type)) {
423 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE)                        \
424         case FieldDescriptor::TYPE_##UPPERCASE:                             \
425           for (int i = 0; i < repeated_##LOWERCASE##_value->size(); i++) {  \
426             target = WireFormatLite::Write##CAMELCASE##NoTagToArray(        \
427               repeated_##LOWERCASE##_value->Get(i), target);                \
428           }                                                                 \
429           break
430 
431         HANDLE_TYPE(   INT32,    Int32,   int32);
432         HANDLE_TYPE(   INT64,    Int64,   int64);
433         HANDLE_TYPE(  UINT32,   UInt32,  uint32);
434         HANDLE_TYPE(  UINT64,   UInt64,  uint64);
435         HANDLE_TYPE(  SINT32,   SInt32,   int32);
436         HANDLE_TYPE(  SINT64,   SInt64,   int64);
437         HANDLE_TYPE( FIXED32,  Fixed32,  uint32);
438         HANDLE_TYPE( FIXED64,  Fixed64,  uint64);
439         HANDLE_TYPE(SFIXED32, SFixed32,   int32);
440         HANDLE_TYPE(SFIXED64, SFixed64,   int64);
441         HANDLE_TYPE(   FLOAT,    Float,   float);
442         HANDLE_TYPE(  DOUBLE,   Double,  double);
443         HANDLE_TYPE(    BOOL,     Bool,    bool);
444         HANDLE_TYPE(    ENUM,     Enum,    enum);
445 #undef HANDLE_TYPE
446 
447         case WireFormatLite::TYPE_STRING:
448         case WireFormatLite::TYPE_BYTES:
449         case WireFormatLite::TYPE_GROUP:
450         case WireFormatLite::TYPE_MESSAGE:
451           GOOGLE_LOG(FATAL) << "Non-primitive types can't be packed.";
452           break;
453       }
454     } else {
455       switch (real_type(type)) {
456 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE)                        \
457         case FieldDescriptor::TYPE_##UPPERCASE:                             \
458           for (int i = 0; i < repeated_##LOWERCASE##_value->size(); i++) {  \
459             target = WireFormatLite::Write##CAMELCASE##ToArray(number,      \
460               repeated_##LOWERCASE##_value->Get(i), target);                \
461           }                                                                 \
462           break
463 
464         HANDLE_TYPE(   INT32,    Int32,   int32);
465         HANDLE_TYPE(   INT64,    Int64,   int64);
466         HANDLE_TYPE(  UINT32,   UInt32,  uint32);
467         HANDLE_TYPE(  UINT64,   UInt64,  uint64);
468         HANDLE_TYPE(  SINT32,   SInt32,   int32);
469         HANDLE_TYPE(  SINT64,   SInt64,   int64);
470         HANDLE_TYPE( FIXED32,  Fixed32,  uint32);
471         HANDLE_TYPE( FIXED64,  Fixed64,  uint64);
472         HANDLE_TYPE(SFIXED32, SFixed32,   int32);
473         HANDLE_TYPE(SFIXED64, SFixed64,   int64);
474         HANDLE_TYPE(   FLOAT,    Float,   float);
475         HANDLE_TYPE(  DOUBLE,   Double,  double);
476         HANDLE_TYPE(    BOOL,     Bool,    bool);
477         HANDLE_TYPE(  STRING,   String,  string);
478         HANDLE_TYPE(   BYTES,    Bytes,  string);
479         HANDLE_TYPE(    ENUM,     Enum,    enum);
480         HANDLE_TYPE(   GROUP,    Group, message);
481         HANDLE_TYPE( MESSAGE,  Message, message);
482 #undef HANDLE_TYPE
483       }
484     }
485   } else if (!is_cleared) {
486     switch (real_type(type)) {
487 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, VALUE)                 \
488       case FieldDescriptor::TYPE_##UPPERCASE:                    \
489         target = WireFormatLite::Write##CAMELCASE##ToArray(      \
490             number, VALUE, target); \
491         break
492 
493       HANDLE_TYPE(   INT32,    Int32,    int32_value);
494       HANDLE_TYPE(   INT64,    Int64,    int64_value);
495       HANDLE_TYPE(  UINT32,   UInt32,   uint32_value);
496       HANDLE_TYPE(  UINT64,   UInt64,   uint64_value);
497       HANDLE_TYPE(  SINT32,   SInt32,    int32_value);
498       HANDLE_TYPE(  SINT64,   SInt64,    int64_value);
499       HANDLE_TYPE( FIXED32,  Fixed32,   uint32_value);
500       HANDLE_TYPE( FIXED64,  Fixed64,   uint64_value);
501       HANDLE_TYPE(SFIXED32, SFixed32,    int32_value);
502       HANDLE_TYPE(SFIXED64, SFixed64,    int64_value);
503       HANDLE_TYPE(   FLOAT,    Float,    float_value);
504       HANDLE_TYPE(  DOUBLE,   Double,   double_value);
505       HANDLE_TYPE(    BOOL,     Bool,     bool_value);
506       HANDLE_TYPE(  STRING,   String,  *string_value);
507       HANDLE_TYPE(   BYTES,    Bytes,  *string_value);
508       HANDLE_TYPE(    ENUM,     Enum,     enum_value);
509       HANDLE_TYPE(   GROUP,    Group, *message_value);
510 #undef HANDLE_TYPE
511       case FieldDescriptor::TYPE_MESSAGE:
512         if (is_lazy) {
513           target = lazymessage_value->WriteMessageToArray(number, target);
514         } else {
515           target = WireFormatLite::WriteMessageToArray(
516               number, *message_value, target);
517         }
518         break;
519     }
520   }
521   return target;
522 }
523 
SerializeMessageSetItemWithCachedSizesToArray(int number,uint8 * target) const524 uint8* ExtensionSet::Extension::SerializeMessageSetItemWithCachedSizesToArray(
525     int number,
526     uint8* target) const {
527   if (type != WireFormatLite::TYPE_MESSAGE || is_repeated) {
528     // Not a valid MessageSet extension, but serialize it the normal way.
529     GOOGLE_LOG(WARNING) << "Invalid message set extension.";
530     return SerializeFieldWithCachedSizesToArray(number, target);
531   }
532 
533   if (is_cleared) return target;
534 
535   // Start group.
536   target = io::CodedOutputStream::WriteTagToArray(
537       WireFormatLite::kMessageSetItemStartTag, target);
538   // Write type ID.
539   target = WireFormatLite::WriteUInt32ToArray(
540       WireFormatLite::kMessageSetTypeIdNumber, number, target);
541   // Write message.
542   if (is_lazy) {
543     target = lazymessage_value->WriteMessageToArray(
544         WireFormatLite::kMessageSetMessageNumber, target);
545   } else {
546     target = WireFormatLite::WriteMessageToArray(
547         WireFormatLite::kMessageSetMessageNumber, *message_value, target);
548   }
549   // End group.
550   target = io::CodedOutputStream::WriteTagToArray(
551       WireFormatLite::kMessageSetItemEndTag, target);
552   return target;
553 }
554 
555 
ParseFieldMaybeLazily(int wire_type,int field_number,io::CodedInputStream * input,ExtensionFinder * extension_finder,MessageSetFieldSkipper * field_skipper)556 bool ExtensionSet::ParseFieldMaybeLazily(
557     int wire_type, int field_number, io::CodedInputStream* input,
558     ExtensionFinder* extension_finder,
559     MessageSetFieldSkipper* field_skipper) {
560   return ParseField(WireFormatLite::MakeTag(
561       field_number, static_cast<WireFormatLite::WireType>(wire_type)),
562                     input, extension_finder, field_skipper);
563 }
564 
ParseMessageSet(io::CodedInputStream * input,ExtensionFinder * extension_finder,MessageSetFieldSkipper * field_skipper)565 bool ExtensionSet::ParseMessageSet(io::CodedInputStream* input,
566                                    ExtensionFinder* extension_finder,
567                                    MessageSetFieldSkipper* field_skipper) {
568   while (true) {
569     const uint32 tag = input->ReadTag();
570     switch (tag) {
571       case 0:
572         return true;
573       case WireFormatLite::kMessageSetItemStartTag:
574         if (!ParseMessageSetItem(input, extension_finder, field_skipper)) {
575           return false;
576         }
577         break;
578       default:
579         if (!ParseField(tag, input, extension_finder, field_skipper)) {
580           return false;
581         }
582         break;
583     }
584   }
585 }
586 
ParseMessageSet(io::CodedInputStream * input,const MessageLite * containing_type)587 bool ExtensionSet::ParseMessageSet(io::CodedInputStream* input,
588                                    const MessageLite* containing_type) {
589   MessageSetFieldSkipper skipper(NULL);
590   GeneratedExtensionFinder finder(containing_type);
591   return ParseMessageSet(input, &finder, &skipper);
592 }
593 
ParseMessageSetItem(io::CodedInputStream * input,ExtensionFinder * extension_finder,MessageSetFieldSkipper * field_skipper)594 bool ExtensionSet::ParseMessageSetItem(io::CodedInputStream* input,
595                                        ExtensionFinder* extension_finder,
596                                        MessageSetFieldSkipper* field_skipper) {
597   // TODO(kenton):  It would be nice to share code between this and
598   // WireFormatLite::ParseAndMergeMessageSetItem(), but I think the
599   // differences would be hard to factor out.
600 
601   // This method parses a group which should contain two fields:
602   //   required int32 type_id = 2;
603   //   required data message = 3;
604 
605   uint32 last_type_id = 0;
606 
607   // If we see message data before the type_id, we'll append it to this so
608   // we can parse it later.
609   string message_data;
610 
611   while (true) {
612     const uint32 tag = input->ReadTag();
613     if (tag == 0) return false;
614 
615     switch (tag) {
616       case WireFormatLite::kMessageSetTypeIdTag: {
617         uint32 type_id;
618         if (!input->ReadVarint32(&type_id)) return false;
619         last_type_id = type_id;
620 
621         if (!message_data.empty()) {
622           // We saw some message data before the type_id.  Have to parse it
623           // now.
624           io::CodedInputStream sub_input(
625               reinterpret_cast<const uint8*>(message_data.data()),
626               message_data.size());
627           if (!ParseFieldMaybeLazily(WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
628                                      last_type_id, &sub_input,
629                                      extension_finder, field_skipper)) {
630             return false;
631           }
632           message_data.clear();
633         }
634 
635         break;
636       }
637 
638       case WireFormatLite::kMessageSetMessageTag: {
639         if (last_type_id == 0) {
640           // We haven't seen a type_id yet.  Append this data to message_data.
641           string temp;
642           uint32 length;
643           if (!input->ReadVarint32(&length)) return false;
644           if (!input->ReadString(&temp, length)) return false;
645           io::StringOutputStream output_stream(&message_data);
646           io::CodedOutputStream coded_output(&output_stream);
647           coded_output.WriteVarint32(length);
648           coded_output.WriteString(temp);
649         } else {
650           // Already saw type_id, so we can parse this directly.
651           if (!ParseFieldMaybeLazily(WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
652                                      last_type_id, input,
653                                      extension_finder, field_skipper)) {
654             return false;
655           }
656         }
657 
658         break;
659       }
660 
661       case WireFormatLite::kMessageSetItemEndTag: {
662         return true;
663       }
664 
665       default: {
666         if (!field_skipper->SkipField(input, tag)) return false;
667       }
668     }
669   }
670 }
671 
SerializeMessageSetItemWithCachedSizes(int number,io::CodedOutputStream * output) const672 void ExtensionSet::Extension::SerializeMessageSetItemWithCachedSizes(
673     int number,
674     io::CodedOutputStream* output) const {
675   if (type != WireFormatLite::TYPE_MESSAGE || is_repeated) {
676     // Not a valid MessageSet extension, but serialize it the normal way.
677     SerializeFieldWithCachedSizes(number, output);
678     return;
679   }
680 
681   if (is_cleared) return;
682 
683   // Start group.
684   output->WriteTag(WireFormatLite::kMessageSetItemStartTag);
685 
686   // Write type ID.
687   WireFormatLite::WriteUInt32(WireFormatLite::kMessageSetTypeIdNumber,
688                               number,
689                               output);
690   // Write message.
691   if (is_lazy) {
692     lazymessage_value->WriteMessage(
693         WireFormatLite::kMessageSetMessageNumber, output);
694   } else {
695     WireFormatLite::WriteMessageMaybeToArray(
696         WireFormatLite::kMessageSetMessageNumber,
697         *message_value,
698         output);
699   }
700 
701   // End group.
702   output->WriteTag(WireFormatLite::kMessageSetItemEndTag);
703 }
704 
MessageSetItemByteSize(int number) const705 int ExtensionSet::Extension::MessageSetItemByteSize(int number) const {
706   if (type != WireFormatLite::TYPE_MESSAGE || is_repeated) {
707     // Not a valid MessageSet extension, but compute the byte size for it the
708     // normal way.
709     return ByteSize(number);
710   }
711 
712   if (is_cleared) return 0;
713 
714   int our_size = WireFormatLite::kMessageSetItemTagsSize;
715 
716   // type_id
717   our_size += io::CodedOutputStream::VarintSize32(number);
718 
719   // message
720   int message_size = 0;
721   if (is_lazy) {
722     message_size = lazymessage_value->ByteSize();
723   } else {
724     message_size = message_value->ByteSize();
725   }
726 
727   our_size += io::CodedOutputStream::VarintSize32(message_size);
728   our_size += message_size;
729 
730   return our_size;
731 }
732 
SerializeMessageSetWithCachedSizes(io::CodedOutputStream * output) const733 void ExtensionSet::SerializeMessageSetWithCachedSizes(
734     io::CodedOutputStream* output) const {
735   for (map<int, Extension>::const_iterator iter = extensions_.begin();
736        iter != extensions_.end(); ++iter) {
737     iter->second.SerializeMessageSetItemWithCachedSizes(iter->first, output);
738   }
739 }
740 
MessageSetByteSize() const741 int ExtensionSet::MessageSetByteSize() const {
742   int total_size = 0;
743 
744   for (map<int, Extension>::const_iterator iter = extensions_.begin();
745        iter != extensions_.end(); ++iter) {
746     total_size += iter->second.MessageSetItemByteSize(iter->first);
747   }
748 
749   return total_size;
750 }
751 
752 }  // namespace internal
753 }  // namespace protobuf
754 }  // namespace google
755