1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Inline functions for parsing the protocol buffers wire format.
17 //
18 // These functions have been optimized at the expense of safety.
19 // They are broken out into a separate file for readability but are
20 // not intended for use by clients other than the decode_proto op.
21 //
22 // The calling code in the decode_proto op does some fairly
23 // complicated things to ensure that this code is called
24 // safely. Changes to this code should be thoroughly fuzz tested.
25 
26 #ifndef TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_
27 #define TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_
28 
29 #include "tensorflow/core/framework/tensor.h"
30 #include "tensorflow/core/framework/types.h"
31 #include "tensorflow/core/platform/protobuf.h"
32 #include "tensorflow/core/platform/types.h"
33 
34 namespace tensorflow {
35 namespace internal {
36 
37 using tensorflow::protobuf::internal::WireFormatLite;
38 using tensorflow::protobuf::io::CodedInputStream;
39 using tensorflow::protobuf::io::CodedOutputStream;
40 using tensorflow::protobuf::io::StringOutputStream;
41 
42 // Converts an uint64 to an int64 without loss of information.
43 // Unsigned values greater than INT64_MAX are represented as
44 // negative numbers by wrapping (same as twos-complement bit equivalence).
WrapUnsignedAsSigned64(uint64 unsigned_value)45 inline int64 WrapUnsignedAsSigned64(uint64 unsigned_value) {
46   // For a detailed explanation of why this works to wrap unsigned ints, see
47   // http://stackoverflow.com/questions/13150449/efficient-unsigned-to-signed-cast-avoiding-implementation-defined-behavior
48   // Both if tests should be optimized out.
49   if (unsigned_value <= INT64_MAX) {
50     return static_cast<int64>(unsigned_value);
51   }
52   // The C++ spec allows an architecture where this test is required.
53   if (unsigned_value >= INT64_MIN) {
54     return static_cast<int64>(unsigned_value - INT64_MIN) + INT64_MIN;
55   }
56   return 0;  // This should never occur.
57 }
58 
59 // Converts an uint32 to an int32 without loss of information.
60 // Unsigned values greater than INT_MAX are represented as
61 // negative numbers by wrapping (same as twos-complement bit equivalence).
WrapUnsignedAsSigned32(uint32 unsigned_value)62 inline int32 WrapUnsignedAsSigned32(uint32 unsigned_value) {
63   // For a detailed explanation of why this works to wrap unsigned ints, see
64   // http://stackoverflow.com/questions/13150449/efficient-unsigned-to-signed-cast-avoiding-implementation-defined-behavior
65   // Both if tests should be optimized out.
66   if (unsigned_value <= INT_MAX) {
67     return static_cast<int32>(unsigned_value);
68   }
69   // The C++ spec allows an architecture where this test is required.
70   if (unsigned_value >= INT_MIN) {
71     return static_cast<int32>(unsigned_value - INT_MIN) + INT_MIN;
72   }
73   return 0;  // This should never occur.
74 }
75 
// Reads a single varint64 from a byte array.
// It is the caller's responsibility to ensure that there is enough
// space in the buffer (up to 10 bytes may be read).
// The ok value will be set to false if the buffer does not contain
// a valid varint.
// Only declared here; the definition is at the end of this file. The
// forward declaration lets ReadVarint32FromArray and the ReadFromArray
// specializations below call it.
inline const uint8* ReadVarint64FromArray(const uint8* buffer, bool* ok,
                                          uint64* value);
83 
84 // Reads a single varint32 from a byte array.
85 // It is the caller's responsibility to ensure that there is enough
86 // space in the buffer.
87 // The ok value will be set to false if the buffer does not contain
88 // a valid varint.
89 // This is slightly less efficient than the private version in
90 // coded_stream.cc but we duplicate less code by calling
91 // the 64 bit version instead of copying the code.
ReadVarint32FromArray(const uint8 * buffer,bool * ok,uint32 * value)92 inline const uint8* ReadVarint32FromArray(const uint8* buffer, bool* ok,
93                                           uint32* value) {
94   uint64 tmp = 0;
95   const uint8* buf = ReadVarint64FromArray(buffer, ok, &tmp);
96   *value = tmp & 0xffffffff;
97   return buf;
98 }
99 
// Reads a single proto field value from a byte array into an array.
// The array is part of a Tensor that was allocated by the caller
// with type TensorType, while DeclaredType is the proto field type.
// Returns a pointer just past the bytes consumed from `buf`.
// There is no generic definition: only the explicit specializations below
// exist, so an unsupported (TensorType, DeclaredType) pair cannot be linked.
template <class TensorType, enum WireFormatLite::FieldType DeclaredType>
const uint8* ReadFromArray(const uint8* buf, TensorType* value);
105 
106 template <>
107 inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_INT32>(
108     const uint8* buf, int64* value) {
109   uint32 temp = 0;
110   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
111   buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
112   *value = static_cast<int64>(temp);
113   return buf;
114 }
115 
116 template <>
117 inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_INT32>(
118     const uint8* buf, int32* value) {
119   uint32 temp = 0;
120   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
121   buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
122   *value = static_cast<int32>(temp);
123   return buf;
124 }
125 
126 template <>
127 inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_INT64>(
128     const uint8* buf, int64* value) {
129   uint64 temp = 0;
130   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
131   buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
132   *value = WrapUnsignedAsSigned64(temp);
133   return buf;
134 }
135 
136 template <>
137 inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_UINT32>(
138     const uint8* buf, uint64* value) {
139   uint32 temp = 0;
140   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
141   buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
142   *value = temp;
143   return buf;
144 }
145 
146 template <>
147 inline const uint8* ReadFromArray<uint32, WireFormatLite::TYPE_UINT32>(
148     const uint8* buf, uint32* value) {
149   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
150   return ReadVarint32FromArray(buf, &unused_ok, value);
151 }
152 
153 template <>
154 inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_UINT64>(
155     const uint8* buf, uint64* value) {
156   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
157   return ReadVarint64FromArray(buf, &unused_ok, value);
158 }
159 
160 template <>
161 inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SINT32>(
162     const uint8* buf, int64* value) {
163   uint64 temp = 0;
164   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
165   buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
166   *value = WireFormatLite::ZigZagDecode32(temp);
167   return buf;
168 }
169 
170 template <>
171 inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_SINT32>(
172     const uint8* buf, int32* value) {
173   uint32 temp = 0;
174   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
175   buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
176   *value = WireFormatLite::ZigZagDecode32(temp);
177   return buf;
178 }
179 
180 template <>
181 inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SINT64>(
182     const uint8* buf, int64* value) {
183   uint64 temp = 0;
184   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
185   buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
186   *value = WireFormatLite::ZigZagDecode64(temp);
187   return buf;
188 }
189 
190 template <>
191 inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_FIXED32>(
192     const uint8* buf, uint64* value) {
193   uint32 temp;
194   buf = WireFormatLite::ReadPrimitiveFromArray<uint32,
195                                                WireFormatLite::TYPE_FIXED32>(
196       buf, &temp);
197   *value = temp;
198   return buf;
199 }
200 
201 template <>
202 inline const uint8* ReadFromArray<uint32, WireFormatLite::TYPE_FIXED32>(
203     const uint8* buf, uint32* value) {
204   uint32 temp;
205   buf = WireFormatLite::ReadPrimitiveFromArray<uint32,
206                                                WireFormatLite::TYPE_FIXED32>(
207       buf, &temp);
208   *value = WrapUnsignedAsSigned32(temp);
209   return buf;
210 }
211 
212 template <>
213 inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_FIXED64>(
214     const uint8* buf, uint64* value) {
215   protobuf_uint64 temp;
216   buf = WireFormatLite::ReadPrimitiveFromArray<protobuf_uint64,
217                                                WireFormatLite::TYPE_FIXED64>(
218       buf, &temp);
219   *value = WrapUnsignedAsSigned64(temp);
220   return buf;
221 }
222 
223 template <>
224 inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SFIXED32>(
225     const uint8* buf, int64* value) {
226   int32 temp;
227   buf = WireFormatLite::ReadPrimitiveFromArray<int32,
228                                                WireFormatLite::TYPE_SFIXED32>(
229       buf, &temp);
230   *value = temp;
231   return buf;
232 }
233 
234 template <>
235 inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_SFIXED32>(
236     const uint8* buf, int32* value) {
237   return WireFormatLite::ReadPrimitiveFromArray<int32,
238                                                 WireFormatLite::TYPE_SFIXED32>(
239       buf, value);
240 }
241 
242 template <>
243 inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SFIXED64>(
244     const uint8* buf, int64* value) {
245   protobuf_int64 temp;
246   buf = WireFormatLite::ReadPrimitiveFromArray<protobuf_int64,
247                                                WireFormatLite::TYPE_SFIXED64>(
248       buf, &temp);
249   *value = temp;
250   return buf;
251 }
252 
253 template <>
254 inline const uint8* ReadFromArray<float, WireFormatLite::TYPE_FLOAT>(
255     const uint8* buf, float* value) {
256   return WireFormatLite::ReadPrimitiveFromArray<float,
257                                                 WireFormatLite::TYPE_FLOAT>(
258       buf, value);
259 }
260 
261 template <>
262 inline const uint8* ReadFromArray<double, WireFormatLite::TYPE_FLOAT>(
263     const uint8* buf, double* value) {
264   float temp;
265   buf =
266       WireFormatLite::ReadPrimitiveFromArray<float, WireFormatLite::TYPE_FLOAT>(
267           buf, &temp);
268   *value = temp;
269   return buf;
270 }
271 
272 template <>
273 inline const uint8* ReadFromArray<double, WireFormatLite::TYPE_DOUBLE>(
274     const uint8* buf, double* value) {
275   return WireFormatLite::ReadPrimitiveFromArray<double,
276                                                 WireFormatLite::TYPE_DOUBLE>(
277       buf, value);
278 }
279 
280 template <>
281 inline const uint8* ReadFromArray<bool, WireFormatLite::TYPE_BOOL>(
282     const uint8* buf, bool* value) {
283   uint64 temp = 0;
284   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
285   buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
286   *value = temp != 0;
287   return buf;
288 }
289 
290 template <>
291 inline const uint8* ReadFromArray<int, WireFormatLite::TYPE_ENUM>(
292     const uint8* buf, int* value) {
293   uint32 temp = 0;
294   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
295   buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
296   *value = static_cast<int>(temp);
297   return buf;
298 }
299 
300 // Reads packed values from an array.
301 // Stride is set to 1 for repeated fields, and 0 for non-repeated fields
302 // (where any value overwrites previous values).
303 template <class TensorType, enum WireFormatLite::FieldType DeclaredType>
ReadPackedPrimitives(const void * bufp,const size_t len,const int index,const int stride,void * datap)304 inline int ReadPackedPrimitives(const void* bufp, const size_t len,
305                                 const int index, const int stride,
306                                 void* datap) {
307   const uint8* buf = reinterpret_cast<const uint8*>(bufp);
308   const uint8* bound = buf + len;
309   TensorType* data = reinterpret_cast<TensorType*>(datap) + index;
310   int count;
311 
312   // This could overrun the bound by stride-1. This is defended
313   // against in the caller, where it ensures that the input buffer
314   // contains complete values.
315   for (count = 0; buf < bound; count += stride) {
316     buf = ReadFromArray<TensorType, DeclaredType>(buf, data + count);
317   }
318   return count;
319 }
320 
321 // Reads a value of a primitive type field from a serialized proto.
322 // The value is parsed from the serialized format, then static_cast
323 // to the desired type for TensorFlow and stored.
324 template <class ValueType, class TensorType,
325           enum WireFormatLite::FieldType DeclaredType>
ReadPrimitive(CodedInputStream * input,int index,void * data)326 inline Status ReadPrimitive(CodedInputStream* input, int index, void* data) {
327   ValueType v;
328   if (!WireFormatLite::ReadPrimitive<ValueType, DeclaredType>(input, &v)) {
329     return errors::DataLoss("Failed reading primitive");
330   }
331 
332   reinterpret_cast<TensorType*>(data)[index] = v;
333   return Status::OK();
334 }
335 
336 // Reads a string, submessage, or other variable-length field from a
337 // serialized proto.
338 // May read all or part of a repeated field.
ReadBytes(CodedInputStream * input,int index,void * datap)339 inline Status ReadBytes(CodedInputStream* input, int index, void* datap) {
340   string* data = reinterpret_cast<string*>(datap) + index;
341   if (!WireFormatLite::ReadBytes(input, data)) {
342     return errors::DataLoss("Failed reading bytes");
343   }
344   return Status::OK();
345 }
346 
347 // Reads a tag-delimited field (TYPE_GROUP) from a serialized proto,
348 // as a bytestring.
ReadGroupBytes(CodedInputStream * input,int field_number,int index,void * datap)349 inline Status ReadGroupBytes(CodedInputStream* input, int field_number,
350                              int index, void* datap) {
351   // WireFormatLite::SkipField has an option to emit the
352   // skipped bytes to an output stream. We could do better by implementing our
353   // own scanner but this is simpler for now.
354   // TODO(nix): there is a faster way to grab TYPE_GROUP bytes by relying
355   // on input->IsFlat() == true and using input->GetDirectBufferPointer()
356   // with input->CurrentPosition().
357   string* data = reinterpret_cast<string*>(datap) + index;
358   StringOutputStream string_stream(data);
359   CodedOutputStream out(&string_stream);
360   if (!WireFormatLite::SkipField(
361           input,
362           WireFormatLite::MakeTag(field_number,
363                                   WireFormatLite::WIRETYPE_START_GROUP),
364           &out)) {
365     return errors::DataLoss("Failed reading group");
366   }
367   return Status::OK();
368 }
369 
370 // Reads a single field value from a CodedInputStream into a tensor.
ReadValue(CodedInputStream * input,WireFormatLite::FieldType field_type,int field_number,DataType dtype,int index,void * datap)371 inline Status ReadValue(CodedInputStream* input,
372                         WireFormatLite::FieldType field_type, int field_number,
373                         DataType dtype, int index, void* datap) {
374   // Dispatch to the appropriately typed field reader based on the schema type.
375   switch (field_type) {
376     case WireFormatLite::TYPE_DOUBLE:
377       return ReadPrimitive<double, double, WireFormatLite::TYPE_DOUBLE>(
378           input, index, datap);
379     case WireFormatLite::TYPE_FLOAT:
380       switch (dtype) {
381         case DataType::DT_DOUBLE:
382           return ReadPrimitive<float, double, WireFormatLite::TYPE_FLOAT>(
383               input, index, datap);
384         case DataType::DT_FLOAT:
385           return ReadPrimitive<float, float, WireFormatLite::TYPE_FLOAT>(
386               input, index, datap);
387         default:
388           return errors::DataLoss("Failed reading TYPE_FLOAT for ",
389                                   DataTypeString(dtype));
390       }
391     case WireFormatLite::TYPE_INT64:
392       return ReadPrimitive<protobuf_int64, int64, WireFormatLite::TYPE_INT64>(
393           input, index, datap);
394     case WireFormatLite::TYPE_UINT64:
395       return ReadPrimitive<protobuf_uint64, uint64,
396                            WireFormatLite::TYPE_UINT64>(input, index, datap);
397     case WireFormatLite::TYPE_INT32:
398       switch (dtype) {
399         case DataType::DT_INT64:
400           return ReadPrimitive<int32, int64, WireFormatLite::TYPE_INT32>(
401               input, index, datap);
402         case DataType::DT_INT32:
403           return ReadPrimitive<int32, int32, WireFormatLite::TYPE_INT32>(
404               input, index, datap);
405         default:
406           return errors::DataLoss("Failed reading TYPE_INT32 for ",
407                                   DataTypeString(dtype));
408       }
409     case WireFormatLite::TYPE_FIXED64:
410       return ReadPrimitive<protobuf_uint64, uint64,
411                            WireFormatLite::TYPE_FIXED64>(input, index, datap);
412     case WireFormatLite::TYPE_FIXED32:
413       switch (dtype) {
414         case DataType::DT_UINT64:
415           return ReadPrimitive<uint32, uint64, WireFormatLite::TYPE_FIXED32>(
416               input, index, datap);
417         case DataType::DT_UINT32:
418           return ReadPrimitive<uint32, uint32, WireFormatLite::TYPE_FIXED32>(
419               input, index, datap);
420         default:
421           return errors::DataLoss("Failed reading TYPE_FIXED32 for ",
422                                   DataTypeString(dtype));
423       }
424     case WireFormatLite::TYPE_BOOL:
425       return ReadPrimitive<bool, bool, WireFormatLite::TYPE_BOOL>(input, index,
426                                                                   datap);
427     case WireFormatLite::TYPE_STRING:
428       return ReadBytes(input, index, datap);
429     case WireFormatLite::TYPE_GROUP:
430       return ReadGroupBytes(input, field_number, index, datap);
431     case WireFormatLite::TYPE_MESSAGE:
432       return ReadBytes(input, index, datap);
433     case WireFormatLite::TYPE_BYTES:
434       return ReadBytes(input, index, datap);
435     case WireFormatLite::TYPE_UINT32:
436       switch (dtype) {
437         case DataType::DT_UINT64:
438           return ReadPrimitive<uint32, uint64, WireFormatLite::TYPE_UINT32>(
439               input, index, datap);
440         case DataType::DT_UINT32:
441           return ReadPrimitive<uint32, uint32, WireFormatLite::TYPE_UINT32>(
442               input, index, datap);
443         default:
444           return errors::DataLoss("Failed reading TYPE_UINT32 for ",
445                                   DataTypeString(dtype));
446       }
447     case WireFormatLite::TYPE_ENUM:
448       return ReadPrimitive<int32, int32, WireFormatLite::TYPE_ENUM>(
449           input, index, datap);
450     case WireFormatLite::TYPE_SFIXED32:
451       switch (dtype) {
452         case DataType::DT_INT64:
453           return ReadPrimitive<int32, int64, WireFormatLite::TYPE_SFIXED32>(
454               input, index, datap);
455         case DataType::DT_INT32:
456           return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SFIXED32>(
457               input, index, datap);
458         default:
459           return errors::DataLoss("Failed reading TYPE_SFIXED32 for ",
460                                   DataTypeString(dtype));
461       }
462     case WireFormatLite::TYPE_SFIXED64:
463       return ReadPrimitive<protobuf_int64, int64,
464                            WireFormatLite::TYPE_SFIXED64>(input, index, datap);
465     case WireFormatLite::TYPE_SINT32:
466       switch (dtype) {
467         case DataType::DT_INT64:
468           return ReadPrimitive<int32, int64, WireFormatLite::TYPE_SINT32>(
469               input, index, datap);
470         case DataType::DT_INT32:
471           return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SINT32>(
472               input, index, datap);
473         default:
474           return errors::DataLoss("Failed reading TYPE_SINT32 for ",
475                                   DataTypeString(dtype));
476       }
477     case WireFormatLite::TYPE_SINT64:
478       return ReadPrimitive<protobuf_int64, int64, WireFormatLite::TYPE_SINT64>(
479           input, index, datap);
480       // default: intentionally omitted in order to enable static checking.
481   }
482   // Unreachable.
483   return errors::DataLoss("Failed reading unknown wire type");
484 }
485 
486 // Reads and stores a length-delimited list of values.
ReadPackedFromArray(const void * buf,size_t buf_size,const WireFormatLite::FieldType field_type,const int field_number,const DataType dtype,const int stride,int * index,void * data)487 inline Status ReadPackedFromArray(const void* buf, size_t buf_size,
488                                   const WireFormatLite::FieldType field_type,
489                                   const int field_number, const DataType dtype,
490                                   const int stride, int* index, void* data) {
491   // Dispatch to the appropriately typed field reader based on the schema type.
492   switch (field_type) {
493     case WireFormatLite::TYPE_DOUBLE:
494       *index += ReadPackedPrimitives<double, WireFormatLite::TYPE_DOUBLE>(
495           buf, buf_size, *index, stride, data);
496       return Status::OK();
497     case WireFormatLite::TYPE_FLOAT:
498       switch (dtype) {
499         case DataType::DT_DOUBLE:
500           *index += ReadPackedPrimitives<double, WireFormatLite::TYPE_FLOAT>(
501               buf, buf_size, *index, stride, data);
502           return Status::OK();
503         case DataType::DT_FLOAT:
504           *index += ReadPackedPrimitives<float, WireFormatLite::TYPE_FLOAT>(
505               buf, buf_size, *index, stride, data);
506           return Status::OK();
507         default:
508           return errors::DataLoss("Failed reading TYPE_FLOAT for ",
509                                   DataTypeString(dtype));
510       }
511     case WireFormatLite::TYPE_INT64:
512       *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_INT64>(
513           buf, buf_size, *index, stride, data);
514       return Status::OK();
515     case WireFormatLite::TYPE_UINT64:
516       *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_UINT64>(
517           buf, buf_size, *index, stride, data);
518       return Status::OK();
519     case WireFormatLite::TYPE_INT32:
520       switch (dtype) {
521         case DataType::DT_INT64:
522           *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_INT32>(
523               buf, buf_size, *index, stride, data);
524           return Status::OK();
525         case DataType::DT_INT32:
526           *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_INT32>(
527               buf, buf_size, *index, stride, data);
528           return Status::OK();
529         default:
530           return errors::DataLoss("Failed reading TYPE_INT32 for ",
531                                   DataTypeString(dtype));
532       }
533     case WireFormatLite::TYPE_FIXED64:
534       *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_FIXED64>(
535           buf, buf_size, *index, stride, data);
536       return Status::OK();
537     case WireFormatLite::TYPE_FIXED32:
538       switch (dtype) {
539         case DataType::DT_UINT64:
540           *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_FIXED32>(
541               buf, buf_size, *index, stride, data);
542           return Status::OK();
543         case DataType::DT_UINT32:
544           *index += ReadPackedPrimitives<uint32, WireFormatLite::TYPE_FIXED32>(
545               buf, buf_size, *index, stride, data);
546           return Status::OK();
547         default:
548           return errors::DataLoss("Failed reading TYPE_FIXED32 for ",
549                                   DataTypeString(dtype));
550       }
551     case WireFormatLite::TYPE_BOOL:
552       *index += ReadPackedPrimitives<bool, WireFormatLite::TYPE_BOOL>(
553           buf, buf_size, *index, stride, data);
554       return Status::OK();
555     case WireFormatLite::TYPE_STRING:
556     case WireFormatLite::TYPE_GROUP:
557     case WireFormatLite::TYPE_MESSAGE:
558     case WireFormatLite::TYPE_BYTES:
559       return errors::DataLoss("Non-primitive type encountered as packed");
560     case WireFormatLite::TYPE_UINT32:
561       switch (dtype) {
562         case DataType::DT_UINT64:
563           *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_UINT32>(
564               buf, buf_size, *index, stride, data);
565           return Status::OK();
566         case DataType::DT_UINT32:
567           *index += ReadPackedPrimitives<uint32, WireFormatLite::TYPE_UINT32>(
568               buf, buf_size, *index, stride, data);
569           return Status::OK();
570         default:
571           return errors::DataLoss("Failed reading TYPE_UINT32 for ",
572                                   DataTypeString(dtype));
573       }
574     case WireFormatLite::TYPE_ENUM:
575       *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_ENUM>(
576           buf, buf_size, *index, stride, data);
577       return Status::OK();
578     case WireFormatLite::TYPE_SFIXED32:
579       switch (dtype) {
580         case DataType::DT_INT64:
581           *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SFIXED32>(
582               buf, buf_size, *index, stride, data);
583           return Status::OK();
584         case DataType::DT_INT32:
585           *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SFIXED32>(
586               buf, buf_size, *index, stride, data);
587           return Status::OK();
588         default:
589           return errors::DataLoss("Failed reading TYPE_INT32 for ",
590                                   DataTypeString(dtype));
591       }
592     case WireFormatLite::TYPE_SFIXED64:
593       *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SFIXED64>(
594           buf, buf_size, *index, stride, data);
595       return Status::OK();
596 
597     case WireFormatLite::TYPE_SINT32:
598       switch (dtype) {
599         case DataType::DT_INT64:
600           *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SINT32>(
601               buf, buf_size, *index, stride, data);
602           return Status::OK();
603         case DataType::DT_INT32:
604           *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SINT32>(
605               buf, buf_size, *index, stride, data);
606           return Status::OK();
607         default:
608           return errors::DataLoss("Failed reading TYPE_SINT32 for ",
609                                   DataTypeString(dtype));
610       }
611     case WireFormatLite::TYPE_SINT64:
612       *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SINT64>(
613           buf, buf_size, *index, stride, data);
614       return Status::OK();
615       // default: intentionally omitted in order to enable static checking.
616   }
617   // Unreachable.
618   return errors::DataLoss("Failed reading unknown wire type");
619 }
620 
// Reads a varint from the given buffer, writes it to *value, and returns the
// new buffer pointer.
// This was copied from coded_stream.cc where it is private.
// Important: This routine may read as much as kMaxVarintBytes (10 bytes) from
// the buffer. It is the caller's responsibility to make sure that there is
// enough space in the buffer.
//
// On success *ok is set to true, *value receives the decoded value, and the
// returned pointer is one past the varint's final byte (the first byte whose
// high bit is clear). If no such byte appears within 10 bytes, *ok is set to
// false, *value is left unmodified, and the returned pointer is 10 bytes past
// `buffer`.
inline const uint8* ReadVarint64FromArray(const uint8* buffer, bool* ok,
                                          uint64* value) {
  const uint8* ptr = buffer;
  uint32 b;

  // Splitting into 32-bit pieces gives better performance on 32-bit
  // processors.
  // part0 collects payload bits 0-27 (bytes 1-4), part1 bits 28-55
  // (bytes 5-8) and part2 bits 56 and up (bytes 9-10); they are combined
  // at `done`.
  uint32 part0 = 0, part1 = 0, part2 = 0;

  // Each byte contributes its low 7 bits, and a set high bit (0x80) means
  // another byte follows. Each byte is added with its continuation bit still
  // set; once we know another byte follows, the matching `-= 0x80 << k`
  // removes that bit from the accumulator.
  b = *(ptr++);
  part0 = b;
  if (!(b & 0x80)) goto done;
  part0 -= 0x80;
  b = *(ptr++);
  part0 += b << 7;
  if (!(b & 0x80)) goto done;
  part0 -= 0x80 << 7;
  b = *(ptr++);
  part0 += b << 14;
  if (!(b & 0x80)) goto done;
  part0 -= 0x80 << 14;
  b = *(ptr++);
  part0 += b << 21;
  if (!(b & 0x80)) goto done;
  part0 -= 0x80 << 21;
  b = *(ptr++);
  part1 = b;
  if (!(b & 0x80)) goto done;
  part1 -= 0x80;
  b = *(ptr++);
  part1 += b << 7;
  if (!(b & 0x80)) goto done;
  part1 -= 0x80 << 7;
  b = *(ptr++);
  part1 += b << 14;
  if (!(b & 0x80)) goto done;
  part1 -= 0x80 << 14;
  b = *(ptr++);
  part1 += b << 21;
  if (!(b & 0x80)) goto done;
  part1 -= 0x80 << 21;
  b = *(ptr++);
  part2 = b;
  if (!(b & 0x80)) goto done;
  part2 -= 0x80;
  b = *(ptr++);
  part2 += b << 7;
  if (!(b & 0x80)) goto done;
  // "part2 -= 0x80 << 7" is irrelevant because (0x80 << 7) << 56 is 0.

  // We have overrun the maximum size of a varint (10 bytes).  Assume
  // the data is corrupt.
  *ok = false;
  return ptr;

done:
  *ok = true;
  // Reassemble the three 7-bit-group accumulators; bits shifted past bit 63
  // are discarded by the uint64 shift.
  *value = (static_cast<uint64>(part0)) | (static_cast<uint64>(part1) << 28) |
           (static_cast<uint64>(part2) << 56);
  return ptr;
}
688 
689 }  // namespace internal
690 }  // namespace tensorflow
691 
692 #endif  // TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_
693