1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 //     * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 //     * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 //     * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 
31 // Author: kenton@google.com (Kenton Varda)
32 //  Based on original Protocol Buffers design by
33 //  Sanjay Ghemawat, Jeff Dean, and others.
34 //
35 // Contains methods defined in extension_set.h which cannot be part of the
36 // lite library because they use descriptors or reflection.
37 
38 #include <google/protobuf/io/zero_copy_stream_impl_lite.h>
39 #include <google/protobuf/descriptor.h>
40 #include <google/protobuf/extension_set.h>
41 #include <google/protobuf/message.h>
42 #include <google/protobuf/repeated_field.h>
43 #include <google/protobuf/wire_format.h>
44 #include <google/protobuf/wire_format_lite_inl.h>
45 
46 namespace google {
47 
48 namespace protobuf {
49 namespace internal {
50 
51 // A FieldSkipper used to store unknown MessageSet fields into UnknownFieldSet.
52 class MessageSetFieldSkipper
53     : public UnknownFieldSetFieldSkipper {
54  public:
MessageSetFieldSkipper(UnknownFieldSet * unknown_fields)55   explicit MessageSetFieldSkipper(UnknownFieldSet* unknown_fields)
56       : UnknownFieldSetFieldSkipper(unknown_fields) {}
~MessageSetFieldSkipper()57   virtual ~MessageSetFieldSkipper() {}
58 
59   virtual bool SkipMessageSetField(io::CodedInputStream* input,
60                                    int field_number);
61 };
SkipMessageSetField(io::CodedInputStream * input,int field_number)62 bool MessageSetFieldSkipper::SkipMessageSetField(
63     io::CodedInputStream* input, int field_number) {
64   uint32 length;
65   if (!input->ReadVarint32(&length)) return false;
66   if (unknown_fields_ == NULL) {
67     return input->Skip(length);
68   } else {
69     return input->ReadString(
70         unknown_fields_->AddLengthDelimited(field_number), length);
71   }
72 }
73 
74 
75 // Implementation of ExtensionFinder which finds extensions in a given
76 // DescriptorPool, using the given MessageFactory to construct sub-objects.
77 // This class is implemented in extension_set_heavy.cc.
78 class DescriptorPoolExtensionFinder : public ExtensionFinder {
79  public:
DescriptorPoolExtensionFinder(const DescriptorPool * pool,MessageFactory * factory,const Descriptor * containing_type)80   DescriptorPoolExtensionFinder(const DescriptorPool* pool,
81                                 MessageFactory* factory,
82                                 const Descriptor* containing_type)
83       : pool_(pool), factory_(factory), containing_type_(containing_type) {}
~DescriptorPoolExtensionFinder()84   virtual ~DescriptorPoolExtensionFinder() {}
85 
86   virtual bool Find(int number, ExtensionInfo* output);
87 
88  private:
89   const DescriptorPool* pool_;
90   MessageFactory* factory_;
91   const Descriptor* containing_type_;
92 };
93 
AppendToList(const Descriptor * containing_type,const DescriptorPool * pool,std::vector<const FieldDescriptor * > * output) const94 void ExtensionSet::AppendToList(const Descriptor* containing_type,
95                                 const DescriptorPool* pool,
96                                 std::vector<const FieldDescriptor*>* output) const {
97   for (map<int, Extension>::const_iterator iter = extensions_.begin();
98        iter != extensions_.end(); ++iter) {
99     bool has = false;
100     if (iter->second.is_repeated) {
101       has = iter->second.GetSize() > 0;
102     } else {
103       has = !iter->second.is_cleared;
104     }
105 
106     if (has) {
107       // TODO(kenton): Looking up each field by number is somewhat unfortunate.
108       //   Is there a better way?  The problem is that descriptors are lazily-
109       //   initialized, so they might not even be constructed until
110       //   AppendToList() is called.
111 
112       if (iter->second.descriptor == NULL) {
113         output->push_back(pool->FindExtensionByNumber(
114             containing_type, iter->first));
115       } else {
116         output->push_back(iter->second.descriptor);
117       }
118     }
119   }
120 }
121 
real_type(FieldType type)122 inline FieldDescriptor::Type real_type(FieldType type) {
123   GOOGLE_DCHECK(type > 0 && type <= FieldDescriptor::MAX_TYPE);
124   return static_cast<FieldDescriptor::Type>(type);
125 }
126 
cpp_type(FieldType type)127 inline FieldDescriptor::CppType cpp_type(FieldType type) {
128   return FieldDescriptor::TypeToCppType(
129       static_cast<FieldDescriptor::Type>(type));
130 }
131 
field_type(FieldType type)132 inline WireFormatLite::FieldType field_type(FieldType type) {
133   GOOGLE_DCHECK(type > 0 && type <= WireFormatLite::MAX_FIELD_TYPE);
134   return static_cast<WireFormatLite::FieldType>(type);
135 }
136 
// Checks, via GOOGLE_DCHECK_EQ, that |EXTENSION| was created with the
// expected label (OPTIONAL or REPEATED) and the expected C++ type.
// The two checks are wrapped in do { } while (0) so the macro expands to a
// single statement.  Without the wrapper, using it as the body of an unbraced
// if/else would leave the second check outside the condition.
#define GOOGLE_DCHECK_TYPE(EXTENSION, LABEL, CPPTYPE)                            \
  do {                                                                    \
    GOOGLE_DCHECK_EQ((EXTENSION).is_repeated ? FieldDescriptor::LABEL_REPEATED   \
                                      : FieldDescriptor::LABEL_OPTIONAL,  \
              FieldDescriptor::LABEL_##LABEL);                            \
    GOOGLE_DCHECK_EQ(cpp_type((EXTENSION).type),                                 \
              FieldDescriptor::CPPTYPE_##CPPTYPE);                        \
  } while (0)
142 
GetMessage(int number,const Descriptor * message_type,MessageFactory * factory) const143 const MessageLite& ExtensionSet::GetMessage(int number,
144                                             const Descriptor* message_type,
145                                             MessageFactory* factory) const {
146   map<int, Extension>::const_iterator iter = extensions_.find(number);
147   if (iter == extensions_.end() || iter->second.is_cleared) {
148     // Not present.  Return the default value.
149     return *factory->GetPrototype(message_type);
150   } else {
151     GOOGLE_DCHECK_TYPE(iter->second, OPTIONAL, MESSAGE);
152     if (iter->second.is_lazy) {
153       return iter->second.lazymessage_value->GetMessage(
154           *factory->GetPrototype(message_type));
155     } else {
156       return *iter->second.message_value;
157     }
158   }
159 }
160 
MutableMessage(const FieldDescriptor * descriptor,MessageFactory * factory)161 MessageLite* ExtensionSet::MutableMessage(const FieldDescriptor* descriptor,
162                                           MessageFactory* factory) {
163   Extension* extension;
164   if (MaybeNewExtension(descriptor->number(), descriptor, &extension)) {
165     extension->type = descriptor->type();
166     GOOGLE_DCHECK_EQ(cpp_type(extension->type), FieldDescriptor::CPPTYPE_MESSAGE);
167     extension->is_repeated = false;
168     extension->is_packed = false;
169     const MessageLite* prototype =
170         factory->GetPrototype(descriptor->message_type());
171     extension->is_lazy = false;
172     extension->message_value = prototype->New();
173     extension->is_cleared = false;
174     return extension->message_value;
175   } else {
176     GOOGLE_DCHECK_TYPE(*extension, OPTIONAL, MESSAGE);
177     extension->is_cleared = false;
178     if (extension->is_lazy) {
179       return extension->lazymessage_value->MutableMessage(
180           *factory->GetPrototype(descriptor->message_type()));
181     } else {
182       return extension->message_value;
183     }
184   }
185 }
186 
ReleaseMessage(const FieldDescriptor * descriptor,MessageFactory * factory)187 MessageLite* ExtensionSet::ReleaseMessage(const FieldDescriptor* descriptor,
188                                           MessageFactory* factory) {
189   map<int, Extension>::iterator iter = extensions_.find(descriptor->number());
190   if (iter == extensions_.end()) {
191     // Not present.  Return NULL.
192     return NULL;
193   } else {
194     GOOGLE_DCHECK_TYPE(iter->second, OPTIONAL, MESSAGE);
195     MessageLite* ret = NULL;
196     if (iter->second.is_lazy) {
197       ret = iter->second.lazymessage_value->ReleaseMessage(
198           *factory->GetPrototype(descriptor->message_type()));
199       delete iter->second.lazymessage_value;
200     } else {
201       ret = iter->second.message_value;
202     }
203     extensions_.erase(descriptor->number());
204     return ret;
205   }
206 }
207 
AddMessage(const FieldDescriptor * descriptor,MessageFactory * factory)208 MessageLite* ExtensionSet::AddMessage(const FieldDescriptor* descriptor,
209                                       MessageFactory* factory) {
210   Extension* extension;
211   if (MaybeNewExtension(descriptor->number(), descriptor, &extension)) {
212     extension->type = descriptor->type();
213     GOOGLE_DCHECK_EQ(cpp_type(extension->type), FieldDescriptor::CPPTYPE_MESSAGE);
214     extension->is_repeated = true;
215     extension->repeated_message_value =
216       new RepeatedPtrField<MessageLite>();
217   } else {
218     GOOGLE_DCHECK_TYPE(*extension, REPEATED, MESSAGE);
219   }
220 
221   // RepeatedPtrField<Message> does not know how to Add() since it cannot
222   // allocate an abstract object, so we have to be tricky.
223   MessageLite* result = extension->repeated_message_value
224       ->AddFromCleared<GenericTypeHandler<MessageLite> >();
225   if (result == NULL) {
226     const MessageLite* prototype;
227     if (extension->repeated_message_value->size() == 0) {
228       prototype = factory->GetPrototype(descriptor->message_type());
229       GOOGLE_CHECK(prototype != NULL);
230     } else {
231       prototype = &extension->repeated_message_value->Get(0);
232     }
233     result = prototype->New();
234     extension->repeated_message_value->AddAllocated(result);
235   }
236   return result;
237 }
238 
ValidateEnumUsingDescriptor(const void * arg,int number)239 static bool ValidateEnumUsingDescriptor(const void* arg, int number) {
240   return reinterpret_cast<const EnumDescriptor*>(arg)
241       ->FindValueByNumber(number) != NULL;
242 }
243 
Find(int number,ExtensionInfo * output)244 bool DescriptorPoolExtensionFinder::Find(int number, ExtensionInfo* output) {
245   const FieldDescriptor* extension =
246       pool_->FindExtensionByNumber(containing_type_, number);
247   if (extension == NULL) {
248     return false;
249   } else {
250     output->type = extension->type();
251     output->is_repeated = extension->is_repeated();
252     output->is_packed = extension->options().packed();
253     output->descriptor = extension;
254     if (extension->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
255       output->message_prototype =
256           factory_->GetPrototype(extension->message_type());
257       GOOGLE_CHECK(output->message_prototype != NULL)
258           << "Extension factory's GetPrototype() returned NULL for extension: "
259           << extension->full_name();
260     } else if (extension->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
261       output->enum_validity_check.func = ValidateEnumUsingDescriptor;
262       output->enum_validity_check.arg = extension->enum_type();
263     }
264 
265     return true;
266   }
267 }
268 
ParseField(uint32 tag,io::CodedInputStream * input,const Message * containing_type,UnknownFieldSet * unknown_fields)269 bool ExtensionSet::ParseField(uint32 tag, io::CodedInputStream* input,
270                               const Message* containing_type,
271                               UnknownFieldSet* unknown_fields) {
272   UnknownFieldSetFieldSkipper skipper(unknown_fields);
273   if (input->GetExtensionPool() == NULL) {
274     GeneratedExtensionFinder finder(containing_type);
275     return ParseField(tag, input, &finder, &skipper);
276   } else {
277     DescriptorPoolExtensionFinder finder(input->GetExtensionPool(),
278                                          input->GetExtensionFactory(),
279                                          containing_type->GetDescriptor());
280     return ParseField(tag, input, &finder, &skipper);
281   }
282 }
283 
ParseMessageSet(io::CodedInputStream * input,const Message * containing_type,UnknownFieldSet * unknown_fields)284 bool ExtensionSet::ParseMessageSet(io::CodedInputStream* input,
285                                    const Message* containing_type,
286                                    UnknownFieldSet* unknown_fields) {
287   MessageSetFieldSkipper skipper(unknown_fields);
288   if (input->GetExtensionPool() == NULL) {
289     GeneratedExtensionFinder finder(containing_type);
290     return ParseMessageSet(input, &finder, &skipper);
291   } else {
292     DescriptorPoolExtensionFinder finder(input->GetExtensionPool(),
293                                          input->GetExtensionFactory(),
294                                          containing_type->GetDescriptor());
295     return ParseMessageSet(input, &finder, &skipper);
296   }
297 }
298 
SpaceUsedExcludingSelf() const299 int ExtensionSet::SpaceUsedExcludingSelf() const {
300   int total_size =
301       extensions_.size() * sizeof(map<int, Extension>::value_type);
302   for (map<int, Extension>::const_iterator iter = extensions_.begin(),
303        end = extensions_.end();
304        iter != end;
305        ++iter) {
306     total_size += iter->second.SpaceUsedExcludingSelf();
307   }
308   return total_size;
309 }
310 
RepeatedMessage_SpaceUsedExcludingSelf(RepeatedPtrFieldBase * field)311 inline int ExtensionSet::RepeatedMessage_SpaceUsedExcludingSelf(
312     RepeatedPtrFieldBase* field) {
313   return field->SpaceUsedExcludingSelf<GenericTypeHandler<Message> >();
314 }
315 
SpaceUsedExcludingSelf() const316 int ExtensionSet::Extension::SpaceUsedExcludingSelf() const {
317   int total_size = 0;
318   if (is_repeated) {
319     switch (cpp_type(type)) {
320 #define HANDLE_TYPE(UPPERCASE, LOWERCASE)                          \
321       case FieldDescriptor::CPPTYPE_##UPPERCASE:                   \
322         total_size += sizeof(*repeated_##LOWERCASE##_value) +      \
323             repeated_##LOWERCASE##_value->SpaceUsedExcludingSelf();\
324         break
325 
326       HANDLE_TYPE(  INT32,   int32);
327       HANDLE_TYPE(  INT64,   int64);
328       HANDLE_TYPE( UINT32,  uint32);
329       HANDLE_TYPE( UINT64,  uint64);
330       HANDLE_TYPE(  FLOAT,   float);
331       HANDLE_TYPE( DOUBLE,  double);
332       HANDLE_TYPE(   BOOL,    bool);
333       HANDLE_TYPE(   ENUM,    enum);
334       HANDLE_TYPE( STRING,  string);
335 #undef HANDLE_TYPE
336 
337       case FieldDescriptor::CPPTYPE_MESSAGE:
338         // repeated_message_value is actually a RepeatedPtrField<MessageLite>,
339         // but MessageLite has no SpaceUsed(), so we must directly call
340         // RepeatedPtrFieldBase::SpaceUsedExcludingSelf() with a different type
341         // handler.
342         total_size += sizeof(*repeated_message_value) +
343             RepeatedMessage_SpaceUsedExcludingSelf(repeated_message_value);
344         break;
345     }
346   } else {
347     switch (cpp_type(type)) {
348       case FieldDescriptor::CPPTYPE_STRING:
349         total_size += sizeof(*string_value) +
350                       StringSpaceUsedExcludingSelf(*string_value);
351         break;
352       case FieldDescriptor::CPPTYPE_MESSAGE:
353         if (is_lazy) {
354           total_size += lazymessage_value->SpaceUsed();
355         } else {
356           total_size += down_cast<Message*>(message_value)->SpaceUsed();
357         }
358         break;
359       default:
360         // No extra storage costs for primitive types.
361         break;
362     }
363   }
364   return total_size;
365 }
366 
367 // The Serialize*ToArray methods are only needed in the heavy library, as
368 // the lite library only generates SerializeWithCachedSizes.
SerializeWithCachedSizesToArray(int start_field_number,int end_field_number,uint8 * target) const369 uint8* ExtensionSet::SerializeWithCachedSizesToArray(
370     int start_field_number, int end_field_number,
371     uint8* target) const {
372   map<int, Extension>::const_iterator iter;
373   for (iter = extensions_.lower_bound(start_field_number);
374        iter != extensions_.end() && iter->first < end_field_number;
375        ++iter) {
376     target = iter->second.SerializeFieldWithCachedSizesToArray(iter->first,
377                                                                target);
378   }
379   return target;
380 }
381 
SerializeMessageSetWithCachedSizesToArray(uint8 * target) const382 uint8* ExtensionSet::SerializeMessageSetWithCachedSizesToArray(
383     uint8* target) const {
384   map<int, Extension>::const_iterator iter;
385   for (iter = extensions_.begin(); iter != extensions_.end(); ++iter) {
386     target = iter->second.SerializeMessageSetItemWithCachedSizesToArray(
387         iter->first, target);
388   }
389   return target;
390 }
391 
SerializeFieldWithCachedSizesToArray(int number,uint8 * target) const392 uint8* ExtensionSet::Extension::SerializeFieldWithCachedSizesToArray(
393     int number, uint8* target) const {
394   if (is_repeated) {
395     if (is_packed) {
396       if (cached_size == 0) return target;
397 
398       target = WireFormatLite::WriteTagToArray(number,
399           WireFormatLite::WIRETYPE_LENGTH_DELIMITED, target);
400       target = WireFormatLite::WriteInt32NoTagToArray(cached_size, target);
401 
402       switch (real_type(type)) {
403 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE)                        \
404         case FieldDescriptor::TYPE_##UPPERCASE:                             \
405           for (int i = 0; i < repeated_##LOWERCASE##_value->size(); i++) {  \
406             target = WireFormatLite::Write##CAMELCASE##NoTagToArray(        \
407               repeated_##LOWERCASE##_value->Get(i), target);                \
408           }                                                                 \
409           break
410 
411         HANDLE_TYPE(   INT32,    Int32,   int32);
412         HANDLE_TYPE(   INT64,    Int64,   int64);
413         HANDLE_TYPE(  UINT32,   UInt32,  uint32);
414         HANDLE_TYPE(  UINT64,   UInt64,  uint64);
415         HANDLE_TYPE(  SINT32,   SInt32,   int32);
416         HANDLE_TYPE(  SINT64,   SInt64,   int64);
417         HANDLE_TYPE( FIXED32,  Fixed32,  uint32);
418         HANDLE_TYPE( FIXED64,  Fixed64,  uint64);
419         HANDLE_TYPE(SFIXED32, SFixed32,   int32);
420         HANDLE_TYPE(SFIXED64, SFixed64,   int64);
421         HANDLE_TYPE(   FLOAT,    Float,   float);
422         HANDLE_TYPE(  DOUBLE,   Double,  double);
423         HANDLE_TYPE(    BOOL,     Bool,    bool);
424         HANDLE_TYPE(    ENUM,     Enum,    enum);
425 #undef HANDLE_TYPE
426 
427         case WireFormatLite::TYPE_STRING:
428         case WireFormatLite::TYPE_BYTES:
429         case WireFormatLite::TYPE_GROUP:
430         case WireFormatLite::TYPE_MESSAGE:
431           GOOGLE_LOG(FATAL) << "Non-primitive types can't be packed.";
432           break;
433       }
434     } else {
435       switch (real_type(type)) {
436 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE)                        \
437         case FieldDescriptor::TYPE_##UPPERCASE:                             \
438           for (int i = 0; i < repeated_##LOWERCASE##_value->size(); i++) {  \
439             target = WireFormatLite::Write##CAMELCASE##ToArray(number,      \
440               repeated_##LOWERCASE##_value->Get(i), target);                \
441           }                                                                 \
442           break
443 
444         HANDLE_TYPE(   INT32,    Int32,   int32);
445         HANDLE_TYPE(   INT64,    Int64,   int64);
446         HANDLE_TYPE(  UINT32,   UInt32,  uint32);
447         HANDLE_TYPE(  UINT64,   UInt64,  uint64);
448         HANDLE_TYPE(  SINT32,   SInt32,   int32);
449         HANDLE_TYPE(  SINT64,   SInt64,   int64);
450         HANDLE_TYPE( FIXED32,  Fixed32,  uint32);
451         HANDLE_TYPE( FIXED64,  Fixed64,  uint64);
452         HANDLE_TYPE(SFIXED32, SFixed32,   int32);
453         HANDLE_TYPE(SFIXED64, SFixed64,   int64);
454         HANDLE_TYPE(   FLOAT,    Float,   float);
455         HANDLE_TYPE(  DOUBLE,   Double,  double);
456         HANDLE_TYPE(    BOOL,     Bool,    bool);
457         HANDLE_TYPE(  STRING,   String,  string);
458         HANDLE_TYPE(   BYTES,    Bytes,  string);
459         HANDLE_TYPE(    ENUM,     Enum,    enum);
460         HANDLE_TYPE(   GROUP,    Group, message);
461         HANDLE_TYPE( MESSAGE,  Message, message);
462 #undef HANDLE_TYPE
463       }
464     }
465   } else if (!is_cleared) {
466     switch (real_type(type)) {
467 #define HANDLE_TYPE(UPPERCASE, CAMELCASE, VALUE)                 \
468       case FieldDescriptor::TYPE_##UPPERCASE:                    \
469         target = WireFormatLite::Write##CAMELCASE##ToArray(      \
470             number, VALUE, target); \
471         break
472 
473       HANDLE_TYPE(   INT32,    Int32,    int32_value);
474       HANDLE_TYPE(   INT64,    Int64,    int64_value);
475       HANDLE_TYPE(  UINT32,   UInt32,   uint32_value);
476       HANDLE_TYPE(  UINT64,   UInt64,   uint64_value);
477       HANDLE_TYPE(  SINT32,   SInt32,    int32_value);
478       HANDLE_TYPE(  SINT64,   SInt64,    int64_value);
479       HANDLE_TYPE( FIXED32,  Fixed32,   uint32_value);
480       HANDLE_TYPE( FIXED64,  Fixed64,   uint64_value);
481       HANDLE_TYPE(SFIXED32, SFixed32,    int32_value);
482       HANDLE_TYPE(SFIXED64, SFixed64,    int64_value);
483       HANDLE_TYPE(   FLOAT,    Float,    float_value);
484       HANDLE_TYPE(  DOUBLE,   Double,   double_value);
485       HANDLE_TYPE(    BOOL,     Bool,     bool_value);
486       HANDLE_TYPE(  STRING,   String,  *string_value);
487       HANDLE_TYPE(   BYTES,    Bytes,  *string_value);
488       HANDLE_TYPE(    ENUM,     Enum,     enum_value);
489       HANDLE_TYPE(   GROUP,    Group, *message_value);
490 #undef HANDLE_TYPE
491       case FieldDescriptor::TYPE_MESSAGE:
492         if (is_lazy) {
493           target = lazymessage_value->WriteMessageToArray(number, target);
494         } else {
495           target = WireFormatLite::WriteMessageToArray(
496               number, *message_value, target);
497         }
498         break;
499     }
500   }
501   return target;
502 }
503 
SerializeMessageSetItemWithCachedSizesToArray(int number,uint8 * target) const504 uint8* ExtensionSet::Extension::SerializeMessageSetItemWithCachedSizesToArray(
505     int number,
506     uint8* target) const {
507   if (type != WireFormatLite::TYPE_MESSAGE || is_repeated) {
508     // Not a valid MessageSet extension, but serialize it the normal way.
509     GOOGLE_LOG(WARNING) << "Invalid message set extension.";
510     return SerializeFieldWithCachedSizesToArray(number, target);
511   }
512 
513   if (is_cleared) return target;
514 
515   // Start group.
516   target = io::CodedOutputStream::WriteTagToArray(
517       WireFormatLite::kMessageSetItemStartTag, target);
518   // Write type ID.
519   target = WireFormatLite::WriteUInt32ToArray(
520       WireFormatLite::kMessageSetTypeIdNumber, number, target);
521   // Write message.
522   if (is_lazy) {
523     target = lazymessage_value->WriteMessageToArray(
524         WireFormatLite::kMessageSetMessageNumber, target);
525   } else {
526     target = WireFormatLite::WriteMessageToArray(
527         WireFormatLite::kMessageSetMessageNumber, *message_value, target);
528   }
529   // End group.
530   target = io::CodedOutputStream::WriteTagToArray(
531       WireFormatLite::kMessageSetItemEndTag, target);
532   return target;
533 }
534 
535 
ParseFieldMaybeLazily(int wire_type,int field_number,io::CodedInputStream * input,ExtensionFinder * extension_finder,MessageSetFieldSkipper * field_skipper)536 bool ExtensionSet::ParseFieldMaybeLazily(
537     int wire_type, int field_number, io::CodedInputStream* input,
538     ExtensionFinder* extension_finder,
539     MessageSetFieldSkipper* field_skipper) {
540   return ParseField(WireFormatLite::MakeTag(
541       field_number, static_cast<WireFormatLite::WireType>(wire_type)),
542                     input, extension_finder, field_skipper);
543 }
544 
ParseMessageSet(io::CodedInputStream * input,ExtensionFinder * extension_finder,MessageSetFieldSkipper * field_skipper)545 bool ExtensionSet::ParseMessageSet(io::CodedInputStream* input,
546                                    ExtensionFinder* extension_finder,
547                                    MessageSetFieldSkipper* field_skipper) {
548   while (true) {
549     const uint32 tag = input->ReadTag();
550     switch (tag) {
551       case 0:
552         return true;
553       case WireFormatLite::kMessageSetItemStartTag:
554         if (!ParseMessageSetItem(input, extension_finder, field_skipper)) {
555           return false;
556         }
557         break;
558       default:
559         if (!ParseField(tag, input, extension_finder, field_skipper)) {
560           return false;
561         }
562         break;
563     }
564   }
565 }
566 
ParseMessageSet(io::CodedInputStream * input,const MessageLite * containing_type)567 bool ExtensionSet::ParseMessageSet(io::CodedInputStream* input,
568                                    const MessageLite* containing_type) {
569   MessageSetFieldSkipper skipper(NULL);
570   GeneratedExtensionFinder finder(containing_type);
571   return ParseMessageSet(input, &finder, &skipper);
572 }
573 
ParseMessageSetItem(io::CodedInputStream * input,ExtensionFinder * extension_finder,MessageSetFieldSkipper * field_skipper)574 bool ExtensionSet::ParseMessageSetItem(io::CodedInputStream* input,
575                                        ExtensionFinder* extension_finder,
576                                        MessageSetFieldSkipper* field_skipper) {
577   // TODO(kenton):  It would be nice to share code between this and
578   // WireFormatLite::ParseAndMergeMessageSetItem(), but I think the
579   // differences would be hard to factor out.
580 
581   // This method parses a group which should contain two fields:
582   //   required int32 type_id = 2;
583   //   required data message = 3;
584 
585   uint32 last_type_id = 0;
586 
587   // If we see message data before the type_id, we'll append it to this so
588   // we can parse it later.
589   string message_data;
590 
591   while (true) {
592     const uint32 tag = input->ReadTag();
593     if (tag == 0) return false;
594 
595     switch (tag) {
596       case WireFormatLite::kMessageSetTypeIdTag: {
597         uint32 type_id;
598         if (!input->ReadVarint32(&type_id)) return false;
599         last_type_id = type_id;
600 
601         if (!message_data.empty()) {
602           // We saw some message data before the type_id.  Have to parse it
603           // now.
604           io::CodedInputStream sub_input(
605               reinterpret_cast<const uint8*>(message_data.data()),
606               message_data.size());
607           if (!ParseFieldMaybeLazily(WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
608                                      last_type_id, &sub_input,
609                                      extension_finder, field_skipper)) {
610             return false;
611           }
612           message_data.clear();
613         }
614 
615         break;
616       }
617 
618       case WireFormatLite::kMessageSetMessageTag: {
619         if (last_type_id == 0) {
620           // We haven't seen a type_id yet.  Append this data to message_data.
621           string temp;
622           uint32 length;
623           if (!input->ReadVarint32(&length)) return false;
624           if (!input->ReadString(&temp, length)) return false;
625           io::StringOutputStream output_stream(&message_data);
626           io::CodedOutputStream coded_output(&output_stream);
627           coded_output.WriteVarint32(length);
628           coded_output.WriteString(temp);
629         } else {
630           // Already saw type_id, so we can parse this directly.
631           if (!ParseFieldMaybeLazily(WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
632                                      last_type_id, input,
633                                      extension_finder, field_skipper)) {
634             return false;
635           }
636         }
637 
638         break;
639       }
640 
641       case WireFormatLite::kMessageSetItemEndTag: {
642         return true;
643       }
644 
645       default: {
646         if (!field_skipper->SkipField(input, tag)) return false;
647       }
648     }
649   }
650 }
651 
SerializeMessageSetItemWithCachedSizes(int number,io::CodedOutputStream * output) const652 void ExtensionSet::Extension::SerializeMessageSetItemWithCachedSizes(
653     int number,
654     io::CodedOutputStream* output) const {
655   if (type != WireFormatLite::TYPE_MESSAGE || is_repeated) {
656     // Not a valid MessageSet extension, but serialize it the normal way.
657     SerializeFieldWithCachedSizes(number, output);
658     return;
659   }
660 
661   if (is_cleared) return;
662 
663   // Start group.
664   output->WriteTag(WireFormatLite::kMessageSetItemStartTag);
665 
666   // Write type ID.
667   WireFormatLite::WriteUInt32(WireFormatLite::kMessageSetTypeIdNumber,
668                               number,
669                               output);
670   // Write message.
671   if (is_lazy) {
672     lazymessage_value->WriteMessage(
673         WireFormatLite::kMessageSetMessageNumber, output);
674   } else {
675     WireFormatLite::WriteMessageMaybeToArray(
676         WireFormatLite::kMessageSetMessageNumber,
677         *message_value,
678         output);
679   }
680 
681   // End group.
682   output->WriteTag(WireFormatLite::kMessageSetItemEndTag);
683 }
684 
MessageSetItemByteSize(int number) const685 int ExtensionSet::Extension::MessageSetItemByteSize(int number) const {
686   if (type != WireFormatLite::TYPE_MESSAGE || is_repeated) {
687     // Not a valid MessageSet extension, but compute the byte size for it the
688     // normal way.
689     return ByteSize(number);
690   }
691 
692   if (is_cleared) return 0;
693 
694   int our_size = WireFormatLite::kMessageSetItemTagsSize;
695 
696   // type_id
697   our_size += io::CodedOutputStream::VarintSize32(number);
698 
699   // message
700   int message_size = 0;
701   if (is_lazy) {
702     message_size = lazymessage_value->ByteSize();
703   } else {
704     message_size = message_value->ByteSize();
705   }
706 
707   our_size += io::CodedOutputStream::VarintSize32(message_size);
708   our_size += message_size;
709 
710   return our_size;
711 }
712 
SerializeMessageSetWithCachedSizes(io::CodedOutputStream * output) const713 void ExtensionSet::SerializeMessageSetWithCachedSizes(
714     io::CodedOutputStream* output) const {
715   for (map<int, Extension>::const_iterator iter = extensions_.begin();
716        iter != extensions_.end(); ++iter) {
717     iter->second.SerializeMessageSetItemWithCachedSizes(iter->first, output);
718   }
719 }
720 
MessageSetByteSize() const721 int ExtensionSet::MessageSetByteSize() const {
722   int total_size = 0;
723 
724   for (map<int, Extension>::const_iterator iter = extensions_.begin();
725        iter != extensions_.end(); ++iter) {
726     total_size += iter->second.MessageSetItemByteSize(iter->first);
727   }
728 
729   return total_size;
730 }
731 
732 }  // namespace internal
733 }  // namespace protobuf
734 }  // namespace google
735