1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.util.proto;
18 
19 import android.util.LongArray;
20 
21 import java.io.IOException;
22 import java.io.InputStream;
23 import java.nio.charset.StandardCharsets;
24 import java.util.Arrays;
25 import java.util.Objects;
26 
27 /**
28  * Class to read from a protobuf stream.
29  *
30  * Each read method takes an ID code from the protoc generated classes
31  * and returns the value of the field. To read a nested object, call #start
32  * and then #end when you are done.
33  *
34  * The ID codes have type information embedded into them, so if you call
35  * the incorrect function you will get an IllegalArgumentException.
36  *
37  * nextField will return the field number of the next field, which can be
38  * matched to the protoc generated ID code and used to determine how to
39  * read the next field.
40  *
41  * It is STRONGLY RECOMMENDED to read from the ProtoInputStream with a switch
42  * statement wrapped in a while loop. Additionally, it is worth logging or
43  * storing unexpected fields or ones that do not match the expected wire type
44  *
45  * ex:
46  * void parseFromProto(ProtoInputStream stream) {
47  *     while(stream.nextField() != ProtoInputStream.NO_MORE_FIELDS) {
48  *         try {
49  *             switch (stream.getFieldNumber()) {
50  *                 case (int) DummyProto.NAME:
51  *                     mName = stream.readString(DummyProto.NAME);
52  *                     break;
53  *                 case (int) DummyProto.VALUE:
54  *                     mValue = stream.readInt(DummyProto.VALUE);
55  *                     break;
56  *                 default:
57  *                     LOG(TAG, "Unhandled field in proto!\n"
58  *                              + ProtoUtils.currentFieldToString(stream));
59  *             }
60  *         } catch (WireTypeMismatchException wtme) {
61  *             LOG(TAG, "Wire Type mismatch in proto!\n" + ProtoUtils.currentFieldToString(stream));
62  *         }
63  *     }
64  * }
65  *
66  * @hide
67  */
68 @android.ravenwood.annotation.RavenwoodKeepWholeClass
69 public final class ProtoInputStream extends ProtoStream {
70 
71     public static final int NO_MORE_FIELDS = -1;
72 
73     /**
74      * Our stream.  If there is one.
75      */
76     private InputStream mStream;
77 
78     /**
79      * The field number of the current field. Will be equal to NO_MORE_FIELDS if end of message is
80      * reached
81      */
82     private int mFieldNumber;
83 
84     /**
85      * The wire type of the current field
86      */
87     private int mWireType;
88 
89     private static final byte STATE_STARTED_FIELD_READ = 1 << 0;
90     private static final byte STATE_READING_PACKED = 1 << 1;
91     private static final byte STATE_FIELD_MISS = 2 << 1;
92 
93     /**
94      * Tracks some boolean states for the proto input stream
95      * bit 0: Started Field Read, true - tag has been read, ready to read field data.
96      * false - field data has been read, reading to start next field.
97      * bit 1: Reading Packed Field, true - currently reading values from a packed field
98      * false - not reading from packed field.
99      */
100     private byte mState = 0;
101 
102     /**
103      * Keeps track of the currently read nested Objects, for end object checking and debug
104      */
105     private LongArray mExpectedObjectTokenStack = null;
106 
107     /**
108      * Current nesting depth of start calls.
109      */
110     private int mDepth = -1;
111 
112     /**
113      * Buffer for the to be read data. If mStream is not null, it will be constantly refilled from
114      * the stream.
115      */
116     private byte[] mBuffer;
117 
118     private static final int DEFAULT_BUFFER_SIZE = 8192;
119 
120     /**
121      * Size of the buffer if reading from a stream.
122      */
123     private final int mBufferSize;
124 
125     /**
126      * The number of bytes that have been skipped or dropped from the buffer.
127      */
128     private int mDiscardedBytes = 0;
129 
130     /**
131      * Current offset in the buffer
132      * mOffset + mDiscardedBytes = current offset in proto binary
133      */
134     private int mOffset = 0;
135 
136     /**
137      * Note the offset of the last byte in the buffer. Usually will equal the size of the buffer.
138      * mEnd + mDiscardedBytes = the last known byte offset + 1
139      */
140     private int mEnd = 0;
141 
142     /**
143      * Packed repeated fields are not read in one go. mPackedEnd keeps track of where the packed
144      * field ends in the proto binary if current field is packed.
145      */
146     private int mPackedEnd = 0;
147 
148     /**
149      * Construct a ProtoInputStream on top of an InputStream to read a proto. Also specify the
150      * number of bytes the ProtoInputStream will buffer from the input stream
151      *
152      * @param stream from which the proto is read
153      */
ProtoInputStream(InputStream stream, int bufferSize)154     public ProtoInputStream(InputStream stream, int bufferSize) {
155         mStream = stream;
156         if (bufferSize > 0) {
157             mBufferSize = bufferSize;
158         } else {
159             mBufferSize = DEFAULT_BUFFER_SIZE;
160         }
161         mBuffer = new byte[mBufferSize];
162     }
163 
164     /**
165      * Construct a ProtoInputStream on top of an InputStream to read a proto
166      *
167      * @param stream from which the proto is read
168      */
ProtoInputStream(InputStream stream)169     public ProtoInputStream(InputStream stream) {
170         this(stream, DEFAULT_BUFFER_SIZE);
171     }
172 
173     /**
174      * Construct a ProtoInputStream to read a proto directly from a byte array
175      *
176      * @param buffer - the byte array to be parsed
177      */
ProtoInputStream(byte[] buffer)178     public ProtoInputStream(byte[] buffer) {
179         mBufferSize = buffer.length;
180         mEnd = buffer.length;
181         mBuffer = buffer;
182         mStream = null;
183     }
184 
185     /**
186      * Get the field number of the current field.
187      */
getFieldNumber()188     public int getFieldNumber() {
189         return mFieldNumber;
190     }
191 
192     /**
193      * Get the wire type of the current field.
194      *
195      * @return an int that matches one of the ProtoStream WIRE_TYPE_ constants
196      */
getWireType()197     public int getWireType() {
198         if ((mState & STATE_READING_PACKED) == STATE_READING_PACKED) {
199             // mWireType got overwritten when STATE_READING_PACKED was set. Send length delimited
200             // constant instead
201             return WIRE_TYPE_LENGTH_DELIMITED;
202         }
203         return mWireType;
204     }
205 
206     /**
207      * Get the current offset in the proto binary.
208      */
getOffset()209     public int getOffset() {
210         return mOffset + mDiscardedBytes;
211     }
212 
213     /**
214      * Reads the tag of the next field from the stream. If previous field value was not read, its
215      * data will be skipped over.
216      *
217      * @return the field number of the next field
218      * @throws IOException if an I/O error occurs
219      */
nextField()220     public int nextField() throws IOException {
221 
222         if ((mState & STATE_FIELD_MISS) == STATE_FIELD_MISS) {
223             // Data from the last nextField was not used, reuse the info
224             mState &= ~STATE_FIELD_MISS;
225             return mFieldNumber;
226         }
227         if ((mState & STATE_STARTED_FIELD_READ) == STATE_STARTED_FIELD_READ) {
228             // Field data was not read, skip to the next field
229             skip();
230             mState &= ~STATE_STARTED_FIELD_READ;
231         }
232         if ((mState & STATE_READING_PACKED) == STATE_READING_PACKED) {
233             if (getOffset() < mPackedEnd) {
234                 // In the middle of a packed field, return the same tag until last packed value
235                 // has been read
236                 mState |= STATE_STARTED_FIELD_READ;
237                 return mFieldNumber;
238             } else if (getOffset() == mPackedEnd) {
239                 // Reached the end of the packed field
240                 mState &= ~STATE_READING_PACKED;
241             } else {
242                 throw new ProtoParseException(
243                         "Unexpectedly reached end of packed field at offset 0x"
244                                 + Integer.toHexString(mPackedEnd)
245                                 + dumpDebugData());
246             }
247         }
248 
249         if ((mDepth >= 0) && (getOffset() == getOffsetFromToken(
250                 mExpectedObjectTokenStack.get(mDepth)))) {
251             // reached end of a embedded message
252             mFieldNumber = NO_MORE_FIELDS;
253         } else {
254             readTag();
255         }
256         return mFieldNumber;
257     }
258 
259     /**
260      * Reads the tag of the next field from the stream. If previous field value was not read, its
261      * data will be skipped over. If {@code fieldId} matches the next field ID, the field data will
262      * be ready to read. If it does not match, {@link #nextField()} or {@link #nextField(long)} will
263      * need to be called again before the field data can be read.
264      *
265      * @return true if fieldId matches the next field, false if not
266      */
nextField(long fieldId)267     public boolean nextField(long fieldId) throws IOException {
268         if (nextField() == (int) fieldId) {
269             return true;
270         }
271         // Note to reuse the info from the nextField call in the next call.
272         mState |= STATE_FIELD_MISS;
273         return false;
274     }
275 
276     /**
277      * Read a single double.
278      * Will throw if the current wire type is not fixed64
279      *
280      * @param fieldId - must match the current field number and field type
281      */
readDouble(long fieldId)282     public double readDouble(long fieldId) throws IOException {
283         assertFreshData();
284         assertFieldNumber(fieldId);
285         checkPacked(fieldId);
286 
287         double value;
288         switch ((int) ((fieldId & FIELD_TYPE_MASK)
289                 >>> FIELD_TYPE_SHIFT)) {
290             case (int) (FIELD_TYPE_DOUBLE >>> FIELD_TYPE_SHIFT):
291                 assertWireType(WIRE_TYPE_FIXED64);
292                 value = Double.longBitsToDouble(readFixed64());
293                 break;
294             default:
295                 throw new IllegalArgumentException(
296                         "Requested field id (" + getFieldIdString(fieldId)
297                                 + ") cannot be read as a double"
298                                 + dumpDebugData());
299         }
300         // Successfully read the field
301         mState &= ~STATE_STARTED_FIELD_READ;
302         return value;
303     }
304 
305     /**
306      * Read a single float.
307      * Will throw if the current wire type is not fixed32
308      *
309      * @param fieldId - must match the current field number and field type
310      */
readFloat(long fieldId)311     public float readFloat(long fieldId) throws IOException {
312         assertFreshData();
313         assertFieldNumber(fieldId);
314         checkPacked(fieldId);
315 
316         float value;
317         switch ((int) ((fieldId & FIELD_TYPE_MASK)
318                 >>> FIELD_TYPE_SHIFT)) {
319             case (int) (FIELD_TYPE_FLOAT >>> FIELD_TYPE_SHIFT):
320                 assertWireType(WIRE_TYPE_FIXED32);
321                 value = Float.intBitsToFloat(readFixed32());
322                 break;
323             default:
324                 throw new IllegalArgumentException(
325                         "Requested field id (" + getFieldIdString(fieldId) + ") is not a float"
326                                 + dumpDebugData());
327         }
328         // Successfully read the field
329         mState &= ~STATE_STARTED_FIELD_READ;
330         return value;
331     }
332 
333     /**
334      * Read a single 32bit or varint proto type field as an int.
335      * Will throw if the current wire type is not varint or fixed32
336      *
337      * @param fieldId - must match the current field number and field type
338      */
readInt(long fieldId)339     public int readInt(long fieldId) throws IOException {
340         assertFreshData();
341         assertFieldNumber(fieldId);
342         checkPacked(fieldId);
343 
344         int value;
345         switch ((int) ((fieldId & FIELD_TYPE_MASK)
346                 >>> FIELD_TYPE_SHIFT)) {
347             case (int) (FIELD_TYPE_FIXED32 >>> FIELD_TYPE_SHIFT):
348             case (int) (FIELD_TYPE_SFIXED32 >>> FIELD_TYPE_SHIFT):
349                 assertWireType(WIRE_TYPE_FIXED32);
350                 value = readFixed32();
351                 break;
352             case (int) (FIELD_TYPE_SINT32 >>> FIELD_TYPE_SHIFT):
353                 assertWireType(WIRE_TYPE_VARINT);
354                 value = decodeZigZag32((int) readVarint());
355                 break;
356             case (int) (FIELD_TYPE_INT32 >>> FIELD_TYPE_SHIFT):
357             case (int) (FIELD_TYPE_UINT32 >>> FIELD_TYPE_SHIFT):
358             case (int) (FIELD_TYPE_ENUM >>> FIELD_TYPE_SHIFT):
359                 assertWireType(WIRE_TYPE_VARINT);
360                 value = (int) readVarint();
361                 break;
362             default:
363                 throw new IllegalArgumentException(
364                         "Requested field id (" + getFieldIdString(fieldId) + ") is not an int"
365                                 + dumpDebugData());
366         }
367         // Successfully read the field
368         mState &= ~STATE_STARTED_FIELD_READ;
369         return value;
370     }
371 
372     /**
373      * Read a single 64bit or varint proto type field as an long.
374      *
375      * @param fieldId - must match the current field number
376      */
readLong(long fieldId)377     public long readLong(long fieldId) throws IOException {
378         assertFreshData();
379         assertFieldNumber(fieldId);
380         checkPacked(fieldId);
381 
382         long value;
383         switch ((int) ((fieldId & FIELD_TYPE_MASK)
384                 >>> FIELD_TYPE_SHIFT)) {
385             case (int) (FIELD_TYPE_FIXED64 >>> FIELD_TYPE_SHIFT):
386             case (int) (FIELD_TYPE_SFIXED64 >>> FIELD_TYPE_SHIFT):
387                 assertWireType(WIRE_TYPE_FIXED64);
388                 value = readFixed64();
389                 break;
390             case (int) (FIELD_TYPE_SINT64 >>> FIELD_TYPE_SHIFT):
391                 assertWireType(WIRE_TYPE_VARINT);
392                 value = decodeZigZag64(readVarint());
393                 break;
394             case (int) (FIELD_TYPE_INT64 >>> FIELD_TYPE_SHIFT):
395             case (int) (FIELD_TYPE_UINT64 >>> FIELD_TYPE_SHIFT):
396                 assertWireType(WIRE_TYPE_VARINT);
397                 value = readVarint();
398                 break;
399             default:
400                 throw new IllegalArgumentException(
401                         "Requested field id (" + getFieldIdString(fieldId) + ") is not an long"
402                                 + dumpDebugData());
403         }
404         // Successfully read the field
405         mState &= ~STATE_STARTED_FIELD_READ;
406         return value;
407     }
408 
409     /**
410      * Read a single 32bit or varint proto type field as an boolean.
411      *
412      * @param fieldId - must match the current field number
413      */
readBoolean(long fieldId)414     public boolean readBoolean(long fieldId) throws IOException {
415         assertFreshData();
416         assertFieldNumber(fieldId);
417         checkPacked(fieldId);
418 
419         boolean value;
420         switch ((int) ((fieldId & FIELD_TYPE_MASK)
421                 >>> FIELD_TYPE_SHIFT)) {
422             case (int) (FIELD_TYPE_BOOL >>> FIELD_TYPE_SHIFT):
423                 assertWireType(WIRE_TYPE_VARINT);
424                 value = readVarint() != 0;
425                 break;
426             default:
427                 throw new IllegalArgumentException(
428                         "Requested field id (" + getFieldIdString(fieldId) + ") is not an boolean"
429                                 + dumpDebugData());
430         }
431         // Successfully read the field
432         mState &= ~STATE_STARTED_FIELD_READ;
433         return value;
434     }
435 
436     /**
437      * Read a string field
438      *
439      * @param fieldId - must match the current field number
440      */
readString(long fieldId)441     public String readString(long fieldId) throws IOException {
442         assertFreshData();
443         assertFieldNumber(fieldId);
444 
445         String value;
446         switch ((int) ((fieldId & FIELD_TYPE_MASK) >>> FIELD_TYPE_SHIFT)) {
447             case (int) (FIELD_TYPE_STRING >>> FIELD_TYPE_SHIFT):
448                 assertWireType(WIRE_TYPE_LENGTH_DELIMITED);
449                 int len = (int) readVarint();
450                 value = readRawString(len);
451                 break;
452             default:
453                 throw new IllegalArgumentException(
454                         "Requested field id(" + getFieldIdString(fieldId)
455                                 + ") is not an string"
456                                 + dumpDebugData());
457         }
458         // Successfully read the field
459         mState &= ~STATE_STARTED_FIELD_READ;
460         return value;
461     }
462 
463     /**
464      * Read a bytes field
465      *
466      * @param fieldId - must match the current field number
467      */
readBytes(long fieldId)468     public byte[] readBytes(long fieldId) throws IOException {
469         assertFreshData();
470         assertFieldNumber(fieldId);
471 
472         byte[] value;
473         switch ((int) ((fieldId & FIELD_TYPE_MASK) >>> FIELD_TYPE_SHIFT)) {
474             case (int) (FIELD_TYPE_MESSAGE >>> FIELD_TYPE_SHIFT):
475             case (int) (FIELD_TYPE_BYTES >>> FIELD_TYPE_SHIFT):
476                 assertWireType(WIRE_TYPE_LENGTH_DELIMITED);
477                 int len = (int) readVarint();
478                 value = readRawBytes(len);
479                 break;
480             default:
481                 throw new IllegalArgumentException(
482                         "Requested field type (" + getFieldIdString(fieldId)
483                                 + ") cannot be read as raw bytes"
484                                 + dumpDebugData());
485         }
486         // Successfully read the field
487         mState &= ~STATE_STARTED_FIELD_READ;
488         return value;
489     }
490 
491     /**
492      * Start the read of an embedded Object
493      *
494      * @param fieldId - must match the current field number
495      * @return a token. The token must be handed back when finished reading embedded Object
496      */
start(long fieldId)497     public long start(long fieldId) throws IOException {
498         assertFreshData();
499         assertFieldNumber(fieldId);
500         assertWireType(WIRE_TYPE_LENGTH_DELIMITED);
501 
502         int messageSize = (int) readVarint();
503 
504         if (mExpectedObjectTokenStack == null) {
505             mExpectedObjectTokenStack = new LongArray();
506         }
507         if (++mDepth == mExpectedObjectTokenStack.size()) {
508             // Create a token to keep track of nested Object and extend the object stack
509             mExpectedObjectTokenStack.add(makeToken(0,
510                     (fieldId & FIELD_COUNT_REPEATED) == FIELD_COUNT_REPEATED, mDepth,
511                     (int) fieldId, getOffset() + messageSize));
512 
513         } else {
514             // Create a token to keep track of nested Object
515             mExpectedObjectTokenStack.set(mDepth, makeToken(0,
516                     (fieldId & FIELD_COUNT_REPEATED) == FIELD_COUNT_REPEATED, mDepth,
517                     (int) fieldId, getOffset() + messageSize));
518         }
519 
520         // Validation check
521         if (mDepth > 0
522                 && getOffsetFromToken(mExpectedObjectTokenStack.get(mDepth))
523                 > getOffsetFromToken(mExpectedObjectTokenStack.get(mDepth - 1))) {
524             throw new ProtoParseException("Embedded Object ("
525                     + token2String(mExpectedObjectTokenStack.get(mDepth))
526                     + ") ends after of parent Objects's ("
527                     + token2String(mExpectedObjectTokenStack.get(mDepth - 1))
528                     + ") end"
529                     + dumpDebugData());
530         }
531         mState &= ~STATE_STARTED_FIELD_READ;
532         return mExpectedObjectTokenStack.get(mDepth);
533     }
534 
535     /**
536      * Note the end of a nested object. Must be called to continue streaming the rest of the proto.
537      * end can be called mid object parse. The offset will be moved to the next field outside the
538      * object.
539      *
540      * @param token - token
541      */
end(long token)542     public void end(long token) {
543         // Make sure user is keeping track of their embedded messages
544         if (mExpectedObjectTokenStack.get(mDepth) != token) {
545             throw new ProtoParseException(
546                     "end token " + token + " does not match current message token "
547                             + mExpectedObjectTokenStack.get(mDepth)
548                             + dumpDebugData());
549         }
550         if (getOffsetFromToken(mExpectedObjectTokenStack.get(mDepth)) > getOffset()) {
551             // Did not read all of the message, skip to the end
552             incOffset(getOffsetFromToken(mExpectedObjectTokenStack.get(mDepth)) - getOffset());
553         }
554         mDepth--;
555         mState &= ~STATE_STARTED_FIELD_READ;
556     }
557 
558     /**
559      * Read the tag at the start of the next field and collect field number and wire type.
560      * Will set mFieldNumber to NO_MORE_FIELDS if end of buffer/stream reached.
561      */
readTag()562     private void readTag() throws IOException {
563         fillBuffer();
564         if (mOffset >= mEnd) {
565             // reached end of the stream
566             mFieldNumber = NO_MORE_FIELDS;
567             return;
568         }
569         int tag = (int) readVarint();
570         mFieldNumber = tag >>> FIELD_ID_SHIFT;
571         mWireType = tag & WIRE_TYPE_MASK;
572         mState |= STATE_STARTED_FIELD_READ;
573     }
574 
575     /**
576      * Decode a 32 bit ZigZag encoded signed int.
577      *
578      * @param n - int to decode
579      * @return the decoded signed int
580      */
decodeZigZag32(final int n)581     public int decodeZigZag32(final int n) {
582         return (n >>> 1) ^ -(n & 1);
583     }
584 
585     /**
586      * Decode a 64 bit ZigZag encoded signed long.
587      *
588      * @param n - long to decode
589      * @return the decoded signed long
590      */
decodeZigZag64(final long n)591     public long decodeZigZag64(final long n) {
592         return (n >>> 1) ^ -(n & 1);
593     }
594 
595     /**
596      * Read a varint from the buffer
597      *
598      * @return the varint as a long
599      */
readVarint()600     private long readVarint() throws IOException {
601         long value = 0;
602         int shift = 0;
603         while (true) {
604             fillBuffer();
605             // Limit how much bookkeeping is done by checking how far away the end of the buffer is
606             // and directly accessing buffer up until the end.
607             final int fragment = mEnd - mOffset;
608             if (fragment < 0) {
609                 throw new ProtoParseException(
610                         "Incomplete varint at offset 0x"
611                                 + Integer.toHexString(getOffset())
612                                 + dumpDebugData());
613             }
614             for (int i = 0; i < fragment; i++) {
615                 byte b = mBuffer[(mOffset + i)];
616                 value |= (b & 0x7FL) << shift;
617                 if ((b & 0x80) == 0) {
618                     incOffset(i + 1);
619                     return value;
620                 }
621                 shift += 7;
622                 if (shift > 63) {
623                     throw new ProtoParseException(
624                             "Varint is too large at offset 0x"
625                                     + Integer.toHexString(getOffset() + i)
626                                     + dumpDebugData());
627                 }
628             }
629             // Hit the end of the buffer, do some incrementing and checking, then continue
630             incOffset(fragment);
631         }
632     }
633 
634     /**
635      * Read a fixed 32 bit int from the buffer
636      *
637      * @return the fixed32 as a int
638      */
readFixed32()639     private int readFixed32() throws IOException {
640         // check for fast path, which is likely with a reasonable buffer size
641         if (mOffset + 4 <= mEnd) {
642             // don't bother filling buffer since we know the end is plenty far away
643             incOffset(4);
644             return (mBuffer[mOffset - 4] & 0xFF)
645                     | ((mBuffer[mOffset - 3] & 0xFF) << 8)
646                     | ((mBuffer[mOffset - 2] & 0xFF) << 16)
647                     | ((mBuffer[mOffset - 1] & 0xFF) << 24);
648         }
649 
650         // the Fixed32 crosses the edge of a chunk, read the Fixed32 in multiple fragments.
651         // There will be two fragment reads except when the chunk size is 2 or less.
652         int value = 0;
653         int shift = 0;
654         int bytesLeft = 4;
655         while (bytesLeft > 0) {
656             fillBuffer();
657             // Find the number of bytes available until the end of the chunk or Fixed32
658             int fragment = (mEnd - mOffset) < bytesLeft ? (mEnd - mOffset) : bytesLeft;
659             if (fragment < 0) {
660                 throw new ProtoParseException(
661                         "Incomplete fixed32 at offset 0x"
662                                 + Integer.toHexString(getOffset())
663                                 + dumpDebugData());
664             }
665             incOffset(fragment);
666             bytesLeft -= fragment;
667             while (fragment > 0) {
668                 value |= ((mBuffer[mOffset - fragment] & 0xFF) << shift);
669                 fragment--;
670                 shift += 8;
671             }
672         }
673         return value;
674     }
675 
676     /**
677      * Read a fixed 64 bit long from the buffer
678      *
679      * @return the fixed64 as a long
680      */
681     private long readFixed64() throws IOException {
682         // check for fast path, which is likely with a reasonable buffer size
683         if (mOffset + 8 <= mEnd) {
684             // don't bother filling buffer since we know the end is plenty far away
685             incOffset(8);
686             return (mBuffer[mOffset - 8] & 0xFFL)
687                     | ((mBuffer[mOffset - 7] & 0xFFL) << 8)
688                     | ((mBuffer[mOffset - 6] & 0xFFL) << 16)
689                     | ((mBuffer[mOffset - 5] & 0xFFL) << 24)
690                     | ((mBuffer[mOffset - 4] & 0xFFL) << 32)
691                     | ((mBuffer[mOffset - 3] & 0xFFL) << 40)
692                     | ((mBuffer[mOffset - 2] & 0xFFL) << 48)
693                     | ((mBuffer[mOffset - 1] & 0xFFL) << 56);
694         }
695 
696         // the Fixed64 crosses the edge of a chunk, read the Fixed64 in multiple fragments.
697         // There will be two fragment reads except when the chunk size is 6 or less.
698         long value = 0;
699         int shift = 0;
700         int bytesLeft = 8;
701         while (bytesLeft > 0) {
702             fillBuffer();
703             // Find the number of bytes available until the end of the chunk or Fixed64
704             int fragment = (mEnd - mOffset) < bytesLeft ? (mEnd - mOffset) : bytesLeft;
705             if (fragment < 0) {
706                 throw new ProtoParseException(
707                         "Incomplete fixed64 at offset 0x"
708                                 + Integer.toHexString(getOffset())
709                                 + dumpDebugData());
710             }
711             incOffset(fragment);
712             bytesLeft -= fragment;
713             while (fragment > 0) {
714                 value |= ((mBuffer[(mOffset - fragment)] & 0xFFL) << shift);
715                 fragment--;
716                 shift += 8;
717             }
718         }
719         return value;
720     }
721 
722     /**
723      * Read raw bytes from the buffer
724      *
725      * @param n - number of bytes to read
726      * @return a byte array with raw bytes
727      */
728     private byte[] readRawBytes(int n) throws IOException {
729         byte[] buffer = new byte[n];
730         int pos = 0;
731         while (mOffset + n - pos > mEnd) {
732             int fragment = mEnd - mOffset;
733             if (fragment > 0) {
734                 System.arraycopy(mBuffer, mOffset, buffer, pos, fragment);
735                 incOffset(fragment);
736                 pos += fragment;
737             }
738             fillBuffer();
739             if (mOffset >= mEnd) {
740                 throw new ProtoParseException(
741                         "Unexpectedly reached end of the InputStream at offset 0x"
742                                 + Integer.toHexString(mEnd)
743                                 + dumpDebugData());
744             }
745         }
746         System.arraycopy(mBuffer, mOffset, buffer, pos, n - pos);
747         incOffset(n - pos);
748         return buffer;
749     }
750 
751     /**
752      * Read raw string from the buffer
753      *
754      * @param n - number of bytes to read
755      * @return a string
756      */
757     private String readRawString(int n) throws IOException {
758         fillBuffer();
759         if (mOffset + n <= mEnd) {
760             // fast path read. String is well within the current buffer
761             String value = new String(mBuffer, mOffset, n, StandardCharsets.UTF_8);
762             incOffset(n);
763             return value;
764         } else if (n <= mBufferSize) {
765             // String extends past buffer, but can be encapsulated in a buffer. Copy the first chunk
766             // of the string to the start of the buffer and then fill the rest of the buffer from
767             // the stream.
768             final int stringHead = mEnd - mOffset;
769             System.arraycopy(mBuffer, mOffset, mBuffer, 0, stringHead);
770             mEnd = stringHead + mStream.read(mBuffer, stringHead, n - stringHead);
771 
772             mDiscardedBytes += mOffset;
773             mOffset = 0;
774 
775             String value = new String(mBuffer, mOffset, n, StandardCharsets.UTF_8);
776             incOffset(n);
777             return value;
778         }
779         // Otherwise, the string is too large to use the buffer. Create the string from a
780         // separate byte array.
781         return new String(readRawBytes(n), 0, n, StandardCharsets.UTF_8);
782     }
783 
784     /**
785      * Fill the buffer with a chunk from the stream if need be.
786      * Will skip chunks until mOffset is reached
787      */
788     private void fillBuffer() throws IOException {
789         if (mOffset >= mEnd && mStream != null) {
790             mOffset -= mEnd;
791             mDiscardedBytes += mEnd;
792             if (mOffset >= mBufferSize) {
793                 int skipped = (int) mStream.skip((mOffset / mBufferSize) * mBufferSize);
794                 mDiscardedBytes += skipped;
795                 mOffset -= skipped;
796             }
797             mEnd = mStream.read(mBuffer);
798         }
799     }
800 
801     /**
802      * Skips the rest of current field and moves to the start of the next field. This should only be
803      * called while state is STATE_STARTED_FIELD_READ
804      */
805     public void skip() throws IOException {
806         if ((mState & STATE_READING_PACKED) == STATE_READING_PACKED) {
807             incOffset(mPackedEnd - getOffset());
808         } else {
809             switch (mWireType) {
810                 case WIRE_TYPE_VARINT:
811                     byte b;
812                     do {
813                         fillBuffer();
814                         b = mBuffer[mOffset];
815                         incOffset(1);
816                     } while ((b & 0x80) != 0);
817                     break;
818                 case WIRE_TYPE_FIXED64:
819                     incOffset(8);
820                     break;
821                 case WIRE_TYPE_LENGTH_DELIMITED:
822                     fillBuffer();
823                     int length = (int) readVarint();
824                     incOffset(length);
825                     break;
826                 /*
827             case WIRE_TYPE_START_GROUP:
828                 // Not implemented
829                 break;
830             case WIRE_TYPE_END_GROUP:
831                 // Not implemented
832                 break;
833                 */
834                 case WIRE_TYPE_FIXED32:
835                     incOffset(4);
836                     break;
837                 default:
838                     throw new ProtoParseException(
839                             "Unexpected wire type: " + mWireType + " at offset 0x"
840                                     + Integer.toHexString(mOffset)
841                                     + dumpDebugData());
842             }
843         }
844         mState &= ~STATE_STARTED_FIELD_READ;
845     }
846 
847     /**
848      * Increment the offset and handle all the relevant bookkeeping
849      * Refilling the buffer when its end is reached will be handled elsewhere (ideally just before
850      * a read, to avoid unnecessary reads from stream)
851      *
852      * @param n - number of bytes to increment
853      */
854     private void incOffset(int n) {
855         mOffset += n;
856 
857         if (mDepth >= 0 && getOffset() > getOffsetFromToken(
858                 mExpectedObjectTokenStack.get(mDepth))) {
859             throw new ProtoParseException("Unexpectedly reached end of embedded object.  "
860                     + token2String(mExpectedObjectTokenStack.get(mDepth))
861                     + dumpDebugData());
862         }
863     }
864 
865     /**
866      * Check the current wire type to determine if current numeric field is packed. If it is packed,
867      * set up to deal with the field
868      * This should only be called for primitive numeric field types.
869      *
870      * @param fieldId - used to determine what the packed wire type is.
871      */
872     private void checkPacked(long fieldId) throws IOException {
873         if (mWireType == WIRE_TYPE_LENGTH_DELIMITED) {
874             // Primitive Field is length delimited, must be a packed field.
875             final int length = (int) readVarint();
876             mPackedEnd = getOffset() + length;
877             mState |= STATE_READING_PACKED;
878 
879             // Fake the wire type, based on the field type
880             switch ((int) ((fieldId & FIELD_TYPE_MASK)
881                     >>> FIELD_TYPE_SHIFT)) {
882                 case (int) (FIELD_TYPE_FLOAT >>> FIELD_TYPE_SHIFT):
883                 case (int) (FIELD_TYPE_FIXED32 >>> FIELD_TYPE_SHIFT):
884                 case (int) (FIELD_TYPE_SFIXED32 >>> FIELD_TYPE_SHIFT):
885                     if (length % 4 != 0) {
886                         throw new IllegalArgumentException(
887                                 "Requested field id (" + getFieldIdString(fieldId)
888                                         + ") packed length " + length
889                                         + " is not aligned for fixed32"
890                                         + dumpDebugData());
891                     }
892                     mWireType = WIRE_TYPE_FIXED32;
893                     break;
894                 case (int) (FIELD_TYPE_DOUBLE >>> FIELD_TYPE_SHIFT):
895                 case (int) (FIELD_TYPE_FIXED64 >>> FIELD_TYPE_SHIFT):
896                 case (int) (FIELD_TYPE_SFIXED64 >>> FIELD_TYPE_SHIFT):
897                     if (length % 8 != 0) {
898                         throw new IllegalArgumentException(
899                                 "Requested field id (" + getFieldIdString(fieldId)
900                                         + ") packed length " + length
901                                         + " is not aligned for fixed64"
902                                         + dumpDebugData());
903                     }
904                     mWireType = WIRE_TYPE_FIXED64;
905                     break;
906                 case (int) (FIELD_TYPE_SINT32 >>> FIELD_TYPE_SHIFT):
907                 case (int) (FIELD_TYPE_INT32 >>> FIELD_TYPE_SHIFT):
908                 case (int) (FIELD_TYPE_UINT32 >>> FIELD_TYPE_SHIFT):
909                 case (int) (FIELD_TYPE_SINT64 >>> FIELD_TYPE_SHIFT):
910                 case (int) (FIELD_TYPE_INT64 >>> FIELD_TYPE_SHIFT):
911                 case (int) (FIELD_TYPE_UINT64 >>> FIELD_TYPE_SHIFT):
912                 case (int) (FIELD_TYPE_ENUM >>> FIELD_TYPE_SHIFT):
913                 case (int) (FIELD_TYPE_BOOL >>> FIELD_TYPE_SHIFT):
914                     mWireType = WIRE_TYPE_VARINT;
915                     break;
916                 default:
917                     throw new IllegalArgumentException(
918                             "Requested field id (" + getFieldIdString(fieldId)
919                                     + ") is not a packable field"
920                                     + dumpDebugData());
921             }
922         }
923     }
924 
925 
926     /**
927      * Check a field id constant against current field number
928      *
929      * @param fieldId - throws if fieldId does not match mFieldNumber
930      */
assertFieldNumber(long fieldId)931     private void assertFieldNumber(long fieldId) {
932         if ((int) fieldId != mFieldNumber) {
933             throw new IllegalArgumentException("Requested field id (" + getFieldIdString(fieldId)
934                     + ") does not match current field number (0x" + Integer.toHexString(
935                     mFieldNumber)
936                     + ") at offset 0x" + Integer.toHexString(getOffset())
937                     + dumpDebugData());
938         }
939     }
940 
941 
942     /**
943      * Check a wire type against current wire type.
944      *
945      * @param wireType - throws if wireType does not match mWireType.
946      */
assertWireType(int wireType)947     private void assertWireType(int wireType) {
948         if (wireType != mWireType) {
949             throw new WireTypeMismatchException(
950                     "Current wire type " + getWireTypeString(mWireType)
951                             + " does not match expected wire type " + getWireTypeString(wireType)
952                             + " at offset 0x" + Integer.toHexString(getOffset())
953                             + dumpDebugData());
954         }
955     }
956 
957     /**
958      * Check if there is data ready to be read.
959      */
assertFreshData()960     private void assertFreshData() {
961         if ((mState & STATE_STARTED_FIELD_READ) != STATE_STARTED_FIELD_READ) {
962             throw new ProtoParseException(
963                     "Attempting to read already read field at offset 0x" + Integer.toHexString(
964                             getOffset()) + dumpDebugData());
965         }
966     }
967 
968     /**
969      * Dump debugging data about the buffer.
970      */
dumpDebugData()971     public String dumpDebugData() {
972         StringBuilder sb = new StringBuilder();
973 
974         sb.append("\nmFieldNumber : 0x").append(Integer.toHexString(mFieldNumber));
975         sb.append("\nmWireType : 0x").append(Integer.toHexString(mWireType));
976         sb.append("\nmState : 0x").append(Integer.toHexString(mState));
977         sb.append("\nmDiscardedBytes : 0x").append(Integer.toHexString(mDiscardedBytes));
978         sb.append("\nmOffset : 0x").append(Integer.toHexString(mOffset));
979         sb.append("\nmExpectedObjectTokenStack : ")
980                 .append(Objects.toString(mExpectedObjectTokenStack));
981         sb.append("\nmDepth : 0x").append(Integer.toHexString(mDepth));
982         sb.append("\nmBuffer : ").append(Arrays.toString(mBuffer));
983         sb.append("\nmBufferSize : 0x").append(Integer.toHexString(mBufferSize));
984         sb.append("\nmEnd : 0x").append(Integer.toHexString(mEnd));
985 
986         return sb.toString();
987     }
988 }
989