1 // Copyright 2016 Google Inc. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef SRC_FIELD_INSTANCE_H_ 16 #define SRC_FIELD_INSTANCE_H_ 17 18 #include <memory> 19 #include <string> 20 21 #include "port/protobuf.h" 22 23 namespace protobuf_mutator { 24 25 // Helper class for common protobuf fields operations. 26 class ConstFieldInstance { 27 public: 28 static const size_t kInvalidIndex = -1; 29 30 struct Enum { 31 size_t index; 32 size_t count; 33 }; 34 ConstFieldInstance()35 ConstFieldInstance() 36 : message_(nullptr), descriptor_(nullptr), index_(kInvalidIndex) {} 37 ConstFieldInstance(const protobuf::Message * message,const protobuf::FieldDescriptor * field,size_t index)38 ConstFieldInstance(const protobuf::Message* message, 39 const protobuf::FieldDescriptor* field, size_t index) 40 : message_(message), descriptor_(field), index_(index) { 41 assert(message_); 42 assert(descriptor_); 43 assert(index_ != kInvalidIndex); 44 assert(descriptor_->is_repeated()); 45 } 46 ConstFieldInstance(const protobuf::Message * message,const protobuf::FieldDescriptor * field)47 ConstFieldInstance(const protobuf::Message* message, 48 const protobuf::FieldDescriptor* field) 49 : message_(message), descriptor_(field), index_(kInvalidIndex) { 50 assert(message_); 51 assert(descriptor_); 52 assert(!descriptor_->is_repeated()); 53 } 54 GetDefault(int32_t * out)55 void GetDefault(int32_t* out) const { 56 *out = descriptor_->default_value_int32(); 57 } 58 GetDefault(int64_t * out)59 void GetDefault(int64_t* out) const { 60 *out = descriptor_->default_value_int64(); 61 } 62 GetDefault(uint32_t * out)63 void GetDefault(uint32_t* out) const { 64 *out = descriptor_->default_value_uint32(); 65 } 66 GetDefault(uint64_t * out)67 void GetDefault(uint64_t* out) const { 68 *out = descriptor_->default_value_uint64(); 69 } 70 GetDefault(double * out)71 void GetDefault(double* out) const { 72 *out = descriptor_->default_value_double(); 73 } 74 GetDefault(float * out)75 void GetDefault(float* out) const { 76 *out = descriptor_->default_value_float(); 77 } 78 GetDefault(bool * out)79 void GetDefault(bool* out) const { *out = descriptor_->default_value_bool(); } 80 GetDefault(Enum * out)81 void GetDefault(Enum* out) const { 82 const protobuf::EnumValueDescriptor* value = 83 descriptor_->default_value_enum(); 84 const protobuf::EnumDescriptor* type = value->type(); 85 *out = {static_cast<size_t>(value->index()), 86 static_cast<size_t>(type->value_count())}; 87 } 88 GetDefault(std::string * out)89 void GetDefault(std::string* out) const { 90 *out = descriptor_->default_value_string(); 91 } 92 GetDefault(std::unique_ptr<protobuf::Message> * out)93 void GetDefault(std::unique_ptr<protobuf::Message>* out) const { 94 out->reset(reflection() 95 .GetMessageFactory() 96 ->GetPrototype(descriptor_->message_type()) 97 ->New()); 98 } 99 Load(int32_t * value)100 void Load(int32_t* value) const { 101 *value = is_repeated() 102 ? reflection().GetRepeatedInt32(*message_, descriptor_, index_) 103 : reflection().GetInt32(*message_, descriptor_); 104 } 105 Load(int64_t * value)106 void Load(int64_t* value) const { 107 *value = is_repeated() 108 ? reflection().GetRepeatedInt64(*message_, descriptor_, index_) 109 : reflection().GetInt64(*message_, descriptor_); 110 } 111 Load(uint32_t * value)112 void Load(uint32_t* value) const { 113 *value = is_repeated() ? reflection().GetRepeatedUInt32(*message_, 114 descriptor_, index_) 115 : reflection().GetUInt32(*message_, descriptor_); 116 } 117 Load(uint64_t * value)118 void Load(uint64_t* value) const { 119 *value = is_repeated() ? reflection().GetRepeatedUInt64(*message_, 120 descriptor_, index_) 121 : reflection().GetUInt64(*message_, descriptor_); 122 } 123 Load(double * value)124 void Load(double* value) const { 125 *value = is_repeated() ? reflection().GetRepeatedDouble(*message_, 126 descriptor_, index_) 127 : reflection().GetDouble(*message_, descriptor_); 128 } 129 Load(float * value)130 void Load(float* value) const { 131 *value = is_repeated() 132 ? reflection().GetRepeatedFloat(*message_, descriptor_, index_) 133 : reflection().GetFloat(*message_, descriptor_); 134 } 135 Load(bool * value)136 void Load(bool* value) const { 137 *value = is_repeated() 138 ? reflection().GetRepeatedBool(*message_, descriptor_, index_) 139 : reflection().GetBool(*message_, descriptor_); 140 } 141 Load(Enum * value)142 void Load(Enum* value) const { 143 const protobuf::EnumValueDescriptor* value_descriptor = 144 is_repeated() 145 ? reflection().GetRepeatedEnum(*message_, descriptor_, index_) 146 : reflection().GetEnum(*message_, descriptor_); 147 *value = {static_cast<size_t>(value_descriptor->index()), 148 static_cast<size_t>(value_descriptor->type()->value_count())}; 149 } 150 Load(std::string * value)151 void Load(std::string* value) const { 152 *value = is_repeated() ? reflection().GetRepeatedString(*message_, 153 descriptor_, index_) 154 : reflection().GetString(*message_, descriptor_); 155 } 156 Load(std::unique_ptr<protobuf::Message> * value)157 void Load(std::unique_ptr<protobuf::Message>* value) const { 158 const protobuf::Message& source = 159 is_repeated() 160 ? reflection().GetRepeatedMessage(*message_, descriptor_, index_) 161 : reflection().GetMessage(*message_, descriptor_); 162 value->reset(source.New()); 163 (*value)->CopyFrom(source); 164 } 165 name()166 std::string name() const { return descriptor_->name(); } 167 cpp_type()168 protobuf::FieldDescriptor::CppType cpp_type() const { 169 return descriptor_->cpp_type(); 170 } 171 enum_type()172 const protobuf::EnumDescriptor* enum_type() const { 173 return descriptor_->enum_type(); 174 } 175 message_type()176 const protobuf::Descriptor* message_type() const { 177 return descriptor_->message_type(); 178 } 179 EnforceUtf8()180 bool EnforceUtf8() const { 181 return descriptor_->type() == protobuf::FieldDescriptor::TYPE_STRING && 182 descriptor()->file()->syntax() == 183 protobuf::FileDescriptor::SYNTAX_PROTO3; 184 } 185 186 protected: is_repeated()187 bool is_repeated() const { return descriptor_->is_repeated(); } 188 reflection()189 const protobuf::Reflection& reflection() const { 190 return *message_->GetReflection(); 191 } 192 descriptor()193 const protobuf::FieldDescriptor* descriptor() const { return descriptor_; } 194 index()195 size_t index() const { return index_; } 196 197 private: 198 template <class Fn, class T> 199 friend struct FieldFunction; 200 201 const protobuf::Message* message_; 202 const protobuf::FieldDescriptor* descriptor_; 203 size_t index_; 204 }; 205 206 class FieldInstance : public ConstFieldInstance { 207 public: 208 static const size_t kInvalidIndex = -1; 209 FieldInstance()210 FieldInstance() : ConstFieldInstance(), message_(nullptr) {} 211 FieldInstance(protobuf::Message * message,const protobuf::FieldDescriptor * field,size_t index)212 FieldInstance(protobuf::Message* message, 213 const protobuf::FieldDescriptor* field, size_t index) 214 : ConstFieldInstance(message, field, index), message_(message) {} 215 FieldInstance(protobuf::Message * message,const protobuf::FieldDescriptor * field)216 FieldInstance(protobuf::Message* message, 217 const protobuf::FieldDescriptor* field) 218 : ConstFieldInstance(message, field), message_(message) {} 219 Delete()220 void Delete() const { 221 if (!is_repeated()) return reflection().ClearField(message_, descriptor()); 222 int field_size = reflection().FieldSize(*message_, descriptor()); 223 // API has only method to delete the last message, so we move method from 224 // the 225 // middle to the end. 226 for (int i = index() + 1; i < field_size; ++i) 227 reflection().SwapElements(message_, descriptor(), i, i - 1); 228 reflection().RemoveLast(message_, descriptor()); 229 } 230 231 template <class T> Create(const T & value)232 void Create(const T& value) const { 233 if (!is_repeated()) return Store(value); 234 InsertRepeated(value); 235 } 236 Store(int32_t value)237 void Store(int32_t value) const { 238 if (is_repeated()) 239 reflection().SetRepeatedInt32(message_, descriptor(), index(), value); 240 else 241 reflection().SetInt32(message_, descriptor(), value); 242 } 243 Store(int64_t value)244 void Store(int64_t value) const { 245 if (is_repeated()) 246 reflection().SetRepeatedInt64(message_, descriptor(), index(), value); 247 else 248 reflection().SetInt64(message_, descriptor(), value); 249 } 250 Store(uint32_t value)251 void Store(uint32_t value) const { 252 if (is_repeated()) 253 reflection().SetRepeatedUInt32(message_, descriptor(), index(), value); 254 else 255 reflection().SetUInt32(message_, descriptor(), value); 256 } 257 Store(uint64_t value)258 void Store(uint64_t value) const { 259 if (is_repeated()) 260 reflection().SetRepeatedUInt64(message_, descriptor(), index(), value); 261 else 262 reflection().SetUInt64(message_, descriptor(), value); 263 } 264 Store(double value)265 void Store(double value) const { 266 if (is_repeated()) 267 reflection().SetRepeatedDouble(message_, descriptor(), index(), value); 268 else 269 reflection().SetDouble(message_, descriptor(), value); 270 } 271 Store(float value)272 void Store(float value) const { 273 if (is_repeated()) 274 reflection().SetRepeatedFloat(message_, descriptor(), index(), value); 275 else 276 reflection().SetFloat(message_, descriptor(), value); 277 } 278 Store(bool value)279 void Store(bool value) const { 280 if (is_repeated()) 281 reflection().SetRepeatedBool(message_, descriptor(), index(), value); 282 else 283 reflection().SetBool(message_, descriptor(), value); 284 } 285 Store(const Enum & value)286 void Store(const Enum& value) const { 287 assert(value.index < value.count); 288 const protobuf::EnumValueDescriptor* enum_value = 289 descriptor()->enum_type()->value(value.index); 290 if (is_repeated()) 291 reflection().SetRepeatedEnum(message_, descriptor(), index(), enum_value); 292 else 293 reflection().SetEnum(message_, descriptor(), enum_value); 294 } 295 Store(const std::string & value)296 void Store(const std::string& value) const { 297 if (is_repeated()) 298 reflection().SetRepeatedString(message_, descriptor(), index(), value); 299 else 300 reflection().SetString(message_, descriptor(), value); 301 } 302 Store(const std::unique_ptr<protobuf::Message> & value)303 void Store(const std::unique_ptr<protobuf::Message>& value) const { 304 protobuf::Message* mutable_message = 305 is_repeated() ? reflection().MutableRepeatedMessage( 306 message_, descriptor(), index()) 307 : reflection().MutableMessage(message_, descriptor()); 308 mutable_message->Clear(); 309 if (value) mutable_message->CopyFrom(*value); 310 } 311 312 private: 313 template <class T> InsertRepeated(const T & value)314 void InsertRepeated(const T& value) const { 315 PushBackRepeated(value); 316 size_t field_size = reflection().FieldSize(*message_, descriptor()); 317 if (field_size == 1) return; 318 // API has only method to add field to the end of the list. So we add 319 // descriptor() 320 // and move it into the middle. 321 for (size_t i = field_size - 1; i > index(); --i) 322 reflection().SwapElements(message_, descriptor(), i, i - 1); 323 } 324 PushBackRepeated(int32_t value)325 void PushBackRepeated(int32_t value) const { 326 assert(is_repeated()); 327 reflection().AddInt32(message_, descriptor(), value); 328 } 329 PushBackRepeated(int64_t value)330 void PushBackRepeated(int64_t value) const { 331 assert(is_repeated()); 332 reflection().AddInt64(message_, descriptor(), value); 333 } 334 PushBackRepeated(uint32_t value)335 void PushBackRepeated(uint32_t value) const { 336 assert(is_repeated()); 337 reflection().AddUInt32(message_, descriptor(), value); 338 } 339 PushBackRepeated(uint64_t value)340 void PushBackRepeated(uint64_t value) const { 341 assert(is_repeated()); 342 reflection().AddUInt64(message_, descriptor(), value); 343 } 344 PushBackRepeated(double value)345 void PushBackRepeated(double value) const { 346 assert(is_repeated()); 347 reflection().AddDouble(message_, descriptor(), value); 348 } 349 PushBackRepeated(float value)350 void PushBackRepeated(float value) const { 351 assert(is_repeated()); 352 reflection().AddFloat(message_, descriptor(), value); 353 } 354 PushBackRepeated(bool value)355 void PushBackRepeated(bool value) const { 356 assert(is_repeated()); 357 reflection().AddBool(message_, descriptor(), value); 358 } 359 PushBackRepeated(const Enum & value)360 void PushBackRepeated(const Enum& value) const { 361 assert(value.index < value.count); 362 const protobuf::EnumValueDescriptor* enum_value = 363 descriptor()->enum_type()->value(value.index); 364 assert(is_repeated()); 365 reflection().AddEnum(message_, descriptor(), enum_value); 366 } 367 PushBackRepeated(const std::string & value)368 void PushBackRepeated(const std::string& value) const { 369 assert(is_repeated()); 370 reflection().AddString(message_, descriptor(), value); 371 } 372 PushBackRepeated(const std::unique_ptr<protobuf::Message> & value)373 void PushBackRepeated(const std::unique_ptr<protobuf::Message>& value) const { 374 assert(is_repeated()); 375 protobuf::Message* mutable_message = 376 reflection().AddMessage(message_, descriptor()); 377 mutable_message->Clear(); 378 if (value) mutable_message->CopyFrom(*value); 379 } 380 381 protobuf::Message* message_; 382 }; 383 384 template <class Fn, class R = void> 385 struct FieldFunction { 386 template <class Field, class... Args> operatorFieldFunction387 R operator()(const Field& field, const Args&... args) const { 388 assert(field.descriptor()); 389 using protobuf::FieldDescriptor; 390 switch (field.cpp_type()) { 391 case FieldDescriptor::CPPTYPE_INT32: 392 return static_cast<const Fn*>(this)->template ForType<int32_t>(field, 393 args...); 394 case FieldDescriptor::CPPTYPE_INT64: 395 return static_cast<const Fn*>(this)->template ForType<int64_t>(field, 396 args...); 397 case FieldDescriptor::CPPTYPE_UINT32: 398 return static_cast<const Fn*>(this)->template ForType<uint32_t>( 399 field, args...); 400 case FieldDescriptor::CPPTYPE_UINT64: 401 return static_cast<const Fn*>(this)->template ForType<uint64_t>( 402 field, args...); 403 case FieldDescriptor::CPPTYPE_DOUBLE: 404 return static_cast<const Fn*>(this)->template ForType<double>(field, 405 args...); 406 case FieldDescriptor::CPPTYPE_FLOAT: 407 return static_cast<const Fn*>(this)->template ForType<float>(field, 408 args...); 409 case FieldDescriptor::CPPTYPE_BOOL: 410 return static_cast<const Fn*>(this)->template ForType<bool>(field, 411 args...); 412 case FieldDescriptor::CPPTYPE_ENUM: 413 return static_cast<const Fn*>(this) 414 ->template ForType<ConstFieldInstance::Enum>(field, args...); 415 case FieldDescriptor::CPPTYPE_STRING: 416 return static_cast<const Fn*>(this)->template ForType<std::string>( 417 field, args...); 418 case FieldDescriptor::CPPTYPE_MESSAGE: 419 return static_cast<const Fn*>(this) 420 ->template ForType<std::unique_ptr<protobuf::Message>>(field, 421 args...); 422 } 423 assert(false && "Unknown type"); 424 abort(); 425 } 426 }; 427 428 } // namespace protobuf_mutator 429 430 #endif // SRC_FIELD_INSTANCE_H_ 431