1 #region Copyright notice and license
2 // Protocol Buffers - Google's data interchange format
3 // Copyright 2008 Google Inc.  All rights reserved.
4 // https://developers.google.com/protocol-buffers/
5 //
6 // Redistribution and use in source and binary forms, with or without
7 // modification, are permitted provided that the following conditions are
8 // met:
9 //
10 //     * Redistributions of source code must retain the above copyright
11 // notice, this list of conditions and the following disclaimer.
12 //     * Redistributions in binary form must reproduce the above
13 // copyright notice, this list of conditions and the following disclaimer
14 // in the documentation and/or other materials provided with the
15 // distribution.
16 //     * Neither the name of Google Inc. nor the names of its
17 // contributors may be used to endorse or promote products derived from
18 // this software without specific prior written permission.
19 //
20 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 #endregion
32 
33 using Google.Protobuf.Collections;
34 using System;
35 using System.Collections.Generic;
36 using System.IO;
37 
38 namespace Google.Protobuf
39 {
40     /// <summary>
41     /// Reads and decodes protocol message fields.
42     /// </summary>
43     /// <remarks>
44     /// <para>
45     /// This class is generally used by generated code to read appropriate
46     /// primitives from the stream. It effectively encapsulates the lowest
47     /// levels of protocol buffer format.
48     /// </para>
49     /// <para>
        /// Repeated fields and map fields are not handled by this class; use <see cref="RepeatedField{T}"/>
        /// and <see cref="MapField{TKey, TValue}"/> to deserialize such fields.
52     /// </para>
53     /// </remarks>
54     public sealed class CodedInputStream : IDisposable
55     {
56         /// <summary>
57         /// Whether to leave the underlying stream open when disposing of this stream.
58         /// This is always true when there's no stream.
59         /// </summary>
60         private readonly bool leaveOpen;
61 
62         /// <summary>
63         /// Buffer of data read from the stream or provided at construction time.
64         /// </summary>
65         private readonly byte[] buffer;
66 
67         /// <summary>
68         /// The index of the buffer at which we need to refill from the stream (if there is one).
69         /// </summary>
70         private int bufferSize;
71 
72         private int bufferSizeAfterLimit = 0;
73         /// <summary>
74         /// The position within the current buffer (i.e. the next byte to read)
75         /// </summary>
76         private int bufferPos = 0;
77 
78         /// <summary>
79         /// The stream to read further input from, or null if the byte array buffer was provided
80         /// directly on construction, with no further data available.
81         /// </summary>
82         private readonly Stream input;
83 
84         /// <summary>
85         /// The last tag we read. 0 indicates we've read to the end of the stream
86         /// (or haven't read anything yet).
87         /// </summary>
88         private uint lastTag = 0;
89 
90         /// <summary>
91         /// The next tag, used to store the value read by PeekTag.
92         /// </summary>
93         private uint nextTag = 0;
94         private bool hasNextTag = false;
95 
96         internal const int DefaultRecursionLimit = 100;
97         internal const int DefaultSizeLimit = Int32.MaxValue;
98         internal const int BufferSize = 4096;
99 
100         /// <summary>
101         /// The total number of bytes read before the current buffer. The
102         /// total bytes read up to the current position can be computed as
103         /// totalBytesRetired + bufferPos.
104         /// </summary>
105         private int totalBytesRetired = 0;
106 
107         /// <summary>
108         /// The absolute position of the end of the current message.
109         /// </summary>
110         private int currentLimit = int.MaxValue;
111 
112         private int recursionDepth = 0;
113 
114         private readonly int recursionLimit;
115         private readonly int sizeLimit;
116 
117         #region Construction
118         // Note that the checks are performed such that we don't end up checking obviously-valid things
119         // like non-null references for arrays we've just created.
120 
121         /// <summary>
122         /// Creates a new CodedInputStream reading data from the given byte array.
123         /// </summary>
CodedInputStream(byte[] buffer)124         public CodedInputStream(byte[] buffer) : this(null, ProtoPreconditions.CheckNotNull(buffer, "buffer"), 0, buffer.Length, true)
125         {
126         }
127 
128         /// <summary>
129         /// Creates a new <see cref="CodedInputStream"/> that reads from the given byte array slice.
130         /// </summary>
CodedInputStream(byte[] buffer, int offset, int length)131         public CodedInputStream(byte[] buffer, int offset, int length)
132             : this(null, ProtoPreconditions.CheckNotNull(buffer, "buffer"), offset, offset + length, true)
133         {
134             if (offset < 0 || offset > buffer.Length)
135             {
136                 throw new ArgumentOutOfRangeException("offset", "Offset must be within the buffer");
137             }
138             if (length < 0 || offset + length > buffer.Length)
139             {
140                 throw new ArgumentOutOfRangeException("length", "Length must be non-negative and within the buffer");
141             }
142         }
143 
144         /// <summary>
145         /// Creates a new <see cref="CodedInputStream"/> reading data from the given stream, which will be disposed
146         /// when the returned object is disposed.
147         /// </summary>
148         /// <param name="input">The stream to read from.</param>
CodedInputStream(Stream input)149         public CodedInputStream(Stream input) : this(input, false)
150         {
151         }
152 
153         /// <summary>
154         /// Creates a new <see cref="CodedInputStream"/> reading data from the given stream.
155         /// </summary>
156         /// <param name="input">The stream to read from.</param>
157         /// <param name="leaveOpen"><c>true</c> to leave <paramref name="input"/> open when the returned
158         /// <c cref="CodedInputStream"/> is disposed; <c>false</c> to dispose of the given stream when the
159         /// returned object is disposed.</param>
CodedInputStream(Stream input, bool leaveOpen)160         public CodedInputStream(Stream input, bool leaveOpen)
161             : this(ProtoPreconditions.CheckNotNull(input, "input"), new byte[BufferSize], 0, 0, leaveOpen)
162         {
163         }
164 
165         /// <summary>
166         /// Creates a new CodedInputStream reading data from the given
167         /// stream and buffer, using the default limits.
168         /// </summary>
CodedInputStream(Stream input, byte[] buffer, int bufferPos, int bufferSize, bool leaveOpen)169         internal CodedInputStream(Stream input, byte[] buffer, int bufferPos, int bufferSize, bool leaveOpen)
170         {
171             this.input = input;
172             this.buffer = buffer;
173             this.bufferPos = bufferPos;
174             this.bufferSize = bufferSize;
175             this.sizeLimit = DefaultSizeLimit;
176             this.recursionLimit = DefaultRecursionLimit;
177             this.leaveOpen = leaveOpen;
178         }
179 
180         /// <summary>
181         /// Creates a new CodedInputStream reading data from the given
182         /// stream and buffer, using the specified limits.
183         /// </summary>
184         /// <remarks>
185         /// This chains to the version with the default limits instead of vice versa to avoid
186         /// having to check that the default values are valid every time.
187         /// </remarks>
CodedInputStream(Stream input, byte[] buffer, int bufferPos, int bufferSize, int sizeLimit, int recursionLimit, bool leaveOpen)188         internal CodedInputStream(Stream input, byte[] buffer, int bufferPos, int bufferSize, int sizeLimit, int recursionLimit, bool leaveOpen)
189             : this(input, buffer, bufferPos, bufferSize, leaveOpen)
190         {
191             if (sizeLimit <= 0)
192             {
193                 throw new ArgumentOutOfRangeException("sizeLimit", "Size limit must be positive");
194             }
195             if (recursionLimit <= 0)
196             {
197                 throw new ArgumentOutOfRangeException("recursionLimit!", "Recursion limit must be positive");
198             }
199             this.sizeLimit = sizeLimit;
200             this.recursionLimit = recursionLimit;
201         }
202         #endregion
203 
204         /// <summary>
205         /// Creates a <see cref="CodedInputStream"/> with the specified size and recursion limits, reading
206         /// from an input stream.
207         /// </summary>
208         /// <remarks>
209         /// This method exists separately from the constructor to reduce the number of constructor overloads.
210         /// It is likely to be used considerably less frequently than the constructors, as the default limits
211         /// are suitable for most use cases.
212         /// </remarks>
213         /// <param name="input">The input stream to read from</param>
214         /// <param name="sizeLimit">The total limit of data to read from the stream.</param>
215         /// <param name="recursionLimit">The maximum recursion depth to allow while reading.</param>
216         /// <returns>A <c>CodedInputStream</c> reading from <paramref name="input"/> with the specified size
217         /// and recursion limits.</returns>
CreateWithLimits(Stream input, int sizeLimit, int recursionLimit)218         public static CodedInputStream CreateWithLimits(Stream input, int sizeLimit, int recursionLimit)
219         {
220             // Note: we may want an overload accepting leaveOpen
221             return new CodedInputStream(input, new byte[BufferSize], 0, 0, sizeLimit, recursionLimit, false);
222         }
223 
224         /// <summary>
225         /// Returns the current position in the input stream, or the position in the input buffer
226         /// </summary>
227         public long Position
228         {
229             get
230             {
231                 if (input != null)
232                 {
233                     return input.Position - ((bufferSize + bufferSizeAfterLimit) - bufferPos);
234                 }
235                 return bufferPos;
236             }
237         }
238 
239         /// <summary>
240         /// Returns the last tag read, or 0 if no tags have been read or we've read beyond
241         /// the end of the stream.
242         /// </summary>
243         internal uint LastTag { get { return lastTag; } }
244 
245         /// <summary>
246         /// Returns the size limit for this stream.
247         /// </summary>
248         /// <remarks>
249         /// This limit is applied when reading from the underlying stream, as a sanity check. It is
250         /// not applied when reading from a byte array data source without an underlying stream.
251         /// The default value is Int32.MaxValue.
252         /// </remarks>
253         /// <value>
254         /// The size limit.
255         /// </value>
256         public int SizeLimit { get { return sizeLimit; } }
257 
258         /// <summary>
259         /// Returns the recursion limit for this stream. This limit is applied whilst reading messages,
260         /// to avoid maliciously-recursive data.
261         /// </summary>
262         /// <remarks>
263         /// The default limit is 100.
264         /// </remarks>
265         /// <value>
266         /// The recursion limit for this stream.
267         /// </value>
268         public int RecursionLimit { get { return recursionLimit; } }
269 
        /// <summary>
        /// Internal-only property; when set to true, unknown fields will be discarded while parsing.
        /// </summary>
        /// <remarks>
        /// This is a plain auto-property, so it is <c>false</c> until explicitly set.
        /// NOTE(review): presumably the message parsing code consults this flag - confirm against the callers.
        /// </remarks>
        internal bool DiscardUnknownFields { get; set; }
274 
        /// <summary>
        /// Internal-only property; provides extension identifiers to compatible messages while parsing.
        /// </summary>
        /// <remarks>
        /// This is a plain auto-property, so it is <c>null</c> until explicitly set.
        /// NOTE(review): presumably the message parsing code consults this registry - confirm against the callers.
        /// </remarks>
        internal ExtensionRegistry ExtensionRegistry { get; set; }
279 
280         /// <summary>
281         /// Disposes of this instance, potentially closing any underlying stream.
282         /// </summary>
283         /// <remarks>
284         /// As there is no flushing to perform here, disposing of a <see cref="CodedInputStream"/> which
285         /// was constructed with the <c>leaveOpen</c> option parameter set to <c>true</c> (or one which
286         /// was constructed to read from a byte array) has no effect.
287         /// </remarks>
Dispose()288         public void Dispose()
289         {
290             if (!leaveOpen)
291             {
292                 input.Dispose();
293             }
294         }
295 
296         #region Validation
297         /// <summary>
298         /// Verifies that the last call to ReadTag() returned tag 0 - in other words,
299         /// we've reached the end of the stream when we expected to.
300         /// </summary>
301         /// <exception cref="InvalidProtocolBufferException">The
302         /// tag read was not the one specified</exception>
CheckReadEndOfStreamTag()303         internal void CheckReadEndOfStreamTag()
304         {
305             if (lastTag != 0)
306             {
307                 throw InvalidProtocolBufferException.MoreDataAvailable();
308             }
309         }
310         #endregion
311 
312         #region Reading of tags etc
313 
314         /// <summary>
315         /// Peeks at the next field tag. This is like calling <see cref="ReadTag"/>, but the
316         /// tag is not consumed. (So a subsequent call to <see cref="ReadTag"/> will return the
317         /// same value.)
318         /// </summary>
PeekTag()319         public uint PeekTag()
320         {
321             if (hasNextTag)
322             {
323                 return nextTag;
324             }
325 
326             uint savedLast = lastTag;
327             nextTag = ReadTag();
328             hasNextTag = true;
329             lastTag = savedLast; // Undo the side effect of ReadTag
330             return nextTag;
331         }
332 
333         /// <summary>
334         /// Reads a field tag, returning the tag of 0 for "end of stream".
335         /// </summary>
336         /// <remarks>
337         /// If this method returns 0, it doesn't necessarily mean the end of all
338         /// the data in this CodedInputStream; it may be the end of the logical stream
339         /// for an embedded message, for example.
340         /// </remarks>
341         /// <returns>The next field tag, or 0 for end of stream. (0 is never a valid tag.)</returns>
ReadTag()342         public uint ReadTag()
343         {
344             if (hasNextTag)
345             {
346                 lastTag = nextTag;
347                 hasNextTag = false;
348                 return lastTag;
349             }
350 
351             // Optimize for the incredibly common case of having at least two bytes left in the buffer,
352             // and those two bytes being enough to get the tag. This will be true for fields up to 4095.
353             if (bufferPos + 2 <= bufferSize)
354             {
355                 int tmp = buffer[bufferPos++];
356                 if (tmp < 128)
357                 {
358                     lastTag = (uint)tmp;
359                 }
360                 else
361                 {
362                     int result = tmp & 0x7f;
363                     if ((tmp = buffer[bufferPos++]) < 128)
364                     {
365                         result |= tmp << 7;
366                         lastTag = (uint) result;
367                     }
368                     else
369                     {
370                         // Nope, rewind and go the potentially slow route.
371                         bufferPos -= 2;
372                         lastTag = ReadRawVarint32();
373                     }
374                 }
375             }
376             else
377             {
378                 if (IsAtEnd)
379                 {
380                     lastTag = 0;
381                     return 0;
382                 }
383 
384                 lastTag = ReadRawVarint32();
385             }
386             if (WireFormat.GetTagFieldNumber(lastTag) == 0)
387             {
388                 // If we actually read a tag with a field of 0, that's not a valid tag.
389                 throw InvalidProtocolBufferException.InvalidTag();
390             }
391             if (ReachedLimit)
392             {
393                 return 0;
394             }
395             return lastTag;
396         }
397 
398         /// <summary>
399         /// Skips the data for the field with the tag we've just read.
400         /// This should be called directly after <see cref="ReadTag"/>, when
401         /// the caller wishes to skip an unknown field.
402         /// </summary>
403         /// <remarks>
404         /// This method throws <see cref="InvalidProtocolBufferException"/> if the last-read tag was an end-group tag.
405         /// If a caller wishes to skip a group, they should skip the whole group, by calling this method after reading the
406         /// start-group tag. This behavior allows callers to call this method on any field they don't understand, correctly
407         /// resulting in an error if an end-group tag has not been paired with an earlier start-group tag.
408         /// </remarks>
409         /// <exception cref="InvalidProtocolBufferException">The last tag was an end-group tag</exception>
410         /// <exception cref="InvalidOperationException">The last read operation read to the end of the logical stream</exception>
SkipLastField()411         public void SkipLastField()
412         {
413             if (lastTag == 0)
414             {
415                 throw new InvalidOperationException("SkipLastField cannot be called at the end of a stream");
416             }
417             switch (WireFormat.GetTagWireType(lastTag))
418             {
419                 case WireFormat.WireType.StartGroup:
420                     SkipGroup(lastTag);
421                     break;
422                 case WireFormat.WireType.EndGroup:
423                     throw new InvalidProtocolBufferException(
424                         "SkipLastField called on an end-group tag, indicating that the corresponding start-group was missing");
425                 case WireFormat.WireType.Fixed32:
426                     ReadFixed32();
427                     break;
428                 case WireFormat.WireType.Fixed64:
429                     ReadFixed64();
430                     break;
431                 case WireFormat.WireType.LengthDelimited:
432                     var length = ReadLength();
433                     SkipRawBytes(length);
434                     break;
435                 case WireFormat.WireType.Varint:
436                     ReadRawVarint32();
437                     break;
438             }
439         }
440 
441         /// <summary>
442         /// Skip a group.
443         /// </summary>
SkipGroup(uint startGroupTag)444         internal void SkipGroup(uint startGroupTag)
445         {
446             // Note: Currently we expect this to be the way that groups are read. We could put the recursion
447             // depth changes into the ReadTag method instead, potentially...
448             recursionDepth++;
449             if (recursionDepth >= recursionLimit)
450             {
451                 throw InvalidProtocolBufferException.RecursionLimitExceeded();
452             }
453             uint tag;
454             while (true)
455             {
456                 tag = ReadTag();
457                 if (tag == 0)
458                 {
459                     throw InvalidProtocolBufferException.TruncatedMessage();
460                 }
461                 // Can't call SkipLastField for this case- that would throw.
462                 if (WireFormat.GetTagWireType(tag) == WireFormat.WireType.EndGroup)
463                 {
464                     break;
465                 }
466                 // This recursion will allow us to handle nested groups.
467                 SkipLastField();
468             }
469             int startField = WireFormat.GetTagFieldNumber(startGroupTag);
470             int endField = WireFormat.GetTagFieldNumber(tag);
471             if (startField != endField)
472             {
473                 throw new InvalidProtocolBufferException(
474                     $"Mismatched end-group tag. Started with field {startField}; ended with field {endField}");
475             }
476             recursionDepth--;
477         }
478 
479         /// <summary>
480         /// Reads a double field from the stream.
481         /// </summary>
ReadDouble()482         public double ReadDouble()
483         {
484             return BitConverter.Int64BitsToDouble((long) ReadRawLittleEndian64());
485         }
486 
487         /// <summary>
488         /// Reads a float field from the stream.
489         /// </summary>
ReadFloat()490         public float ReadFloat()
491         {
492             if (BitConverter.IsLittleEndian && 4 <= bufferSize - bufferPos)
493             {
494                 float ret = BitConverter.ToSingle(buffer, bufferPos);
495                 bufferPos += 4;
496                 return ret;
497             }
498             else
499             {
500                 byte[] rawBytes = ReadRawBytes(4);
501                 if (!BitConverter.IsLittleEndian)
502                 {
503                     ByteArray.Reverse(rawBytes);
504                 }
505                 return BitConverter.ToSingle(rawBytes, 0);
506             }
507         }
508 
509         /// <summary>
510         /// Reads a uint64 field from the stream.
511         /// </summary>
ReadUInt64()512         public ulong ReadUInt64()
513         {
514             return ReadRawVarint64();
515         }
516 
517         /// <summary>
518         /// Reads an int64 field from the stream.
519         /// </summary>
ReadInt64()520         public long ReadInt64()
521         {
522             return (long) ReadRawVarint64();
523         }
524 
525         /// <summary>
526         /// Reads an int32 field from the stream.
527         /// </summary>
ReadInt32()528         public int ReadInt32()
529         {
530             return (int) ReadRawVarint32();
531         }
532 
533         /// <summary>
534         /// Reads a fixed64 field from the stream.
535         /// </summary>
ReadFixed64()536         public ulong ReadFixed64()
537         {
538             return ReadRawLittleEndian64();
539         }
540 
541         /// <summary>
542         /// Reads a fixed32 field from the stream.
543         /// </summary>
ReadFixed32()544         public uint ReadFixed32()
545         {
546             return ReadRawLittleEndian32();
547         }
548 
549         /// <summary>
550         /// Reads a bool field from the stream.
551         /// </summary>
ReadBool()552         public bool ReadBool()
553         {
554             return ReadRawVarint32() != 0;
555         }
556 
557         /// <summary>
558         /// Reads a string field from the stream.
559         /// </summary>
ReadString()560         public string ReadString()
561         {
562             int length = ReadLength();
563             // No need to read any data for an empty string.
564             if (length == 0)
565             {
566                 return "";
567             }
568             if (length <= bufferSize - bufferPos && length > 0)
569             {
570                 // Fast path:  We already have the bytes in a contiguous buffer, so
571                 //   just copy directly from it.
572                 String result = CodedOutputStream.Utf8Encoding.GetString(buffer, bufferPos, length);
573                 bufferPos += length;
574                 return result;
575             }
576             // Slow path: Build a byte array first then copy it.
577             return CodedOutputStream.Utf8Encoding.GetString(ReadRawBytes(length), 0, length);
578         }
579 
580         /// <summary>
581         /// Reads an embedded message field value from the stream.
582         /// </summary>
ReadMessage(IMessage builder)583         public void ReadMessage(IMessage builder)
584         {
585             int length = ReadLength();
586             if (recursionDepth >= recursionLimit)
587             {
588                 throw InvalidProtocolBufferException.RecursionLimitExceeded();
589             }
590             int oldLimit = PushLimit(length);
591             ++recursionDepth;
592             builder.MergeFrom(this);
593             CheckReadEndOfStreamTag();
594             // Check that we've read exactly as much data as expected.
595             if (!ReachedLimit)
596             {
597                 throw InvalidProtocolBufferException.TruncatedMessage();
598             }
599             --recursionDepth;
600             PopLimit(oldLimit);
601         }
602 
603         /// <summary>
604         /// Reads an embedded group field from the stream.
605         /// </summary>
ReadGroup(IMessage builder)606         public void ReadGroup(IMessage builder)
607         {
608             if (recursionDepth >= recursionLimit)
609             {
610                 throw InvalidProtocolBufferException.RecursionLimitExceeded();
611             }
612             ++recursionDepth;
613             builder.MergeFrom(this);
614             --recursionDepth;
615         }
616 
617         /// <summary>
618         /// Reads a bytes field value from the stream.
619         /// </summary>
ReadBytes()620         public ByteString ReadBytes()
621         {
622             int length = ReadLength();
623             if (length <= bufferSize - bufferPos && length > 0)
624             {
625                 // Fast path:  We already have the bytes in a contiguous buffer, so
626                 //   just copy directly from it.
627                 ByteString result = ByteString.CopyFrom(buffer, bufferPos, length);
628                 bufferPos += length;
629                 return result;
630             }
631             else
632             {
633                 // Slow path:  Build a byte array and attach it to a new ByteString.
634                 return ByteString.AttachBytes(ReadRawBytes(length));
635             }
636         }
637 
638         /// <summary>
639         /// Reads a uint32 field value from the stream.
640         /// </summary>
ReadUInt32()641         public uint ReadUInt32()
642         {
643             return ReadRawVarint32();
644         }
645 
646         /// <summary>
647         /// Reads an enum field value from the stream.
648         /// </summary>
ReadEnum()649         public int ReadEnum()
650         {
651             // Currently just a pass-through, but it's nice to separate it logically from WriteInt32.
652             return (int) ReadRawVarint32();
653         }
654 
655         /// <summary>
656         /// Reads an sfixed32 field value from the stream.
657         /// </summary>
ReadSFixed32()658         public int ReadSFixed32()
659         {
660             return (int) ReadRawLittleEndian32();
661         }
662 
663         /// <summary>
664         /// Reads an sfixed64 field value from the stream.
665         /// </summary>
ReadSFixed64()666         public long ReadSFixed64()
667         {
668             return (long) ReadRawLittleEndian64();
669         }
670 
671         /// <summary>
672         /// Reads an sint32 field value from the stream.
673         /// </summary>
ReadSInt32()674         public int ReadSInt32()
675         {
676             return DecodeZigZag32(ReadRawVarint32());
677         }
678 
679         /// <summary>
680         /// Reads an sint64 field value from the stream.
681         /// </summary>
ReadSInt64()682         public long ReadSInt64()
683         {
684             return DecodeZigZag64(ReadRawVarint64());
685         }
686 
687         /// <summary>
688         /// Reads a length for length-delimited data.
689         /// </summary>
690         /// <remarks>
691         /// This is internally just reading a varint, but this method exists
692         /// to make the calling code clearer.
693         /// </remarks>
ReadLength()694         public int ReadLength()
695         {
696             return (int) ReadRawVarint32();
697         }
698 
699         /// <summary>
700         /// Peeks at the next tag in the stream. If it matches <paramref name="tag"/>,
701         /// the tag is consumed and the method returns <c>true</c>; otherwise, the
702         /// stream is left in the original position and the method returns <c>false</c>.
703         /// </summary>
MaybeConsumeTag(uint tag)704         public bool MaybeConsumeTag(uint tag)
705         {
706             if (PeekTag() == tag)
707             {
708                 hasNextTag = false;
709                 return true;
710             }
711             return false;
712         }
713 
714         #endregion
715 
716         #region Underlying reading primitives
717 
/// <summary>
/// Decodes a varint into 32 bits with the same algorithm as ReadRawVarint32,
/// but fetches every byte through ReadRawByte so that running off the end of
/// the buffer is checked for each individual byte.
/// </summary>
/// <returns>The low 32 bits of the decoded varint.</returns>
private uint SlowReadRawVarint32()
{
    // Byte 1: bits 0-6.
    int current = ReadRawByte();
    if (current < 128)
    {
        return (uint) current;
    }
    int value = current & 0x7f;

    // Byte 2: bits 7-13.
    current = ReadRawByte();
    value |= (current & 0x7f) << 7;
    if (current < 128)
    {
        return (uint) value;
    }

    // Byte 3: bits 14-20.
    current = ReadRawByte();
    value |= (current & 0x7f) << 14;
    if (current < 128)
    {
        return (uint) value;
    }

    // Byte 4: bits 21-27.
    current = ReadRawByte();
    value |= (current & 0x7f) << 21;
    if (current < 128)
    {
        return (uint) value;
    }

    // Byte 5: only its low four bits fit into the result; the shift drops the rest,
    // including the continuation bit, so no mask is needed.
    current = ReadRawByte();
    value |= current << 28;
    if (current < 128)
    {
        return (uint) value;
    }

    // The varint is wider than 32 bits: discard up to five more bytes (the upper
    // 32 bits) and stop at the first one without the continuation bit.
    for (int skipped = 0; skipped < 5; skipped++)
    {
        if (ReadRawByte() < 128)
        {
            return (uint) value;
        }
    }

    // Ten bytes in a row had the continuation bit set.
    throw InvalidProtocolBufferException.MalformedVarint();
}
769 
/// <summary>
/// Reads a raw varint from the stream, keeping only the low 32 bits when the
/// encoded value is wider than that.
/// </summary>
/// <remarks>
/// Tuned for the common case of a well-filled buffer: the remaining length is
/// checked once up front, after which the first five bytes are taken straight
/// from the buffer without any further bounds bookkeeping.
/// </remarks>
/// <returns>The low 32 bits of the decoded varint.</returns>
internal uint ReadRawVarint32()
{
    // Fewer than five buffered bytes left: use the byte-at-a-time variant.
    if (bufferPos + 5 > bufferSize)
    {
        return SlowReadRawVarint32();
    }

    // Byte 1: bits 0-6.
    int current = buffer[bufferPos++];
    if (current < 128)
    {
        return (uint) current;
    }
    int value = current & 0x7f;

    // Byte 2: bits 7-13.
    current = buffer[bufferPos++];
    value |= (current & 0x7f) << 7;
    if (current < 128)
    {
        return (uint) value;
    }

    // Byte 3: bits 14-20.
    current = buffer[bufferPos++];
    value |= (current & 0x7f) << 14;
    if (current < 128)
    {
        return (uint) value;
    }

    // Byte 4: bits 21-27.
    current = buffer[bufferPos++];
    value |= (current & 0x7f) << 21;
    if (current < 128)
    {
        return (uint) value;
    }

    // Byte 5: only its low four bits fit into the result; the shift drops the rest.
    current = buffer[bufferPos++];
    value |= current << 28;
    if (current < 128)
    {
        return (uint) value;
    }

    // Discard the upper 32 bits. This part must go through ReadRawByte() because
    // only five buffered bytes were guaranteed at the top of the method; that
    // keeps the fast path usable more often, and this section is rarely reached.
    for (int skipped = 0; skipped < 5; skipped++)
    {
        if (ReadRawByte() < 128)
        {
            return (uint) value;
        }
    }

    // Ten bytes in a row had the continuation bit set.
    throw InvalidProtocolBufferException.MalformedVarint();
}
831 
/// <summary>
/// Decodes a varint directly from <paramref name="input"/>, pulling one byte at
/// a time so that nothing beyond the end of the varint is consumed. Wrapping
/// the stream in a CodedInputStream and reading the varint through it would
/// probably read past the end of the varint, because CodedInputStream buffers
/// its input.
/// </summary>
/// <param name="input">The stream to read the varint bytes from.</param>
/// <returns>The low 32 bits of the decoded varint.</returns>
/// <exception cref="InvalidProtocolBufferException">
/// The stream ended before the varint was complete, or the varint ran longer
/// than ten bytes.
/// </exception>
internal static uint ReadRawVarint32(Stream input)
{
    int value = 0;
    int shift = 0;

    // First five bytes: these carry the bits that end up in the result.
    while (shift < 32)
    {
        int next = input.ReadByte();
        if (next == -1)
        {
            throw InvalidProtocolBufferException.TruncatedMessage();
        }
        value |= (next & 0x7f) << shift;
        if ((next & 0x80) == 0)
        {
            return (uint) value;
        }
        shift += 7;
    }

    // Keep reading up to 64 bits; these bytes are consumed but their payload is dropped.
    while (shift < 64)
    {
        int next = input.ReadByte();
        if (next == -1)
        {
            throw InvalidProtocolBufferException.TruncatedMessage();
        }
        if ((next & 0x80) == 0)
        {
            return (uint) value;
        }
        shift += 7;
    }

    throw InvalidProtocolBufferException.MalformedVarint();
}
873 
/// <summary>
/// Reads a raw varint of up to 64 bits from the stream.
/// </summary>
/// <returns>The decoded value.</returns>
/// <exception cref="InvalidProtocolBufferException">
/// Ten consecutive bytes all had the continuation bit set.
/// </exception>
internal ulong ReadRawVarint64()
{
    ulong value = 0;

    // Each byte contributes seven payload bits, least significant group first.
    for (int shift = 0; shift < 64; shift += 7)
    {
        byte next = ReadRawByte();
        value |= (ulong) (next & 0x7F) << shift;
        if ((next & 0x80) == 0)
        {
            // Continuation bit clear: this was the final byte.
            return value;
        }
    }

    throw InvalidProtocolBufferException.MalformedVarint();
}
893 
/// <summary>
/// Reads four bytes from the stream and assembles them into a 32-bit
/// little-endian integer.
/// </summary>
/// <returns>The assembled value; the first byte read is the least significant.</returns>
internal uint ReadRawLittleEndian32()
{
    uint value = 0;
    for (int shift = 0; shift < 32; shift += 8)
    {
        // Bytes arrive in increasing order of significance.
        value |= (uint) ReadRawByte() << shift;
    }
    return value;
}
905 
/// <summary>
/// Reads eight bytes from the stream and assembles them into a 64-bit
/// little-endian integer.
/// </summary>
/// <returns>The assembled value; the first byte read is the least significant.</returns>
internal ulong ReadRawLittleEndian64()
{
    ulong value = 0;
    for (int shift = 0; shift < 64; shift += 8)
    {
        // Bytes arrive in increasing order of significance.
        value |= (ulong) ReadRawByte() << shift;
    }
    return value;
}
922 
/// <summary>
/// Decodes a 32-bit ZigZag-encoded value back into a signed integer.
/// </summary>
/// <remarks>
/// ZigZag maps signed integers onto unsigned values that varint can encode
/// compactly. (Without it, a negative value has to be sign-extended to
/// 64 bits before varint encoding and therefore always occupies 10 bytes
/// on the wire.)
/// </remarks>
/// <param name="n">The ZigZag-encoded value.</param>
/// <returns>The decoded signed value.</returns>
internal static int DecodeZigZag32(uint n)
{
    // The upper 31 bits carry the magnitude; the lowest bit selects the sign.
    int magnitude = (int) (n >> 1);
    // 0 when the low bit is clear, all ones (-1) when it is set.
    int signMask = -(int) (n & 1);
    return magnitude ^ signMask;
}
936 
/// <summary>
/// Decodes a 64-bit ZigZag-encoded value back into a signed integer.
/// </summary>
/// <remarks>
/// ZigZag maps signed integers onto unsigned values that varint can encode
/// compactly. (Without it, a negative value has to be sign-extended to
/// 64 bits before varint encoding and therefore always occupies 10 bytes
/// on the wire.)
/// </remarks>
/// <param name="n">The ZigZag-encoded value.</param>
/// <returns>The decoded signed value.</returns>
internal static long DecodeZigZag64(ulong n)
{
    // The upper 63 bits carry the magnitude; the lowest bit selects the sign.
    long magnitude = (long) (n >> 1);
    // 0 when the low bit is clear, all ones (-1) when it is set.
    long signMask = -(long) (n & 1);
    return magnitude ^ signMask;
}
950         #endregion
951 
952         #region Internal reading and buffer management
953 
/// <summary>
/// Sets currentLimit to (current position) + byteLimit. This is called
/// when descending into a length-delimited embedded message. The previous
/// limit is returned so that it can be restored with PopLimit.
/// </summary>
/// <param name="byteLimit">
/// Number of bytes, counted from the current position, that may still be read.
/// </param>
/// <returns>The old limit.</returns>
/// <exception cref="InvalidProtocolBufferException">
/// <paramref name="byteLimit"/> is negative, or the new limit would lie beyond
/// the limit currently in force.
/// </exception>
internal int PushLimit(int byteLimit)
{
    if (byteLimit < 0)
    {
        throw InvalidProtocolBufferException.NegativeSize();
    }

    // Work in 64 bits: byteLimit comes straight from the wire, and adding it to
    // the current position in 32-bit arithmetic can wrap around to a negative
    // value that would slip past the comparison below and install a bogus limit.
    long absoluteLimit = (long) totalBytesRetired + bufferPos + byteLimit;
    int oldLimit = currentLimit;
    if (absoluteLimit > oldLimit)
    {
        throw InvalidProtocolBufferException.TruncatedMessage();
    }

    // absoluteLimit is no larger than oldLimit here, so it fits in an int.
    currentLimit = (int) absoluteLimit;

    RecomputeBufferSizeAfterLimit();

    return oldLimit;
}
978 
/// <summary>
/// Re-splits the buffered bytes into the part in front of currentLimit
/// (bufferSize) and the part behind it (bufferSizeAfterLimit).
/// </summary>
private void RecomputeBufferSizeAfterLimit()
{
    // Undo the previous split so that bufferSize spans every buffered byte again.
    bufferSize += bufferSizeAfterLimit;

    int endOfBuffer = totalBytesRetired + bufferSize;
    if (endOfBuffer <= currentLimit)
    {
        // The whole buffer lies in front of the limit.
        bufferSizeAfterLimit = 0;
        return;
    }

    // The limit falls inside the current buffer: hide the bytes behind it.
    bufferSizeAfterLimit = endOfBuffer - currentLimit;
    bufferSize -= bufferSizeAfterLimit;
}
994 
/// <summary>
/// Discards the current limit and puts <paramref name="oldLimit"/> back in force.
/// </summary>
/// <param name="oldLimit">The limit to restore, as returned by PushLimit.</param>
internal void PopLimit(int oldLimit)
{
    currentLimit = oldLimit;

    // The buffer split depends on the limit, so it has to be refreshed.
    RecomputeBufferSizeAfterLimit();
}
1003 
/// <summary>
/// Returns whether or not all the data before the limit has been read.
/// Always false while currentLimit is int.MaxValue.
/// </summary>
internal bool ReachedLimit
{
    get
    {
        // The absolute position is the retired byte count plus the offset in the buffer.
        return currentLimit != int.MaxValue
               && totalBytesRetired + bufferPos >= currentLimit;
    }
}
1020 
/// <summary>
/// Returns true if the stream has reached the end of the input. This is the
/// case if either the end of the underlying input source has been reached or
/// the stream has reached a limit created using PushLimit.
/// </summary>
public bool IsAtEnd
{
    get
    {
        // Unread bytes still sit in the buffer.
        if (bufferPos != bufferSize)
        {
            return false;
        }

        // The buffer is drained; the end is reached only if it cannot be refilled.
        return !RefillBuffer(false);
    }
}
1030 
/// <summary>
/// Called when the buffer is empty to read more bytes from the input.
/// If <paramref name="mustSucceed"/> is true, RefillBuffer() guarantees that
/// either there will be at least one byte in the buffer when it returns
/// or it will throw an exception. If <paramref name="mustSucceed"/> is false,
/// RefillBuffer() returns false if no more bytes were available.
/// </summary>
/// <param name="mustSucceed">
/// Whether running out of data is reported by throwing (true) or by returning false.
/// </param>
/// <returns>True if the buffer now holds at least one byte; otherwise false.</returns>
/// <exception cref="InvalidOperationException">
/// The buffer still contained unread bytes, or the stream reported a negative read count.
/// </exception>
/// <exception cref="InvalidProtocolBufferException">
/// No more data was available while <paramref name="mustSucceed"/> was true, or
/// the total number of bytes read exceeds sizeLimit.
/// </exception>
private bool RefillBuffer(bool mustSucceed)
{
    if (bufferPos < bufferSize)
    {
        throw new InvalidOperationException("RefillBuffer() called when buffer wasn't empty.");
    }

    // The end of the buffer coincides with the current limit: nothing more may be read.
    if (totalBytesRetired + bufferSize == currentLimit)
    {
        if (mustSucceed)
        {
            throw InvalidProtocolBufferException.TruncatedMessage();
        }
        return false;
    }

    // Retire the drained buffer, then fetch the next one.
    totalBytesRetired += bufferSize;
    bufferPos = 0;
    if (input == null)
    {
        // No underlying stream, so there is nothing further to fetch.
        bufferSize = 0;
    }
    else
    {
        bufferSize = input.Read(buffer, 0, buffer.Length);
    }

    if (bufferSize < 0)
    {
        throw new InvalidOperationException("Stream.Read returned a negative count");
    }

    if (bufferSize == 0)
    {
        // End of input.
        if (mustSucceed)
        {
            throw InvalidProtocolBufferException.TruncatedMessage();
        }
        return false;
    }

    RecomputeBufferSizeAfterLimit();

    // A negative sum means the running byte count wrapped around.
    int totalBytesRead = totalBytesRetired + bufferSize + bufferSizeAfterLimit;
    if (totalBytesRead < 0 || totalBytesRead > sizeLimit)
    {
        throw InvalidProtocolBufferException.SizeLimitExceeded();
    }
    return true;
}
1091 
/// <summary>
/// Read one byte from the input, refilling the buffer first when it is drained.
/// </summary>
/// <returns>The next byte of the input.</returns>
/// <exception cref="InvalidProtocolBufferException">
/// the end of the stream or the current limit was reached
/// </exception>
internal byte ReadRawByte()
{
    bool bufferDrained = bufferPos == bufferSize;
    if (bufferDrained)
    {
        // Throws instead of returning when no more data can be obtained.
        RefillBuffer(true);
    }

    int index = bufferPos++;
    return buffer[index];
}
1106 
/// <summary>
/// Reads a fixed size of bytes from the input.
/// </summary>
/// <param name="size">The number of bytes to read; must not be negative.</param>
/// <returns>A newly allocated array holding exactly <paramref name="size"/> bytes.</returns>
/// <exception cref="InvalidProtocolBufferException">
/// <paramref name="size"/> is negative, or the end of the stream or the current
/// limit was reached
/// </exception>
internal byte[] ReadRawBytes(int size)
{
    if (size < 0)
    {
        throw InvalidProtocolBufferException.NegativeSize();
    }

    // Compare in 64 bits: size comes straight from the wire, and adding it to the
    // current position in 32-bit arithmetic can wrap around to a negative value,
    // which would let the read run past the limit currently in force.
    long requestedEnd = (long) totalBytesRetired + bufferPos + size;
    if (requestedEnd > currentLimit)
    {
        // Read to the end of the stream (up to the current limit) anyway.
        SkipRawBytes(currentLimit - totalBytesRetired - bufferPos);
        // Then fail.
        throw InvalidProtocolBufferException.TruncatedMessage();
    }

    if (size <= bufferSize - bufferPos)
    {
        // We have all the bytes we need already.
        byte[] bytes = new byte[size];
        ByteArray.Copy(buffer, bufferPos, bytes, 0, size);
        bufferPos += size;
        return bytes;
    }
    else if (size < buffer.Length)
    {
        // Reading more bytes than are in the buffer, but not an excessive number
        // of bytes.  We can safely allocate the resulting array ahead of time.

        // First copy what we have.
        byte[] bytes = new byte[size];
        int pos = bufferSize - bufferPos;
        ByteArray.Copy(buffer, bufferPos, bytes, 0, pos);
        bufferPos = bufferSize;

        // We want to use RefillBuffer() and then copy from the buffer into our
        // byte array rather than reading directly into our byte array because
        // the input may be unbuffered.
        RefillBuffer(true);

        // A refill may deliver fewer bytes than are still missing; keep draining
        // whole buffers until the remainder fits into the current one.
        while (size - pos > bufferSize)
        {
            Buffer.BlockCopy(buffer, 0, bytes, pos, bufferSize);
            pos += bufferSize;
            bufferPos = bufferSize;
            RefillBuffer(true);
        }

        // Take the final piece from the front of the current buffer.
        ByteArray.Copy(buffer, 0, bytes, pos, size - pos);
        bufferPos = size - pos;

        return bytes;
    }
    else
    {
        // The size is very large.  For security reasons, we can't allocate the
        // entire byte array yet.  The size comes directly from the input, so a
        // maliciously-crafted message could provide a bogus very large size in
        // order to trick the app into allocating a lot of memory.  We avoid this
        // by allocating and reading only a small chunk at a time, so that the
        // malicious message must actually *be* extremely large to cause
        // problems.  Meanwhile, we limit the allowed size of a message elsewhere.

        // Remember the buffer markers since we'll have to copy the bytes out of
        // it later.
        int originalBufferPos = bufferPos;
        int originalBufferSize = bufferSize;

        // Mark the current buffer consumed.
        totalBytesRetired += bufferSize;
        bufferPos = 0;
        bufferSize = 0;

        // Read all the rest of the bytes we need.
        int sizeLeft = size - (originalBufferSize - originalBufferPos);
        List<byte[]> chunks = new List<byte[]>();

        while (sizeLeft > 0)
        {
            byte[] chunk = new byte[Math.Min(sizeLeft, buffer.Length)];
            int pos = 0;
            while (pos < chunk.Length)
            {
                // Without an underlying stream there is nothing left to read.
                int n = (input == null) ? -1 : input.Read(chunk, pos, chunk.Length - pos);
                if (n <= 0)
                {
                    throw InvalidProtocolBufferException.TruncatedMessage();
                }
                totalBytesRetired += n;
                pos += n;
            }
            sizeLeft -= chunk.Length;
            chunks.Add(chunk);
        }

        // OK, got everything.  Now concatenate it all into one buffer.
        byte[] bytes = new byte[size];

        // Start by copying the leftover bytes from this.buffer.
        int newPos = originalBufferSize - originalBufferPos;
        ByteArray.Copy(buffer, originalBufferPos, bytes, 0, newPos);

        // And now all the chunks.
        foreach (byte[] chunk in chunks)
        {
            Buffer.BlockCopy(chunk, 0, bytes, newPos, chunk.Length);
            newPos += chunk.Length;
        }

        // Done.
        return bytes;
    }
}
1225 
/// <summary>
/// Reads and discards <paramref name="size"/> bytes.
/// </summary>
/// <param name="size">The number of bytes to skip; must not be negative.</param>
/// <exception cref="InvalidProtocolBufferException">
/// <paramref name="size"/> is negative, or the end of the stream or the current
/// limit was reached
/// </exception>
private void SkipRawBytes(int size)
{
    if (size < 0)
    {
        throw InvalidProtocolBufferException.NegativeSize();
    }

    // Compare in 64 bits: size comes straight from the wire, and adding it to the
    // current position in 32-bit arithmetic can wrap around to a negative value,
    // which would let the skip run past the limit currently in force.
    long requestedEnd = (long) totalBytesRetired + bufferPos + size;
    if (requestedEnd > currentLimit)
    {
        // Read to the end of the stream anyway.
        SkipRawBytes(currentLimit - totalBytesRetired - bufferPos);
        // Then fail.
        throw InvalidProtocolBufferException.TruncatedMessage();
    }

    int available = bufferSize - bufferPos;
    if (size <= available)
    {
        // We have all the bytes we need already.
        bufferPos += size;
        return;
    }

    // Skipping more bytes than are in the buffer.  First skip what we have.
    // All bytes in the buffer are retired (bufferSize), not only the unread
    // ones, because the bytes in front of bufferPos have not been counted yet
    // (Issue #54).
    totalBytesRetired += bufferSize;
    bufferPos = 0;
    bufferSize = 0;

    // Then skip directly from the InputStream for the rest.  At this point
    // size is always larger than what the buffer held, so there is a remainder.
    if (input == null)
    {
        throw InvalidProtocolBufferException.TruncatedMessage();
    }
    int remaining = size - available;
    SkipImpl(remaining);
    totalBytesRetired += remaining;
}
1275 
/// <summary>
/// Abstraction of skipping to cope with streams which can't really skip.
/// </summary>
/// <param name="amountToSkip">The number of bytes to skip in the underlying stream.</param>
/// <exception cref="InvalidProtocolBufferException">
/// The stream holds fewer than <paramref name="amountToSkip"/> further bytes.
/// </exception>
private void SkipImpl(int amountToSkip)
{
    if (input.CanSeek)
    {
        long previousPosition = input.Position;
        long targetPosition = previousPosition + amountToSkip;

        // Seekable streams accept positions beyond their end, so comparing the
        // position after the assignment cannot reveal truncation on its own;
        // the target has to be checked against the stream length.
        long length = input.Length;
        if (targetPosition > length)
        {
            // Consume what is left, as the non-seekable path below does, then fail.
            if (previousPosition < length)
            {
                input.Position = length;
            }
            throw InvalidProtocolBufferException.TruncatedMessage();
        }

        input.Position = targetPosition;
        if (input.Position != targetPosition)
        {
            throw InvalidProtocolBufferException.TruncatedMessage();
        }
    }
    else
    {
        // The stream cannot seek: read into a scratch buffer and throw the data away.
        byte[] skipBuffer = new byte[Math.Min(1024, amountToSkip)];
        while (amountToSkip > 0)
        {
            int bytesRead = input.Read(skipBuffer, 0, Math.Min(skipBuffer.Length, amountToSkip));
            if (bytesRead <= 0)
            {
                // The stream ended before everything was skipped.
                throw InvalidProtocolBufferException.TruncatedMessage();
            }
            amountToSkip -= bytesRead;
        }
    }
}
1304         #endregion
1305     }
1306 }