1 // Copyright 2016 Google Inc. All rights reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef SRC_FIELD_INSTANCE_H_
16 #define SRC_FIELD_INSTANCE_H_
17 
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <string>

#include "port/protobuf.h"
22 
23 namespace protobuf_mutator {
24 
25 // Helper class for common protobuf fields operations.
26 class ConstFieldInstance {
27  public:
28   static const size_t kInvalidIndex = -1;
29 
30   struct Enum {
31     size_t index;
32     size_t count;
33   };
34 
ConstFieldInstance()35   ConstFieldInstance()
36       : message_(nullptr), descriptor_(nullptr), index_(kInvalidIndex) {}
37 
ConstFieldInstance(const protobuf::Message * message,const protobuf::FieldDescriptor * field,size_t index)38   ConstFieldInstance(const protobuf::Message* message,
39                      const protobuf::FieldDescriptor* field, size_t index)
40       : message_(message), descriptor_(field), index_(index) {
41     assert(message_);
42     assert(descriptor_);
43     assert(index_ != kInvalidIndex);
44     assert(descriptor_->is_repeated());
45   }
46 
ConstFieldInstance(const protobuf::Message * message,const protobuf::FieldDescriptor * field)47   ConstFieldInstance(const protobuf::Message* message,
48                      const protobuf::FieldDescriptor* field)
49       : message_(message), descriptor_(field), index_(kInvalidIndex) {
50     assert(message_);
51     assert(descriptor_);
52     assert(!descriptor_->is_repeated());
53   }
54 
GetDefault(int32_t * out)55   void GetDefault(int32_t* out) const {
56     *out = descriptor_->default_value_int32();
57   }
58 
GetDefault(int64_t * out)59   void GetDefault(int64_t* out) const {
60     *out = descriptor_->default_value_int64();
61   }
62 
GetDefault(uint32_t * out)63   void GetDefault(uint32_t* out) const {
64     *out = descriptor_->default_value_uint32();
65   }
66 
GetDefault(uint64_t * out)67   void GetDefault(uint64_t* out) const {
68     *out = descriptor_->default_value_uint64();
69   }
70 
GetDefault(double * out)71   void GetDefault(double* out) const {
72     *out = descriptor_->default_value_double();
73   }
74 
GetDefault(float * out)75   void GetDefault(float* out) const {
76     *out = descriptor_->default_value_float();
77   }
78 
GetDefault(bool * out)79   void GetDefault(bool* out) const { *out = descriptor_->default_value_bool(); }
80 
GetDefault(Enum * out)81   void GetDefault(Enum* out) const {
82     const protobuf::EnumValueDescriptor* value =
83         descriptor_->default_value_enum();
84     const protobuf::EnumDescriptor* type = value->type();
85     *out = {static_cast<size_t>(value->index()),
86             static_cast<size_t>(type->value_count())};
87   }
88 
GetDefault(std::string * out)89   void GetDefault(std::string* out) const {
90     *out = descriptor_->default_value_string();
91   }
92 
GetDefault(std::unique_ptr<protobuf::Message> * out)93   void GetDefault(std::unique_ptr<protobuf::Message>* out) const {
94     out->reset(reflection()
95                    .GetMessageFactory()
96                    ->GetPrototype(descriptor_->message_type())
97                    ->New());
98   }
99 
Load(int32_t * value)100   void Load(int32_t* value) const {
101     *value = is_repeated()
102                  ? reflection().GetRepeatedInt32(*message_, descriptor_, index_)
103                  : reflection().GetInt32(*message_, descriptor_);
104   }
105 
Load(int64_t * value)106   void Load(int64_t* value) const {
107     *value = is_repeated()
108                  ? reflection().GetRepeatedInt64(*message_, descriptor_, index_)
109                  : reflection().GetInt64(*message_, descriptor_);
110   }
111 
Load(uint32_t * value)112   void Load(uint32_t* value) const {
113     *value = is_repeated() ? reflection().GetRepeatedUInt32(*message_,
114                                                             descriptor_, index_)
115                            : reflection().GetUInt32(*message_, descriptor_);
116   }
117 
Load(uint64_t * value)118   void Load(uint64_t* value) const {
119     *value = is_repeated() ? reflection().GetRepeatedUInt64(*message_,
120                                                             descriptor_, index_)
121                            : reflection().GetUInt64(*message_, descriptor_);
122   }
123 
Load(double * value)124   void Load(double* value) const {
125     *value = is_repeated() ? reflection().GetRepeatedDouble(*message_,
126                                                             descriptor_, index_)
127                            : reflection().GetDouble(*message_, descriptor_);
128   }
129 
Load(float * value)130   void Load(float* value) const {
131     *value = is_repeated()
132                  ? reflection().GetRepeatedFloat(*message_, descriptor_, index_)
133                  : reflection().GetFloat(*message_, descriptor_);
134   }
135 
Load(bool * value)136   void Load(bool* value) const {
137     *value = is_repeated()
138                  ? reflection().GetRepeatedBool(*message_, descriptor_, index_)
139                  : reflection().GetBool(*message_, descriptor_);
140   }
141 
Load(Enum * value)142   void Load(Enum* value) const {
143     const protobuf::EnumValueDescriptor* value_descriptor =
144         is_repeated()
145             ? reflection().GetRepeatedEnum(*message_, descriptor_, index_)
146             : reflection().GetEnum(*message_, descriptor_);
147     *value = {static_cast<size_t>(value_descriptor->index()),
148               static_cast<size_t>(value_descriptor->type()->value_count())};
149   }
150 
Load(std::string * value)151   void Load(std::string* value) const {
152     *value = is_repeated() ? reflection().GetRepeatedString(*message_,
153                                                             descriptor_, index_)
154                            : reflection().GetString(*message_, descriptor_);
155   }
156 
Load(std::unique_ptr<protobuf::Message> * value)157   void Load(std::unique_ptr<protobuf::Message>* value) const {
158     const protobuf::Message& source =
159         is_repeated()
160             ? reflection().GetRepeatedMessage(*message_, descriptor_, index_)
161             : reflection().GetMessage(*message_, descriptor_);
162     value->reset(source.New());
163     (*value)->CopyFrom(source);
164   }
165 
name()166   std::string name() const { return descriptor_->name(); }
167 
cpp_type()168   protobuf::FieldDescriptor::CppType cpp_type() const {
169     return descriptor_->cpp_type();
170   }
171 
enum_type()172   const protobuf::EnumDescriptor* enum_type() const {
173     return descriptor_->enum_type();
174   }
175 
message_type()176   const protobuf::Descriptor* message_type() const {
177     return descriptor_->message_type();
178   }
179 
EnforceUtf8()180   bool EnforceUtf8() const {
181     return descriptor_->type() == protobuf::FieldDescriptor::TYPE_STRING &&
182            descriptor()->file()->syntax() ==
183                protobuf::FileDescriptor::SYNTAX_PROTO3;
184   }
185 
186  protected:
is_repeated()187   bool is_repeated() const { return descriptor_->is_repeated(); }
188 
reflection()189   const protobuf::Reflection& reflection() const {
190     return *message_->GetReflection();
191   }
192 
descriptor()193   const protobuf::FieldDescriptor* descriptor() const { return descriptor_; }
194 
index()195   size_t index() const { return index_; }
196 
197  private:
198   template <class Fn, class T>
199   friend struct FieldFunction;
200 
201   const protobuf::Message* message_;
202   const protobuf::FieldDescriptor* descriptor_;
203   size_t index_;
204 };
205 
206 class FieldInstance : public ConstFieldInstance {
207  public:
208   static const size_t kInvalidIndex = -1;
209 
FieldInstance()210   FieldInstance() : ConstFieldInstance(), message_(nullptr) {}
211 
FieldInstance(protobuf::Message * message,const protobuf::FieldDescriptor * field,size_t index)212   FieldInstance(protobuf::Message* message,
213                 const protobuf::FieldDescriptor* field, size_t index)
214       : ConstFieldInstance(message, field, index), message_(message) {}
215 
FieldInstance(protobuf::Message * message,const protobuf::FieldDescriptor * field)216   FieldInstance(protobuf::Message* message,
217                 const protobuf::FieldDescriptor* field)
218       : ConstFieldInstance(message, field), message_(message) {}
219 
Delete()220   void Delete() const {
221     if (!is_repeated()) return reflection().ClearField(message_, descriptor());
222     int field_size = reflection().FieldSize(*message_, descriptor());
223     // API has only method to delete the last message, so we move method from
224     // the
225     // middle to the end.
226     for (int i = index() + 1; i < field_size; ++i)
227       reflection().SwapElements(message_, descriptor(), i, i - 1);
228     reflection().RemoveLast(message_, descriptor());
229   }
230 
231   template <class T>
Create(const T & value)232   void Create(const T& value) const {
233     if (!is_repeated()) return Store(value);
234     InsertRepeated(value);
235   }
236 
Store(int32_t value)237   void Store(int32_t value) const {
238     if (is_repeated())
239       reflection().SetRepeatedInt32(message_, descriptor(), index(), value);
240     else
241       reflection().SetInt32(message_, descriptor(), value);
242   }
243 
Store(int64_t value)244   void Store(int64_t value) const {
245     if (is_repeated())
246       reflection().SetRepeatedInt64(message_, descriptor(), index(), value);
247     else
248       reflection().SetInt64(message_, descriptor(), value);
249   }
250 
Store(uint32_t value)251   void Store(uint32_t value) const {
252     if (is_repeated())
253       reflection().SetRepeatedUInt32(message_, descriptor(), index(), value);
254     else
255       reflection().SetUInt32(message_, descriptor(), value);
256   }
257 
Store(uint64_t value)258   void Store(uint64_t value) const {
259     if (is_repeated())
260       reflection().SetRepeatedUInt64(message_, descriptor(), index(), value);
261     else
262       reflection().SetUInt64(message_, descriptor(), value);
263   }
264 
Store(double value)265   void Store(double value) const {
266     if (is_repeated())
267       reflection().SetRepeatedDouble(message_, descriptor(), index(), value);
268     else
269       reflection().SetDouble(message_, descriptor(), value);
270   }
271 
Store(float value)272   void Store(float value) const {
273     if (is_repeated())
274       reflection().SetRepeatedFloat(message_, descriptor(), index(), value);
275     else
276       reflection().SetFloat(message_, descriptor(), value);
277   }
278 
Store(bool value)279   void Store(bool value) const {
280     if (is_repeated())
281       reflection().SetRepeatedBool(message_, descriptor(), index(), value);
282     else
283       reflection().SetBool(message_, descriptor(), value);
284   }
285 
Store(const Enum & value)286   void Store(const Enum& value) const {
287     assert(value.index < value.count);
288     const protobuf::EnumValueDescriptor* enum_value =
289         descriptor()->enum_type()->value(value.index);
290     if (is_repeated())
291       reflection().SetRepeatedEnum(message_, descriptor(), index(), enum_value);
292     else
293       reflection().SetEnum(message_, descriptor(), enum_value);
294   }
295 
Store(const std::string & value)296   void Store(const std::string& value) const {
297     if (is_repeated())
298       reflection().SetRepeatedString(message_, descriptor(), index(), value);
299     else
300       reflection().SetString(message_, descriptor(), value);
301   }
302 
Store(const std::unique_ptr<protobuf::Message> & value)303   void Store(const std::unique_ptr<protobuf::Message>& value) const {
304     protobuf::Message* mutable_message =
305         is_repeated() ? reflection().MutableRepeatedMessage(
306                             message_, descriptor(), index())
307                       : reflection().MutableMessage(message_, descriptor());
308     mutable_message->Clear();
309     if (value) mutable_message->CopyFrom(*value);
310   }
311 
312  private:
313   template <class T>
InsertRepeated(const T & value)314   void InsertRepeated(const T& value) const {
315     PushBackRepeated(value);
316     size_t field_size = reflection().FieldSize(*message_, descriptor());
317     if (field_size == 1) return;
318     // API has only method to add field to the end of the list. So we add
319     // descriptor()
320     // and move it into the middle.
321     for (size_t i = field_size - 1; i > index(); --i)
322       reflection().SwapElements(message_, descriptor(), i, i - 1);
323   }
324 
PushBackRepeated(int32_t value)325   void PushBackRepeated(int32_t value) const {
326     assert(is_repeated());
327     reflection().AddInt32(message_, descriptor(), value);
328   }
329 
PushBackRepeated(int64_t value)330   void PushBackRepeated(int64_t value) const {
331     assert(is_repeated());
332     reflection().AddInt64(message_, descriptor(), value);
333   }
334 
PushBackRepeated(uint32_t value)335   void PushBackRepeated(uint32_t value) const {
336     assert(is_repeated());
337     reflection().AddUInt32(message_, descriptor(), value);
338   }
339 
PushBackRepeated(uint64_t value)340   void PushBackRepeated(uint64_t value) const {
341     assert(is_repeated());
342     reflection().AddUInt64(message_, descriptor(), value);
343   }
344 
PushBackRepeated(double value)345   void PushBackRepeated(double value) const {
346     assert(is_repeated());
347     reflection().AddDouble(message_, descriptor(), value);
348   }
349 
PushBackRepeated(float value)350   void PushBackRepeated(float value) const {
351     assert(is_repeated());
352     reflection().AddFloat(message_, descriptor(), value);
353   }
354 
PushBackRepeated(bool value)355   void PushBackRepeated(bool value) const {
356     assert(is_repeated());
357     reflection().AddBool(message_, descriptor(), value);
358   }
359 
PushBackRepeated(const Enum & value)360   void PushBackRepeated(const Enum& value) const {
361     assert(value.index < value.count);
362     const protobuf::EnumValueDescriptor* enum_value =
363         descriptor()->enum_type()->value(value.index);
364     assert(is_repeated());
365     reflection().AddEnum(message_, descriptor(), enum_value);
366   }
367 
PushBackRepeated(const std::string & value)368   void PushBackRepeated(const std::string& value) const {
369     assert(is_repeated());
370     reflection().AddString(message_, descriptor(), value);
371   }
372 
PushBackRepeated(const std::unique_ptr<protobuf::Message> & value)373   void PushBackRepeated(const std::unique_ptr<protobuf::Message>& value) const {
374     assert(is_repeated());
375     protobuf::Message* mutable_message =
376         reflection().AddMessage(message_, descriptor());
377     mutable_message->Clear();
378     if (value) mutable_message->CopyFrom(*value);
379   }
380 
381   protobuf::Message* message_;
382 };
383 
384 template <class Fn, class R = void>
385 struct FieldFunction {
386   template <class Field, class... Args>
operatorFieldFunction387   R operator()(const Field& field, const Args&... args) const {
388     assert(field.descriptor());
389     using protobuf::FieldDescriptor;
390     switch (field.cpp_type()) {
391       case FieldDescriptor::CPPTYPE_INT32:
392         return static_cast<const Fn*>(this)->template ForType<int32_t>(field,
393                                                                        args...);
394       case FieldDescriptor::CPPTYPE_INT64:
395         return static_cast<const Fn*>(this)->template ForType<int64_t>(field,
396                                                                        args...);
397       case FieldDescriptor::CPPTYPE_UINT32:
398         return static_cast<const Fn*>(this)->template ForType<uint32_t>(
399             field, args...);
400       case FieldDescriptor::CPPTYPE_UINT64:
401         return static_cast<const Fn*>(this)->template ForType<uint64_t>(
402             field, args...);
403       case FieldDescriptor::CPPTYPE_DOUBLE:
404         return static_cast<const Fn*>(this)->template ForType<double>(field,
405                                                                       args...);
406       case FieldDescriptor::CPPTYPE_FLOAT:
407         return static_cast<const Fn*>(this)->template ForType<float>(field,
408                                                                      args...);
409       case FieldDescriptor::CPPTYPE_BOOL:
410         return static_cast<const Fn*>(this)->template ForType<bool>(field,
411                                                                     args...);
412       case FieldDescriptor::CPPTYPE_ENUM:
413         return static_cast<const Fn*>(this)
414             ->template ForType<ConstFieldInstance::Enum>(field, args...);
415       case FieldDescriptor::CPPTYPE_STRING:
416         return static_cast<const Fn*>(this)->template ForType<std::string>(
417             field, args...);
418       case FieldDescriptor::CPPTYPE_MESSAGE:
419         return static_cast<const Fn*>(this)
420             ->template ForType<std::unique_ptr<protobuf::Message>>(field,
421                                                                    args...);
422     }
423     assert(false && "Unknown type");
424     abort();
425   }
426 };
427 
428 }  // namespace protobuf_mutator
429 
430 #endif  // SRC_FIELD_INSTANCE_H_
431