1// Go support for Protocol Buffers - Google's data interchange format 2// 3// Copyright 2010 The Go Authors. All rights reserved. 4// https://github.com/golang/protobuf 5// 6// Redistribution and use in source and binary forms, with or without 7// modification, are permitted provided that the following conditions are 8// met: 9// 10// * Redistributions of source code must retain the above copyright 11// notice, this list of conditions and the following disclaimer. 12// * Redistributions in binary form must reproduce the above 13// copyright notice, this list of conditions and the following disclaimer 14// in the documentation and/or other materials provided with the 15// distribution. 16// * Neither the name of Google Inc. nor the names of its 17// contributors may be used to endorse or promote products derived from 18// this software without specific prior written permission. 19// 20// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 24// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 32package proto 33 34/* 35 * Types and routines for supporting protocol buffer extensions. 36 */ 37 38import ( 39 "errors" 40 "fmt" 41 "io" 42 "reflect" 43 "strconv" 44 "sync" 45) 46 47// ErrMissingExtension is the error returned by GetExtension if the named extension is not in the message. 48var ErrMissingExtension = errors.New("proto: missing extension") 49 50// ExtensionRange represents a range of message extensions for a protocol buffer. 51// Used in code generated by the protocol compiler. 52type ExtensionRange struct { 53 Start, End int32 // both inclusive 54} 55 56// extendableProto is an interface implemented by any protocol buffer generated by the current 57// proto compiler that may be extended. 58type extendableProto interface { 59 Message 60 ExtensionRangeArray() []ExtensionRange 61 extensionsWrite() map[int32]Extension 62 extensionsRead() (map[int32]Extension, sync.Locker) 63} 64 65// extendableProtoV1 is an interface implemented by a protocol buffer generated by the previous 66// version of the proto compiler that may be extended. 67type extendableProtoV1 interface { 68 Message 69 ExtensionRangeArray() []ExtensionRange 70 ExtensionMap() map[int32]Extension 71} 72 73// extensionAdapter is a wrapper around extendableProtoV1 that implements extendableProto. 74type extensionAdapter struct { 75 extendableProtoV1 76} 77 78func (e extensionAdapter) extensionsWrite() map[int32]Extension { 79 return e.ExtensionMap() 80} 81 82func (e extensionAdapter) extensionsRead() (map[int32]Extension, sync.Locker) { 83 return e.ExtensionMap(), notLocker{} 84} 85 86// notLocker is a sync.Locker whose Lock and Unlock methods are nops. 87type notLocker struct{} 88 89func (n notLocker) Lock() {} 90func (n notLocker) Unlock() {} 91 92// extendable returns the extendableProto interface for the given generated proto message. 93// If the proto message has the old extension format, it returns a wrapper that implements 94// the extendableProto interface. 95func extendable(p interface{}) (extendableProto, error) { 96 switch p := p.(type) { 97 case extendableProto: 98 if isNilPtr(p) { 99 return nil, fmt.Errorf("proto: nil %T is not extendable", p) 100 } 101 return p, nil 102 case extendableProtoV1: 103 if isNilPtr(p) { 104 return nil, fmt.Errorf("proto: nil %T is not extendable", p) 105 } 106 return extensionAdapter{p}, nil 107 } 108 // Don't allocate a specific error containing %T: 109 // this is the hot path for Clone and MarshalText. 110 return nil, errNotExtendable 111} 112 113var errNotExtendable = errors.New("proto: not an extendable proto.Message") 114 115func isNilPtr(x interface{}) bool { 116 v := reflect.ValueOf(x) 117 return v.Kind() == reflect.Ptr && v.IsNil() 118} 119 120// XXX_InternalExtensions is an internal representation of proto extensions. 121// 122// Each generated message struct type embeds an anonymous XXX_InternalExtensions field, 123// thus gaining the unexported 'extensions' method, which can be called only from the proto package. 124// 125// The methods of XXX_InternalExtensions are not concurrency safe in general, 126// but calls to logically read-only methods such as has and get may be executed concurrently. 127type XXX_InternalExtensions struct { 128 // The struct must be indirect so that if a user inadvertently copies a 129 // generated message and its embedded XXX_InternalExtensions, they 130 // avoid the mayhem of a copied mutex. 131 // 132 // The mutex serializes all logically read-only operations to p.extensionMap. 133 // It is up to the client to ensure that write operations to p.extensionMap are 134 // mutually exclusive with other accesses. 135 p *struct { 136 mu sync.Mutex 137 extensionMap map[int32]Extension 138 } 139} 140 141// extensionsWrite returns the extension map, creating it on first use. 142func (e *XXX_InternalExtensions) extensionsWrite() map[int32]Extension { 143 if e.p == nil { 144 e.p = new(struct { 145 mu sync.Mutex 146 extensionMap map[int32]Extension 147 }) 148 e.p.extensionMap = make(map[int32]Extension) 149 } 150 return e.p.extensionMap 151} 152 153// extensionsRead returns the extensions map for read-only use. It may be nil. 154// The caller must hold the returned mutex's lock when accessing Elements within the map. 155func (e *XXX_InternalExtensions) extensionsRead() (map[int32]Extension, sync.Locker) { 156 if e.p == nil { 157 return nil, nil 158 } 159 return e.p.extensionMap, &e.p.mu 160} 161 162// ExtensionDesc represents an extension specification. 163// Used in generated code from the protocol compiler. 164type ExtensionDesc struct { 165 ExtendedType Message // nil pointer to the type that is being extended 166 ExtensionType interface{} // nil pointer to the extension type 167 Field int32 // field number 168 Name string // fully-qualified name of extension, for text formatting 169 Tag string // protobuf tag style 170 Filename string // name of the file in which the extension is defined 171} 172 173func (ed *ExtensionDesc) repeated() bool { 174 t := reflect.TypeOf(ed.ExtensionType) 175 return t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8 176} 177 178// Extension represents an extension in a message. 179type Extension struct { 180 // When an extension is stored in a message using SetExtension 181 // only desc and value are set. When the message is marshaled 182 // enc will be set to the encoded form of the message. 183 // 184 // When a message is unmarshaled and contains extensions, each 185 // extension will have only enc set. When such an extension is 186 // accessed using GetExtension (or GetExtensions) desc and value 187 // will be set. 188 desc *ExtensionDesc 189 190 // value is a concrete value for the extension field. Let the type of 191 // desc.ExtensionType be the "API type" and the type of Extension.value 192 // be the "storage type". The API type and storage type are the same except: 193 // * For scalars (except []byte), the API type uses *T, 194 // while the storage type uses T. 195 // * For repeated fields, the API type uses []T, while the storage type 196 // uses *[]T. 197 // 198 // The reason for the divergence is so that the storage type more naturally 199 // matches what is expected of when retrieving the values through the 200 // protobuf reflection APIs. 201 // 202 // The value may only be populated if desc is also populated. 203 value interface{} 204 205 // enc is the raw bytes for the extension field. 206 enc []byte 207} 208 209// SetRawExtension is for testing only. 210func SetRawExtension(base Message, id int32, b []byte) { 211 epb, err := extendable(base) 212 if err != nil { 213 return 214 } 215 extmap := epb.extensionsWrite() 216 extmap[id] = Extension{enc: b} 217} 218 219// isExtensionField returns true iff the given field number is in an extension range. 220func isExtensionField(pb extendableProto, field int32) bool { 221 for _, er := range pb.ExtensionRangeArray() { 222 if er.Start <= field && field <= er.End { 223 return true 224 } 225 } 226 return false 227} 228 229// checkExtensionTypes checks that the given extension is valid for pb. 230func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error { 231 var pbi interface{} = pb 232 // Check the extended type. 233 if ea, ok := pbi.(extensionAdapter); ok { 234 pbi = ea.extendableProtoV1 235 } 236 if a, b := reflect.TypeOf(pbi), reflect.TypeOf(extension.ExtendedType); a != b { 237 return fmt.Errorf("proto: bad extended type; %v does not extend %v", b, a) 238 } 239 // Check the range. 240 if !isExtensionField(pb, extension.Field) { 241 return errors.New("proto: bad extension number; not in declared ranges") 242 } 243 return nil 244} 245 246// extPropKey is sufficient to uniquely identify an extension. 247type extPropKey struct { 248 base reflect.Type 249 field int32 250} 251 252var extProp = struct { 253 sync.RWMutex 254 m map[extPropKey]*Properties 255}{ 256 m: make(map[extPropKey]*Properties), 257} 258 259func extensionProperties(ed *ExtensionDesc) *Properties { 260 key := extPropKey{base: reflect.TypeOf(ed.ExtendedType), field: ed.Field} 261 262 extProp.RLock() 263 if prop, ok := extProp.m[key]; ok { 264 extProp.RUnlock() 265 return prop 266 } 267 extProp.RUnlock() 268 269 extProp.Lock() 270 defer extProp.Unlock() 271 // Check again. 272 if prop, ok := extProp.m[key]; ok { 273 return prop 274 } 275 276 prop := new(Properties) 277 prop.Init(reflect.TypeOf(ed.ExtensionType), "unknown_name", ed.Tag, nil) 278 extProp.m[key] = prop 279 return prop 280} 281 282// HasExtension returns whether the given extension is present in pb. 283func HasExtension(pb Message, extension *ExtensionDesc) bool { 284 // TODO: Check types, field numbers, etc.? 285 epb, err := extendable(pb) 286 if err != nil { 287 return false 288 } 289 extmap, mu := epb.extensionsRead() 290 if extmap == nil { 291 return false 292 } 293 mu.Lock() 294 _, ok := extmap[extension.Field] 295 mu.Unlock() 296 return ok 297} 298 299// ClearExtension removes the given extension from pb. 300func ClearExtension(pb Message, extension *ExtensionDesc) { 301 epb, err := extendable(pb) 302 if err != nil { 303 return 304 } 305 // TODO: Check types, field numbers, etc.? 306 extmap := epb.extensionsWrite() 307 delete(extmap, extension.Field) 308} 309 310// GetExtension retrieves a proto2 extended field from pb. 311// 312// If the descriptor is type complete (i.e., ExtensionDesc.ExtensionType is non-nil), 313// then GetExtension parses the encoded field and returns a Go value of the specified type. 314// If the field is not present, then the default value is returned (if one is specified), 315// otherwise ErrMissingExtension is reported. 316// 317// If the descriptor is not type complete (i.e., ExtensionDesc.ExtensionType is nil), 318// then GetExtension returns the raw encoded bytes of the field extension. 319func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) { 320 epb, err := extendable(pb) 321 if err != nil { 322 return nil, err 323 } 324 325 if extension.ExtendedType != nil { 326 // can only check type if this is a complete descriptor 327 if err := checkExtensionTypes(epb, extension); err != nil { 328 return nil, err 329 } 330 } 331 332 emap, mu := epb.extensionsRead() 333 if emap == nil { 334 return defaultExtensionValue(extension) 335 } 336 mu.Lock() 337 defer mu.Unlock() 338 e, ok := emap[extension.Field] 339 if !ok { 340 // defaultExtensionValue returns the default value or 341 // ErrMissingExtension if there is no default. 342 return defaultExtensionValue(extension) 343 } 344 345 if e.value != nil { 346 // Already decoded. Check the descriptor, though. 347 if e.desc != extension { 348 // This shouldn't happen. If it does, it means that 349 // GetExtension was called twice with two different 350 // descriptors with the same field number. 351 return nil, errors.New("proto: descriptor conflict") 352 } 353 return extensionAsLegacyType(e.value), nil 354 } 355 356 if extension.ExtensionType == nil { 357 // incomplete descriptor 358 return e.enc, nil 359 } 360 361 v, err := decodeExtension(e.enc, extension) 362 if err != nil { 363 return nil, err 364 } 365 366 // Remember the decoded version and drop the encoded version. 367 // That way it is safe to mutate what we return. 368 e.value = extensionAsStorageType(v) 369 e.desc = extension 370 e.enc = nil 371 emap[extension.Field] = e 372 return extensionAsLegacyType(e.value), nil 373} 374 375// defaultExtensionValue returns the default value for extension. 376// If no default for an extension is defined ErrMissingExtension is returned. 377func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) { 378 if extension.ExtensionType == nil { 379 // incomplete descriptor, so no default 380 return nil, ErrMissingExtension 381 } 382 383 t := reflect.TypeOf(extension.ExtensionType) 384 props := extensionProperties(extension) 385 386 sf, _, err := fieldDefault(t, props) 387 if err != nil { 388 return nil, err 389 } 390 391 if sf == nil || sf.value == nil { 392 // There is no default value. 393 return nil, ErrMissingExtension 394 } 395 396 if t.Kind() != reflect.Ptr { 397 // We do not need to return a Ptr, we can directly return sf.value. 398 return sf.value, nil 399 } 400 401 // We need to return an interface{} that is a pointer to sf.value. 402 value := reflect.New(t).Elem() 403 value.Set(reflect.New(value.Type().Elem())) 404 if sf.kind == reflect.Int32 { 405 // We may have an int32 or an enum, but the underlying data is int32. 406 // Since we can't set an int32 into a non int32 reflect.value directly 407 // set it as a int32. 408 value.Elem().SetInt(int64(sf.value.(int32))) 409 } else { 410 value.Elem().Set(reflect.ValueOf(sf.value)) 411 } 412 return value.Interface(), nil 413} 414 415// decodeExtension decodes an extension encoded in b. 416func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) { 417 t := reflect.TypeOf(extension.ExtensionType) 418 unmarshal := typeUnmarshaler(t, extension.Tag) 419 420 // t is a pointer to a struct, pointer to basic type or a slice. 421 // Allocate space to store the pointer/slice. 422 value := reflect.New(t).Elem() 423 424 var err error 425 for { 426 x, n := decodeVarint(b) 427 if n == 0 { 428 return nil, io.ErrUnexpectedEOF 429 } 430 b = b[n:] 431 wire := int(x) & 7 432 433 b, err = unmarshal(b, valToPointer(value.Addr()), wire) 434 if err != nil { 435 return nil, err 436 } 437 438 if len(b) == 0 { 439 break 440 } 441 } 442 return value.Interface(), nil 443} 444 445// GetExtensions returns a slice of the extensions present in pb that are also listed in es. 446// The returned slice has the same length as es; missing extensions will appear as nil elements. 447func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) { 448 epb, err := extendable(pb) 449 if err != nil { 450 return nil, err 451 } 452 extensions = make([]interface{}, len(es)) 453 for i, e := range es { 454 extensions[i], err = GetExtension(epb, e) 455 if err == ErrMissingExtension { 456 err = nil 457 } 458 if err != nil { 459 return 460 } 461 } 462 return 463} 464 465// ExtensionDescs returns a new slice containing pb's extension descriptors, in undefined order. 466// For non-registered extensions, ExtensionDescs returns an incomplete descriptor containing 467// just the Field field, which defines the extension's field number. 468func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) { 469 epb, err := extendable(pb) 470 if err != nil { 471 return nil, err 472 } 473 registeredExtensions := RegisteredExtensions(pb) 474 475 emap, mu := epb.extensionsRead() 476 if emap == nil { 477 return nil, nil 478 } 479 mu.Lock() 480 defer mu.Unlock() 481 extensions := make([]*ExtensionDesc, 0, len(emap)) 482 for extid, e := range emap { 483 desc := e.desc 484 if desc == nil { 485 desc = registeredExtensions[extid] 486 if desc == nil { 487 desc = &ExtensionDesc{Field: extid} 488 } 489 } 490 491 extensions = append(extensions, desc) 492 } 493 return extensions, nil 494} 495 496// SetExtension sets the specified extension of pb to the specified value. 497func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error { 498 epb, err := extendable(pb) 499 if err != nil { 500 return err 501 } 502 if err := checkExtensionTypes(epb, extension); err != nil { 503 return err 504 } 505 typ := reflect.TypeOf(extension.ExtensionType) 506 if typ != reflect.TypeOf(value) { 507 return fmt.Errorf("proto: bad extension value type. got: %T, want: %T", value, extension.ExtensionType) 508 } 509 // nil extension values need to be caught early, because the 510 // encoder can't distinguish an ErrNil due to a nil extension 511 // from an ErrNil due to a missing field. Extensions are 512 // always optional, so the encoder would just swallow the error 513 // and drop all the extensions from the encoded message. 514 if reflect.ValueOf(value).IsNil() { 515 return fmt.Errorf("proto: SetExtension called with nil value of type %T", value) 516 } 517 518 extmap := epb.extensionsWrite() 519 extmap[extension.Field] = Extension{desc: extension, value: extensionAsStorageType(value)} 520 return nil 521} 522 523// ClearAllExtensions clears all extensions from pb. 524func ClearAllExtensions(pb Message) { 525 epb, err := extendable(pb) 526 if err != nil { 527 return 528 } 529 m := epb.extensionsWrite() 530 for k := range m { 531 delete(m, k) 532 } 533} 534 535// A global registry of extensions. 536// The generated code will register the generated descriptors by calling RegisterExtension. 537 538var extensionMaps = make(map[reflect.Type]map[int32]*ExtensionDesc) 539 540// RegisterExtension is called from the generated code. 541func RegisterExtension(desc *ExtensionDesc) { 542 st := reflect.TypeOf(desc.ExtendedType).Elem() 543 m := extensionMaps[st] 544 if m == nil { 545 m = make(map[int32]*ExtensionDesc) 546 extensionMaps[st] = m 547 } 548 if _, ok := m[desc.Field]; ok { 549 panic("proto: duplicate extension registered: " + st.String() + " " + strconv.Itoa(int(desc.Field))) 550 } 551 m[desc.Field] = desc 552} 553 554// RegisteredExtensions returns a map of the registered extensions of a 555// protocol buffer struct, indexed by the extension number. 556// The argument pb should be a nil pointer to the struct type. 557func RegisteredExtensions(pb Message) map[int32]*ExtensionDesc { 558 return extensionMaps[reflect.TypeOf(pb).Elem()] 559} 560 561// extensionAsLegacyType converts an value in the storage type as the API type. 562// See Extension.value. 563func extensionAsLegacyType(v interface{}) interface{} { 564 switch rv := reflect.ValueOf(v); rv.Kind() { 565 case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String: 566 // Represent primitive types as a pointer to the value. 567 rv2 := reflect.New(rv.Type()) 568 rv2.Elem().Set(rv) 569 v = rv2.Interface() 570 case reflect.Ptr: 571 // Represent slice types as the value itself. 572 switch rv.Type().Elem().Kind() { 573 case reflect.Slice: 574 if rv.IsNil() { 575 v = reflect.Zero(rv.Type().Elem()).Interface() 576 } else { 577 v = rv.Elem().Interface() 578 } 579 } 580 } 581 return v 582} 583 584// extensionAsStorageType converts an value in the API type as the storage type. 585// See Extension.value. 586func extensionAsStorageType(v interface{}) interface{} { 587 switch rv := reflect.ValueOf(v); rv.Kind() { 588 case reflect.Ptr: 589 // Represent slice types as the value itself. 590 switch rv.Type().Elem().Kind() { 591 case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String: 592 if rv.IsNil() { 593 v = reflect.Zero(rv.Type().Elem()).Interface() 594 } else { 595 v = rv.Elem().Interface() 596 } 597 } 598 case reflect.Slice: 599 // Represent slice types as a pointer to the value. 600 if rv.Type().Elem().Kind() != reflect.Uint8 { 601 rv2 := reflect.New(rv.Type()) 602 rv2.Elem().Set(rv) 603 v = rv2.Interface() 604 } 605 } 606 return v 607} 608