1 #region Copyright notice and license
2 // Protocol Buffers - Google's data interchange format
3 // Copyright 2008 Google Inc.  All rights reserved.
4 // https://developers.google.com/protocol-buffers/
5 //
6 // Redistribution and use in source and binary forms, with or without
7 // modification, are permitted provided that the following conditions are
8 // met:
9 //
10 //     * Redistributions of source code must retain the above copyright
11 // notice, this list of conditions and the following disclaimer.
12 //     * Redistributions in binary form must reproduce the above
13 // copyright notice, this list of conditions and the following disclaimer
14 // in the documentation and/or other materials provided with the
15 // distribution.
16 //     * Neither the name of Google Inc. nor the names of its
17 // contributors may be used to endorse or promote products derived from
18 // this software without specific prior written permission.
19 //
20 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 #endregion
32 
33 using Google.Protobuf.Collections;
34 using System;
35 using System.Collections.Generic;
36 using System.Linq;
37 
38 namespace Google.Protobuf
39 {
40     /// <summary>
41     /// Methods for managing <see cref="ExtensionSet{TTarget}"/>s with null checking.
42     ///
43     /// Most users will not use this class directly and its API is experimental and subject to change.
44     /// </summary>
45     public static class ExtensionSet
46     {
47         private static bool TryGetValue<TTarget>(ref ExtensionSet<TTarget> set, Extension extension, out IExtensionValue value) where TTarget : IExtendableMessage<TTarget>
48         {
49             if (set == null)
50             {
51                 value = null;
52                 return false;
53             }
54             return set.ValuesByNumber.TryGetValue(extension.FieldNumber, out value);
55         }
56 
57         /// <summary>
58         /// Gets the value of the specified extension
59         /// </summary>
60         public static TValue Get<TTarget, TValue>(ref ExtensionSet<TTarget> set, Extension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
61         {
62             IExtensionValue value;
63             if (TryGetValue(ref set, extension, out value))
64             {
65                 return ((ExtensionValue<TValue>)value).GetValue();
66             }
67             else
68             {
69                 return extension.DefaultValue;
70             }
71         }
72 
73         /// <summary>
74         /// Gets the value of the specified repeated extension or null if it doesn't exist in this set
75         /// </summary>
76         public static RepeatedField<TValue> Get<TTarget, TValue>(ref ExtensionSet<TTarget> set, RepeatedExtension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
77         {
78             IExtensionValue value;
79             if (TryGetValue(ref set, extension, out value))
80             {
81                 return ((RepeatedExtensionValue<TValue>)value).GetValue();
82             }
83             else
84             {
85                 return null;
86             }
87         }
88 
89         /// <summary>
90         /// Gets the value of the specified repeated extension, registering it if it doesn't exist
91         /// </summary>
92         public static RepeatedField<TValue> GetOrRegister<TTarget, TValue>(ref ExtensionSet<TTarget> set, RepeatedExtension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
93         {
94             IExtensionValue value;
95             if (set == null)
96             {
97                 value = extension.CreateValue();
98                 set = new ExtensionSet<TTarget>();
99                 set.ValuesByNumber.Add(extension.FieldNumber, value);
100             }
101             else
102             {
103                 if (!set.ValuesByNumber.TryGetValue(extension.FieldNumber, out value))
104                 {
105                     value = extension.CreateValue();
106                     set.ValuesByNumber.Add(extension.FieldNumber, value);
107                 }
108             }
109 
110             return ((RepeatedExtensionValue<TValue>)value).GetValue();
111         }
112 
113         /// <summary>
114         /// Sets the value of the specified extension. This will make a new instance of ExtensionSet if the set is null.
115         /// </summary>
116         public static void Set<TTarget, TValue>(ref ExtensionSet<TTarget> set, Extension<TTarget, TValue> extension, TValue value) where TTarget : IExtendableMessage<TTarget>
117         {
118             IExtensionValue extensionValue;
119             if (set == null)
120             {
121                 extensionValue = extension.CreateValue();
122                 set = new ExtensionSet<TTarget>();
123                 set.ValuesByNumber.Add(extension.FieldNumber, extensionValue);
124             }
125             else
126             {
127                 if (!set.ValuesByNumber.TryGetValue(extension.FieldNumber, out extensionValue))
128                 {
129                     extensionValue = extension.CreateValue();
130                     set.ValuesByNumber.Add(extension.FieldNumber, extensionValue);
131                 }
132             }
133 
134             ((ExtensionValue<TValue>)extensionValue).SetValue(value);
135         }
136 
137         /// <summary>
138         /// Gets whether the value of the specified extension is set
139         /// </summary>
140         public static bool Has<TTarget, TValue>(ref ExtensionSet<TTarget> set, Extension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
141         {
142             IExtensionValue value;
143             return TryGetValue(ref set, extension, out value);
144         }
145 
146         /// <summary>
147         /// Clears the value of the specified extension
148         /// </summary>
149         public static void Clear<TTarget, TValue>(ref ExtensionSet<TTarget> set, Extension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
150         {
151             if (set == null)
152             {
153                 return;
154             }
155             set.ValuesByNumber.Remove(extension.FieldNumber);
156             if (set.ValuesByNumber.Count == 0)
157             {
158                 set = null;
159             }
160         }
161 
162         /// <summary>
163         /// Clears the value of the specified extension
164         /// </summary>
165         public static void Clear<TTarget, TValue>(ref ExtensionSet<TTarget> set, RepeatedExtension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
166         {
167             if (set == null)
168             {
169                 return;
170             }
171             set.ValuesByNumber.Remove(extension.FieldNumber);
172             if (set.ValuesByNumber.Count == 0)
173             {
174                 set = null;
175             }
176         }
177 
178         /// <summary>
179         /// Tries to merge a field from the coded input, returning true if the field was merged.
180         /// If the set is null or the field was not otherwise merged, this returns false.
181         /// </summary>
182         public static bool TryMergeFieldFrom<TTarget>(ref ExtensionSet<TTarget> set, CodedInputStream stream) where TTarget : IExtendableMessage<TTarget>
183         {
184             Extension extension;
185             int lastFieldNumber = WireFormat.GetTagFieldNumber(stream.LastTag);
186 
187             IExtensionValue extensionValue;
188             if (set != null && set.ValuesByNumber.TryGetValue(lastFieldNumber, out extensionValue))
189             {
190                 extensionValue.MergeFrom(stream);
191                 return true;
192             }
193             else if (stream.ExtensionRegistry != null && stream.ExtensionRegistry.ContainsInputField(stream, typeof(TTarget), out extension))
194             {
195                 IExtensionValue value = extension.CreateValue();
196                 value.MergeFrom(stream);
197                 set = (set ?? new ExtensionSet<TTarget>());
198                 set.ValuesByNumber.Add(extension.FieldNumber, value);
199                 return true;
200             }
201             else
202             {
203                 return false;
204             }
205         }
206 
207         /// <summary>
208         /// Merges the second set into the first set, creating a new instance if first is null
209         /// </summary>
210         public static void MergeFrom<TTarget>(ref ExtensionSet<TTarget> first, ExtensionSet<TTarget> second) where TTarget : IExtendableMessage<TTarget>
211         {
212             if (second == null)
213             {
214                 return;
215             }
216             if (first == null)
217             {
218                 first = new ExtensionSet<TTarget>();
219             }
220             foreach (var pair in second.ValuesByNumber)
221             {
222                 IExtensionValue value;
223                 if (first.ValuesByNumber.TryGetValue(pair.Key, out value))
224                 {
225                     value.MergeFrom(pair.Value);
226                 }
227                 else
228                 {
229                     var cloned = pair.Value.Clone();
230                     first.ValuesByNumber[pair.Key] = cloned;
231                 }
232             }
233         }
234 
235         /// <summary>
236         /// Clones the set into a new set. If the set is null, this returns null
237         /// </summary>
238         public static ExtensionSet<TTarget> Clone<TTarget>(ExtensionSet<TTarget> set) where TTarget : IExtendableMessage<TTarget>
239         {
240             if (set == null)
241             {
242                 return null;
243             }
244 
245             var newSet = new ExtensionSet<TTarget>();
246             foreach (var pair in set.ValuesByNumber)
247             {
248                 var cloned = pair.Value.Clone();
249                 newSet.ValuesByNumber[pair.Key] = cloned;
250             }
251             return newSet;
252         }
253     }
254 
255     /// <summary>
256     /// Used for keeping track of extensions in messages.
257     /// <see cref="IExtendableMessage{T}"/> methods route to this set.
258     ///
259     /// Most users will not need to use this class directly
260     /// </summary>
261     /// <typeparam name="TTarget">The message type that extensions in this set target</typeparam>
262     public sealed class ExtensionSet<TTarget> where TTarget : IExtendableMessage<TTarget>
263     {
264         internal Dictionary<int, IExtensionValue> ValuesByNumber { get; } = new Dictionary<int, IExtensionValue>();
265 
266         /// <summary>
267         /// Gets a hash code of the set
268         /// </summary>
GetHashCode()269         public override int GetHashCode()
270         {
271             int ret = typeof(TTarget).GetHashCode();
272             foreach (KeyValuePair<int, IExtensionValue> field in ValuesByNumber)
273             {
274                 // Use ^ here to make the field order irrelevant.
275                 int hash = field.Key.GetHashCode() ^ field.Value.GetHashCode();
276                 ret ^= hash;
277             }
278             return ret;
279         }
280 
281         /// <summary>
282         /// Returns whether this set is equal to the other object
283         /// </summary>
Equals(object other)284         public override bool Equals(object other)
285         {
286             if (ReferenceEquals(this, other))
287             {
288                 return true;
289             }
290             ExtensionSet<TTarget> otherSet = other as ExtensionSet<TTarget>;
291             if (ValuesByNumber.Count != otherSet.ValuesByNumber.Count)
292             {
293                 return false;
294             }
295             foreach (var pair in ValuesByNumber)
296             {
297                 IExtensionValue secondValue;
298                 if (!otherSet.ValuesByNumber.TryGetValue(pair.Key, out secondValue))
299                 {
300                     return false;
301                 }
302                 if (!pair.Value.Equals(secondValue))
303                 {
304                     return false;
305                 }
306             }
307             return true;
308         }
309 
310         /// <summary>
311         /// Calculates the size of this extension set
312         /// </summary>
CalculateSize()313         public int CalculateSize()
314         {
315             int size = 0;
316             foreach (var value in ValuesByNumber.Values)
317             {
318                 size += value.CalculateSize();
319             }
320             return size;
321         }
322 
323         /// <summary>
324         /// Writes the extension values in this set to the output stream
325         /// </summary>
WriteTo(CodedOutputStream stream)326         public void WriteTo(CodedOutputStream stream)
327         {
328             foreach (var value in ValuesByNumber.Values)
329             {
330                 value.WriteTo(stream);
331             }
332         }
333     }
334 }
335